mirror of
https://github.com/facebookresearch/blt.git
synced 2025-01-18 16:37:46 +00:00
Replace regular filesystem calls with fsspec + add s3 support
Summary: For compatibility with either local/nfs or S3 datasets, swap to fsspec. Add a tool to compare local and remote filesystems Test Plan: - Ran regular train script - Ran with config with data in S3
This commit is contained in:
parent
d4ddb95322
commit
a1d05403b4
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -165,4 +165,4 @@ cython_debug/
|
||||||
figures/
|
figures/
|
||||||
.vscode/
|
.vscode/
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
internal/
|
||||||
|
|
|
@ -46,8 +46,11 @@ def distribute_data_to_rank(
|
||||||
arrow_batch_size: int,
|
arrow_batch_size: int,
|
||||||
rank: int,
|
rank: int,
|
||||||
world_size: int,
|
world_size: int,
|
||||||
|
s3_profile: str | None = None,
|
||||||
) -> ArrowFileIterator:
|
) -> ArrowFileIterator:
|
||||||
dataset_chunks = find_and_sanitize_chunks(dataset_path, world_size)
|
dataset_chunks = find_and_sanitize_chunks(
|
||||||
|
dataset_path, world_size, s3_profile=s3_profile
|
||||||
|
)
|
||||||
n_workers_per_chunk = world_size // len(dataset_chunks)
|
n_workers_per_chunk = world_size // len(dataset_chunks)
|
||||||
rank_to_arrow_iterator_params = []
|
rank_to_arrow_iterator_params = []
|
||||||
for chunk_path in dataset_chunks:
|
for chunk_path in dataset_chunks:
|
||||||
|
@ -61,6 +64,7 @@ def distribute_data_to_rank(
|
||||||
dataset_files=None,
|
dataset_files=None,
|
||||||
entropy_model_name=entropy_model_name,
|
entropy_model_name=entropy_model_name,
|
||||||
arrow_batch_size=arrow_batch_size,
|
arrow_batch_size=arrow_batch_size,
|
||||||
|
s3_profile=s3_profile,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return rank_to_arrow_iterator_params[rank]
|
return rank_to_arrow_iterator_params[rank]
|
||||||
|
@ -68,6 +72,7 @@ def distribute_data_to_rank(
|
||||||
|
|
||||||
class DataloaderArgs(BaseModel):
|
class DataloaderArgs(BaseModel):
|
||||||
model_config = ConfigDict(extra="forbid")
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
s3_profile: str | None = None
|
||||||
root_dir: str | None = None
|
root_dir: str | None = None
|
||||||
sources: dict[str, float] = {}
|
sources: dict[str, float] = {}
|
||||||
batch_size: int = 2
|
batch_size: int = 2
|
||||||
|
@ -107,6 +112,7 @@ class DataloaderArgs(BaseModel):
|
||||||
arrow_batch_size=self.arrow_batch_size,
|
arrow_batch_size=self.arrow_batch_size,
|
||||||
rank=rank,
|
rank=rank,
|
||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
|
s3_profile=self.s3_profile,
|
||||||
)
|
)
|
||||||
looping_iterator = LoopingIterator(arrow_iterator)
|
looping_iterator = LoopingIterator(arrow_iterator)
|
||||||
preprocess_iterator = PreprocessIterator(
|
preprocess_iterator = PreprocessIterator(
|
||||||
|
|
117
bytelatent/data/file_util.py
Normal file
117
bytelatent/data/file_util.py
Normal file
|
@ -0,0 +1,117 @@
|
||||||
|
import os
|
||||||
|
|
||||||
|
import fsspec
|
||||||
|
import pyarrow as pa
|
||||||
|
|
||||||
|
# pyarrow needs the initialization from this import
|
||||||
|
import pyarrow.dataset # pyright: ignore
|
||||||
|
import typer
|
||||||
|
from pyarrow.lib import ArrowInvalid
|
||||||
|
from rich.progress import track
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid_arrow_file(path: str):
|
||||||
|
try:
|
||||||
|
dataset = pa.dataset.dataset(path, format="arrow")
|
||||||
|
return True
|
||||||
|
except ArrowInvalid:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
app = typer.Typer()
|
||||||
|
|
||||||
|
S3_PREFIX = "s3://"
|
||||||
|
|
||||||
|
|
||||||
|
def get_fs(path: str, s3_profile: str | None = None) -> fsspec.AbstractFileSystem:
|
||||||
|
if path.startswith("s3://"):
|
||||||
|
if s3_profile is None:
|
||||||
|
return fsspec.filesystem("s3")
|
||||||
|
else:
|
||||||
|
return fsspec.filesystem("s3", profile=s3_profile)
|
||||||
|
else:
|
||||||
|
return fsspec.filesystem("file")
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def print_local_to_delete(
|
||||||
|
blob_dir: str, local_dirs: list[str], s3_profile: str = "blt"
|
||||||
|
):
|
||||||
|
for s in local_dirs:
|
||||||
|
assert s.endswith("/"), "Dirs must end with /"
|
||||||
|
assert blob_dir.endswith("/"), "Dirs must end with /"
|
||||||
|
blob_fs = fsspec.filesystem("s3", profile=s3_profile)
|
||||||
|
blob_files = blob_fs.find(blob_dir)
|
||||||
|
for f in track(blob_files):
|
||||||
|
size = blob_fs.info(f)["Size"]
|
||||||
|
if not f.lower().endswith(".complete"):
|
||||||
|
assert size != 0, f"Size was invalidly zero for {f}"
|
||||||
|
|
||||||
|
blob_relative_paths = {f[len(blob_dir) - len(S3_PREFIX) :] for f in blob_files}
|
||||||
|
local_fs = fsspec.filesystem("file")
|
||||||
|
|
||||||
|
files_to_delete = []
|
||||||
|
for local_dir in local_dirs:
|
||||||
|
local_files = local_fs.find(local_dir)
|
||||||
|
for f in local_files:
|
||||||
|
relative_path = f[len(local_dir) :]
|
||||||
|
if relative_path in blob_relative_paths and not os.path.islink(f):
|
||||||
|
files_to_delete.append(f)
|
||||||
|
print(len(files_to_delete))
|
||||||
|
with open("/tmp/files_to_delete.txt", "w") as f:
|
||||||
|
for file in files_to_delete:
|
||||||
|
f.write(f"{file}\n")
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def compare_local_to_blob(
|
||||||
|
source_dirs: list[str], dst_dir: str, s3_profile: str = "blt"
|
||||||
|
):
|
||||||
|
for s in source_dirs:
|
||||||
|
assert s.endswith("/"), "Dirs must end with /"
|
||||||
|
assert dst_dir.endswith("/"), "Dirs must end with /"
|
||||||
|
assert len(source_dirs) != 0
|
||||||
|
assert dst_dir.startswith("s3://")
|
||||||
|
local_fs = fsspec.filesystem("file")
|
||||||
|
dst_fs = fsspec.filesystem("s3", profile=s3_profile)
|
||||||
|
source_to_files = {}
|
||||||
|
all_local_files = set()
|
||||||
|
for s in source_dirs:
|
||||||
|
skipped = []
|
||||||
|
if s not in source_to_files:
|
||||||
|
source_to_files[s] = []
|
||||||
|
for f in local_fs.find(s):
|
||||||
|
if os.path.islink(f):
|
||||||
|
continue
|
||||||
|
if f.endswith(".COMPLETE") or f.endswith(".complete"):
|
||||||
|
is_complete_file = True
|
||||||
|
assert os.path.getsize(f) == 0, ".COMPLETE files should be empty"
|
||||||
|
else:
|
||||||
|
is_complete_file = False
|
||||||
|
|
||||||
|
if not is_complete_file and os.path.getsize(f) == 0:
|
||||||
|
skipped.append(f)
|
||||||
|
continue
|
||||||
|
if f.endswith(".arrow"):
|
||||||
|
if not is_valid_arrow_file(f):
|
||||||
|
skipped.append(f)
|
||||||
|
continue
|
||||||
|
|
||||||
|
source_to_files[s].append(f)
|
||||||
|
all_local_files.add(f[len(s) :])
|
||||||
|
print(s, len(source_to_files[s]), "skipped", len(skipped), skipped[:10])
|
||||||
|
|
||||||
|
dst_files = dst_fs.find(dst_dir)
|
||||||
|
print(dst_dir, len(dst_files))
|
||||||
|
|
||||||
|
dst_file_set = {f[len(dst_dir) - len(S3_PREFIX) :] for f in dst_files}
|
||||||
|
diff = all_local_files.symmetric_difference(dst_file_set)
|
||||||
|
print("Local files", len(all_local_files))
|
||||||
|
print("DST Files", len(dst_file_set))
|
||||||
|
print("Symmetric difference", len(diff))
|
||||||
|
dst_only_files = dst_file_set - all_local_files
|
||||||
|
print("DST only", len(dst_only_files), list(dst_only_files)[:10])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
app()
|
|
@ -1,17 +1,20 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
import os
|
||||||
import re
|
import re
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Generator
|
from typing import Any, Generator
|
||||||
|
|
||||||
|
import fsspec
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
|
|
||||||
# pyarrow needs the initialization from this import
|
# pyarrow needs the initialization from this import
|
||||||
import pyarrow.dataset # pyright: ignore
|
import pyarrow.dataset # pyright: ignore
|
||||||
|
import s3fs
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
from bytelatent import ByteLatentError
|
from bytelatent import ByteLatentError
|
||||||
from bytelatent.data.data_types import BltExample
|
from bytelatent.data.data_types import BltExample
|
||||||
|
from bytelatent.data.file_util import get_fs
|
||||||
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
|
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
@ -27,6 +30,8 @@ class ArrowFileIteratorState(BaseModel, IteratorState):
|
||||||
dataset_files: list[str] | None
|
dataset_files: list[str] | None
|
||||||
entropy_model_name: str | None
|
entropy_model_name: str | None
|
||||||
arrow_batch_size: int = 100
|
arrow_batch_size: int = 100
|
||||||
|
s3_profile: str | None
|
||||||
|
filesystem_type: str | None = None
|
||||||
|
|
||||||
def build(self) -> "ArrowFileIterator":
|
def build(self) -> "ArrowFileIterator":
|
||||||
arrow_file = ArrowFileIterator(
|
arrow_file = ArrowFileIterator(
|
||||||
|
@ -37,14 +42,17 @@ class ArrowFileIteratorState(BaseModel, IteratorState):
|
||||||
entropy_model_name=self.entropy_model_name,
|
entropy_model_name=self.entropy_model_name,
|
||||||
arrow_batch_size=self.arrow_batch_size,
|
arrow_batch_size=self.arrow_batch_size,
|
||||||
dataset_files=self.dataset_files,
|
dataset_files=self.dataset_files,
|
||||||
|
s3_profile=self.s3_profile,
|
||||||
|
filesystem_type=self.filesystem_type,
|
||||||
)
|
)
|
||||||
if self.row_num != 0:
|
if self.row_num != 0:
|
||||||
arrow_file._set_row_num(self.row_num)
|
arrow_file._set_row_num(self.row_num)
|
||||||
return arrow_file
|
return arrow_file
|
||||||
|
|
||||||
|
|
||||||
def shard_sort_key(file: str | Path):
|
def shard_sort_key(file: str):
|
||||||
match = re.search(r".+\.shard_([0-9]+)\.arrow", str(file))
|
assert isinstance(file, str)
|
||||||
|
match = re.search(r".+\.shard_([0-9]+)\.arrow", file)
|
||||||
shard_number = int(match.group(1))
|
shard_number = int(match.group(1))
|
||||||
return shard_number
|
return shard_number
|
||||||
|
|
||||||
|
@ -60,6 +68,8 @@ class ArrowFileIterator(StatefulIterator):
|
||||||
entropy_model_name: str | None,
|
entropy_model_name: str | None,
|
||||||
arrow_batch_size: int,
|
arrow_batch_size: int,
|
||||||
dataset_files: list[str] | None = None,
|
dataset_files: list[str] | None = None,
|
||||||
|
s3_profile: str | None = None,
|
||||||
|
filesystem_type: str | None = None,
|
||||||
):
|
):
|
||||||
assert 0 <= worker_id < num_workers, (worker_id, num_workers)
|
assert 0 <= worker_id < num_workers, (worker_id, num_workers)
|
||||||
if file_path is None and dataset_files is None:
|
if file_path is None and dataset_files is None:
|
||||||
|
@ -75,16 +85,41 @@ class ArrowFileIterator(StatefulIterator):
|
||||||
self.preprocess_dir = preprocess_dir
|
self.preprocess_dir = preprocess_dir
|
||||||
self.entropy_model_name = entropy_model_name
|
self.entropy_model_name = entropy_model_name
|
||||||
self.arrow_batch_size = arrow_batch_size
|
self.arrow_batch_size = arrow_batch_size
|
||||||
|
self.s3_profile = s3_profile
|
||||||
|
self.filesystem_type = filesystem_type
|
||||||
|
self.fs = None
|
||||||
|
if self.filesystem_type is not None:
|
||||||
|
if self.filesystem_type == "file":
|
||||||
|
self.fs = fsspec.filesystem("file")
|
||||||
|
elif self.filesystem_type == "s3":
|
||||||
|
self.fs = fsspec.filesystem("s3", profile=s3_profile)
|
||||||
|
|
||||||
if dataset_files is None:
|
if dataset_files is None:
|
||||||
# Prepare arrow shards
|
# Prepare arrow shards
|
||||||
jsonl_file = Path(file_path)
|
jsonl_file = file_path
|
||||||
parts = re.match(r"(.+)\.chunk\.[0-9]+\.jsonl", jsonl_file.name)
|
parts = re.match(
|
||||||
|
r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file)
|
||||||
|
)
|
||||||
assert parts is not None
|
assert parts is not None
|
||||||
dataset = parts.group(1)
|
dataset = parts.group(1)
|
||||||
data_dir = Path(preprocess_dir) / dataset / entropy_model_name
|
data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name)
|
||||||
shard_files = list(data_dir.glob(f"{jsonl_file.name}.shard_*.arrow"))
|
data_dir_with_glob = os.path.join(
|
||||||
|
data_dir, f"{os.path.basename(jsonl_file)}.shard_*.arrow"
|
||||||
|
)
|
||||||
|
if self.fs is None:
|
||||||
|
self.fs = get_fs(data_dir_with_glob, s3_profile=s3_profile)
|
||||||
|
if isinstance(self.fs, s3fs.S3FileSystem):
|
||||||
|
self.filesystem_type = "s3"
|
||||||
|
else:
|
||||||
|
self.filesystem_type = "file"
|
||||||
|
shard_files = self.fs.glob(data_dir_with_glob)
|
||||||
|
|
||||||
for s in shard_files:
|
for s in shard_files:
|
||||||
if not (data_dir / f"{s.name}.complete").exists():
|
complete_file = os.path.join(
|
||||||
|
data_dir, f"{os.path.basename(s)}.complete"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.fs.exists(complete_file):
|
||||||
raise ValueError(f"Missing .complete for input file: {s}")
|
raise ValueError(f"Missing .complete for input file: {s}")
|
||||||
|
|
||||||
shard_files = sorted(shard_files, key=shard_sort_key)
|
shard_files = sorted(shard_files, key=shard_sort_key)
|
||||||
|
@ -92,10 +127,19 @@ class ArrowFileIterator(StatefulIterator):
|
||||||
raise ByteLatentError(
|
raise ByteLatentError(
|
||||||
f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow"
|
f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow"
|
||||||
)
|
)
|
||||||
self.dataset_files = [str(f) for f in shard_files]
|
self.dataset_files = [f for f in shard_files]
|
||||||
else:
|
else:
|
||||||
self.preprocess_dir = None
|
self.preprocess_dir = None
|
||||||
self.dataset_files = dataset_files
|
self.dataset_files = dataset_files
|
||||||
|
if dataset_files[0].startswith("s3://"):
|
||||||
|
for f in dataset_files:
|
||||||
|
assert f.startswith("s3://")
|
||||||
|
if self.fs is None:
|
||||||
|
self.fs = get_fs(dataset_files[0], s3_profile=s3_profile)
|
||||||
|
if isinstance(self.fs, s3fs.S3FileSystem):
|
||||||
|
self.filesystem_type = "s3"
|
||||||
|
else:
|
||||||
|
self.filesystem_type = "file"
|
||||||
|
|
||||||
def get_state(self) -> ArrowFileIteratorState:
|
def get_state(self) -> ArrowFileIteratorState:
|
||||||
return ArrowFileIteratorState(
|
return ArrowFileIteratorState(
|
||||||
|
@ -107,13 +151,21 @@ class ArrowFileIterator(StatefulIterator):
|
||||||
entropy_model_name=self.entropy_model_name,
|
entropy_model_name=self.entropy_model_name,
|
||||||
arrow_batch_size=self.arrow_batch_size,
|
arrow_batch_size=self.arrow_batch_size,
|
||||||
dataset_files=self.dataset_files,
|
dataset_files=self.dataset_files,
|
||||||
|
s3_profile=self.s3_profile,
|
||||||
|
filesystem_type=self.filesystem_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_iter(
|
def create_iter(
|
||||||
self,
|
self,
|
||||||
) -> Generator[BltExample, Any, None]:
|
) -> Generator[BltExample, Any, None]:
|
||||||
if self.dataset is None:
|
if self.dataset is None:
|
||||||
self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow")
|
if isinstance(self.fs, s3fs.core.S3FileSystem):
|
||||||
|
filesystem = self.fs
|
||||||
|
else:
|
||||||
|
filesystem = None
|
||||||
|
self.dataset = pa.dataset.dataset(
|
||||||
|
self.dataset_files, format="arrow", filesystem=filesystem
|
||||||
|
)
|
||||||
self.batch_iterator = self.dataset.to_batches(
|
self.batch_iterator = self.dataset.to_batches(
|
||||||
batch_size=self.arrow_batch_size
|
batch_size=self.arrow_batch_size
|
||||||
)
|
)
|
||||||
|
@ -165,7 +217,13 @@ class ArrowFileIterator(StatefulIterator):
|
||||||
self.batch_iterator = None
|
self.batch_iterator = None
|
||||||
self.batch_to_consume = None
|
self.batch_to_consume = None
|
||||||
else:
|
else:
|
||||||
self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow")
|
if isinstance(self.fs, s3fs.core.S3FileSystem):
|
||||||
|
filesystem = self.fs
|
||||||
|
else:
|
||||||
|
filesystem = None
|
||||||
|
self.dataset = pa.dataset.dataset(
|
||||||
|
self.dataset_files, format="arrow", filesystem=filesystem
|
||||||
|
)
|
||||||
self.batch_iterator = self.dataset.to_batches(
|
self.batch_iterator = self.dataset.to_batches(
|
||||||
batch_size=self.arrow_batch_size
|
batch_size=self.arrow_batch_size
|
||||||
)
|
)
|
||||||
|
@ -198,9 +256,14 @@ TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"
|
||||||
|
|
||||||
|
|
||||||
def find_and_sanitize_chunks(
|
def find_and_sanitize_chunks(
|
||||||
dataset_path: str, world_size: int, file_pattern: str = TRAIN_DATA_FILE_PATTERN
|
dataset_path: str,
|
||||||
|
world_size: int,
|
||||||
|
file_pattern: str = TRAIN_DATA_FILE_PATTERN,
|
||||||
|
s3_profile: str | None = None,
|
||||||
):
|
):
|
||||||
dataset_chunks = [str(p) for p in Path(dataset_path).glob(file_pattern)]
|
fs = get_fs(dataset_path, s3_profile=s3_profile)
|
||||||
|
path_with_glob = os.path.join(dataset_path, file_pattern)
|
||||||
|
dataset_chunks = fs.glob(path_with_glob)
|
||||||
n_chunks = len(dataset_chunks)
|
n_chunks = len(dataset_chunks)
|
||||||
|
|
||||||
if n_chunks > world_size:
|
if n_chunks > world_size:
|
||||||
|
|
|
@ -91,7 +91,7 @@ def init_logger(
|
||||||
log_file: str | None = None,
|
log_file: str | None = None,
|
||||||
*,
|
*,
|
||||||
name: str | None = None,
|
name: str | None = None,
|
||||||
level: str = "NOTSET",
|
level: str = "INFO",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Setup logging.
|
Setup logging.
|
||||||
|
|
|
@ -20,3 +20,4 @@ altair
|
||||||
submitit
|
submitit
|
||||||
typer
|
typer
|
||||||
rich
|
rich
|
||||||
|
fsspec[full]
|
||||||
|
|
|
@ -5,6 +5,7 @@ import os
|
||||||
import subprocess
|
import subprocess
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
import fsspec
|
||||||
import requests
|
import requests
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
@ -38,11 +39,21 @@ def download_dataset(repo_id, local_dir, allow_patterns):
|
||||||
print(f"Dataset downloaded to {local_dir}")
|
print(f"Dataset downloaded to {local_dir}")
|
||||||
|
|
||||||
|
|
||||||
def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64):
|
def parquet_to_jsonl(
|
||||||
|
dataset, work_dir, src_dir, tgt_dir, ntasks=64, s3_profile: str | None = None
|
||||||
|
):
|
||||||
from datatrove.executor import LocalPipelineExecutor
|
from datatrove.executor import LocalPipelineExecutor
|
||||||
from datatrove.pipeline.readers import ParquetReader
|
from datatrove.pipeline.readers import ParquetReader
|
||||||
from datatrove.pipeline.writers import JsonlWriter
|
from datatrove.pipeline.writers import JsonlWriter
|
||||||
|
|
||||||
|
if tgt_dir.startswith("s3//"):
|
||||||
|
if s3_profile is None:
|
||||||
|
out_spec = tgt_dir
|
||||||
|
else:
|
||||||
|
out_spec = (tgt_dir, fsspec.filesystem("s3", profile=s3_profile))
|
||||||
|
else:
|
||||||
|
out_spec = tgt_dir
|
||||||
|
|
||||||
pipeline_exec = LocalPipelineExecutor(
|
pipeline_exec = LocalPipelineExecutor(
|
||||||
pipeline=[
|
pipeline=[
|
||||||
ParquetReader(
|
ParquetReader(
|
||||||
|
@ -52,7 +63,7 @@ def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64):
|
||||||
glob_pattern="**/*.parquet",
|
glob_pattern="**/*.parquet",
|
||||||
),
|
),
|
||||||
JsonlWriter(
|
JsonlWriter(
|
||||||
tgt_dir,
|
out_spec,
|
||||||
output_filename=dataset + ".chunk.${rank}.jsonl",
|
output_filename=dataset + ".chunk.${rank}.jsonl",
|
||||||
compression=None,
|
compression=None,
|
||||||
),
|
),
|
||||||
|
@ -77,7 +88,7 @@ def setup_terashuf(work_dir):
|
||||||
return terashuf_dir
|
return terashuf_dir
|
||||||
|
|
||||||
|
|
||||||
def main(dataset, memory, data_dir, seed=42, nchunks=32):
|
def main(dataset, memory, data_dir, seed=42, nchunks=32, s3_profile: str | None = None):
|
||||||
# Configuration
|
# Configuration
|
||||||
repo_id = {
|
repo_id = {
|
||||||
"fineweb_edu": "HuggingFaceFW/fineweb-edu",
|
"fineweb_edu": "HuggingFaceFW/fineweb-edu",
|
||||||
|
|
Loading…
Reference in a new issue