Allow ArrowIterator to read from json

Summary:

Test Plan:
Pedro Rodriguez 2025-02-05 22:47:00 +00:00
parent 2f42633b07
commit 48cf4dfee1
3 changed files with 104 additions and 57 deletions


@@ -1,8 +1,10 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+import json
 import logging
 import os
 from typing import Any

+import fsspec
 import numpy as np
 import yaml
 from omegaconf import OmegaConf
@@ -10,11 +12,9 @@ from pydantic import BaseModel, ConfigDict

 from bytelatent.checkpoint import CheckpointArgs
 from bytelatent.data.data_types import Batch
+from bytelatent.data.file_util import get_fs
 from bytelatent.data.iterators.abstract_iterator import StatefulIterator
-from bytelatent.data.iterators.arrow_iterator import (
-    ArrowFileIterator,
-    find_and_sanitize_chunks,
-)
+from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
 from bytelatent.data.iterators.looping_iterator import LoopingIterator
 from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator
 from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator
@@ -53,6 +53,33 @@ def parse_args(args_cls):
     return pydantic_args


+TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"
+
+
+def find_and_sanitize_chunks(
+    dataset_path: str,
+    world_size: int,
+    file_pattern: str,
+    s3_profile: str | None = None,
+):
+    fs = get_fs(dataset_path, s3_profile=s3_profile)
+    path_with_glob = os.path.join(dataset_path, file_pattern)
+    dataset_chunks = fs.glob(path_with_glob)
+    n_chunks = len(dataset_chunks)
+
+    if n_chunks > world_size:
+        n_discard = n_chunks - world_size
+        dataset_chunks = dataset_chunks[:world_size]
+    else:
+        assert (
+            world_size % n_chunks == 0
+        ), "World size should be a multiple of number of chunks"
+
+    assert n_chunks > 0, f"No valid chunks in {dataset_path}"
+
+    return dataset_chunks
+
+
 def distribute_data_to_rank(
     *,
     dataset_path: str,
@@ -62,9 +89,10 @@ def distribute_data_to_rank(
     rank: int,
     world_size: int,
     s3_profile: str | None = None,
+    file_pattern: str = TRAIN_DATA_FILE_PATTERN,
 ) -> ArrowFileIterator:
     dataset_chunks = find_and_sanitize_chunks(
-        dataset_path, world_size, s3_profile=s3_profile
+        dataset_path, world_size, file_pattern, s3_profile=s3_profile
     )
     n_workers_per_chunk = world_size // len(dataset_chunks)
     rank_to_arrow_iterator_params = []
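To make the chunk-distribution logic above concrete, here is a small illustrative sketch (not code from this commit) of how ranks could map onto the globbed chunk files once world_size is a multiple of the number of chunks. The chunk-major layout and the file names are assumptions for illustration only; the actual assignment is built in rank_to_arrow_iterator_params, which this diff does not show.

# Illustrative sketch only: assumes a chunk-major assignment of ranks to files,
# mirroring n_workers_per_chunk = world_size // len(dataset_chunks) above.
def toy_rank_assignment(dataset_chunks: list[str], world_size: int) -> dict[int, tuple[str, int]]:
    assert world_size % len(dataset_chunks) == 0, "world_size must be a multiple of n_chunks"
    n_workers_per_chunk = world_size // len(dataset_chunks)
    assignment = {}
    for rank in range(world_size):
        chunk = dataset_chunks[rank // n_workers_per_chunk]  # consecutive ranks share a chunk
        worker_id = rank % n_workers_per_chunk               # distinct worker id within that chunk
        assignment[rank] = (chunk, worker_id)
    return assignment

print(toy_rank_assignment(["data.chunk.00.jsonl", "data.chunk.01.jsonl"], world_size=4))
# {0: ('data.chunk.00.jsonl', 0), 1: ('data.chunk.00.jsonl', 1),
#  2: ('data.chunk.01.jsonl', 0), 3: ('data.chunk.01.jsonl', 1)}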


@@ -16,6 +16,7 @@ from bytelatent import ByteLatentError
 from bytelatent.data.data_types import BltExample
 from bytelatent.data.file_util import get_fs
 from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.preprocess.preprocess_entropies import get_id_key, get_text


 logger = getLogger(__name__)
@@ -32,6 +33,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState):
     arrow_batch_size: int = 100
     s3_profile: str | None
     filesystem_type: str | None = None
+    file_format: str

     def build(self) -> "ArrowFileIterator":
         arrow_file = ArrowFileIterator(
@@ -44,6 +46,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState):
             dataset_files=self.dataset_files,
             s3_profile=self.s3_profile,
             filesystem_type=self.filesystem_type,
+            file_format=self.file_format,
         )
         if self.row_num != 0:
             arrow_file._set_row_num(self.row_num)
@@ -70,6 +73,7 @@ class ArrowFileIterator(StatefulIterator):
         dataset_files: list[str] | None = None,
         s3_profile: str | None = None,
         filesystem_type: str | None = None,
+        file_format: str = "arrow",
     ):
         assert 0 <= worker_id < num_workers, (worker_id, num_workers)
         if file_path is None and dataset_files is None:
@@ -87,12 +91,16 @@
         self.arrow_batch_size = arrow_batch_size
         self.s3_profile = s3_profile
         self.filesystem_type = filesystem_type
+        self.file_format = file_format
         self.fs = None
         if self.filesystem_type is not None:
             if self.filesystem_type == "file":
                 self.fs = fsspec.filesystem("file")
             elif self.filesystem_type == "s3":
                 self.fs = fsspec.filesystem("s3", profile=s3_profile)
+            else:
+                raise ValueError("Unknown filesystem")
+        logger.info("Arrow iterator using fs=%s", self.fs)

         if dataset_files is None:
             # Prepare arrow shards
@@ -153,6 +161,7 @@
             dataset_files=self.dataset_files,
             s3_profile=self.s3_profile,
             filesystem_type=self.filesystem_type,
+            file_format=self.file_format,
         )

     def create_iter(
@@ -164,7 +173,7 @@
             else:
                 filesystem = None
             self.dataset = pa.dataset.dataset(
-                self.dataset_files, format="arrow", filesystem=filesystem
+                self.dataset_files, format=self.file_format, filesystem=filesystem
             )
             self.batch_iterator = self.dataset.to_batches(
                 batch_size=self.arrow_batch_size
@@ -173,13 +182,22 @@
         if self.batch_to_consume is not None:
             batch_columns: dict[str, list] = self.batch_to_consume
             self.batch_to_consume = None
-            sample_ids = batch_columns["sample_id"]
-            texts = batch_columns["text"]
-            entropies = batch_columns["entropies"]
+            if self.file_format == "arrow":
+                sample_ids = batch_columns["sample_id"]
+                texts = batch_columns["text"]
+                entropies = batch_columns["entropies"]
+            elif self.file_format == "json":
+                # This data hasn't been preprocessed to a uniform format,
+                # so we have to do it now and omit entropies
+                sample_ids = batch_columns[get_id_key(batch_columns)]
+                texts = get_text(batch_columns)
+                entropies = None
+            else:
+                raise ValueError(f"Unknown file format: {self.file_format}")
             for i in range(len(sample_ids)):
                 out = BltExample(
                     sample_id=sample_ids[i],
-                    entropies=entropies[i],
+                    entropies=entropies[i] if entropies is not None else None,
                     text=texts[i],
                     tokens=None,
                     mask=None,
@@ -191,13 +209,22 @@

         for batch in self.batch_iterator:
             batch_columns = batch.to_pydict()
-            sample_ids = batch_columns["sample_id"]
-            texts = batch_columns["text"]
-            entropies = batch_columns["entropies"]
+            if self.file_format == "arrow":
+                sample_ids = batch_columns["sample_id"]
+                texts = batch_columns["text"]
+                entropies = batch_columns["entropies"]
+            elif self.file_format == "json":
+                # This data hasn't been preprocessed to a uniform format,
+                # so we have to do it now and omit entropies
+                sample_ids = batch_columns[get_id_key(batch_columns)]
+                texts = get_text(batch_columns)
+                entropies = None
+            else:
+                raise ValueError(f"Unknown file format: {self.file_format}")
             for i in range(len(sample_ids)):
                 out = BltExample(
                     sample_id=sample_ids[i],
-                    entropies=entropies[i],
+                    entropies=entropies[i] if entropies is not None else None,
                     text=texts[i],
                     tokens=None,
                     mask=None,
@@ -231,13 +258,24 @@
             for batch in self.batch_iterator:
                 if len(batch) > curr_remaining:
                     batch_columns: dict[str, list] = batch.to_pydict()
-                    batch_columns["sample_id"] = batch_columns["sample_id"][
-                        curr_remaining:
-                    ]
-                    batch_columns["entropies"] = batch_columns["entropies"][
-                        curr_remaining:
-                    ]
-                    batch_columns["text"] = batch_columns["text"][curr_remaining:]
+                    if self.file_format == "arrow":
+                        leftover_sample_ids = batch_columns["sample_id"][
+                            curr_remaining:
+                        ]
+                        leftover_entropies = batch_columns["entropies"][curr_remaining:]
+                        leftover_texts = batch_columns["text"][curr_remaining:]
+                    elif self.file_format == "json":
+                        leftover_sample_ids = batch_columns[get_id_key(batch_columns)][
+                            curr_remaining:
+                        ]
+                        leftover_entropies = None
+                        leftover_texts = get_text(batch_columns)[curr_remaining:]
+                    else:
+                        raise ValueError(f"Unknown file format: {self.file_format}")
+
+                    batch_columns["sample_id"] = leftover_sample_ids
+                    batch_columns["entropies"] = leftover_entropies
+                    batch_columns["text"] = leftover_texts
                     self.batch_to_consume = batch_columns
                     break
                 elif len(batch) == curr_remaining:
@@ -250,30 +288,3 @@
         logger.info(
             f"Finished setting arrow position to {target_row_num} for {self.dataset_files}"
         )
-
-
-TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"
-
-
-def find_and_sanitize_chunks(
-    dataset_path: str,
-    world_size: int,
-    file_pattern: str = TRAIN_DATA_FILE_PATTERN,
-    s3_profile: str | None = None,
-):
-    fs = get_fs(dataset_path, s3_profile=s3_profile)
-    path_with_glob = os.path.join(dataset_path, file_pattern)
-    dataset_chunks = fs.glob(path_with_glob)
-    n_chunks = len(dataset_chunks)
-
-    if n_chunks > world_size:
-        n_discard = n_chunks - world_size
-        dataset_chunks = dataset_chunks[:world_size]
-    else:
-        assert (
-            world_size % n_chunks == 0
-        ), "World size should be a multiple of number of chunks"
-
-    assert n_chunks > 0, f"No valid chunks in {dataset_path}"
-
-    return dataset_chunks
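The file_format="json" path relies on pyarrow.dataset reading newline-delimited JSON directly; the create_iter change above roughly amounts to the standalone sketch below. The jsonl path is hypothetical, and real corpora may expose different column names, which is why the iterator goes through get_id_key and get_text instead of fixed "sample_id"/"text" columns.

# Standalone sketch of the pyarrow calls the iterator relies on;
# the jsonl file name below is hypothetical.
import pyarrow.dataset as pa_dataset

dataset = pa_dataset.dataset("shard.chunk.00.jsonl", format="json")
for batch in dataset.to_batches(batch_size=100):
    batch_columns = batch.to_pydict()  # column name -> list of values
    print(list(batch_columns.keys()))  # e.g. ["title", "text", ...] depending on the corpus
    break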


@@ -15,29 +15,37 @@ from bytelatent.entropy_model import load_entropy_model
 from bytelatent.tokenizers.build_tokenizer import TokenizerArgs


-def get_id_from_doc(doc: dict) -> int:
+def get_id_key(doc: dict) -> int:
     """
     We need a reliable way to ensure that samples from jsonl
     and arrow are the same, but there is no unique id field,
     so derive the best possible
     """
     if "sample_id" in doc:
-        sample_id = doc["sample_id"]
+        return "sample_id"
     elif "title" in doc:
-        sample_id = doc["title"]
+        return "title"
     elif "qid" in doc:
-        sample_id = doc["qid"]
+        return "qid"
     elif "paper_id" in doc:
-        sample_id = doc["paper_id"]
+        return "paper_id"
     elif "path" in doc:
-        sample_id = doc["path"]
+        return "path"
     elif "url" in doc:
-        sample_id = doc["url"]
+        return "url"
     elif "id" in doc:
-        sample_id = doc["id"]
+        return "id"
     else:
         raise ValueError(f"Could not find a id key from: {doc.keys()}")
-    return str(sample_id)
+
+
+def get_id_from_doc(doc: dict) -> int:
+    """
+    We need a reliable way to ensure that samples from jsonl
+    and arrow are the same, but there is no unique id field,
+    so derive the best possible
+    """
+    return str(doc[get_id_key(doc)])


 def get_text(doc: dict):
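For reference, a short hedged example of how the refactored helpers behave on a raw jsonl document. The document contents are made up, and the example assumes the bytelatent package is importable.

# Hypothetical document; get_id_key returns the first matching key in the
# priority list, and get_id_from_doc stringifies the value under that key.
from bytelatent.preprocess.preprocess_entropies import get_id_from_doc, get_id_key

doc = {"title": "An Example Article", "text": "some raw jsonl text"}
print(get_id_key(doc))       # "title"
print(get_id_from_doc(doc))  # "An Example Article"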