Mirror of https://github.com/facebookresearch/blt.git (synced 2025-02-22 21:12:15 +00:00)

commit 0e9421af07 (parent afedb16598)

Allow ArrowIterator to read from json

Summary:
Test Plan:
The hunks below appear to modify bytelatent/args.py, the module defining parse_args and distribute_data_to_rank; find_and_sanitize_chunks moves into this file from the arrow iterator module. Add/remove markers are reconstructed where they can be read off the commit view.

@@ -1,8 +1,10 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 import json
 import logging
 import os
 from typing import Any

 import fsspec
 import numpy as np
 import yaml
 from omegaconf import OmegaConf
@@ -10,11 +12,9 @@ from pydantic import BaseModel, ConfigDict

 from bytelatent.checkpoint import CheckpointArgs
 from bytelatent.data.data_types import Batch
+from bytelatent.data.file_util import get_fs
 from bytelatent.data.iterators.abstract_iterator import StatefulIterator
-from bytelatent.data.iterators.arrow_iterator import (
-    ArrowFileIterator,
-    find_and_sanitize_chunks,
-)
+from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
 from bytelatent.data.iterators.looping_iterator import LoopingIterator
 from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator
 from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator
@@ -53,6 +53,33 @@ def parse_args(args_cls):
     return pydantic_args


+TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"
+
+
+def find_and_sanitize_chunks(
+    dataset_path: str,
+    world_size: int,
+    file_pattern: str,
+    s3_profile: str | None = None,
+):
+    fs = get_fs(dataset_path, s3_profile=s3_profile)
+    path_with_glob = os.path.join(dataset_path, file_pattern)
+    dataset_chunks = fs.glob(path_with_glob)
+    n_chunks = len(dataset_chunks)
+
+    if n_chunks > world_size:
+        n_discard = n_chunks - world_size
+        dataset_chunks = dataset_chunks[:world_size]
+    else:
+        assert (
+            world_size % n_chunks == 0
+        ), "World size should be a multiple of number of chunks"
+
+    assert n_chunks > 0, f"No valid chunks in {dataset_path}"
+
+    return dataset_chunks
+
+
 def distribute_data_to_rank(
     *,
     dataset_path: str,
@@ -62,9 +89,10 @@ def distribute_data_to_rank(
     rank: int,
     world_size: int,
     s3_profile: str | None = None,
+    file_pattern: str = TRAIN_DATA_FILE_PATTERN,
 ) -> ArrowFileIterator:
     dataset_chunks = find_and_sanitize_chunks(
-        dataset_path, world_size, s3_profile=s3_profile
+        dataset_path, world_size, file_pattern, s3_profile=s3_profile
     )
     n_workers_per_chunk = world_size // len(dataset_chunks)
     rank_to_arrow_iterator_params = []
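For orientation, here is a minimal usage sketch of the relocated chunk selection. It assumes the hunks above land in bytelatent/args.py (as they suggest); the directory and chunk file names are illustrative, not from the repo.

    # Hypothetical usage; assumes find_and_sanitize_chunks now lives in bytelatent.args.
    from bytelatent.args import TRAIN_DATA_FILE_PATTERN, find_and_sanitize_chunks

    # With files like data/shuffled/train.chunk.00.jsonl ... train.chunk.03.jsonl,
    # a world size of 8 keeps all 4 chunks (8 % 4 == 0), and distribute_data_to_rank
    # then assigns 8 // 4 = 2 workers per chunk; a world size of 2 would instead
    # truncate the list to the first 2 chunks.
    chunks = find_and_sanitize_chunks(
        "data/shuffled",
        world_size=8,
        file_pattern=TRAIN_DATA_FILE_PATTERN,  # "*.chunk.*.jsonl"
    )
    print(chunks)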
The hunks below modify bytelatent/data/iterators/arrow_iterator.py, threading a new file_format field ("arrow" by default, "json" for raw JSONL) through the iterator state, the constructor, and the pyarrow dataset construction.

@@ -16,6 +16,7 @@ from bytelatent import ByteLatentError
 from bytelatent.data.data_types import BltExample
 from bytelatent.data.file_util import get_fs
 from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.preprocess.preprocess_entropies import get_id_key, get_text

 logger = getLogger(__name__)

@@ -32,6 +33,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState):
     arrow_batch_size: int = 100
     s3_profile: str | None
     filesystem_type: str | None = None
+    file_format: str

     def build(self) -> "ArrowFileIterator":
         arrow_file = ArrowFileIterator(
@@ -44,6 +46,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState):
             dataset_files=self.dataset_files,
             s3_profile=self.s3_profile,
             filesystem_type=self.filesystem_type,
+            file_format=self.file_format,
         )
         if self.row_num != 0:
             arrow_file._set_row_num(self.row_num)
@@ -70,6 +73,7 @@ class ArrowFileIterator(StatefulIterator):
         dataset_files: list[str] | None = None,
         s3_profile: str | None = None,
         filesystem_type: str | None = None,
+        file_format: str = "arrow",
     ):
         assert 0 <= worker_id < num_workers, (worker_id, num_workers)
         if file_path is None and dataset_files is None:
@@ -87,12 +91,16 @@ class ArrowFileIterator(StatefulIterator):
         self.arrow_batch_size = arrow_batch_size
         self.s3_profile = s3_profile
         self.filesystem_type = filesystem_type
+        self.file_format = file_format
         self.fs = None
         if self.filesystem_type is not None:
             if self.filesystem_type == "file":
                 self.fs = fsspec.filesystem("file")
             elif self.filesystem_type == "s3":
                 self.fs = fsspec.filesystem("s3", profile=s3_profile)
             else:
                 raise ValueError("Unknown filesystem")
         logger.info("Arrow iterator using fs=%s", self.fs)

         if dataset_files is None:
             # Prepare arrow shards
@@ -153,6 +161,7 @@ class ArrowFileIterator(StatefulIterator):
             dataset_files=self.dataset_files,
             s3_profile=self.s3_profile,
             filesystem_type=self.filesystem_type,
+            file_format=self.file_format,
         )

     def create_iter(
@@ -164,7 +173,7 @@ class ArrowFileIterator(StatefulIterator):
         else:
             filesystem = None
         self.dataset = pa.dataset.dataset(
-            self.dataset_files, format="arrow", filesystem=filesystem
+            self.dataset_files, format=self.file_format, filesystem=filesystem
         )
         self.batch_iterator = self.dataset.to_batches(
             batch_size=self.arrow_batch_size
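The format=self.file_format change above is what actually enables JSON reading: pyarrow's dataset layer accepts "json" as a file format. A minimal standalone sketch of that mechanism, assuming a recent pyarrow with JSON dataset support; the file name is illustrative.

    import json

    import pyarrow.dataset as pa_dataset

    # Write a tiny JSONL file in the shape the iterator expects to glob.
    with open("demo.chunk.00.jsonl", "w") as f:
        f.write(json.dumps({"title": "doc-1", "text": "hello world"}) + "\n")
        f.write(json.dumps({"title": "doc-2", "text": "byte latent"}) + "\n")

    # format="json" is the same switch the iterator now passes through.
    dataset = pa_dataset.dataset(["demo.chunk.00.jsonl"], format="json")
    for batch in dataset.to_batches(batch_size=100):
        columns = batch.to_pydict()  # e.g. {"title": [...], "text": [...]}
        print(columns["title"], columns["text"])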
@@ -173,13 +182,22 @@
         if self.batch_to_consume is not None:
             batch_columns: dict[str, list] = self.batch_to_consume
             self.batch_to_consume = None
-            sample_ids = batch_columns["sample_id"]
-            texts = batch_columns["text"]
-            entropies = batch_columns["entropies"]
+            if self.file_format == "arrow":
+                sample_ids = batch_columns["sample_id"]
+                texts = batch_columns["text"]
+                entropies = batch_columns["entropies"]
+            elif self.file_format == "json":
+                # This data hasn't been preprocessed to a uniform format,
+                # so we have to do it now and omit entropies
+                sample_ids = batch_columns[get_id_key(batch_columns)]
+                texts = get_text(batch_columns)
+                entropies = None
+            else:
+                raise ValueError(f"Unknown file format: {self.file_format}")
             for i in range(len(sample_ids)):
                 out = BltExample(
                     sample_id=sample_ids[i],
-                    entropies=entropies[i],
+                    entropies=entropies[i] if entropies is not None else None,
                     text=texts[i],
                     tokens=None,
                     mask=None,
@@ -191,13 +209,22 @@

         for batch in self.batch_iterator:
             batch_columns = batch.to_pydict()
-            sample_ids = batch_columns["sample_id"]
-            texts = batch_columns["text"]
-            entropies = batch_columns["entropies"]
+            if self.file_format == "arrow":
+                sample_ids = batch_columns["sample_id"]
+                texts = batch_columns["text"]
+                entropies = batch_columns["entropies"]
+            elif self.file_format == "json":
+                # This data hasn't been preprocessed to a uniform format,
+                # so we have to do it now and omit entropies
+                sample_ids = batch_columns[get_id_key(batch_columns)]
+                texts = get_text(batch_columns)
+                entropies = None
+            else:
+                raise ValueError(f"Unknown file format: {self.file_format}")
             for i in range(len(sample_ids)):
                 out = BltExample(
                     sample_id=sample_ids[i],
-                    entropies=entropies[i],
+                    entropies=entropies[i] if entropies is not None else None,
                     text=texts[i],
                     tokens=None,
                     mask=None,
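The arrow-vs-json dispatch above (and again in the row-seeking hunk below) reduces to one column-selection step. A condensed, illustrative restatement; the helper below is a stand-in, not the repo's API, and the real id/text derivation is get_id_key/get_text from bytelatent.preprocess.preprocess_entropies (shown in a later hunk), which may do more than read a "text" column.

    def extract_columns(batch_columns: dict, file_format: str):
        """Sketch of the per-batch column selection used in create_iter."""
        if file_format == "arrow":
            # Preprocessed shards already carry uniform columns, entropies included.
            return (
                batch_columns["sample_id"],
                batch_columns["text"],
                batch_columns["entropies"],
            )
        if file_format == "json":
            # Raw JSONL has no uniform schema: pick the first usable id column
            # (same priority order as get_id_key) and drop entropies.
            id_key = next(
                k
                for k in ("sample_id", "title", "qid", "paper_id", "path", "url", "id")
                if k in batch_columns
            )
            return batch_columns[id_key], batch_columns["text"], None
        raise ValueError(f"Unknown file format: {file_format}")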
@@ -231,13 +258,24 @@ class ArrowFileIterator(StatefulIterator):
         for batch in self.batch_iterator:
             if len(batch) > curr_remaining:
                 batch_columns: dict[str, list] = batch.to_pydict()
-                batch_columns["sample_id"] = batch_columns["sample_id"][
-                    curr_remaining:
-                ]
-                batch_columns["entropies"] = batch_columns["entropies"][
-                    curr_remaining:
-                ]
-                batch_columns["text"] = batch_columns["text"][curr_remaining:]
+                if self.file_format == "arrow":
+                    leftover_sample_ids = batch_columns["sample_id"][
+                        curr_remaining:
+                    ]
+                    leftover_entropies = batch_columns["entropies"][curr_remaining:]
+                    leftover_texts = batch_columns["text"][curr_remaining:]
+                elif self.file_format == "json":
+                    leftover_sample_ids = batch_columns[get_id_key(batch_columns)][
+                        curr_remaining:
+                    ]
+                    leftover_entropies = None
+                    leftover_texts = get_text(batch_columns)[curr_remaining:]
+                else:
+                    raise ValueError(f"Unknown file format: {self.file_format}")
+
+                batch_columns["sample_id"] = leftover_sample_ids
+                batch_columns["entropies"] = leftover_entropies
+                batch_columns["text"] = leftover_texts
                 self.batch_to_consume = batch_columns
                 break
             elif len(batch) == curr_remaining:
@@ -250,30 +288,3 @@ class ArrowFileIterator(StatefulIterator):
         logger.info(
             f"Finished setting arrow position to {target_row_num} for {self.dataset_files}"
         )
-
-
-TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"
-
-
-def find_and_sanitize_chunks(
-    dataset_path: str,
-    world_size: int,
-    file_pattern: str = TRAIN_DATA_FILE_PATTERN,
-    s3_profile: str | None = None,
-):
-    fs = get_fs(dataset_path, s3_profile=s3_profile)
-    path_with_glob = os.path.join(dataset_path, file_pattern)
-    dataset_chunks = fs.glob(path_with_glob)
-    n_chunks = len(dataset_chunks)
-
-    if n_chunks > world_size:
-        n_discard = n_chunks - world_size
-        dataset_chunks = dataset_chunks[:world_size]
-    else:
-        assert (
-            world_size % n_chunks == 0
-        ), "World size should be a multiple of number of chunks"
-
-    assert n_chunks > 0, f"No valid chunks in {dataset_path}"
-
-    return dataset_chunks
The hunk below modifies bytelatent/preprocess/preprocess_entropies.py: the id-key lookup is split out of get_id_from_doc into get_id_key so the arrow iterator can reuse it on raw JSON columns.

@@ -15,29 +15,37 @@ from bytelatent.entropy_model import load_entropy_model
 from bytelatent.tokenizers.build_tokenizer import TokenizerArgs


-def get_id_from_doc(doc: dict) -> int:
+def get_id_key(doc: dict) -> int:
     """
     We need a reliable way to ensure that samples from jsonl
     and arrow are the same, but there is no unique id field,
     so derive the best possible
     """
     if "sample_id" in doc:
-        sample_id = doc["sample_id"]
+        return "sample_id"
     elif "title" in doc:
-        sample_id = doc["title"]
+        return "title"
     elif "qid" in doc:
-        sample_id = doc["qid"]
+        return "qid"
     elif "paper_id" in doc:
-        sample_id = doc["paper_id"]
+        return "paper_id"
     elif "path" in doc:
-        sample_id = doc["path"]
+        return "path"
     elif "url" in doc:
-        sample_id = doc["url"]
+        return "url"
     elif "id" in doc:
-        sample_id = doc["id"]
+        return "id"
     else:
         raise ValueError(f"Could not find a id key from: {doc.keys()}")
-    return str(sample_id)
+
+
+def get_id_from_doc(doc: dict) -> int:
+    """
+    We need a reliable way to ensure that samples from jsonl
+    and arrow are the same, but there is no unique id field,
+    so derive the best possible
+    """
+    return str(doc[get_id_key(doc)])
+
+

 def get_text(doc: dict):
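A short usage sketch of the refactored helpers, assuming the bytelatent package is importable; the document below is made up.

    from bytelatent.preprocess.preprocess_entropies import get_id_from_doc, get_id_key

    doc = {"title": "Attention Is All You Need", "text": "..."}
    print(get_id_key(doc))       # "title" (first matching key in the priority list)
    print(get_id_from_doc(doc))  # "Attention Is All You Need"

    # get_id_key({"text": "no usable id field"})  # would raise ValueError

(Both functions return strings; the -> int annotations are carried over unchanged in the diff.)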
The final hunk reorders imports in the file defining StoolArgs (likely bytelatent/stool.py), moving the pydantic import below omegaconf:

@@ -4,10 +4,10 @@ import json
 import os
 import shutil
 import subprocess
-from pydantic import BaseModel
 from typing import Any, Dict

 from omegaconf import OmegaConf
+from pydantic import BaseModel


 class StoolArgs(BaseModel):