Mirror of https://github.com/facebookresearch/blt.git, synced 2025-02-22 21:12:15 +00:00
Allow ArrowIterator to read from json
Summary: Currently, the arrow iterator can only read Arrow files. However, the pyarrow library can read other formats, including jsonlines. This change lets the same ArrowIterator read from jsonlines files, so we can read from the original source data and simply omit the entropy column when doing so.

Test Plan: Run the train script until the dataloader starts.
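The key enabler is pyarrow's dataset API, which reads newline-delimited JSON through the same interface it uses for Arrow IPC files. A minimal sketch of that behavior, separate from the commit itself (the file name and the presence of a "text" column are assumptions):

import pyarrow.dataset as ds

# Hypothetical jsonl shard; each line is one JSON object.
dataset = ds.dataset("data/shard.chunk.00.jsonl", format="json")
for batch in dataset.to_batches(batch_size=100):
    columns = batch.to_pydict()  # column name -> list of Python values
    texts = columns["text"]      # assumes the raw rows carry a "text" field
    # Raw jsonl has no "entropies" column, so downstream code must tolerate None.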
Parent: afedb16598
Commit: 9c3c997cae
@@ -1,8 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import json
import logging
import os
from typing import Any

import fsspec
import numpy as np
import yaml
from omegaconf import OmegaConf

@@ -10,11 +12,9 @@ from pydantic import BaseModel, ConfigDict
from bytelatent.checkpoint import CheckpointArgs
from bytelatent.data.data_types import Batch
from bytelatent.data.file_util import get_fs
from bytelatent.data.iterators.abstract_iterator import StatefulIterator
from bytelatent.data.iterators.arrow_iterator import (
    ArrowFileIterator,
    find_and_sanitize_chunks,
)
from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
from bytelatent.data.iterators.looping_iterator import LoopingIterator
from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator
from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator

@@ -53,6 +53,33 @@ def parse_args(args_cls):
    return pydantic_args


TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"


def find_and_sanitize_chunks(
    dataset_path: str,
    world_size: int,
    file_pattern: str,
    s3_profile: str | None = None,
):
    fs = get_fs(dataset_path, s3_profile=s3_profile)
    path_with_glob = os.path.join(dataset_path, file_pattern)
    dataset_chunks = fs.glob(path_with_glob)
    n_chunks = len(dataset_chunks)

    if n_chunks > world_size:
        n_discard = n_chunks - world_size
        dataset_chunks = dataset_chunks[:world_size]
    else:
        assert (
            world_size % n_chunks == 0
        ), "World size should be a multiple of number of chunks"

    assert n_chunks > 0, f"No valid chunks in {dataset_path}"

    return dataset_chunks


def distribute_data_to_rank(
    *,
    dataset_path: str,

@@ -62,9 +89,10 @@ def distribute_data_to_rank(
    rank: int,
    world_size: int,
    s3_profile: str | None = None,
    file_pattern: str = TRAIN_DATA_FILE_PATTERN,
) -> ArrowFileIterator:
    dataset_chunks = find_and_sanitize_chunks(
        dataset_path, world_size, s3_profile=s3_profile
        dataset_path, world_size, file_pattern, s3_profile=s3_profile
    )
    n_workers_per_chunk = world_size // len(dataset_chunks)
    rank_to_arrow_iterator_params = []
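These first hunks appear to come from the training-args module (presumably bytelatent/args.py): find_and_sanitize_chunks moves here and gains an explicit file_pattern argument, and distribute_data_to_rank forwards it, defaulting to the new TRAIN_DATA_FILE_PATTERN glob for jsonl shards. A hedged usage sketch follows; the module path, dataset directory, and world size are illustrative, not taken from the commit:

# Sketch only: the import path is an assumption based on the hunk context.
from bytelatent.args import TRAIN_DATA_FILE_PATTERN, find_and_sanitize_chunks

# Hypothetical local dataset directory containing files like corpus.chunk.00.jsonl
chunks = find_and_sanitize_chunks(
    dataset_path="data/my_corpus",
    world_size=8,
    file_pattern=TRAIN_DATA_FILE_PATTERN,  # "*.chunk.*.jsonl"
    s3_profile=None,
)
# At most world_size shard paths come back; otherwise world_size must be a
# multiple of the number of shards so workers divide evenly.
print(chunks)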
@@ -16,6 +16,7 @@ from bytelatent import ByteLatentError
from bytelatent.data.data_types import BltExample
from bytelatent.data.file_util import get_fs
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
from bytelatent.preprocess.preprocess_entropies import get_id_key, get_text

logger = getLogger(__name__)

@@ -32,6 +33,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState):
    arrow_batch_size: int = 100
    s3_profile: str | None
    filesystem_type: str | None = None
    file_format: str

    def build(self) -> "ArrowFileIterator":
        arrow_file = ArrowFileIterator(

@@ -44,6 +46,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState):
            dataset_files=self.dataset_files,
            s3_profile=self.s3_profile,
            filesystem_type=self.filesystem_type,
            file_format=self.file_format,
        )
        if self.row_num != 0:
            arrow_file._set_row_num(self.row_num)

@@ -70,6 +73,7 @@ class ArrowFileIterator(StatefulIterator):
        dataset_files: list[str] | None = None,
        s3_profile: str | None = None,
        filesystem_type: str | None = None,
        file_format: str = "arrow",
    ):
        assert 0 <= worker_id < num_workers, (worker_id, num_workers)
        if file_path is None and dataset_files is None:

@@ -87,12 +91,16 @@ class ArrowFileIterator(StatefulIterator):
        self.arrow_batch_size = arrow_batch_size
        self.s3_profile = s3_profile
        self.filesystem_type = filesystem_type
        self.file_format = file_format
        self.fs = None
        if self.filesystem_type is not None:
            if self.filesystem_type == "file":
                self.fs = fsspec.filesystem("file")
            elif self.filesystem_type == "s3":
                self.fs = fsspec.filesystem("s3", profile=s3_profile)
            else:
                raise ValueError("Unknown filesystem")
        logger.info("Arrow iterator using fs=%s", self.fs)

        if dataset_files is None:
            # Prepare arrow shards

@@ -153,6 +161,7 @@ class ArrowFileIterator(StatefulIterator):
            dataset_files=self.dataset_files,
            s3_profile=self.s3_profile,
            filesystem_type=self.filesystem_type,
            file_format=self.file_format,
        )

    def create_iter(

@@ -164,7 +173,7 @@ class ArrowFileIterator(StatefulIterator):
            else:
                filesystem = None
            self.dataset = pa.dataset.dataset(
                self.dataset_files, format="arrow", filesystem=filesystem
                self.dataset_files, format=self.file_format, filesystem=filesystem
            )
            self.batch_iterator = self.dataset.to_batches(
                batch_size=self.arrow_batch_size

@@ -173,13 +182,22 @@ class ArrowFileIterator(StatefulIterator):
        if self.batch_to_consume is not None:
            batch_columns: dict[str, list] = self.batch_to_consume
            self.batch_to_consume = None
            if self.file_format == "arrow":
                sample_ids = batch_columns["sample_id"]
                texts = batch_columns["text"]
                entropies = batch_columns["entropies"]
            elif self.file_format == "json":
                # This data hasn't been preprocessed to a uniform format,
                # so we have to do it now and omit entropies
                sample_ids = batch_columns[get_id_key(batch_columns)]
                texts = get_text(batch_columns)
                entropies = None
            else:
                raise ValueError(f"Unknown file format: {self.file_format}")
            for i in range(len(sample_ids)):
                out = BltExample(
                    sample_id=sample_ids[i],
                    entropies=entropies[i],
                    entropies=entropies[i] if entropies is not None else None,
                    text=texts[i],
                    tokens=None,
                    mask=None,

@@ -191,13 +209,22 @@ class ArrowFileIterator(StatefulIterator):

        for batch in self.batch_iterator:
            batch_columns = batch.to_pydict()
            if self.file_format == "arrow":
                sample_ids = batch_columns["sample_id"]
                texts = batch_columns["text"]
                entropies = batch_columns["entropies"]
            elif self.file_format == "json":
                # This data hasn't been preprocessed to a uniform format,
                # so we have to do it now and omit entropies
                sample_ids = batch_columns[get_id_key(batch_columns)]
                texts = get_text(batch_columns)
                entropies = None
            else:
                raise ValueError(f"Unknown file format: {self.file_format}")
            for i in range(len(sample_ids)):
                out = BltExample(
                    sample_id=sample_ids[i],
                    entropies=entropies[i],
                    entropies=entropies[i] if entropies is not None else None,
                    text=texts[i],
                    tokens=None,
                    mask=None,

@@ -231,13 +258,24 @@ class ArrowFileIterator(StatefulIterator):
            for batch in self.batch_iterator:
                if len(batch) > curr_remaining:
                    batch_columns: dict[str, list] = batch.to_pydict()
                    batch_columns["sample_id"] = batch_columns["sample_id"][
                    if self.file_format == "arrow":
                        leftover_sample_ids = batch_columns["sample_id"][
                            curr_remaining:
                        ]
                        batch_columns["entropies"] = batch_columns["entropies"][
                        leftover_entropies = batch_columns["entropies"][curr_remaining:]
                        leftover_texts = batch_columns["text"][curr_remaining:]
                    elif self.file_format == "json":
                        leftover_sample_ids = batch_columns[get_id_key(batch_columns)][
                            curr_remaining:
                        ]
                        batch_columns["text"] = batch_columns["text"][curr_remaining:]
                        leftover_entropies = None
                        leftover_texts = get_text(batch_columns)[curr_remaining:]
                    else:
                        raise ValueError(f"Unknown file format: {self.file_format}")

                    batch_columns["sample_id"] = leftover_sample_ids
                    batch_columns["entropies"] = leftover_entropies
                    batch_columns["text"] = leftover_texts
                    self.batch_to_consume = batch_columns
                    break
                elif len(batch) == curr_remaining:

@@ -250,30 +288,3 @@ class ArrowFileIterator(StatefulIterator):
            logger.info(
                f"Finished setting arrow position to {target_row_num} for {self.dataset_files}"
            )


TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"


def find_and_sanitize_chunks(
    dataset_path: str,
    world_size: int,
    file_pattern: str = TRAIN_DATA_FILE_PATTERN,
    s3_profile: str | None = None,
):
    fs = get_fs(dataset_path, s3_profile=s3_profile)
    path_with_glob = os.path.join(dataset_path, file_pattern)
    dataset_chunks = fs.glob(path_with_glob)
    n_chunks = len(dataset_chunks)

    if n_chunks > world_size:
        n_discard = n_chunks - world_size
        dataset_chunks = dataset_chunks[:world_size]
    else:
        assert (
            world_size % n_chunks == 0
        ), "World size should be a multiple of number of chunks"

    assert n_chunks > 0, f"No valid chunks in {dataset_path}"

    return dataset_chunks
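The bulk of the change is in the arrow iterator itself (presumably bytelatent/data/iterators/arrow_iterator.py): the iterator state and constructor gain a file_format field, pa.dataset.dataset is called with that format, and the json branch derives sample ids and text via get_id_key/get_text while leaving entropies as None. A hedged construction sketch follows; apart from file_format="json", the argument names and values shown are assumptions, and the real constructor may require additional parameters not visible in this diff:

from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator

# Sketch only: hypothetical jsonl shard path; batch size mirrors the state default of 100.
jsonl_iterator = ArrowFileIterator(
    file_path="data/my_corpus/shard.chunk.00.jsonl",
    worker_id=0,
    num_workers=1,
    arrow_batch_size=100,
    file_format="json",  # new in this commit; the default stays "arrow"
)
for example in jsonl_iterator.create_iter():
    # Each item is a BltExample whose entropies field is None,
    # since raw jsonl carries no entropy column.
    print(example.sample_id, example.text[:80])
    break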
@@ -15,29 +15,37 @@ from bytelatent.entropy_model import load_entropy_model
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs


def get_id_from_doc(doc: dict) -> int:
def get_id_key(doc: dict) -> int:
    """
    We need a reliable way to ensure that samples from jsonl
    and arrow are the same, but there is no unique id field,
    so derive the best possible
    """
    if "sample_id" in doc:
        sample_id = doc["sample_id"]
        return "sample_id"
    elif "title" in doc:
        sample_id = doc["title"]
        return "title"
    elif "qid" in doc:
        sample_id = doc["qid"]
        return "qid"
    elif "paper_id" in doc:
        sample_id = doc["paper_id"]
        return "paper_id"
    elif "path" in doc:
        sample_id = doc["path"]
        return "path"
    elif "url" in doc:
        sample_id = doc["url"]
        return "url"
    elif "id" in doc:
        sample_id = doc["id"]
        return "id"
    else:
        raise ValueError(f"Could not find a id key from: {doc.keys()}")
    return str(sample_id)


def get_id_from_doc(doc: dict) -> int:
    """
    We need a reliable way to ensure that samples from jsonl
    and arrow are the same, but there is no unique id field,
    so derive the best possible
    """
    return str(doc[get_id_key(doc)])


def get_text(doc: dict):
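The preprocessing module (bytelatent.preprocess.preprocess_entropies, per the import added in the iterator hunk) now exposes get_id_key, which picks the best available id-like field, while get_id_from_doc builds on it. A small illustrative call with a made-up document:

from bytelatent.preprocess.preprocess_entropies import get_id_from_doc, get_id_key

doc = {"title": "An Example Page", "url": "https://example.com/a", "text": "..."}
print(get_id_key(doc))       # "title": "sample_id" is absent and "title" is checked before "url"
print(get_id_from_doc(doc))  # "An Example Page"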
@@ -4,10 +4,10 @@ import json
import os
import shutil
import subprocess
from pydantic import BaseModel
from typing import Any, Dict

from omegaconf import OmegaConf
from pydantic import BaseModel


class StoolArgs(BaseModel):