Allow ArrowIterator to read from json (#45)
Some checks are pending
Lint with Black / lint (push) Waiting to run
Lint with isort / lint (push) Waiting to run

Summary:

Currently, arrow iterator can only read arrow files. However, the pyarrow library can read
other formats, including jsonlines. This allows the same ArrowIterator to read from jsonlines,
so we can read from the original source data, and simply omit the entropy column when doing so

Test Plan:

Run train script until dataloader starts
This commit is contained in:
Pedro Rodriguez 2025-02-06 09:57:22 -08:00 committed by GitHub
parent afedb16598
commit 936d9437be
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 105 additions and 58 deletions

View file

@ -1,8 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import json
import logging
import os
from typing import Any
import fsspec
import numpy as np
import yaml
from omegaconf import OmegaConf
@ -10,11 +12,9 @@ from pydantic import BaseModel, ConfigDict
from bytelatent.checkpoint import CheckpointArgs
from bytelatent.data.data_types import Batch
from bytelatent.data.file_util import get_fs
from bytelatent.data.iterators.abstract_iterator import StatefulIterator
from bytelatent.data.iterators.arrow_iterator import (
ArrowFileIterator,
find_and_sanitize_chunks,
)
from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
from bytelatent.data.iterators.looping_iterator import LoopingIterator
from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator
from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator
@ -53,6 +53,33 @@ def parse_args(args_cls):
return pydantic_args
TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"
def find_and_sanitize_chunks(
dataset_path: str,
world_size: int,
file_pattern: str,
s3_profile: str | None = None,
):
fs = get_fs(dataset_path, s3_profile=s3_profile)
path_with_glob = os.path.join(dataset_path, file_pattern)
dataset_chunks = fs.glob(path_with_glob)
n_chunks = len(dataset_chunks)
if n_chunks > world_size:
n_discard = n_chunks - world_size
dataset_chunks = dataset_chunks[:world_size]
else:
assert (
world_size % n_chunks == 0
), "World size should be a multiple of number of chunks"
assert n_chunks > 0, f"No valid chunks in {dataset_path}"
return dataset_chunks
def distribute_data_to_rank(
*,
dataset_path: str,
@ -62,9 +89,10 @@ def distribute_data_to_rank(
rank: int,
world_size: int,
s3_profile: str | None = None,
file_pattern: str = TRAIN_DATA_FILE_PATTERN,
) -> ArrowFileIterator:
dataset_chunks = find_and_sanitize_chunks(
dataset_path, world_size, s3_profile=s3_profile
dataset_path, world_size, file_pattern, s3_profile=s3_profile
)
n_workers_per_chunk = world_size // len(dataset_chunks)
rank_to_arrow_iterator_params = []

View file

@ -16,6 +16,7 @@ from bytelatent import ByteLatentError
from bytelatent.data.data_types import BltExample
from bytelatent.data.file_util import get_fs
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
from bytelatent.preprocess.preprocess_entropies import get_id_key, get_text
logger = getLogger(__name__)
@ -32,6 +33,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState):
arrow_batch_size: int = 100
s3_profile: str | None
filesystem_type: str | None = None
file_format: str
def build(self) -> "ArrowFileIterator":
arrow_file = ArrowFileIterator(
@ -44,6 +46,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState):
dataset_files=self.dataset_files,
s3_profile=self.s3_profile,
filesystem_type=self.filesystem_type,
file_format=self.file_format,
)
if self.row_num != 0:
arrow_file._set_row_num(self.row_num)
@ -70,6 +73,7 @@ class ArrowFileIterator(StatefulIterator):
dataset_files: list[str] | None = None,
s3_profile: str | None = None,
filesystem_type: str | None = None,
file_format: str = "arrow",
):
assert 0 <= worker_id < num_workers, (worker_id, num_workers)
if file_path is None and dataset_files is None:
@ -87,12 +91,16 @@ class ArrowFileIterator(StatefulIterator):
self.arrow_batch_size = arrow_batch_size
self.s3_profile = s3_profile
self.filesystem_type = filesystem_type
self.file_format = file_format
self.fs = None
if self.filesystem_type is not None:
if self.filesystem_type == "file":
self.fs = fsspec.filesystem("file")
elif self.filesystem_type == "s3":
self.fs = fsspec.filesystem("s3", profile=s3_profile)
else:
raise ValueError("Unknown filesystem")
logger.info("Arrow iterator using fs=%s", self.fs)
if dataset_files is None:
# Prepare arrow shards
@ -153,6 +161,7 @@ class ArrowFileIterator(StatefulIterator):
dataset_files=self.dataset_files,
s3_profile=self.s3_profile,
filesystem_type=self.filesystem_type,
file_format=self.file_format,
)
def create_iter(
@ -164,7 +173,7 @@ class ArrowFileIterator(StatefulIterator):
else:
filesystem = None
self.dataset = pa.dataset.dataset(
self.dataset_files, format="arrow", filesystem=filesystem
self.dataset_files, format=self.file_format, filesystem=filesystem
)
self.batch_iterator = self.dataset.to_batches(
batch_size=self.arrow_batch_size
@ -173,13 +182,22 @@ class ArrowFileIterator(StatefulIterator):
if self.batch_to_consume is not None:
batch_columns: dict[str, list] = self.batch_to_consume
self.batch_to_consume = None
sample_ids = batch_columns["sample_id"]
texts = batch_columns["text"]
entropies = batch_columns["entropies"]
if self.file_format == "arrow":
sample_ids = batch_columns["sample_id"]
texts = batch_columns["text"]
entropies = batch_columns["entropies"]
elif self.file_format == "json":
# This data hasn't been preprocessed to a uniform format,
# so we have to do it now and omit entropies
sample_ids = batch_columns[get_id_key(batch_columns)]
texts = get_text(batch_columns)
entropies = None
else:
raise ValueError(f"Unknown file format: {self.file_format}")
for i in range(len(sample_ids)):
out = BltExample(
sample_id=sample_ids[i],
entropies=entropies[i],
entropies=entropies[i] if entropies is not None else None,
text=texts[i],
tokens=None,
mask=None,
@ -191,13 +209,22 @@ class ArrowFileIterator(StatefulIterator):
for batch in self.batch_iterator:
batch_columns = batch.to_pydict()
sample_ids = batch_columns["sample_id"]
texts = batch_columns["text"]
entropies = batch_columns["entropies"]
if self.file_format == "arrow":
sample_ids = batch_columns["sample_id"]
texts = batch_columns["text"]
entropies = batch_columns["entropies"]
elif self.file_format == "json":
# This data hasn't been preprocessed to a uniform format,
# so we have to do it now and omit entropies
sample_ids = batch_columns[get_id_key(batch_columns)]
texts = get_text(batch_columns)
entropies = None
else:
raise ValueError(f"Unknown file format: {self.file_format}")
for i in range(len(sample_ids)):
out = BltExample(
sample_id=sample_ids[i],
entropies=entropies[i],
entropies=entropies[i] if entropies is not None else None,
text=texts[i],
tokens=None,
mask=None,
@ -231,13 +258,24 @@ class ArrowFileIterator(StatefulIterator):
for batch in self.batch_iterator:
if len(batch) > curr_remaining:
batch_columns: dict[str, list] = batch.to_pydict()
batch_columns["sample_id"] = batch_columns["sample_id"][
curr_remaining:
]
batch_columns["entropies"] = batch_columns["entropies"][
curr_remaining:
]
batch_columns["text"] = batch_columns["text"][curr_remaining:]
if self.file_format == "arrow":
leftover_sample_ids = batch_columns["sample_id"][
curr_remaining:
]
leftover_entropies = batch_columns["entropies"][curr_remaining:]
leftover_texts = batch_columns["text"][curr_remaining:]
elif self.file_format == "json":
leftover_sample_ids = batch_columns[get_id_key(batch_columns)][
curr_remaining:
]
leftover_entropies = None
leftover_texts = get_text(batch_columns)[curr_remaining:]
else:
raise ValueError(f"Unknown file format: {self.file_format}")
batch_columns["sample_id"] = leftover_sample_ids
batch_columns["entropies"] = leftover_entropies
batch_columns["text"] = leftover_texts
self.batch_to_consume = batch_columns
break
elif len(batch) == curr_remaining:
@ -250,30 +288,3 @@ class ArrowFileIterator(StatefulIterator):
logger.info(
f"Finished setting arrow position to {target_row_num} for {self.dataset_files}"
)
TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"
def find_and_sanitize_chunks(
dataset_path: str,
world_size: int,
file_pattern: str = TRAIN_DATA_FILE_PATTERN,
s3_profile: str | None = None,
):
fs = get_fs(dataset_path, s3_profile=s3_profile)
path_with_glob = os.path.join(dataset_path, file_pattern)
dataset_chunks = fs.glob(path_with_glob)
n_chunks = len(dataset_chunks)
if n_chunks > world_size:
n_discard = n_chunks - world_size
dataset_chunks = dataset_chunks[:world_size]
else:
assert (
world_size % n_chunks == 0
), "World size should be a multiple of number of chunks"
assert n_chunks > 0, f"No valid chunks in {dataset_path}"
return dataset_chunks

View file

@ -15,29 +15,37 @@ from bytelatent.entropy_model import load_entropy_model
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
def get_id_from_doc(doc: dict) -> int:
def get_id_key(doc: dict) -> int:
"""
We need a reliable way to ensure that samples from jsonl
and arrow are the same, but there is no unique id field,
so derive the best possible
"""
if "sample_id" in doc:
sample_id = doc["sample_id"]
return "sample_id"
elif "title" in doc:
sample_id = doc["title"]
return "title"
elif "qid" in doc:
sample_id = doc["qid"]
return "qid"
elif "paper_id" in doc:
sample_id = doc["paper_id"]
return "paper_id"
elif "path" in doc:
sample_id = doc["path"]
return "path"
elif "url" in doc:
sample_id = doc["url"]
return "url"
elif "id" in doc:
sample_id = doc["id"]
return "id"
else:
raise ValueError(f"Could not find a id key from: {doc.keys()}")
return str(sample_id)
def get_id_from_doc(doc: dict) -> int:
"""
We need a reliable way to ensure that samples from jsonl
and arrow are the same, but there is no unique id field,
so derive the best possible
"""
return str(doc[get_id_key(doc)])
def get_text(doc: dict):

View file

@ -4,10 +4,10 @@ import json
import os
import shutil
import subprocess
from pydantic import BaseModel
from typing import Any, Dict
from omegaconf import OmegaConf
from pydantic import BaseModel
class StoolArgs(BaseModel):