From 28016f144d5fbcc2f279955f0038463186f0de30 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Thu, 9 Jan 2025 20:06:18 +0000 Subject: [PATCH 1/6] Add plotting code from paper Summary: Test Plan: --- bytelatent/plotting/config_entropy_figure.yaml | 3 ++- bytelatent/plotting/entropy_figure.py | 17 ++++++++++++++++- plot_data/scores.json | 1 + 3 files changed, 19 insertions(+), 2 deletions(-) create mode 100644 plot_data/scores.json diff --git a/bytelatent/plotting/config_entropy_figure.yaml b/bytelatent/plotting/config_entropy_figure.yaml index 4d7bfd7..296ea07 100644 --- a/bytelatent/plotting/config_entropy_figure.yaml +++ b/bytelatent/plotting/config_entropy_figure.yaml @@ -1,3 +1,4 @@ data_path: plot_data/entropy_figure.json chart_path: figures/entropy_figure.pdf -# chart_path: figures/entropy_figure.pdf +threshold_override: 1.7171002626419067 +score_override_path: plot_data/scores.json diff --git a/bytelatent/plotting/entropy_figure.py b/bytelatent/plotting/entropy_figure.py index c1966a1..f401d7c 100644 --- a/bytelatent/plotting/entropy_figure.py +++ b/bytelatent/plotting/entropy_figure.py @@ -1,4 +1,5 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. +import json import os import sys from pathlib import Path @@ -12,6 +13,8 @@ from pydantic import BaseModel class PlotEntropiesConfig(BaseModel): data_path: str | None chart_path: str + score_override_path: str | None = None + threshold_override: float | None = None class Config: extra = "forbid" @@ -37,8 +40,20 @@ def main(): plot_config = PlotEntropiesConfig(**conf_dict) with open(plot_config.data_path) as f: json_data = f.read() + plot_data = PlotEntropiesData.model_validate_json(json_data) df = pd.read_json(plot_data.dataframe_json) + print("LEN", len(df)) + if plot_config.threshold_override is None: + threshold = plot_data.threshold + else: + threshold = plot_config.threshold_override + if plot_config.score_override_path is not None: + with open(plot_config.score_override_path) as f: + scores = json.load(f)["score"] + assert len(scores) == len(df) + df["entropies"] = scores + df["start"] = [1] + (df["entropies"] > threshold).values.tolist()[:-1] x_ticks = [] for row in df.itertuples(): @@ -65,7 +80,7 @@ def main(): ), ) rule = base.mark_rule(color="red", strokeDash=[4, 4]).encode( - y=alt.datum(plot_data.threshold), + y=alt.datum(threshold), ) patch_rules = ( alt.Chart(df[df["start"] > 0]) diff --git a/plot_data/scores.json b/plot_data/scores.json new file mode 100644 index 0000000..202cafc --- /dev/null +++ b/plot_data/scores.json @@ -0,0 +1 @@ +{"score": [3.3949153423309326, 2.1647746562957764, 2.3216569423675537, 2.8114914894104004, 1.505232334136963, 0.04055612534284592, 0.09150367230176926, 0.06008715182542801, 0.3453567624092102, 1.0483067035675049, 0.1967127025127411, 0.12737397849559784, 0.05923430994153023, 0.001597292022779584, 0.004362526815384626, 0.005547997076064348, 0.0011689786333590746, 0.0010273229563608766, 1.0228447914123535, 3.6863417625427246, 0.46605175733566284, 0.048645928502082825, 2.2544963359832764, 0.37329360842704773, 1.001160979270935, 2.9116122722625732, 1.8948925733566284, 1.4017235040664673, 0.3879640996456146, 0.2652309536933899, 1.780383825302124, 0.013964788988232613, 0.005456871818751097, 0.5426468253135681, 0.20666983723640442, 0.0051853349432349205, 0.0005802579107694328, 0.0007443525246344507, 0.0004390323010738939, 0.005452247802168131, 1.1932975053787231, 0.023798620328307152, 3.1230878829956055, 1.3915895223617554, 3.0489213466644287, 1.7018193006515503, 1.873910903930664, 
1.4662408828735352, 0.004920408595353365, 0.02599342167377472, 0.6620859503746033, 0.31743818521499634, 2.8409600257873535, 1.1354060173034668, 0.0520976223051548, 0.3519965708255768, 0.40707266330718994, 2.5438783168792725, 1.3343133926391602, 0.023993035778403282, 3.445943832397461, 1.8542104959487915, 0.7849258780479431, 0.6848396062850952, 0.06938046962022781, 0.20923230051994324, 0.10084306448698044, 0.18334199488162994, 0.4126923978328705, 0.5505472421646118, 0.1042013093829155, 0.019447727128863335, 0.0014866517158225179, 0.0009848219342529774, 0.00021391961490735412, 0.007746236398816109, 0.00038792978739365935, 0.0007933690212666988, 1.2369810342788696, 0.4436197578907013, 4.6366687456611544e-05]} \ No newline at end of file From 84854423c49fa59187855785c935d6cdbc9e0c45 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Fri, 10 Jan 2025 01:00:54 +0000 Subject: [PATCH 2/6] Replace regular filesystem calls with fsspec + add s3 support Summary: For compatibility with either local/nfs or S3 datasets, swap to fsspec. Add a tool to compare local and remote filesystems Test Plan: - Ran regular train script - Ran with config with data in S3 --- .gitignore | 2 +- bytelatent/args.py | 8 +- bytelatent/data/file_util.py | 116 ++++++++++++++++++++ bytelatent/data/iterators/arrow_iterator.py | 95 +++++++++++++--- bytelatent/logger.py | 2 +- requirements.txt | 1 + setup/download_prepare_hf_data.py | 17 ++- 7 files changed, 222 insertions(+), 19 deletions(-) create mode 100644 bytelatent/data/file_util.py diff --git a/.gitignore b/.gitignore index 56891a9..6c664b8 100644 --- a/.gitignore +++ b/.gitignore @@ -165,4 +165,4 @@ cython_debug/ figures/ .vscode/ .DS_Store - +internal/ diff --git a/bytelatent/args.py b/bytelatent/args.py index 4fba100..cfba3bf 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -46,8 +46,11 @@ def distribute_data_to_rank( arrow_batch_size: int, rank: int, world_size: int, + s3_profile: str | None = None, ) -> ArrowFileIterator: - dataset_chunks = find_and_sanitize_chunks(dataset_path, world_size) + dataset_chunks = find_and_sanitize_chunks( + dataset_path, world_size, s3_profile=s3_profile + ) n_workers_per_chunk = world_size // len(dataset_chunks) rank_to_arrow_iterator_params = [] for chunk_path in dataset_chunks: @@ -61,6 +64,7 @@ def distribute_data_to_rank( dataset_files=None, entropy_model_name=entropy_model_name, arrow_batch_size=arrow_batch_size, + s3_profile=s3_profile, ) ) return rank_to_arrow_iterator_params[rank] @@ -68,6 +72,7 @@ def distribute_data_to_rank( class DataloaderArgs(BaseModel): model_config = ConfigDict(extra="forbid") + s3_profile: str | None = None root_dir: str | None = None sources: dict[str, float] = {} batch_size: int = 2 @@ -107,6 +112,7 @@ class DataloaderArgs(BaseModel): arrow_batch_size=self.arrow_batch_size, rank=rank, world_size=world_size, + s3_profile=self.s3_profile, ) looping_iterator = LoopingIterator(arrow_iterator) preprocess_iterator = PreprocessIterator( diff --git a/bytelatent/data/file_util.py b/bytelatent/data/file_util.py new file mode 100644 index 0000000..312e1b4 --- /dev/null +++ b/bytelatent/data/file_util.py @@ -0,0 +1,116 @@ +import os +from rich.progress import track +import typer +import fsspec +import pyarrow as pa +from pyarrow.lib import ArrowInvalid + +# pyarrow needs the initialization from this import +import pyarrow.dataset # pyright: ignore + + +def is_valid_arrow_file(path: str): + try: + dataset = pa.dataset.dataset(path, format="arrow") + return True + except ArrowInvalid: + return False + + 
+app = typer.Typer() + +S3_PREFIX = "s3://" + + +def get_fs(path: str, s3_profile: str | None = None) -> fsspec.AbstractFileSystem: + if path.startswith("s3://"): + if s3_profile is None: + return fsspec.filesystem("s3") + else: + return fsspec.filesystem("s3", profile=s3_profile) + else: + return fsspec.filesystem("file") + + +@app.command() +def print_local_to_delete( + blob_dir: str, local_dirs: list[str], s3_profile: str = "blt" +): + for s in local_dirs: + assert s.endswith("/"), "Dirs must end with /" + assert blob_dir.endswith("/"), "Dirs must end with /" + blob_fs = fsspec.filesystem("s3", profile=s3_profile) + blob_files = blob_fs.find(blob_dir) + for f in track(blob_files): + size = blob_fs.info(f)["Size"] + if not f.lower().endswith(".complete"): + assert size != 0, f"Size was invalidly zero for {f}" + + blob_relative_paths = {f[len(blob_dir) - len(S3_PREFIX) :] for f in blob_files} + local_fs = fsspec.filesystem("file") + + files_to_delete = [] + for local_dir in local_dirs: + local_files = local_fs.find(local_dir) + for f in local_files: + relative_path = f[len(local_dir) :] + if relative_path in blob_relative_paths and not os.path.islink(f): + files_to_delete.append(f) + print(len(files_to_delete)) + with open("/tmp/files_to_delete.txt", "w") as f: + for file in files_to_delete: + f.write(f"{file}\n") + + +@app.command() +def compare_local_to_blob( + source_dirs: list[str], dst_dir: str, s3_profile: str = "blt" +): + for s in source_dirs: + assert s.endswith("/"), "Dirs must end with /" + assert dst_dir.endswith("/"), "Dirs must end with /" + assert len(source_dirs) != 0 + assert dst_dir.startswith("s3://") + local_fs = fsspec.filesystem("file") + dst_fs = fsspec.filesystem("s3", profile=s3_profile) + source_to_files = {} + all_local_files = set() + for s in source_dirs: + skipped = [] + if s not in source_to_files: + source_to_files[s] = [] + for f in local_fs.find(s): + if os.path.islink(f): + continue + if f.endswith(".COMPLETE") or f.endswith(".complete"): + is_complete_file = True + assert os.path.getsize(f) == 0, ".COMPLETE files should be empty" + else: + is_complete_file = False + + if not is_complete_file and os.path.getsize(f) == 0: + skipped.append(f) + continue + if f.endswith(".arrow"): + if not is_valid_arrow_file(f): + skipped.append(f) + continue + + source_to_files[s].append(f) + all_local_files.add(f[len(s) :]) + print(s, len(source_to_files[s]), "skipped", len(skipped), skipped[:10]) + + dst_files = dst_fs.find(dst_dir) + print(dst_dir, len(dst_files)) + + dst_file_set = {f[len(dst_dir) - len(S3_PREFIX) :] for f in dst_files} + diff = all_local_files.symmetric_difference(dst_file_set) + print("Local files", len(all_local_files)) + print("DST Files", len(dst_file_set)) + print("Symmetric difference", len(diff)) + dst_only_files = dst_file_set - all_local_files + print("DST only", len(dst_only_files), list(dst_only_files)[:10]) + + +if __name__ == "__main__": + app() diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index df5f023..17569ee 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -1,7 +1,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
+import os import re from logging import getLogger -from pathlib import Path from typing import Any, Generator import pyarrow as pa @@ -9,9 +9,12 @@ import pyarrow as pa # pyarrow needs the initialization from this import import pyarrow.dataset # pyright: ignore from pydantic import BaseModel, ConfigDict +import fsspec +import s3fs from bytelatent import ByteLatentError from bytelatent.data.data_types import BltExample +from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator logger = getLogger(__name__) @@ -27,6 +30,8 @@ class ArrowFileIteratorState(BaseModel, IteratorState): dataset_files: list[str] | None entropy_model_name: str | None arrow_batch_size: int = 100 + s3_profile: str | None + filesystem_type: str | None = None def build(self) -> "ArrowFileIterator": arrow_file = ArrowFileIterator( @@ -37,14 +42,17 @@ class ArrowFileIteratorState(BaseModel, IteratorState): entropy_model_name=self.entropy_model_name, arrow_batch_size=self.arrow_batch_size, dataset_files=self.dataset_files, + s3_profile=self.s3_profile, + filesystem_type=self.filesystem_type, ) if self.row_num != 0: arrow_file._set_row_num(self.row_num) return arrow_file -def shard_sort_key(file: str | Path): - match = re.search(r".+\.shard_([0-9]+)\.arrow", str(file)) +def shard_sort_key(file: str): + assert isinstance(file, str) + match = re.search(r".+\.shard_([0-9]+)\.arrow", file) shard_number = int(match.group(1)) return shard_number @@ -60,6 +68,8 @@ class ArrowFileIterator(StatefulIterator): entropy_model_name: str | None, arrow_batch_size: int, dataset_files: list[str] | None = None, + s3_profile: str | None = None, + filesystem_type: str | None = None, ): assert 0 <= worker_id < num_workers, (worker_id, num_workers) if file_path is None and dataset_files is None: @@ -75,16 +85,44 @@ class ArrowFileIterator(StatefulIterator): self.preprocess_dir = preprocess_dir self.entropy_model_name = entropy_model_name self.arrow_batch_size = arrow_batch_size + self.s3_profile = s3_profile + self.filesystem_type = filesystem_type + self.fs = None + if self.filesystem_type is not None: + if self.filesystem_type == "file": + self.fs = fsspec.filesystem("file") + elif self.filesystem_type == "s3": + self.fs = fsspec.filesystem("s3", profile=s3_profile) + if dataset_files is None: # Prepare arrow shards - jsonl_file = Path(file_path) - parts = re.match(r"(.+)\.chunk\.[0-9]+\.jsonl", jsonl_file.name) + jsonl_file = file_path + parts = re.match( + r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file) + ) assert parts is not None dataset = parts.group(1) - data_dir = Path(preprocess_dir) / dataset / entropy_model_name - shard_files = list(data_dir.glob(f"{jsonl_file.name}.shard_*.arrow")) + data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name) + data_dir_with_glob = os.path.join( + data_dir, f"{os.path.basename(jsonl_file)}.shard_*.arrow" + ) + print("preprocess_dir", preprocess_dir) + print("data_dir_with_glob", data_dir_with_glob) + if self.fs is None: + self.fs = get_fs(data_dir_with_glob, s3_profile=s3_profile) + if isinstance(self.fs, s3fs.S3FileSystem): + self.filesystem_type = "s3" + else: + self.filesystem_type = "file" + print("fs", self.fs) + shard_files = self.fs.glob(data_dir_with_glob) + for s in shard_files: - if not (data_dir / f"{s.name}.complete").exists(): + complete_file = os.path.join( + data_dir, f"{os.path.basename(s)}.complete" + ) + + if not self.fs.exists(complete_file): raise ValueError(f"Missing .complete for input 
file: {s}") shard_files = sorted(shard_files, key=shard_sort_key) @@ -92,10 +130,19 @@ class ArrowFileIterator(StatefulIterator): raise ByteLatentError( f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow" ) - self.dataset_files = [str(f) for f in shard_files] + self.dataset_files = [f for f in shard_files] else: self.preprocess_dir = None self.dataset_files = dataset_files + if dataset_files[0].startswith("s3://"): + for f in dataset_files: + assert f.startswith("s3://") + if self.fs is None: + self.fs = get_fs(dataset_files[0], s3_profile=s3_profile) + if isinstance(self.fs, s3fs.S3FileSystem): + self.filesystem_type = "s3" + else: + self.filesystem_type = "file" def get_state(self) -> ArrowFileIteratorState: return ArrowFileIteratorState( @@ -107,13 +154,24 @@ class ArrowFileIterator(StatefulIterator): entropy_model_name=self.entropy_model_name, arrow_batch_size=self.arrow_batch_size, dataset_files=self.dataset_files, + s3_profile=self.s3_profile, + filesystem_type=self.filesystem_type, ) def create_iter( self, ) -> Generator[BltExample, Any, None]: if self.dataset is None: - self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow") + if isinstance(self.fs, s3fs.core.S3FileSystem): + print("Using s3fs") + filesystem = self.fs + else: + print("Using local fs") + filesystem = None + print("Iterating over", self.dataset_files) + self.dataset = pa.dataset.dataset( + self.dataset_files, format="arrow", filesystem=filesystem + ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size ) @@ -165,7 +223,13 @@ class ArrowFileIterator(StatefulIterator): self.batch_iterator = None self.batch_to_consume = None else: - self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow") + if isinstance(self.fs, s3fs.core.S3FileSystem): + filesystem = self.fs + else: + filesystem = None + self.dataset = pa.dataset.dataset( + self.dataset_files, format="arrow", filesystem=filesystem + ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size ) @@ -198,9 +262,14 @@ TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" def find_and_sanitize_chunks( - dataset_path: str, world_size: int, file_pattern: str = TRAIN_DATA_FILE_PATTERN + dataset_path: str, + world_size: int, + file_pattern: str = TRAIN_DATA_FILE_PATTERN, + s3_profile: str | None = None, ): - dataset_chunks = [str(p) for p in Path(dataset_path).glob(file_pattern)] + fs = get_fs(dataset_path, s3_profile=s3_profile) + path_with_glob = os.path.join(dataset_path, file_pattern) + dataset_chunks = fs.glob(path_with_glob) n_chunks = len(dataset_chunks) if n_chunks > world_size: diff --git a/bytelatent/logger.py b/bytelatent/logger.py index 6723a84..87f04cc 100644 --- a/bytelatent/logger.py +++ b/bytelatent/logger.py @@ -91,7 +91,7 @@ def init_logger( log_file: str | None = None, *, name: str | None = None, - level: str = "NOTSET", + level: str = "INFO", ): """ Setup logging. 
diff --git a/requirements.txt b/requirements.txt index c6d87f1..59192cd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,3 +20,4 @@ altair submitit typer rich +fsspec[full] diff --git a/setup/download_prepare_hf_data.py b/setup/download_prepare_hf_data.py index 1aacf8f..e194857 100644 --- a/setup/download_prepare_hf_data.py +++ b/setup/download_prepare_hf_data.py @@ -5,6 +5,7 @@ import os import subprocess import time +import fsspec import requests from huggingface_hub import snapshot_download @@ -38,11 +39,21 @@ def download_dataset(repo_id, local_dir, allow_patterns): print(f"Dataset downloaded to {local_dir}") -def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64): +def parquet_to_jsonl( + dataset, work_dir, src_dir, tgt_dir, ntasks=64, s3_profile: str | None = None +): from datatrove.executor import LocalPipelineExecutor from datatrove.pipeline.readers import ParquetReader from datatrove.pipeline.writers import JsonlWriter + if tgt_dir.startswith("s3//"): + if s3_profile is None: + out_spec = tgt_dir + else: + out_spec = (tgt_dir, fsspec.filesystem("s3", profile=s3_profile)) + else: + out_spec = tgt_dir + pipeline_exec = LocalPipelineExecutor( pipeline=[ ParquetReader( @@ -52,7 +63,7 @@ def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64): glob_pattern="**/*.parquet", ), JsonlWriter( - tgt_dir, + out_spec, output_filename=dataset + ".chunk.${rank}.jsonl", compression=None, ), @@ -77,7 +88,7 @@ def setup_terashuf(work_dir): return terashuf_dir -def main(dataset, memory, data_dir, seed=42, nchunks=32): +def main(dataset, memory, data_dir, seed=42, nchunks=32, s3_profile: str | None = None): # Configuration repo_id = { "fineweb_edu": "HuggingFaceFW/fineweb-edu", From c137f4e636ee85ef4110b416b404c455b86efba9 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Fri, 10 Jan 2025 01:02:25 +0000 Subject: [PATCH 3/6] Replace regular filesystem calls with fsspec + add s3 support Summary: For compatibility with either local/nfs or S3 datasets, swap to fsspec. 
Add a tool to compare local and remote filesystems Test Plan: - Ran regular train script - Ran with config with data in S3 --- .gitignore | 2 +- bytelatent/args.py | 8 +- bytelatent/data/file_util.py | 116 ++++++++++++++++++++ bytelatent/data/iterators/arrow_iterator.py | 95 +++++++++++++--- bytelatent/logger.py | 2 +- requirements.txt | 1 + setup/download_prepare_hf_data.py | 17 ++- 7 files changed, 222 insertions(+), 19 deletions(-) create mode 100644 bytelatent/data/file_util.py diff --git a/.gitignore b/.gitignore index 56891a9..6c664b8 100644 --- a/.gitignore +++ b/.gitignore @@ -165,4 +165,4 @@ cython_debug/ figures/ .vscode/ .DS_Store - +internal/ diff --git a/bytelatent/args.py b/bytelatent/args.py index 4fba100..cfba3bf 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -46,8 +46,11 @@ def distribute_data_to_rank( arrow_batch_size: int, rank: int, world_size: int, + s3_profile: str | None = None, ) -> ArrowFileIterator: - dataset_chunks = find_and_sanitize_chunks(dataset_path, world_size) + dataset_chunks = find_and_sanitize_chunks( + dataset_path, world_size, s3_profile=s3_profile + ) n_workers_per_chunk = world_size // len(dataset_chunks) rank_to_arrow_iterator_params = [] for chunk_path in dataset_chunks: @@ -61,6 +64,7 @@ def distribute_data_to_rank( dataset_files=None, entropy_model_name=entropy_model_name, arrow_batch_size=arrow_batch_size, + s3_profile=s3_profile, ) ) return rank_to_arrow_iterator_params[rank] @@ -68,6 +72,7 @@ def distribute_data_to_rank( class DataloaderArgs(BaseModel): model_config = ConfigDict(extra="forbid") + s3_profile: str | None = None root_dir: str | None = None sources: dict[str, float] = {} batch_size: int = 2 @@ -107,6 +112,7 @@ class DataloaderArgs(BaseModel): arrow_batch_size=self.arrow_batch_size, rank=rank, world_size=world_size, + s3_profile=self.s3_profile, ) looping_iterator = LoopingIterator(arrow_iterator) preprocess_iterator = PreprocessIterator( diff --git a/bytelatent/data/file_util.py b/bytelatent/data/file_util.py new file mode 100644 index 0000000..312e1b4 --- /dev/null +++ b/bytelatent/data/file_util.py @@ -0,0 +1,116 @@ +import os +from rich.progress import track +import typer +import fsspec +import pyarrow as pa +from pyarrow.lib import ArrowInvalid + +# pyarrow needs the initialization from this import +import pyarrow.dataset # pyright: ignore + + +def is_valid_arrow_file(path: str): + try: + dataset = pa.dataset.dataset(path, format="arrow") + return True + except ArrowInvalid: + return False + + +app = typer.Typer() + +S3_PREFIX = "s3://" + + +def get_fs(path: str, s3_profile: str | None = None) -> fsspec.AbstractFileSystem: + if path.startswith("s3://"): + if s3_profile is None: + return fsspec.filesystem("s3") + else: + return fsspec.filesystem("s3", profile=s3_profile) + else: + return fsspec.filesystem("file") + + +@app.command() +def print_local_to_delete( + blob_dir: str, local_dirs: list[str], s3_profile: str = "blt" +): + for s in local_dirs: + assert s.endswith("/"), "Dirs must end with /" + assert blob_dir.endswith("/"), "Dirs must end with /" + blob_fs = fsspec.filesystem("s3", profile=s3_profile) + blob_files = blob_fs.find(blob_dir) + for f in track(blob_files): + size = blob_fs.info(f)["Size"] + if not f.lower().endswith(".complete"): + assert size != 0, f"Size was invalidly zero for {f}" + + blob_relative_paths = {f[len(blob_dir) - len(S3_PREFIX) :] for f in blob_files} + local_fs = fsspec.filesystem("file") + + files_to_delete = [] + for local_dir in local_dirs: + local_files = 
local_fs.find(local_dir) + for f in local_files: + relative_path = f[len(local_dir) :] + if relative_path in blob_relative_paths and not os.path.islink(f): + files_to_delete.append(f) + print(len(files_to_delete)) + with open("/tmp/files_to_delete.txt", "w") as f: + for file in files_to_delete: + f.write(f"{file}\n") + + +@app.command() +def compare_local_to_blob( + source_dirs: list[str], dst_dir: str, s3_profile: str = "blt" +): + for s in source_dirs: + assert s.endswith("/"), "Dirs must end with /" + assert dst_dir.endswith("/"), "Dirs must end with /" + assert len(source_dirs) != 0 + assert dst_dir.startswith("s3://") + local_fs = fsspec.filesystem("file") + dst_fs = fsspec.filesystem("s3", profile=s3_profile) + source_to_files = {} + all_local_files = set() + for s in source_dirs: + skipped = [] + if s not in source_to_files: + source_to_files[s] = [] + for f in local_fs.find(s): + if os.path.islink(f): + continue + if f.endswith(".COMPLETE") or f.endswith(".complete"): + is_complete_file = True + assert os.path.getsize(f) == 0, ".COMPLETE files should be empty" + else: + is_complete_file = False + + if not is_complete_file and os.path.getsize(f) == 0: + skipped.append(f) + continue + if f.endswith(".arrow"): + if not is_valid_arrow_file(f): + skipped.append(f) + continue + + source_to_files[s].append(f) + all_local_files.add(f[len(s) :]) + print(s, len(source_to_files[s]), "skipped", len(skipped), skipped[:10]) + + dst_files = dst_fs.find(dst_dir) + print(dst_dir, len(dst_files)) + + dst_file_set = {f[len(dst_dir) - len(S3_PREFIX) :] for f in dst_files} + diff = all_local_files.symmetric_difference(dst_file_set) + print("Local files", len(all_local_files)) + print("DST Files", len(dst_file_set)) + print("Symmetric difference", len(diff)) + dst_only_files = dst_file_set - all_local_files + print("DST only", len(dst_only_files), list(dst_only_files)[:10]) + + +if __name__ == "__main__": + app() diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index df5f023..17569ee 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -1,7 +1,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
+import os import re from logging import getLogger -from pathlib import Path from typing import Any, Generator import pyarrow as pa @@ -9,9 +9,12 @@ import pyarrow as pa # pyarrow needs the initialization from this import import pyarrow.dataset # pyright: ignore from pydantic import BaseModel, ConfigDict +import fsspec +import s3fs from bytelatent import ByteLatentError from bytelatent.data.data_types import BltExample +from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator logger = getLogger(__name__) @@ -27,6 +30,8 @@ class ArrowFileIteratorState(BaseModel, IteratorState): dataset_files: list[str] | None entropy_model_name: str | None arrow_batch_size: int = 100 + s3_profile: str | None + filesystem_type: str | None = None def build(self) -> "ArrowFileIterator": arrow_file = ArrowFileIterator( @@ -37,14 +42,17 @@ class ArrowFileIteratorState(BaseModel, IteratorState): entropy_model_name=self.entropy_model_name, arrow_batch_size=self.arrow_batch_size, dataset_files=self.dataset_files, + s3_profile=self.s3_profile, + filesystem_type=self.filesystem_type, ) if self.row_num != 0: arrow_file._set_row_num(self.row_num) return arrow_file -def shard_sort_key(file: str | Path): - match = re.search(r".+\.shard_([0-9]+)\.arrow", str(file)) +def shard_sort_key(file: str): + assert isinstance(file, str) + match = re.search(r".+\.shard_([0-9]+)\.arrow", file) shard_number = int(match.group(1)) return shard_number @@ -60,6 +68,8 @@ class ArrowFileIterator(StatefulIterator): entropy_model_name: str | None, arrow_batch_size: int, dataset_files: list[str] | None = None, + s3_profile: str | None = None, + filesystem_type: str | None = None, ): assert 0 <= worker_id < num_workers, (worker_id, num_workers) if file_path is None and dataset_files is None: @@ -75,16 +85,44 @@ class ArrowFileIterator(StatefulIterator): self.preprocess_dir = preprocess_dir self.entropy_model_name = entropy_model_name self.arrow_batch_size = arrow_batch_size + self.s3_profile = s3_profile + self.filesystem_type = filesystem_type + self.fs = None + if self.filesystem_type is not None: + if self.filesystem_type == "file": + self.fs = fsspec.filesystem("file") + elif self.filesystem_type == "s3": + self.fs = fsspec.filesystem("s3", profile=s3_profile) + if dataset_files is None: # Prepare arrow shards - jsonl_file = Path(file_path) - parts = re.match(r"(.+)\.chunk\.[0-9]+\.jsonl", jsonl_file.name) + jsonl_file = file_path + parts = re.match( + r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file) + ) assert parts is not None dataset = parts.group(1) - data_dir = Path(preprocess_dir) / dataset / entropy_model_name - shard_files = list(data_dir.glob(f"{jsonl_file.name}.shard_*.arrow")) + data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name) + data_dir_with_glob = os.path.join( + data_dir, f"{os.path.basename(jsonl_file)}.shard_*.arrow" + ) + print("preprocess_dir", preprocess_dir) + print("data_dir_with_glob", data_dir_with_glob) + if self.fs is None: + self.fs = get_fs(data_dir_with_glob, s3_profile=s3_profile) + if isinstance(self.fs, s3fs.S3FileSystem): + self.filesystem_type = "s3" + else: + self.filesystem_type = "file" + print("fs", self.fs) + shard_files = self.fs.glob(data_dir_with_glob) + for s in shard_files: - if not (data_dir / f"{s.name}.complete").exists(): + complete_file = os.path.join( + data_dir, f"{os.path.basename(s)}.complete" + ) + + if not self.fs.exists(complete_file): raise ValueError(f"Missing .complete for input 
file: {s}") shard_files = sorted(shard_files, key=shard_sort_key) @@ -92,10 +130,19 @@ class ArrowFileIterator(StatefulIterator): raise ByteLatentError( f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow" ) - self.dataset_files = [str(f) for f in shard_files] + self.dataset_files = [f for f in shard_files] else: self.preprocess_dir = None self.dataset_files = dataset_files + if dataset_files[0].startswith("s3://"): + for f in dataset_files: + assert f.startswith("s3://") + if self.fs is None: + self.fs = get_fs(dataset_files[0], s3_profile=s3_profile) + if isinstance(self.fs, s3fs.S3FileSystem): + self.filesystem_type = "s3" + else: + self.filesystem_type = "file" def get_state(self) -> ArrowFileIteratorState: return ArrowFileIteratorState( @@ -107,13 +154,24 @@ class ArrowFileIterator(StatefulIterator): entropy_model_name=self.entropy_model_name, arrow_batch_size=self.arrow_batch_size, dataset_files=self.dataset_files, + s3_profile=self.s3_profile, + filesystem_type=self.filesystem_type, ) def create_iter( self, ) -> Generator[BltExample, Any, None]: if self.dataset is None: - self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow") + if isinstance(self.fs, s3fs.core.S3FileSystem): + print("Using s3fs") + filesystem = self.fs + else: + print("Using local fs") + filesystem = None + print("Iterating over", self.dataset_files) + self.dataset = pa.dataset.dataset( + self.dataset_files, format="arrow", filesystem=filesystem + ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size ) @@ -165,7 +223,13 @@ class ArrowFileIterator(StatefulIterator): self.batch_iterator = None self.batch_to_consume = None else: - self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow") + if isinstance(self.fs, s3fs.core.S3FileSystem): + filesystem = self.fs + else: + filesystem = None + self.dataset = pa.dataset.dataset( + self.dataset_files, format="arrow", filesystem=filesystem + ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size ) @@ -198,9 +262,14 @@ TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" def find_and_sanitize_chunks( - dataset_path: str, world_size: int, file_pattern: str = TRAIN_DATA_FILE_PATTERN + dataset_path: str, + world_size: int, + file_pattern: str = TRAIN_DATA_FILE_PATTERN, + s3_profile: str | None = None, ): - dataset_chunks = [str(p) for p in Path(dataset_path).glob(file_pattern)] + fs = get_fs(dataset_path, s3_profile=s3_profile) + path_with_glob = os.path.join(dataset_path, file_pattern) + dataset_chunks = fs.glob(path_with_glob) n_chunks = len(dataset_chunks) if n_chunks > world_size: diff --git a/bytelatent/logger.py b/bytelatent/logger.py index 6723a84..87f04cc 100644 --- a/bytelatent/logger.py +++ b/bytelatent/logger.py @@ -91,7 +91,7 @@ def init_logger( log_file: str | None = None, *, name: str | None = None, - level: str = "NOTSET", + level: str = "INFO", ): """ Setup logging. 
diff --git a/requirements.txt b/requirements.txt index c6d87f1..59192cd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,3 +20,4 @@ altair submitit typer rich +fsspec[full] diff --git a/setup/download_prepare_hf_data.py b/setup/download_prepare_hf_data.py index 1aacf8f..e194857 100644 --- a/setup/download_prepare_hf_data.py +++ b/setup/download_prepare_hf_data.py @@ -5,6 +5,7 @@ import os import subprocess import time +import fsspec import requests from huggingface_hub import snapshot_download @@ -38,11 +39,21 @@ def download_dataset(repo_id, local_dir, allow_patterns): print(f"Dataset downloaded to {local_dir}") -def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64): +def parquet_to_jsonl( + dataset, work_dir, src_dir, tgt_dir, ntasks=64, s3_profile: str | None = None +): from datatrove.executor import LocalPipelineExecutor from datatrove.pipeline.readers import ParquetReader from datatrove.pipeline.writers import JsonlWriter + if tgt_dir.startswith("s3//"): + if s3_profile is None: + out_spec = tgt_dir + else: + out_spec = (tgt_dir, fsspec.filesystem("s3", profile=s3_profile)) + else: + out_spec = tgt_dir + pipeline_exec = LocalPipelineExecutor( pipeline=[ ParquetReader( @@ -52,7 +63,7 @@ def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64): glob_pattern="**/*.parquet", ), JsonlWriter( - tgt_dir, + out_spec, output_filename=dataset + ".chunk.${rank}.jsonl", compression=None, ), @@ -77,7 +88,7 @@ def setup_terashuf(work_dir): return terashuf_dir -def main(dataset, memory, data_dir, seed=42, nchunks=32): +def main(dataset, memory, data_dir, seed=42, nchunks=32, s3_profile: str | None = None): # Configuration repo_id = { "fineweb_edu": "HuggingFaceFW/fineweb-edu", From 1bf6d15e5a3ff76f128288fc547e2604f1981e2f Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Fri, 10 Jan 2025 01:02:25 +0000 Subject: [PATCH 4/6] Replace regular filesystem calls with fsspec + add s3 support Summary: For compatibility with either local/nfs or S3 datasets, swap to fsspec. 
Add a tool to compare local and remote filesystems Test Plan: - Ran regular train script - Ran with config with data in S3 --- .gitignore | 2 +- bytelatent/args.py | 8 +- bytelatent/data/file_util.py | 116 ++++++++++++++++++++ bytelatent/data/iterators/arrow_iterator.py | 89 ++++++++++++--- bytelatent/logger.py | 2 +- requirements.txt | 1 + setup/download_prepare_hf_data.py | 17 ++- 7 files changed, 216 insertions(+), 19 deletions(-) create mode 100644 bytelatent/data/file_util.py diff --git a/.gitignore b/.gitignore index 56891a9..6c664b8 100644 --- a/.gitignore +++ b/.gitignore @@ -165,4 +165,4 @@ cython_debug/ figures/ .vscode/ .DS_Store - +internal/ diff --git a/bytelatent/args.py b/bytelatent/args.py index 4fba100..cfba3bf 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -46,8 +46,11 @@ def distribute_data_to_rank( arrow_batch_size: int, rank: int, world_size: int, + s3_profile: str | None = None, ) -> ArrowFileIterator: - dataset_chunks = find_and_sanitize_chunks(dataset_path, world_size) + dataset_chunks = find_and_sanitize_chunks( + dataset_path, world_size, s3_profile=s3_profile + ) n_workers_per_chunk = world_size // len(dataset_chunks) rank_to_arrow_iterator_params = [] for chunk_path in dataset_chunks: @@ -61,6 +64,7 @@ def distribute_data_to_rank( dataset_files=None, entropy_model_name=entropy_model_name, arrow_batch_size=arrow_batch_size, + s3_profile=s3_profile, ) ) return rank_to_arrow_iterator_params[rank] @@ -68,6 +72,7 @@ def distribute_data_to_rank( class DataloaderArgs(BaseModel): model_config = ConfigDict(extra="forbid") + s3_profile: str | None = None root_dir: str | None = None sources: dict[str, float] = {} batch_size: int = 2 @@ -107,6 +112,7 @@ class DataloaderArgs(BaseModel): arrow_batch_size=self.arrow_batch_size, rank=rank, world_size=world_size, + s3_profile=self.s3_profile, ) looping_iterator = LoopingIterator(arrow_iterator) preprocess_iterator = PreprocessIterator( diff --git a/bytelatent/data/file_util.py b/bytelatent/data/file_util.py new file mode 100644 index 0000000..312e1b4 --- /dev/null +++ b/bytelatent/data/file_util.py @@ -0,0 +1,116 @@ +import os +from rich.progress import track +import typer +import fsspec +import pyarrow as pa +from pyarrow.lib import ArrowInvalid + +# pyarrow needs the initialization from this import +import pyarrow.dataset # pyright: ignore + + +def is_valid_arrow_file(path: str): + try: + dataset = pa.dataset.dataset(path, format="arrow") + return True + except ArrowInvalid: + return False + + +app = typer.Typer() + +S3_PREFIX = "s3://" + + +def get_fs(path: str, s3_profile: str | None = None) -> fsspec.AbstractFileSystem: + if path.startswith("s3://"): + if s3_profile is None: + return fsspec.filesystem("s3") + else: + return fsspec.filesystem("s3", profile=s3_profile) + else: + return fsspec.filesystem("file") + + +@app.command() +def print_local_to_delete( + blob_dir: str, local_dirs: list[str], s3_profile: str = "blt" +): + for s in local_dirs: + assert s.endswith("/"), "Dirs must end with /" + assert blob_dir.endswith("/"), "Dirs must end with /" + blob_fs = fsspec.filesystem("s3", profile=s3_profile) + blob_files = blob_fs.find(blob_dir) + for f in track(blob_files): + size = blob_fs.info(f)["Size"] + if not f.lower().endswith(".complete"): + assert size != 0, f"Size was invalidly zero for {f}" + + blob_relative_paths = {f[len(blob_dir) - len(S3_PREFIX) :] for f in blob_files} + local_fs = fsspec.filesystem("file") + + files_to_delete = [] + for local_dir in local_dirs: + local_files = 
local_fs.find(local_dir) + for f in local_files: + relative_path = f[len(local_dir) :] + if relative_path in blob_relative_paths and not os.path.islink(f): + files_to_delete.append(f) + print(len(files_to_delete)) + with open("/tmp/files_to_delete.txt", "w") as f: + for file in files_to_delete: + f.write(f"{file}\n") + + +@app.command() +def compare_local_to_blob( + source_dirs: list[str], dst_dir: str, s3_profile: str = "blt" +): + for s in source_dirs: + assert s.endswith("/"), "Dirs must end with /" + assert dst_dir.endswith("/"), "Dirs must end with /" + assert len(source_dirs) != 0 + assert dst_dir.startswith("s3://") + local_fs = fsspec.filesystem("file") + dst_fs = fsspec.filesystem("s3", profile=s3_profile) + source_to_files = {} + all_local_files = set() + for s in source_dirs: + skipped = [] + if s not in source_to_files: + source_to_files[s] = [] + for f in local_fs.find(s): + if os.path.islink(f): + continue + if f.endswith(".COMPLETE") or f.endswith(".complete"): + is_complete_file = True + assert os.path.getsize(f) == 0, ".COMPLETE files should be empty" + else: + is_complete_file = False + + if not is_complete_file and os.path.getsize(f) == 0: + skipped.append(f) + continue + if f.endswith(".arrow"): + if not is_valid_arrow_file(f): + skipped.append(f) + continue + + source_to_files[s].append(f) + all_local_files.add(f[len(s) :]) + print(s, len(source_to_files[s]), "skipped", len(skipped), skipped[:10]) + + dst_files = dst_fs.find(dst_dir) + print(dst_dir, len(dst_files)) + + dst_file_set = {f[len(dst_dir) - len(S3_PREFIX) :] for f in dst_files} + diff = all_local_files.symmetric_difference(dst_file_set) + print("Local files", len(all_local_files)) + print("DST Files", len(dst_file_set)) + print("Symmetric difference", len(diff)) + dst_only_files = dst_file_set - all_local_files + print("DST only", len(dst_only_files), list(dst_only_files)[:10]) + + +if __name__ == "__main__": + app() diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index df5f023..ad72091 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -1,7 +1,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
+import os import re from logging import getLogger -from pathlib import Path from typing import Any, Generator import pyarrow as pa @@ -9,9 +9,12 @@ import pyarrow as pa # pyarrow needs the initialization from this import import pyarrow.dataset # pyright: ignore from pydantic import BaseModel, ConfigDict +import fsspec +import s3fs from bytelatent import ByteLatentError from bytelatent.data.data_types import BltExample +from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator logger = getLogger(__name__) @@ -27,6 +30,8 @@ class ArrowFileIteratorState(BaseModel, IteratorState): dataset_files: list[str] | None entropy_model_name: str | None arrow_batch_size: int = 100 + s3_profile: str | None + filesystem_type: str | None = None def build(self) -> "ArrowFileIterator": arrow_file = ArrowFileIterator( @@ -37,14 +42,17 @@ class ArrowFileIteratorState(BaseModel, IteratorState): entropy_model_name=self.entropy_model_name, arrow_batch_size=self.arrow_batch_size, dataset_files=self.dataset_files, + s3_profile=self.s3_profile, + filesystem_type=self.filesystem_type, ) if self.row_num != 0: arrow_file._set_row_num(self.row_num) return arrow_file -def shard_sort_key(file: str | Path): - match = re.search(r".+\.shard_([0-9]+)\.arrow", str(file)) +def shard_sort_key(file: str): + assert isinstance(file, str) + match = re.search(r".+\.shard_([0-9]+)\.arrow", file) shard_number = int(match.group(1)) return shard_number @@ -60,6 +68,8 @@ class ArrowFileIterator(StatefulIterator): entropy_model_name: str | None, arrow_batch_size: int, dataset_files: list[str] | None = None, + s3_profile: str | None = None, + filesystem_type: str | None = None, ): assert 0 <= worker_id < num_workers, (worker_id, num_workers) if file_path is None and dataset_files is None: @@ -75,16 +85,41 @@ class ArrowFileIterator(StatefulIterator): self.preprocess_dir = preprocess_dir self.entropy_model_name = entropy_model_name self.arrow_batch_size = arrow_batch_size + self.s3_profile = s3_profile + self.filesystem_type = filesystem_type + self.fs = None + if self.filesystem_type is not None: + if self.filesystem_type == "file": + self.fs = fsspec.filesystem("file") + elif self.filesystem_type == "s3": + self.fs = fsspec.filesystem("s3", profile=s3_profile) + if dataset_files is None: # Prepare arrow shards - jsonl_file = Path(file_path) - parts = re.match(r"(.+)\.chunk\.[0-9]+\.jsonl", jsonl_file.name) + jsonl_file = file_path + parts = re.match( + r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file) + ) assert parts is not None dataset = parts.group(1) - data_dir = Path(preprocess_dir) / dataset / entropy_model_name - shard_files = list(data_dir.glob(f"{jsonl_file.name}.shard_*.arrow")) + data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name) + data_dir_with_glob = os.path.join( + data_dir, f"{os.path.basename(jsonl_file)}.shard_*.arrow" + ) + if self.fs is None: + self.fs = get_fs(data_dir_with_glob, s3_profile=s3_profile) + if isinstance(self.fs, s3fs.S3FileSystem): + self.filesystem_type = "s3" + else: + self.filesystem_type = "file" + shard_files = self.fs.glob(data_dir_with_glob) + for s in shard_files: - if not (data_dir / f"{s.name}.complete").exists(): + complete_file = os.path.join( + data_dir, f"{os.path.basename(s)}.complete" + ) + + if not self.fs.exists(complete_file): raise ValueError(f"Missing .complete for input file: {s}") shard_files = sorted(shard_files, key=shard_sort_key) @@ -92,10 +127,19 @@ class 
ArrowFileIterator(StatefulIterator): raise ByteLatentError( f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow" ) - self.dataset_files = [str(f) for f in shard_files] + self.dataset_files = [f for f in shard_files] else: self.preprocess_dir = None self.dataset_files = dataset_files + if dataset_files[0].startswith("s3://"): + for f in dataset_files: + assert f.startswith("s3://") + if self.fs is None: + self.fs = get_fs(dataset_files[0], s3_profile=s3_profile) + if isinstance(self.fs, s3fs.S3FileSystem): + self.filesystem_type = "s3" + else: + self.filesystem_type = "file" def get_state(self) -> ArrowFileIteratorState: return ArrowFileIteratorState( @@ -107,13 +151,21 @@ class ArrowFileIterator(StatefulIterator): entropy_model_name=self.entropy_model_name, arrow_batch_size=self.arrow_batch_size, dataset_files=self.dataset_files, + s3_profile=self.s3_profile, + filesystem_type=self.filesystem_type, ) def create_iter( self, ) -> Generator[BltExample, Any, None]: if self.dataset is None: - self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow") + if isinstance(self.fs, s3fs.core.S3FileSystem): + filesystem = self.fs + else: + filesystem = None + self.dataset = pa.dataset.dataset( + self.dataset_files, format="arrow", filesystem=filesystem + ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size ) @@ -165,7 +217,13 @@ class ArrowFileIterator(StatefulIterator): self.batch_iterator = None self.batch_to_consume = None else: - self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow") + if isinstance(self.fs, s3fs.core.S3FileSystem): + filesystem = self.fs + else: + filesystem = None + self.dataset = pa.dataset.dataset( + self.dataset_files, format="arrow", filesystem=filesystem + ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size ) @@ -198,9 +256,14 @@ TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" def find_and_sanitize_chunks( - dataset_path: str, world_size: int, file_pattern: str = TRAIN_DATA_FILE_PATTERN + dataset_path: str, + world_size: int, + file_pattern: str = TRAIN_DATA_FILE_PATTERN, + s3_profile: str | None = None, ): - dataset_chunks = [str(p) for p in Path(dataset_path).glob(file_pattern)] + fs = get_fs(dataset_path, s3_profile=s3_profile) + path_with_glob = os.path.join(dataset_path, file_pattern) + dataset_chunks = fs.glob(path_with_glob) n_chunks = len(dataset_chunks) if n_chunks > world_size: diff --git a/bytelatent/logger.py b/bytelatent/logger.py index 6723a84..87f04cc 100644 --- a/bytelatent/logger.py +++ b/bytelatent/logger.py @@ -91,7 +91,7 @@ def init_logger( log_file: str | None = None, *, name: str | None = None, - level: str = "NOTSET", + level: str = "INFO", ): """ Setup logging. 
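Beyond the path handling, the iterator state now records s3_profile and filesystem_type so that a checkpointed dataloader can rebuild the same filesystem on resume. A minimal sketch under the assumption that the remaining constructor arguments behave as shown in the diff above (the shard path and profile are hypothetical):

from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator

iterator = ArrowFileIterator(
    file_path=None,
    worker_id=0,
    num_workers=1,
    preprocess_dir=None,
    entropy_model_name=None,
    arrow_batch_size=100,
    dataset_files=["s3://my-bucket/shards/data.chunk.0.jsonl.shard_00.arrow"],
    s3_profile="blt",
)
# filesystem_type is inferred as "s3" from the first path; get_state()/build()
# round-trips it together with s3_profile when a run is resumed.
state = iterator.get_state()
restored = state.build()
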
diff --git a/requirements.txt b/requirements.txt index c6d87f1..59192cd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,3 +20,4 @@ altair submitit typer rich +fsspec[full] diff --git a/setup/download_prepare_hf_data.py b/setup/download_prepare_hf_data.py index 1aacf8f..e194857 100644 --- a/setup/download_prepare_hf_data.py +++ b/setup/download_prepare_hf_data.py @@ -5,6 +5,7 @@ import os import subprocess import time +import fsspec import requests from huggingface_hub import snapshot_download @@ -38,11 +39,21 @@ def download_dataset(repo_id, local_dir, allow_patterns): print(f"Dataset downloaded to {local_dir}") -def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64): +def parquet_to_jsonl( + dataset, work_dir, src_dir, tgt_dir, ntasks=64, s3_profile: str | None = None +): from datatrove.executor import LocalPipelineExecutor from datatrove.pipeline.readers import ParquetReader from datatrove.pipeline.writers import JsonlWriter + if tgt_dir.startswith("s3//"): + if s3_profile is None: + out_spec = tgt_dir + else: + out_spec = (tgt_dir, fsspec.filesystem("s3", profile=s3_profile)) + else: + out_spec = tgt_dir + pipeline_exec = LocalPipelineExecutor( pipeline=[ ParquetReader( @@ -52,7 +63,7 @@ def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64): glob_pattern="**/*.parquet", ), JsonlWriter( - tgt_dir, + out_spec, output_filename=dataset + ".chunk.${rank}.jsonl", compression=None, ), @@ -77,7 +88,7 @@ def setup_terashuf(work_dir): return terashuf_dir -def main(dataset, memory, data_dir, seed=42, nchunks=32): +def main(dataset, memory, data_dir, seed=42, nchunks=32, s3_profile: str | None = None): # Configuration repo_id = { "fineweb_edu": "HuggingFaceFW/fineweb-edu", From a1d05403b42cdb111be2cf3034a044e63f45b91b Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Fri, 10 Jan 2025 01:02:25 +0000 Subject: [PATCH 5/6] Replace regular filesystem calls with fsspec + add s3 support Summary: For compatibility with either local/nfs or S3 datasets, swap to fsspec. 
Add a tool to compare local and remote filesystems Test Plan: - Ran regular train script - Ran with config with data in S3 --- .gitignore | 2 +- bytelatent/args.py | 8 +- bytelatent/data/file_util.py | 117 ++++++++++++++++++++ bytelatent/data/iterators/arrow_iterator.py | 89 ++++++++++++--- bytelatent/logger.py | 2 +- requirements.txt | 1 + setup/download_prepare_hf_data.py | 17 ++- 7 files changed, 217 insertions(+), 19 deletions(-) create mode 100644 bytelatent/data/file_util.py diff --git a/.gitignore b/.gitignore index 56891a9..6c664b8 100644 --- a/.gitignore +++ b/.gitignore @@ -165,4 +165,4 @@ cython_debug/ figures/ .vscode/ .DS_Store - +internal/ diff --git a/bytelatent/args.py b/bytelatent/args.py index 4fba100..cfba3bf 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -46,8 +46,11 @@ def distribute_data_to_rank( arrow_batch_size: int, rank: int, world_size: int, + s3_profile: str | None = None, ) -> ArrowFileIterator: - dataset_chunks = find_and_sanitize_chunks(dataset_path, world_size) + dataset_chunks = find_and_sanitize_chunks( + dataset_path, world_size, s3_profile=s3_profile + ) n_workers_per_chunk = world_size // len(dataset_chunks) rank_to_arrow_iterator_params = [] for chunk_path in dataset_chunks: @@ -61,6 +64,7 @@ def distribute_data_to_rank( dataset_files=None, entropy_model_name=entropy_model_name, arrow_batch_size=arrow_batch_size, + s3_profile=s3_profile, ) ) return rank_to_arrow_iterator_params[rank] @@ -68,6 +72,7 @@ def distribute_data_to_rank( class DataloaderArgs(BaseModel): model_config = ConfigDict(extra="forbid") + s3_profile: str | None = None root_dir: str | None = None sources: dict[str, float] = {} batch_size: int = 2 @@ -107,6 +112,7 @@ class DataloaderArgs(BaseModel): arrow_batch_size=self.arrow_batch_size, rank=rank, world_size=world_size, + s3_profile=self.s3_profile, ) looping_iterator = LoopingIterator(arrow_iterator) preprocess_iterator = PreprocessIterator( diff --git a/bytelatent/data/file_util.py b/bytelatent/data/file_util.py new file mode 100644 index 0000000..d67b6db --- /dev/null +++ b/bytelatent/data/file_util.py @@ -0,0 +1,117 @@ +import os + +import fsspec +import pyarrow as pa + +# pyarrow needs the initialization from this import +import pyarrow.dataset # pyright: ignore +import typer +from pyarrow.lib import ArrowInvalid +from rich.progress import track + + +def is_valid_arrow_file(path: str): + try: + dataset = pa.dataset.dataset(path, format="arrow") + return True + except ArrowInvalid: + return False + + +app = typer.Typer() + +S3_PREFIX = "s3://" + + +def get_fs(path: str, s3_profile: str | None = None) -> fsspec.AbstractFileSystem: + if path.startswith("s3://"): + if s3_profile is None: + return fsspec.filesystem("s3") + else: + return fsspec.filesystem("s3", profile=s3_profile) + else: + return fsspec.filesystem("file") + + +@app.command() +def print_local_to_delete( + blob_dir: str, local_dirs: list[str], s3_profile: str = "blt" +): + for s in local_dirs: + assert s.endswith("/"), "Dirs must end with /" + assert blob_dir.endswith("/"), "Dirs must end with /" + blob_fs = fsspec.filesystem("s3", profile=s3_profile) + blob_files = blob_fs.find(blob_dir) + for f in track(blob_files): + size = blob_fs.info(f)["Size"] + if not f.lower().endswith(".complete"): + assert size != 0, f"Size was invalidly zero for {f}" + + blob_relative_paths = {f[len(blob_dir) - len(S3_PREFIX) :] for f in blob_files} + local_fs = fsspec.filesystem("file") + + files_to_delete = [] + for local_dir in local_dirs: + local_files = 
local_fs.find(local_dir) + for f in local_files: + relative_path = f[len(local_dir) :] + if relative_path in blob_relative_paths and not os.path.islink(f): + files_to_delete.append(f) + print(len(files_to_delete)) + with open("/tmp/files_to_delete.txt", "w") as f: + for file in files_to_delete: + f.write(f"{file}\n") + + +@app.command() +def compare_local_to_blob( + source_dirs: list[str], dst_dir: str, s3_profile: str = "blt" +): + for s in source_dirs: + assert s.endswith("/"), "Dirs must end with /" + assert dst_dir.endswith("/"), "Dirs must end with /" + assert len(source_dirs) != 0 + assert dst_dir.startswith("s3://") + local_fs = fsspec.filesystem("file") + dst_fs = fsspec.filesystem("s3", profile=s3_profile) + source_to_files = {} + all_local_files = set() + for s in source_dirs: + skipped = [] + if s not in source_to_files: + source_to_files[s] = [] + for f in local_fs.find(s): + if os.path.islink(f): + continue + if f.endswith(".COMPLETE") or f.endswith(".complete"): + is_complete_file = True + assert os.path.getsize(f) == 0, ".COMPLETE files should be empty" + else: + is_complete_file = False + + if not is_complete_file and os.path.getsize(f) == 0: + skipped.append(f) + continue + if f.endswith(".arrow"): + if not is_valid_arrow_file(f): + skipped.append(f) + continue + + source_to_files[s].append(f) + all_local_files.add(f[len(s) :]) + print(s, len(source_to_files[s]), "skipped", len(skipped), skipped[:10]) + + dst_files = dst_fs.find(dst_dir) + print(dst_dir, len(dst_files)) + + dst_file_set = {f[len(dst_dir) - len(S3_PREFIX) :] for f in dst_files} + diff = all_local_files.symmetric_difference(dst_file_set) + print("Local files", len(all_local_files)) + print("DST Files", len(dst_file_set)) + print("Symmetric difference", len(diff)) + dst_only_files = dst_file_set - all_local_files + print("DST only", len(dst_only_files), list(dst_only_files)[:10]) + + +if __name__ == "__main__": + app() diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index df5f023..4e7b99e 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -1,17 +1,20 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
+import os
 import re
 from logging import getLogger
-from pathlib import Path
 from typing import Any, Generator

+import fsspec
 import pyarrow as pa

 # pyarrow needs the initialization from this import
 import pyarrow.dataset  # pyright: ignore
+import s3fs
 from pydantic import BaseModel, ConfigDict

 from bytelatent import ByteLatentError
 from bytelatent.data.data_types import BltExample
+from bytelatent.data.file_util import get_fs
 from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator

 logger = getLogger(__name__)
@@ -27,6 +30,8 @@ class ArrowFileIteratorState(BaseModel, IteratorState):
     dataset_files: list[str] | None
     entropy_model_name: str | None
     arrow_batch_size: int = 100
+    s3_profile: str | None
+    filesystem_type: str | None = None

     def build(self) -> "ArrowFileIterator":
         arrow_file = ArrowFileIterator(
@@ -37,14 +42,17 @@ class ArrowFileIteratorState(BaseModel, IteratorState):
             entropy_model_name=self.entropy_model_name,
             arrow_batch_size=self.arrow_batch_size,
             dataset_files=self.dataset_files,
+            s3_profile=self.s3_profile,
+            filesystem_type=self.filesystem_type,
         )
         if self.row_num != 0:
             arrow_file._set_row_num(self.row_num)
         return arrow_file


-def shard_sort_key(file: str | Path):
-    match = re.search(r".+\.shard_([0-9]+)\.arrow", str(file))
+def shard_sort_key(file: str):
+    assert isinstance(file, str)
+    match = re.search(r".+\.shard_([0-9]+)\.arrow", file)
     shard_number = int(match.group(1))
     return shard_number

@@ -60,6 +68,8 @@ class ArrowFileIterator(StatefulIterator):
         entropy_model_name: str | None,
         arrow_batch_size: int,
         dataset_files: list[str] | None = None,
+        s3_profile: str | None = None,
+        filesystem_type: str | None = None,
     ):
         assert 0 <= worker_id < num_workers, (worker_id, num_workers)
         if file_path is None and dataset_files is None:
@@ -75,16 +85,41 @@ class ArrowFileIterator(StatefulIterator):
         self.preprocess_dir = preprocess_dir
         self.entropy_model_name = entropy_model_name
         self.arrow_batch_size = arrow_batch_size
+        self.s3_profile = s3_profile
+        self.filesystem_type = filesystem_type
+        self.fs = None
+        if self.filesystem_type is not None:
+            if self.filesystem_type == "file":
+                self.fs = fsspec.filesystem("file")
+            elif self.filesystem_type == "s3":
+                self.fs = fsspec.filesystem("s3", profile=s3_profile)
+
         if dataset_files is None:
             # Prepare arrow shards
-            jsonl_file = Path(file_path)
-            parts = re.match(r"(.+)\.chunk\.[0-9]+\.jsonl", jsonl_file.name)
+            jsonl_file = file_path
+            parts = re.match(
+                r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file)
+            )
             assert parts is not None
             dataset = parts.group(1)
-            data_dir = Path(preprocess_dir) / dataset / entropy_model_name
-            shard_files = list(data_dir.glob(f"{jsonl_file.name}.shard_*.arrow"))
+            data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name)
+            data_dir_with_glob = os.path.join(
+                data_dir, f"{os.path.basename(jsonl_file)}.shard_*.arrow"
+            )
+            if self.fs is None:
+                self.fs = get_fs(data_dir_with_glob, s3_profile=s3_profile)
+                if isinstance(self.fs, s3fs.S3FileSystem):
+                    self.filesystem_type = "s3"
+                else:
+                    self.filesystem_type = "file"
+            shard_files = self.fs.glob(data_dir_with_glob)
+
             for s in shard_files:
-                if not (data_dir / f"{s.name}.complete").exists():
+                complete_file = os.path.join(
+                    data_dir, f"{os.path.basename(s)}.complete"
+                )
+
+                if not self.fs.exists(complete_file):
                     raise ValueError(f"Missing .complete for input file: {s}")

             shard_files = sorted(shard_files, key=shard_sort_key)
@@ -92,10 +127,19 @@ class ArrowFileIterator(StatefulIterator):
                 raise ByteLatentError(
                     f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow"
                 )
-            self.dataset_files = [str(f) for f in shard_files]
+            self.dataset_files = [f for f in shard_files]
         else:
             self.preprocess_dir = None
             self.dataset_files = dataset_files
+            if dataset_files[0].startswith("s3://"):
+                for f in dataset_files:
+                    assert f.startswith("s3://")
+            if self.fs is None:
+                self.fs = get_fs(dataset_files[0], s3_profile=s3_profile)
+                if isinstance(self.fs, s3fs.S3FileSystem):
+                    self.filesystem_type = "s3"
+                else:
+                    self.filesystem_type = "file"

     def get_state(self) -> ArrowFileIteratorState:
         return ArrowFileIteratorState(
@@ -107,13 +151,21 @@ class ArrowFileIterator(StatefulIterator):
             entropy_model_name=self.entropy_model_name,
             arrow_batch_size=self.arrow_batch_size,
             dataset_files=self.dataset_files,
+            s3_profile=self.s3_profile,
+            filesystem_type=self.filesystem_type,
         )

     def create_iter(
         self,
     ) -> Generator[BltExample, Any, None]:
         if self.dataset is None:
-            self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow")
+            if isinstance(self.fs, s3fs.core.S3FileSystem):
+                filesystem = self.fs
+            else:
+                filesystem = None
+            self.dataset = pa.dataset.dataset(
+                self.dataset_files, format="arrow", filesystem=filesystem
+            )
             self.batch_iterator = self.dataset.to_batches(
                 batch_size=self.arrow_batch_size
             )
@@ -165,7 +217,13 @@ class ArrowFileIterator(StatefulIterator):
                 self.batch_iterator = None
                 self.batch_to_consume = None
             else:
-                self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow")
+                if isinstance(self.fs, s3fs.core.S3FileSystem):
+                    filesystem = self.fs
+                else:
+                    filesystem = None
+                self.dataset = pa.dataset.dataset(
+                    self.dataset_files, format="arrow", filesystem=filesystem
+                )
                 self.batch_iterator = self.dataset.to_batches(
                     batch_size=self.arrow_batch_size
                 )
@@ -198,9 +256,14 @@
 TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"


 def find_and_sanitize_chunks(
-    dataset_path: str, world_size: int, file_pattern: str = TRAIN_DATA_FILE_PATTERN
+    dataset_path: str,
+    world_size: int,
+    file_pattern: str = TRAIN_DATA_FILE_PATTERN,
+    s3_profile: str | None = None,
 ):
-    dataset_chunks = [str(p) for p in Path(dataset_path).glob(file_pattern)]
+    fs = get_fs(dataset_path, s3_profile=s3_profile)
+    path_with_glob = os.path.join(dataset_path, file_pattern)
+    dataset_chunks = fs.glob(path_with_glob)
     n_chunks = len(dataset_chunks)

     if n_chunks > world_size:
diff --git a/bytelatent/logger.py b/bytelatent/logger.py
index 6723a84..87f04cc 100644
--- a/bytelatent/logger.py
+++ b/bytelatent/logger.py
@@ -91,7 +91,7 @@ def init_logger(
     log_file: str | None = None,
     *,
     name: str | None = None,
-    level: str = "NOTSET",
+    level: str = "INFO",
 ):
     """
     Setup logging.
diff --git a/requirements.txt b/requirements.txt
index c6d87f1..59192cd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -20,3 +20,4 @@ altair
 submitit
 typer
 rich
+fsspec[full]
diff --git a/setup/download_prepare_hf_data.py b/setup/download_prepare_hf_data.py
index 1aacf8f..e194857 100644
--- a/setup/download_prepare_hf_data.py
+++ b/setup/download_prepare_hf_data.py
@@ -5,6 +5,7 @@
 import os
 import subprocess
 import time

+import fsspec
 import requests
 from huggingface_hub import snapshot_download
@@ -38,11 +39,21 @@ def download_dataset(repo_id, local_dir, allow_patterns):
     print(f"Dataset downloaded to {local_dir}")


-def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64):
+def parquet_to_jsonl(
+    dataset, work_dir, src_dir, tgt_dir, ntasks=64, s3_profile: str | None = None
+):
     from datatrove.executor import LocalPipelineExecutor
     from datatrove.pipeline.readers import ParquetReader
     from datatrove.pipeline.writers import JsonlWriter

+    if tgt_dir.startswith("s3://"):
+        if s3_profile is None:
+            out_spec = tgt_dir
+        else:
+            out_spec = (tgt_dir, fsspec.filesystem("s3", profile=s3_profile))
+    else:
+        out_spec = tgt_dir
+
     pipeline_exec = LocalPipelineExecutor(
         pipeline=[
             ParquetReader(
@@ -52,7 +63,7 @@ def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64):
                 glob_pattern="**/*.parquet",
             ),
             JsonlWriter(
-                tgt_dir,
+                out_spec,
                 output_filename=dataset + ".chunk.${rank}.jsonl",
                 compression=None,
             ),
@@ -77,7 +88,7 @@ def setup_terashuf(work_dir):
     return terashuf_dir


-def main(dataset, memory, data_dir, seed=42, nchunks=32):
+def main(dataset, memory, data_dir, seed=42, nchunks=32, s3_profile: str | None = None):
     # Configuration
     repo_id = {
         "fineweb_edu": "HuggingFaceFW/fineweb-edu",

From 3f045f11238af657269e2051d870515e6ac7913b Mon Sep 17 00:00:00 2001
From: Pedro Rodriguez
Date: Mon, 13 Jan 2025 23:13:49 +0000
Subject: [PATCH 6/6] Update preprocess_entropies script to blt inference + add
 fsspec support

Summary:

Test Plan:
---
 bytelatent/data/patcher.py                    |   8 +-
 bytelatent/preprocess/preprocess_entropies.py | 105 ++++++++++++------
 requirements.txt                              |   1 +
 3 files changed, 74 insertions(+), 40 deletions(-)

diff --git a/bytelatent/data/patcher.py b/bytelatent/data/patcher.py
index ede8b06..f8477a3 100644
--- a/bytelatent/data/patcher.py
+++ b/bytelatent/data/patcher.py
@@ -82,16 +82,16 @@ def calculate_entropies(
             if device is not None:
                 split = split.to(device)
             assert torch.all(split >= 0) and torch.all(split < 260)
-            pred, _ = entropy_model(split)
+            pred = entropy_model(split)
             pred = pred.reshape(-1, pred.shape[-1])[
                 : split.numel() - pad_size, :
             ]  # [batch_size * seq_len, vocab]
             pred_entropies = entropy(pred)
             entropies.append(pred_entropies)

-    entropies = torch.cat(entropies, dim=0)
-    entropies = entropies.reshape(tokens.shape)
-    return entropies
+    concat_entropies = torch.cat(entropies, dim=0)
+    concat_entropies = concat_entropies.reshape(tokens.shape)
+    return concat_entropies


 def patch_start_mask_from_entropy_with_monotonicity(entropies, t):
diff --git a/bytelatent/preprocess/preprocess_entropies.py b/bytelatent/preprocess/preprocess_entropies.py
index 20d1e0c..45c081c 100644
--- a/bytelatent/preprocess/preprocess_entropies.py
+++ b/bytelatent/preprocess/preprocess_entropies.py
@@ -1,14 +1,59 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+import jsonlines
 import time
-from pathlib import Path

+import fsspec
 import numpy as np
 import pyarrow as pa
 import torch
 import typer
 from rich.progress import Progress, TextColumn

-from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
+from bytelatent.data.file_util import get_fs
+from bytelatent.data.patcher import calculate_entropies
+from bytelatent.entropy_model import load_entropy_model
+from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
+
+
+def get_id_from_doc(doc: dict) -> str:
+    """
+    We need a reliable way to ensure that samples from jsonl
+    and arrow are the same, but there is no unique id field,
+    so derive the best possible one.
+    """
+    if "sample_id" in doc:
+        sample_id = doc["sample_id"]
+    elif "title" in doc:
+        sample_id = doc["title"]
+    elif "qid" in doc:
+        sample_id = doc["qid"]
+    elif "paper_id" in doc:
+        sample_id = doc["paper_id"]
+    elif "path" in doc:
+        sample_id = doc["path"]
+    elif "url" in doc:
+        sample_id = doc["url"]
+    elif "id" in doc:
+        sample_id = doc["id"]
+    else:
+        raise ValueError(f"Could not find an id key from: {doc.keys()}")
+    return str(sample_id)
+
+
+def get_text(doc: dict):
+    if "text" in doc:
+        text = doc["text"]
+    elif "content" in doc:
+        text = doc["content"]
+    else:
+        raise ValueError(f"Could not find a text key from: {doc.keys()}")
+    return text
+
+
+def jsonl_file_iterator(fs: fsspec.AbstractFileSystem, path: str):
+    with fs.open(path) as f:
+        reader = jsonlines.Reader(f)
+        yield from reader


 def main(
@@ -16,39 +61,32 @@
     output_file: str,
     patching_device: str = "cuda",
     log_step: int = 10_000,
-    entropy_model_checkpoint_dir: str = "entropy_checkpoint_dir",
+    entropy_model_checkpoint_dir: str = "public_data/entropy_checkpoint",
+    entropy_model_state_dict_path: str = "public_data/entropy_model.pth",
+    bpe_tokenizer_path: str = "public_data/tokenizer.model",
     dry_run: bool = False,
+    s3_profile: str | None = None,
 ):
-    # TODO: Modify this to work with the new code
-    raise NotImplementedError()
-    iterator = ArrowFileIterator(
-        file_path=input_file,
-        worker_id=0,
-        num_workers=1,
-    )
-    tokenization_mode = "bytes"
     print(f"Preprocessing entropies, input: {input_file}, output: {output_file}")
     print("Loading entropy model", entropy_model_checkpoint_dir)
+    input_fs = get_fs(input_file, s3_profile=s3_profile)
+    input_doc_iterator = jsonl_file_iterator(input_fs, input_file)
+
     if dry_run:
         return
     entropy_model = load_entropy_model(
-        entropy_model_checkpoint_dir, device=patching_device
+        entropy_model_checkpoint_dir,
+        entropy_model_state_dict_path,
+        device=patching_device,
     )
-    entropy_model, _ = to_device(entropy_model, patching_device)
+
     print("Creating patcher")
     patching_batch_size = 32
     print("Creating tokenizer")
-    tokenizer = Tokenizer(
-        model_path="/home/artidoro/tokenizers/llama_v2.tokenizer.model",
-        tokenization_mode=tokenization_mode,
-        # BYTE_UNITS
-        vocab_size_unit_1=256,
-        bos=True,
-        eos=True,
-        bpe_delim=False,
-        # This isn't used, just stores a reference for other calls we don't use
-        patcher=None,
+    tokenizer_args = TokenizerArgs(
+        name="blt", init_kwargs={"bpe_tokenizer_path": bpe_tokenizer_path}
     )
+    tokenizer = tokenizer_args.build()
     step = 0
     print("starting")
     start_time = time.time()
@@ -59,8 +97,10 @@
     schema = pa.schema([sample_id_field, text_field, entropy_field])

     arrow_batch_size = 1_000
+    output_fs = get_fs(output_file, s3_profile=s3_profile)
+
     try:
-        with pa.OSFile(output_file, "wb") as sink:
+        with output_fs.open(output_file, "wb") as sink:
             with pa.ipc.new_file(sink, schema) as writer:
                 id_buffer = []
                 entropies_buffer = []
@@ -72,17 +112,9 @@
                 task = progress.add_task(
                     "[green]Calculating entropies...", total=None
                 )
-                for doc in iterator:
+                for doc in input_doc_iterator:
                     sample_id = get_id_from_doc(doc)
-
-                    if "text" in doc:
-                        text = doc["text"]
-                    elif "content" in doc:
-                        text = doc["content"]
-                    else:
-                        raise ValueError(
-                            f"Could not find a text key from: {doc.keys()}"
-                        )
+                    text = get_text(doc)
                     tokens = torch.tensor(tokenizer.encode(text))
                     patch_start = time.time()
                     scores = calculate_entropies(
@@ -128,9 +160,10 @@
                         entropies_buffer = []
                         id_buffer = []
                         text_buffer = []
-        Path(f"{output_file}.complete").touch()
+        output_fs.touch(f"{output_file}.complete")
     except:
-        Path(output_file).unlink(missing_ok=True)
+        if output_fs.exists(output_file):
+            output_fs.rm(output_file)
         raise
     elapsed = time.time() - start_time
     print("steps", step)
diff --git a/requirements.txt b/requirements.txt
index 59192cd..8490556 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -21,3 +21,4 @@ submitit
 typer
 rich
 fsspec[full]
+orjson
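Editor's note: on the patcher side, `entropy(pred)` turns the entropy model's next-byte logits into the per-position scores that the preprocessing script writes out and that patching thresholds against. The function itself is not changed by this patch; the following is a self-contained sketch of the standard computation it is understood to perform (shapes and the helper name are illustrative assumptions).

# Sketch of per-position entropy from next-byte logits, H = -sum(p * log p).
# Shapes are illustrative; the repo's entropy() may differ in minor details.
import torch
import torch.nn.functional as F


def entropy_sketch(logits: torch.Tensor) -> torch.Tensor:
    # logits: [num_positions, vocab_size]
    log_probs = F.log_softmax(logits, dim=-1)
    probs = log_probs.exp()
    return -(probs * log_probs).sum(dim=-1)  # [num_positions]


scores = entropy_sketch(torch.randn(8, 260))  # 260 matches the byte vocab asserted above
print(scores.shape)  # torch.Size([8])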
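The new `jsonl_file_iterator` keeps the reading side storage-agnostic in the same way: combined with `get_fs`, it streams documents identically from disk or S3. A usage sketch under an assumed placeholder path and profile name:

# Sketch: streaming documents from a local or S3 jsonl file via fsspec.
import fsspec
import jsonlines


def iter_jsonl(path: str, s3_profile: str | None = None):
    fs = (
        fsspec.filesystem("s3", profile=s3_profile)
        if path.startswith("s3://")
        else fsspec.filesystem("file")
    )
    with fs.open(path) as f:
        yield from jsonlines.Reader(f)


for doc in iter_jsonl("/tmp/example.chunk.0.jsonl"):
    print(doc.get("text", "")[:40])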
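Finally, the `.complete` marker that the script now touches through `output_fs` is the same sentinel `ArrowFileIterator` checks before it will read a shard, so partially written outputs are never consumed. A sketch of the write-then-mark convention (output path is a placeholder; error handling mirrors the script above):

# Sketch of the write-then-mark convention used for shard outputs.
import fsspec

fs = fsspec.filesystem("file")
output_file = "/tmp/example.chunk.0.jsonl.shard_0.arrow"

try:
    with fs.open(output_file, "wb") as sink:
        sink.write(b"...")  # the Arrow IPC file would be written here
    fs.touch(f"{output_file}.complete")  # readers only trust marked shards
except Exception:
    if fs.exists(output_file):
        fs.rm(output_file)  # drop partial output so it cannot be consumed
    raise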