Mirror of https://github.com/facebookresearch/blt.git
Allow ArrowIterator to read from json
Summary:

Test Plan:
parent afedb16598
commit 8d26140970
@@ -1,8 +1,10 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+import json
 import logging
 import os
 from typing import Any
 
+import fsspec
 import numpy as np
 import yaml
 from omegaconf import OmegaConf
@@ -10,11 +12,9 @@ from pydantic import BaseModel, ConfigDict
 
 from bytelatent.checkpoint import CheckpointArgs
 from bytelatent.data.data_types import Batch
+from bytelatent.data.file_util import get_fs
 from bytelatent.data.iterators.abstract_iterator import StatefulIterator
-from bytelatent.data.iterators.arrow_iterator import (
-    ArrowFileIterator,
-    find_and_sanitize_chunks,
-)
+from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
 from bytelatent.data.iterators.looping_iterator import LoopingIterator
 from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator
 from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator
@@ -53,6 +53,33 @@ def parse_args(args_cls):
     return pydantic_args
 
 
+TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"
+
+
+def find_and_sanitize_chunks(
+    dataset_path: str,
+    world_size: int,
+    file_pattern: str,
+    s3_profile: str | None = None,
+):
+    fs = get_fs(dataset_path, s3_profile=s3_profile)
+    path_with_glob = os.path.join(dataset_path, file_pattern)
+    dataset_chunks = fs.glob(path_with_glob)
+    n_chunks = len(dataset_chunks)
+
+    if n_chunks > world_size:
+        n_discard = n_chunks - world_size
+        dataset_chunks = dataset_chunks[:world_size]
+    else:
+        assert (
+            world_size % n_chunks == 0
+        ), "World size should be a multiple of number of chunks"
+
+    assert n_chunks > 0, f"No valid chunks in {dataset_path}"
+
+    return dataset_chunks
+
+
 def distribute_data_to_rank(
     *,
     dataset_path: str,
@@ -62,9 +89,10 @@ def distribute_data_to_rank(
     rank: int,
     world_size: int,
     s3_profile: str | None = None,
+    file_pattern: str = TRAIN_DATA_FILE_PATTERN,
 ) -> ArrowFileIterator:
     dataset_chunks = find_and_sanitize_chunks(
-        dataset_path, world_size, s3_profile=s3_profile
+        dataset_path, world_size, file_pattern, s3_profile=s3_profile
     )
     n_workers_per_chunk = world_size // len(dataset_chunks)
     rank_to_arrow_iterator_params = []
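The hunks above are the args-side half of the change: chunk discovery (`TRAIN_DATA_FILE_PATTERN`, `find_and_sanitize_chunks`) now lives next to `distribute_data_to_rank`, which gains a `file_pattern` argument and forwards it when globbing for chunks. A minimal usage sketch; the module path `bytelatent.args`, the dataset directory, and the world size are assumptions for illustration, not taken from the diff:

```python
# Usage sketch only: module path, dataset directory, and world size are assumed.
from bytelatent.args import TRAIN_DATA_FILE_PATTERN, find_and_sanitize_chunks

# Discover jsonl chunk files for an 8-rank job. Surplus chunks are dropped;
# otherwise the world size must be a multiple of the number of chunks found.
chunks = find_and_sanitize_chunks(
    "/data/my_dataset",                    # assumed local dataset directory
    8,                                     # world size
    file_pattern=TRAIN_DATA_FILE_PATTERN,  # "*.chunk.*.jsonl"
)
print(chunks)
```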
@@ -16,6 +16,7 @@ from bytelatent import ByteLatentError
 from bytelatent.data.data_types import BltExample
 from bytelatent.data.file_util import get_fs
 from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.preprocess.preprocess_entropies import get_id_key, get_text
 
 logger = getLogger(__name__)
 
@@ -32,6 +33,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState):
     arrow_batch_size: int = 100
     s3_profile: str | None
     filesystem_type: str | None = None
+    file_format: str
 
     def build(self) -> "ArrowFileIterator":
         arrow_file = ArrowFileIterator(
@@ -44,6 +46,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState):
             dataset_files=self.dataset_files,
             s3_profile=self.s3_profile,
             filesystem_type=self.filesystem_type,
+            file_format=self.file_format,
         )
         if self.row_num != 0:
             arrow_file._set_row_num(self.row_num)
@@ -70,6 +73,7 @@ class ArrowFileIterator(StatefulIterator):
         dataset_files: list[str] | None = None,
         s3_profile: str | None = None,
         filesystem_type: str | None = None,
+        file_format: str = "arrow",
     ):
         assert 0 <= worker_id < num_workers, (worker_id, num_workers)
         if file_path is None and dataset_files is None:
@@ -87,12 +91,16 @@ class ArrowFileIterator(StatefulIterator):
         self.arrow_batch_size = arrow_batch_size
         self.s3_profile = s3_profile
         self.filesystem_type = filesystem_type
+        self.file_format = file_format
         self.fs = None
         if self.filesystem_type is not None:
             if self.filesystem_type == "file":
                 self.fs = fsspec.filesystem("file")
             elif self.filesystem_type == "s3":
                 self.fs = fsspec.filesystem("s3", profile=s3_profile)
+            else:
+                raise ValueError("Unknown filesystem")
+        logger.info("Arrow iterator using fs=%s", self.fs)
 
         if dataset_files is None:
             # Prepare arrow shards
@@ -153,6 +161,7 @@ class ArrowFileIterator(StatefulIterator):
             dataset_files=self.dataset_files,
             s3_profile=self.s3_profile,
             filesystem_type=self.filesystem_type,
+            file_format=self.file_format,
         )
 
     def create_iter(
@@ -164,7 +173,7 @@ class ArrowFileIterator(StatefulIterator):
             else:
                 filesystem = None
             self.dataset = pa.dataset.dataset(
-                self.dataset_files, format="arrow", filesystem=filesystem
+                self.dataset_files, format=self.file_format, filesystem=filesystem
             )
             self.batch_iterator = self.dataset.to_batches(
                 batch_size=self.arrow_batch_size
@@ -173,13 +182,22 @@ class ArrowFileIterator(StatefulIterator):
         if self.batch_to_consume is not None:
             batch_columns: dict[str, list] = self.batch_to_consume
             self.batch_to_consume = None
-            sample_ids = batch_columns["sample_id"]
-            texts = batch_columns["text"]
-            entropies = batch_columns["entropies"]
+            if self.file_format == "arrow":
+                sample_ids = batch_columns["sample_id"]
+                texts = batch_columns["text"]
+                entropies = batch_columns["entropies"]
+            elif self.file_format == "json":
+                # This data hasn't been preprocessed to a uniform format,
+                # so we have to do it now and omit entropies
+                sample_ids = batch_columns[get_id_key(batch_columns)]
+                texts = get_text(batch_columns)
+                entropies = None
+            else:
+                raise ValueError(f"Unknown file format: {self.file_format}")
             for i in range(len(sample_ids)):
                 out = BltExample(
                     sample_id=sample_ids[i],
-                    entropies=entropies[i],
+                    entropies=entropies[i] if entropies is not None else None,
                     text=texts[i],
                     tokens=None,
                     mask=None,
@@ -191,13 +209,22 @@ class ArrowFileIterator(StatefulIterator):
 
         for batch in self.batch_iterator:
             batch_columns = batch.to_pydict()
-            sample_ids = batch_columns["sample_id"]
-            texts = batch_columns["text"]
-            entropies = batch_columns["entropies"]
+            if self.file_format == "arrow":
+                sample_ids = batch_columns["sample_id"]
+                texts = batch_columns["text"]
+                entropies = batch_columns["entropies"]
+            elif self.file_format == "json":
+                # This data hasn't been preprocessed to a uniform format,
+                # so we have to do it now and omit entropies
+                sample_ids = batch_columns[get_id_key(batch_columns)]
+                texts = get_text(batch_columns)
+                entropies = None
+            else:
+                raise ValueError(f"Unknown file format: {self.file_format}")
             for i in range(len(sample_ids)):
                 out = BltExample(
                     sample_id=sample_ids[i],
-                    entropies=entropies[i],
+                    entropies=entropies[i] if entropies is not None else None,
                     text=texts[i],
                     tokens=None,
                     mask=None,
@@ -231,13 +258,24 @@ class ArrowFileIterator(StatefulIterator):
         for batch in self.batch_iterator:
             if len(batch) > curr_remaining:
                 batch_columns: dict[str, list] = batch.to_pydict()
-                batch_columns["sample_id"] = batch_columns["sample_id"][
-                    curr_remaining:
-                ]
-                batch_columns["entropies"] = batch_columns["entropies"][
-                    curr_remaining:
-                ]
-                batch_columns["text"] = batch_columns["text"][curr_remaining:]
+                if self.file_format == "arrow":
+                    leftover_sample_ids = batch_columns["sample_id"][
+                        curr_remaining:
+                    ]
+                    leftover_entropies = batch_columns["entropies"][curr_remaining:]
+                    leftover_texts = batch_columns["text"][curr_remaining:]
+                elif self.file_format == "json":
+                    leftover_sample_ids = batch_columns[get_id_key(batch_columns)][
+                        curr_remaining:
+                    ]
+                    leftover_entropies = None
+                    leftover_texts = get_text(batch_columns)[curr_remaining:]
+                else:
+                    raise ValueError(f"Unknown file format: {self.file_format}")
+
+                batch_columns["sample_id"] = leftover_sample_ids
+                batch_columns["entropies"] = leftover_entropies
+                batch_columns["text"] = leftover_texts
                 self.batch_to_consume = batch_columns
                 break
             elif len(batch) == curr_remaining:
@@ -250,30 +288,3 @@ class ArrowFileIterator(StatefulIterator):
         logger.info(
             f"Finished setting arrow position to {target_row_num} for {self.dataset_files}"
         )
-
-
-TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"
-
-
-def find_and_sanitize_chunks(
-    dataset_path: str,
-    world_size: int,
-    file_pattern: str = TRAIN_DATA_FILE_PATTERN,
-    s3_profile: str | None = None,
-):
-    fs = get_fs(dataset_path, s3_profile=s3_profile)
-    path_with_glob = os.path.join(dataset_path, file_pattern)
-    dataset_chunks = fs.glob(path_with_glob)
-    n_chunks = len(dataset_chunks)
-
-    if n_chunks > world_size:
-        n_discard = n_chunks - world_size
-        dataset_chunks = dataset_chunks[:world_size]
-    else:
-        assert (
-            world_size % n_chunks == 0
-        ), "World size should be a multiple of number of chunks"
-
-    assert n_chunks > 0, f"No valid chunks in {dataset_path}"
-
-    return dataset_chunks
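The hunks above thread a `file_format` field through `ArrowFileIteratorState` and `ArrowFileIterator` and hand it to `pyarrow.dataset` as the format string, which is what lets `file_format="json"` read newline-delimited JSON through the same code path as arrow shards. A standalone sketch of that underlying mechanism, assuming a pyarrow version with JSON dataset support (the sample file is generated on the fly and is not part of the commit):

```python
import json
import tempfile

import pyarrow.dataset as pa_dataset

# Write a tiny .jsonl file to stand in for one dataset chunk.
tmp = tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False)
for i in range(3):
    tmp.write(json.dumps({"sample_id": str(i), "text": f"example {i}"}) + "\n")
tmp.close()

# Same mechanism as ArrowFileIterator.create_iter: the format string decides
# whether pyarrow parses arrow/ipc shards or newline-delimited JSON.
dataset = pa_dataset.dataset([tmp.name], format="json")
for batch in dataset.to_batches(batch_size=2):
    columns = batch.to_pydict()
    print(columns["sample_id"], columns["text"])
```

Note that in the `json` branch the iterator passes the whole `batch_columns` dict (column name → list of values) to `get_id_key` and `get_text`, so the key lookup there selects an entire column rather than a single document's field, and entropies are simply omitted.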
@@ -15,29 +15,37 @@ from bytelatent.entropy_model import load_entropy_model
 from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
 
 
-def get_id_from_doc(doc: dict) -> int:
+def get_id_key(doc: dict) -> int:
     """
     We need a reliable way to ensure that samples from jsonl
     and arrow are the same, but there is no unique id field,
     so derive the best possible
     """
     if "sample_id" in doc:
-        sample_id = doc["sample_id"]
+        return "sample_id"
     elif "title" in doc:
-        sample_id = doc["title"]
+        return "title"
     elif "qid" in doc:
-        sample_id = doc["qid"]
+        return "qid"
     elif "paper_id" in doc:
-        sample_id = doc["paper_id"]
+        return "paper_id"
     elif "path" in doc:
-        sample_id = doc["path"]
+        return "path"
     elif "url" in doc:
-        sample_id = doc["url"]
+        return "url"
     elif "id" in doc:
-        sample_id = doc["id"]
+        return "id"
     else:
         raise ValueError(f"Could not find a id key from: {doc.keys()}")
-    return str(sample_id)
+
+
+def get_id_from_doc(doc: dict) -> int:
+    """
+    We need a reliable way to ensure that samples from jsonl
+    and arrow are the same, but there is no unique id field,
+    so derive the best possible
+    """
+    return str(doc[get_id_key(doc)])
 
 
 def get_text(doc: dict):
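The final hunk splits the id logic: `get_id_key` now returns the name of the best available id-like field, and `get_id_from_doc` becomes a thin wrapper that looks that field up and stringifies it. A small illustration (a sketch; the documents below are made up):

```python
from bytelatent.preprocess.preprocess_entropies import get_id_from_doc, get_id_key

# Raw jsonl documents carry different id-like fields depending on the source.
wiki_doc = {"title": "Some Article", "text": "..."}
paper_doc = {"paper_id": "1234.56789", "text": "..."}

print(get_id_key(wiki_doc))        # -> "title"
print(get_id_from_doc(wiki_doc))   # -> "Some Article"
print(get_id_key(paper_doc))       # -> "paper_id"
print(get_id_from_doc(paper_doc))  # -> "1234.56789"
```

Returning the key name (rather than the value) is what allows ArrowFileIterator to reuse the same helper on a columnar batch dict, as in the json branch above.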