From 28016f144d5fbcc2f279955f0038463186f0de30 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Thu, 9 Jan 2025 20:06:18 +0000 Subject: [PATCH 1/9] Add plotting code from paper Summary: Test Plan: --- bytelatent/plotting/config_entropy_figure.yaml | 3 ++- bytelatent/plotting/entropy_figure.py | 17 ++++++++++++++++- plot_data/scores.json | 1 + 3 files changed, 19 insertions(+), 2 deletions(-) create mode 100644 plot_data/scores.json diff --git a/bytelatent/plotting/config_entropy_figure.yaml b/bytelatent/plotting/config_entropy_figure.yaml index 4d7bfd7..296ea07 100644 --- a/bytelatent/plotting/config_entropy_figure.yaml +++ b/bytelatent/plotting/config_entropy_figure.yaml @@ -1,3 +1,4 @@ data_path: plot_data/entropy_figure.json chart_path: figures/entropy_figure.pdf -# chart_path: figures/entropy_figure.pdf +threshold_override: 1.7171002626419067 +score_override_path: plot_data/scores.json diff --git a/bytelatent/plotting/entropy_figure.py b/bytelatent/plotting/entropy_figure.py index c1966a1..f401d7c 100644 --- a/bytelatent/plotting/entropy_figure.py +++ b/bytelatent/plotting/entropy_figure.py @@ -1,4 +1,5 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. +import json import os import sys from pathlib import Path @@ -12,6 +13,8 @@ from pydantic import BaseModel class PlotEntropiesConfig(BaseModel): data_path: str | None chart_path: str + score_override_path: str | None = None + threshold_override: float | None = None class Config: extra = "forbid" @@ -37,8 +40,20 @@ def main(): plot_config = PlotEntropiesConfig(**conf_dict) with open(plot_config.data_path) as f: json_data = f.read() + plot_data = PlotEntropiesData.model_validate_json(json_data) df = pd.read_json(plot_data.dataframe_json) + print("LEN", len(df)) + if plot_config.threshold_override is None: + threshold = plot_data.threshold + else: + threshold = plot_config.threshold_override + if plot_config.score_override_path is not None: + with open(plot_config.score_override_path) as f: + scores = json.load(f)["score"] + assert len(scores) == len(df) + df["entropies"] = scores + df["start"] = [1] + (df["entropies"] > threshold).values.tolist()[:-1] x_ticks = [] for row in df.itertuples(): @@ -65,7 +80,7 @@ def main(): ), ) rule = base.mark_rule(color="red", strokeDash=[4, 4]).encode( - y=alt.datum(plot_data.threshold), + y=alt.datum(threshold), ) patch_rules = ( alt.Chart(df[df["start"] > 0]) diff --git a/plot_data/scores.json b/plot_data/scores.json new file mode 100644 index 0000000..202cafc --- /dev/null +++ b/plot_data/scores.json @@ -0,0 +1 @@ +{"score": [3.3949153423309326, 2.1647746562957764, 2.3216569423675537, 2.8114914894104004, 1.505232334136963, 0.04055612534284592, 0.09150367230176926, 0.06008715182542801, 0.3453567624092102, 1.0483067035675049, 0.1967127025127411, 0.12737397849559784, 0.05923430994153023, 0.001597292022779584, 0.004362526815384626, 0.005547997076064348, 0.0011689786333590746, 0.0010273229563608766, 1.0228447914123535, 3.6863417625427246, 0.46605175733566284, 0.048645928502082825, 2.2544963359832764, 0.37329360842704773, 1.001160979270935, 2.9116122722625732, 1.8948925733566284, 1.4017235040664673, 0.3879640996456146, 0.2652309536933899, 1.780383825302124, 0.013964788988232613, 0.005456871818751097, 0.5426468253135681, 0.20666983723640442, 0.0051853349432349205, 0.0005802579107694328, 0.0007443525246344507, 0.0004390323010738939, 0.005452247802168131, 1.1932975053787231, 0.023798620328307152, 3.1230878829956055, 1.3915895223617554, 3.0489213466644287, 1.7018193006515503, 1.873910903930664, 
1.4662408828735352, 0.004920408595353365, 0.02599342167377472, 0.6620859503746033, 0.31743818521499634, 2.8409600257873535, 1.1354060173034668, 0.0520976223051548, 0.3519965708255768, 0.40707266330718994, 2.5438783168792725, 1.3343133926391602, 0.023993035778403282, 3.445943832397461, 1.8542104959487915, 0.7849258780479431, 0.6848396062850952, 0.06938046962022781, 0.20923230051994324, 0.10084306448698044, 0.18334199488162994, 0.4126923978328705, 0.5505472421646118, 0.1042013093829155, 0.019447727128863335, 0.0014866517158225179, 0.0009848219342529774, 0.00021391961490735412, 0.007746236398816109, 0.00038792978739365935, 0.0007933690212666988, 1.2369810342788696, 0.4436197578907013, 4.6366687456611544e-05]} \ No newline at end of file From 84854423c49fa59187855785c935d6cdbc9e0c45 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Fri, 10 Jan 2025 01:00:54 +0000 Subject: [PATCH 2/9] Replace regular filesystem calls with fsspec + add s3 support Summary: For compatibility with either local/nfs or S3 datasets, swap to fsspec. Add a tool to compare local and remote filesystems Test Plan: - Ran regular train script - Ran with config with data in S3 --- .gitignore | 2 +- bytelatent/args.py | 8 +- bytelatent/data/file_util.py | 116 ++++++++++++++++++++ bytelatent/data/iterators/arrow_iterator.py | 95 +++++++++++++--- bytelatent/logger.py | 2 +- requirements.txt | 1 + setup/download_prepare_hf_data.py | 17 ++- 7 files changed, 222 insertions(+), 19 deletions(-) create mode 100644 bytelatent/data/file_util.py diff --git a/.gitignore b/.gitignore index 56891a9..6c664b8 100644 --- a/.gitignore +++ b/.gitignore @@ -165,4 +165,4 @@ cython_debug/ figures/ .vscode/ .DS_Store - +internal/ diff --git a/bytelatent/args.py b/bytelatent/args.py index 4fba100..cfba3bf 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -46,8 +46,11 @@ def distribute_data_to_rank( arrow_batch_size: int, rank: int, world_size: int, + s3_profile: str | None = None, ) -> ArrowFileIterator: - dataset_chunks = find_and_sanitize_chunks(dataset_path, world_size) + dataset_chunks = find_and_sanitize_chunks( + dataset_path, world_size, s3_profile=s3_profile + ) n_workers_per_chunk = world_size // len(dataset_chunks) rank_to_arrow_iterator_params = [] for chunk_path in dataset_chunks: @@ -61,6 +64,7 @@ def distribute_data_to_rank( dataset_files=None, entropy_model_name=entropy_model_name, arrow_batch_size=arrow_batch_size, + s3_profile=s3_profile, ) ) return rank_to_arrow_iterator_params[rank] @@ -68,6 +72,7 @@ def distribute_data_to_rank( class DataloaderArgs(BaseModel): model_config = ConfigDict(extra="forbid") + s3_profile: str | None = None root_dir: str | None = None sources: dict[str, float] = {} batch_size: int = 2 @@ -107,6 +112,7 @@ class DataloaderArgs(BaseModel): arrow_batch_size=self.arrow_batch_size, rank=rank, world_size=world_size, + s3_profile=self.s3_profile, ) looping_iterator = LoopingIterator(arrow_iterator) preprocess_iterator = PreprocessIterator( diff --git a/bytelatent/data/file_util.py b/bytelatent/data/file_util.py new file mode 100644 index 0000000..312e1b4 --- /dev/null +++ b/bytelatent/data/file_util.py @@ -0,0 +1,116 @@ +import os +from rich.progress import track +import typer +import fsspec +import pyarrow as pa +from pyarrow.lib import ArrowInvalid + +# pyarrow needs the initialization from this import +import pyarrow.dataset # pyright: ignore + + +def is_valid_arrow_file(path: str): + try: + dataset = pa.dataset.dataset(path, format="arrow") + return True + except ArrowInvalid: + return False + + 
+app = typer.Typer() + +S3_PREFIX = "s3://" + + +def get_fs(path: str, s3_profile: str | None = None) -> fsspec.AbstractFileSystem: + if path.startswith("s3://"): + if s3_profile is None: + return fsspec.filesystem("s3") + else: + return fsspec.filesystem("s3", profile=s3_profile) + else: + return fsspec.filesystem("file") + + +@app.command() +def print_local_to_delete( + blob_dir: str, local_dirs: list[str], s3_profile: str = "blt" +): + for s in local_dirs: + assert s.endswith("/"), "Dirs must end with /" + assert blob_dir.endswith("/"), "Dirs must end with /" + blob_fs = fsspec.filesystem("s3", profile=s3_profile) + blob_files = blob_fs.find(blob_dir) + for f in track(blob_files): + size = blob_fs.info(f)["Size"] + if not f.lower().endswith(".complete"): + assert size != 0, f"Size was invalidly zero for {f}" + + blob_relative_paths = {f[len(blob_dir) - len(S3_PREFIX) :] for f in blob_files} + local_fs = fsspec.filesystem("file") + + files_to_delete = [] + for local_dir in local_dirs: + local_files = local_fs.find(local_dir) + for f in local_files: + relative_path = f[len(local_dir) :] + if relative_path in blob_relative_paths and not os.path.islink(f): + files_to_delete.append(f) + print(len(files_to_delete)) + with open("/tmp/files_to_delete.txt", "w") as f: + for file in files_to_delete: + f.write(f"{file}\n") + + +@app.command() +def compare_local_to_blob( + source_dirs: list[str], dst_dir: str, s3_profile: str = "blt" +): + for s in source_dirs: + assert s.endswith("/"), "Dirs must end with /" + assert dst_dir.endswith("/"), "Dirs must end with /" + assert len(source_dirs) != 0 + assert dst_dir.startswith("s3://") + local_fs = fsspec.filesystem("file") + dst_fs = fsspec.filesystem("s3", profile=s3_profile) + source_to_files = {} + all_local_files = set() + for s in source_dirs: + skipped = [] + if s not in source_to_files: + source_to_files[s] = [] + for f in local_fs.find(s): + if os.path.islink(f): + continue + if f.endswith(".COMPLETE") or f.endswith(".complete"): + is_complete_file = True + assert os.path.getsize(f) == 0, ".COMPLETE files should be empty" + else: + is_complete_file = False + + if not is_complete_file and os.path.getsize(f) == 0: + skipped.append(f) + continue + if f.endswith(".arrow"): + if not is_valid_arrow_file(f): + skipped.append(f) + continue + + source_to_files[s].append(f) + all_local_files.add(f[len(s) :]) + print(s, len(source_to_files[s]), "skipped", len(skipped), skipped[:10]) + + dst_files = dst_fs.find(dst_dir) + print(dst_dir, len(dst_files)) + + dst_file_set = {f[len(dst_dir) - len(S3_PREFIX) :] for f in dst_files} + diff = all_local_files.symmetric_difference(dst_file_set) + print("Local files", len(all_local_files)) + print("DST Files", len(dst_file_set)) + print("Symmetric difference", len(diff)) + dst_only_files = dst_file_set - all_local_files + print("DST only", len(dst_only_files), list(dst_only_files)[:10]) + + +if __name__ == "__main__": + app() diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index df5f023..17569ee 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -1,7 +1,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
+import os import re from logging import getLogger -from pathlib import Path from typing import Any, Generator import pyarrow as pa @@ -9,9 +9,12 @@ import pyarrow as pa # pyarrow needs the initialization from this import import pyarrow.dataset # pyright: ignore from pydantic import BaseModel, ConfigDict +import fsspec +import s3fs from bytelatent import ByteLatentError from bytelatent.data.data_types import BltExample +from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator logger = getLogger(__name__) @@ -27,6 +30,8 @@ class ArrowFileIteratorState(BaseModel, IteratorState): dataset_files: list[str] | None entropy_model_name: str | None arrow_batch_size: int = 100 + s3_profile: str | None + filesystem_type: str | None = None def build(self) -> "ArrowFileIterator": arrow_file = ArrowFileIterator( @@ -37,14 +42,17 @@ class ArrowFileIteratorState(BaseModel, IteratorState): entropy_model_name=self.entropy_model_name, arrow_batch_size=self.arrow_batch_size, dataset_files=self.dataset_files, + s3_profile=self.s3_profile, + filesystem_type=self.filesystem_type, ) if self.row_num != 0: arrow_file._set_row_num(self.row_num) return arrow_file -def shard_sort_key(file: str | Path): - match = re.search(r".+\.shard_([0-9]+)\.arrow", str(file)) +def shard_sort_key(file: str): + assert isinstance(file, str) + match = re.search(r".+\.shard_([0-9]+)\.arrow", file) shard_number = int(match.group(1)) return shard_number @@ -60,6 +68,8 @@ class ArrowFileIterator(StatefulIterator): entropy_model_name: str | None, arrow_batch_size: int, dataset_files: list[str] | None = None, + s3_profile: str | None = None, + filesystem_type: str | None = None, ): assert 0 <= worker_id < num_workers, (worker_id, num_workers) if file_path is None and dataset_files is None: @@ -75,16 +85,44 @@ class ArrowFileIterator(StatefulIterator): self.preprocess_dir = preprocess_dir self.entropy_model_name = entropy_model_name self.arrow_batch_size = arrow_batch_size + self.s3_profile = s3_profile + self.filesystem_type = filesystem_type + self.fs = None + if self.filesystem_type is not None: + if self.filesystem_type == "file": + self.fs = fsspec.filesystem("file") + elif self.filesystem_type == "s3": + self.fs = fsspec.filesystem("s3", profile=s3_profile) + if dataset_files is None: # Prepare arrow shards - jsonl_file = Path(file_path) - parts = re.match(r"(.+)\.chunk\.[0-9]+\.jsonl", jsonl_file.name) + jsonl_file = file_path + parts = re.match( + r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file) + ) assert parts is not None dataset = parts.group(1) - data_dir = Path(preprocess_dir) / dataset / entropy_model_name - shard_files = list(data_dir.glob(f"{jsonl_file.name}.shard_*.arrow")) + data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name) + data_dir_with_glob = os.path.join( + data_dir, f"{os.path.basename(jsonl_file)}.shard_*.arrow" + ) + print("preprocess_dir", preprocess_dir) + print("data_dir_with_glob", data_dir_with_glob) + if self.fs is None: + self.fs = get_fs(data_dir_with_glob, s3_profile=s3_profile) + if isinstance(self.fs, s3fs.S3FileSystem): + self.filesystem_type = "s3" + else: + self.filesystem_type = "file" + print("fs", self.fs) + shard_files = self.fs.glob(data_dir_with_glob) + for s in shard_files: - if not (data_dir / f"{s.name}.complete").exists(): + complete_file = os.path.join( + data_dir, f"{os.path.basename(s)}.complete" + ) + + if not self.fs.exists(complete_file): raise ValueError(f"Missing .complete for input 
file: {s}") shard_files = sorted(shard_files, key=shard_sort_key) @@ -92,10 +130,19 @@ class ArrowFileIterator(StatefulIterator): raise ByteLatentError( f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow" ) - self.dataset_files = [str(f) for f in shard_files] + self.dataset_files = [f for f in shard_files] else: self.preprocess_dir = None self.dataset_files = dataset_files + if dataset_files[0].startswith("s3://"): + for f in dataset_files: + assert f.startswith("s3://") + if self.fs is None: + self.fs = get_fs(dataset_files[0], s3_profile=s3_profile) + if isinstance(self.fs, s3fs.S3FileSystem): + self.filesystem_type = "s3" + else: + self.filesystem_type = "file" def get_state(self) -> ArrowFileIteratorState: return ArrowFileIteratorState( @@ -107,13 +154,24 @@ class ArrowFileIterator(StatefulIterator): entropy_model_name=self.entropy_model_name, arrow_batch_size=self.arrow_batch_size, dataset_files=self.dataset_files, + s3_profile=self.s3_profile, + filesystem_type=self.filesystem_type, ) def create_iter( self, ) -> Generator[BltExample, Any, None]: if self.dataset is None: - self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow") + if isinstance(self.fs, s3fs.core.S3FileSystem): + print("Using s3fs") + filesystem = self.fs + else: + print("Using local fs") + filesystem = None + print("Iterating over", self.dataset_files) + self.dataset = pa.dataset.dataset( + self.dataset_files, format="arrow", filesystem=filesystem + ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size ) @@ -165,7 +223,13 @@ class ArrowFileIterator(StatefulIterator): self.batch_iterator = None self.batch_to_consume = None else: - self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow") + if isinstance(self.fs, s3fs.core.S3FileSystem): + filesystem = self.fs + else: + filesystem = None + self.dataset = pa.dataset.dataset( + self.dataset_files, format="arrow", filesystem=filesystem + ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size ) @@ -198,9 +262,14 @@ TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" def find_and_sanitize_chunks( - dataset_path: str, world_size: int, file_pattern: str = TRAIN_DATA_FILE_PATTERN + dataset_path: str, + world_size: int, + file_pattern: str = TRAIN_DATA_FILE_PATTERN, + s3_profile: str | None = None, ): - dataset_chunks = [str(p) for p in Path(dataset_path).glob(file_pattern)] + fs = get_fs(dataset_path, s3_profile=s3_profile) + path_with_glob = os.path.join(dataset_path, file_pattern) + dataset_chunks = fs.glob(path_with_glob) n_chunks = len(dataset_chunks) if n_chunks > world_size: diff --git a/bytelatent/logger.py b/bytelatent/logger.py index 6723a84..87f04cc 100644 --- a/bytelatent/logger.py +++ b/bytelatent/logger.py @@ -91,7 +91,7 @@ def init_logger( log_file: str | None = None, *, name: str | None = None, - level: str = "NOTSET", + level: str = "INFO", ): """ Setup logging. 
diff --git a/requirements.txt b/requirements.txt index c6d87f1..59192cd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,3 +20,4 @@ altair submitit typer rich +fsspec[full] diff --git a/setup/download_prepare_hf_data.py b/setup/download_prepare_hf_data.py index 1aacf8f..e194857 100644 --- a/setup/download_prepare_hf_data.py +++ b/setup/download_prepare_hf_data.py @@ -5,6 +5,7 @@ import os import subprocess import time +import fsspec import requests from huggingface_hub import snapshot_download @@ -38,11 +39,21 @@ def download_dataset(repo_id, local_dir, allow_patterns): print(f"Dataset downloaded to {local_dir}") -def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64): +def parquet_to_jsonl( + dataset, work_dir, src_dir, tgt_dir, ntasks=64, s3_profile: str | None = None +): from datatrove.executor import LocalPipelineExecutor from datatrove.pipeline.readers import ParquetReader from datatrove.pipeline.writers import JsonlWriter + if tgt_dir.startswith("s3//"): + if s3_profile is None: + out_spec = tgt_dir + else: + out_spec = (tgt_dir, fsspec.filesystem("s3", profile=s3_profile)) + else: + out_spec = tgt_dir + pipeline_exec = LocalPipelineExecutor( pipeline=[ ParquetReader( @@ -52,7 +63,7 @@ def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64): glob_pattern="**/*.parquet", ), JsonlWriter( - tgt_dir, + out_spec, output_filename=dataset + ".chunk.${rank}.jsonl", compression=None, ), @@ -77,7 +88,7 @@ def setup_terashuf(work_dir): return terashuf_dir -def main(dataset, memory, data_dir, seed=42, nchunks=32): +def main(dataset, memory, data_dir, seed=42, nchunks=32, s3_profile: str | None = None): # Configuration repo_id = { "fineweb_edu": "HuggingFaceFW/fineweb-edu", From c137f4e636ee85ef4110b416b404c455b86efba9 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Fri, 10 Jan 2025 01:02:25 +0000 Subject: [PATCH 3/9] Replace regular filesystem calls with fsspec + add s3 support Summary: For compatibility with either local/nfs or S3 datasets, swap to fsspec. 
Add a tool to compare local and remote filesystems Test Plan: - Ran regular train script - Ran with config with data in S3 --- .gitignore | 2 +- bytelatent/args.py | 8 +- bytelatent/data/file_util.py | 116 ++++++++++++++++++++ bytelatent/data/iterators/arrow_iterator.py | 95 +++++++++++++--- bytelatent/logger.py | 2 +- requirements.txt | 1 + setup/download_prepare_hf_data.py | 17 ++- 7 files changed, 222 insertions(+), 19 deletions(-) create mode 100644 bytelatent/data/file_util.py diff --git a/.gitignore b/.gitignore index 56891a9..6c664b8 100644 --- a/.gitignore +++ b/.gitignore @@ -165,4 +165,4 @@ cython_debug/ figures/ .vscode/ .DS_Store - +internal/ diff --git a/bytelatent/args.py b/bytelatent/args.py index 4fba100..cfba3bf 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -46,8 +46,11 @@ def distribute_data_to_rank( arrow_batch_size: int, rank: int, world_size: int, + s3_profile: str | None = None, ) -> ArrowFileIterator: - dataset_chunks = find_and_sanitize_chunks(dataset_path, world_size) + dataset_chunks = find_and_sanitize_chunks( + dataset_path, world_size, s3_profile=s3_profile + ) n_workers_per_chunk = world_size // len(dataset_chunks) rank_to_arrow_iterator_params = [] for chunk_path in dataset_chunks: @@ -61,6 +64,7 @@ def distribute_data_to_rank( dataset_files=None, entropy_model_name=entropy_model_name, arrow_batch_size=arrow_batch_size, + s3_profile=s3_profile, ) ) return rank_to_arrow_iterator_params[rank] @@ -68,6 +72,7 @@ def distribute_data_to_rank( class DataloaderArgs(BaseModel): model_config = ConfigDict(extra="forbid") + s3_profile: str | None = None root_dir: str | None = None sources: dict[str, float] = {} batch_size: int = 2 @@ -107,6 +112,7 @@ class DataloaderArgs(BaseModel): arrow_batch_size=self.arrow_batch_size, rank=rank, world_size=world_size, + s3_profile=self.s3_profile, ) looping_iterator = LoopingIterator(arrow_iterator) preprocess_iterator = PreprocessIterator( diff --git a/bytelatent/data/file_util.py b/bytelatent/data/file_util.py new file mode 100644 index 0000000..312e1b4 --- /dev/null +++ b/bytelatent/data/file_util.py @@ -0,0 +1,116 @@ +import os +from rich.progress import track +import typer +import fsspec +import pyarrow as pa +from pyarrow.lib import ArrowInvalid + +# pyarrow needs the initialization from this import +import pyarrow.dataset # pyright: ignore + + +def is_valid_arrow_file(path: str): + try: + dataset = pa.dataset.dataset(path, format="arrow") + return True + except ArrowInvalid: + return False + + +app = typer.Typer() + +S3_PREFIX = "s3://" + + +def get_fs(path: str, s3_profile: str | None = None) -> fsspec.AbstractFileSystem: + if path.startswith("s3://"): + if s3_profile is None: + return fsspec.filesystem("s3") + else: + return fsspec.filesystem("s3", profile=s3_profile) + else: + return fsspec.filesystem("file") + + +@app.command() +def print_local_to_delete( + blob_dir: str, local_dirs: list[str], s3_profile: str = "blt" +): + for s in local_dirs: + assert s.endswith("/"), "Dirs must end with /" + assert blob_dir.endswith("/"), "Dirs must end with /" + blob_fs = fsspec.filesystem("s3", profile=s3_profile) + blob_files = blob_fs.find(blob_dir) + for f in track(blob_files): + size = blob_fs.info(f)["Size"] + if not f.lower().endswith(".complete"): + assert size != 0, f"Size was invalidly zero for {f}" + + blob_relative_paths = {f[len(blob_dir) - len(S3_PREFIX) :] for f in blob_files} + local_fs = fsspec.filesystem("file") + + files_to_delete = [] + for local_dir in local_dirs: + local_files = 
local_fs.find(local_dir) + for f in local_files: + relative_path = f[len(local_dir) :] + if relative_path in blob_relative_paths and not os.path.islink(f): + files_to_delete.append(f) + print(len(files_to_delete)) + with open("/tmp/files_to_delete.txt", "w") as f: + for file in files_to_delete: + f.write(f"{file}\n") + + +@app.command() +def compare_local_to_blob( + source_dirs: list[str], dst_dir: str, s3_profile: str = "blt" +): + for s in source_dirs: + assert s.endswith("/"), "Dirs must end with /" + assert dst_dir.endswith("/"), "Dirs must end with /" + assert len(source_dirs) != 0 + assert dst_dir.startswith("s3://") + local_fs = fsspec.filesystem("file") + dst_fs = fsspec.filesystem("s3", profile=s3_profile) + source_to_files = {} + all_local_files = set() + for s in source_dirs: + skipped = [] + if s not in source_to_files: + source_to_files[s] = [] + for f in local_fs.find(s): + if os.path.islink(f): + continue + if f.endswith(".COMPLETE") or f.endswith(".complete"): + is_complete_file = True + assert os.path.getsize(f) == 0, ".COMPLETE files should be empty" + else: + is_complete_file = False + + if not is_complete_file and os.path.getsize(f) == 0: + skipped.append(f) + continue + if f.endswith(".arrow"): + if not is_valid_arrow_file(f): + skipped.append(f) + continue + + source_to_files[s].append(f) + all_local_files.add(f[len(s) :]) + print(s, len(source_to_files[s]), "skipped", len(skipped), skipped[:10]) + + dst_files = dst_fs.find(dst_dir) + print(dst_dir, len(dst_files)) + + dst_file_set = {f[len(dst_dir) - len(S3_PREFIX) :] for f in dst_files} + diff = all_local_files.symmetric_difference(dst_file_set) + print("Local files", len(all_local_files)) + print("DST Files", len(dst_file_set)) + print("Symmetric difference", len(diff)) + dst_only_files = dst_file_set - all_local_files + print("DST only", len(dst_only_files), list(dst_only_files)[:10]) + + +if __name__ == "__main__": + app() diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index df5f023..17569ee 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -1,7 +1,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
+import os import re from logging import getLogger -from pathlib import Path from typing import Any, Generator import pyarrow as pa @@ -9,9 +9,12 @@ import pyarrow as pa # pyarrow needs the initialization from this import import pyarrow.dataset # pyright: ignore from pydantic import BaseModel, ConfigDict +import fsspec +import s3fs from bytelatent import ByteLatentError from bytelatent.data.data_types import BltExample +from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator logger = getLogger(__name__) @@ -27,6 +30,8 @@ class ArrowFileIteratorState(BaseModel, IteratorState): dataset_files: list[str] | None entropy_model_name: str | None arrow_batch_size: int = 100 + s3_profile: str | None + filesystem_type: str | None = None def build(self) -> "ArrowFileIterator": arrow_file = ArrowFileIterator( @@ -37,14 +42,17 @@ class ArrowFileIteratorState(BaseModel, IteratorState): entropy_model_name=self.entropy_model_name, arrow_batch_size=self.arrow_batch_size, dataset_files=self.dataset_files, + s3_profile=self.s3_profile, + filesystem_type=self.filesystem_type, ) if self.row_num != 0: arrow_file._set_row_num(self.row_num) return arrow_file -def shard_sort_key(file: str | Path): - match = re.search(r".+\.shard_([0-9]+)\.arrow", str(file)) +def shard_sort_key(file: str): + assert isinstance(file, str) + match = re.search(r".+\.shard_([0-9]+)\.arrow", file) shard_number = int(match.group(1)) return shard_number @@ -60,6 +68,8 @@ class ArrowFileIterator(StatefulIterator): entropy_model_name: str | None, arrow_batch_size: int, dataset_files: list[str] | None = None, + s3_profile: str | None = None, + filesystem_type: str | None = None, ): assert 0 <= worker_id < num_workers, (worker_id, num_workers) if file_path is None and dataset_files is None: @@ -75,16 +85,44 @@ class ArrowFileIterator(StatefulIterator): self.preprocess_dir = preprocess_dir self.entropy_model_name = entropy_model_name self.arrow_batch_size = arrow_batch_size + self.s3_profile = s3_profile + self.filesystem_type = filesystem_type + self.fs = None + if self.filesystem_type is not None: + if self.filesystem_type == "file": + self.fs = fsspec.filesystem("file") + elif self.filesystem_type == "s3": + self.fs = fsspec.filesystem("s3", profile=s3_profile) + if dataset_files is None: # Prepare arrow shards - jsonl_file = Path(file_path) - parts = re.match(r"(.+)\.chunk\.[0-9]+\.jsonl", jsonl_file.name) + jsonl_file = file_path + parts = re.match( + r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file) + ) assert parts is not None dataset = parts.group(1) - data_dir = Path(preprocess_dir) / dataset / entropy_model_name - shard_files = list(data_dir.glob(f"{jsonl_file.name}.shard_*.arrow")) + data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name) + data_dir_with_glob = os.path.join( + data_dir, f"{os.path.basename(jsonl_file)}.shard_*.arrow" + ) + print("preprocess_dir", preprocess_dir) + print("data_dir_with_glob", data_dir_with_glob) + if self.fs is None: + self.fs = get_fs(data_dir_with_glob, s3_profile=s3_profile) + if isinstance(self.fs, s3fs.S3FileSystem): + self.filesystem_type = "s3" + else: + self.filesystem_type = "file" + print("fs", self.fs) + shard_files = self.fs.glob(data_dir_with_glob) + for s in shard_files: - if not (data_dir / f"{s.name}.complete").exists(): + complete_file = os.path.join( + data_dir, f"{os.path.basename(s)}.complete" + ) + + if not self.fs.exists(complete_file): raise ValueError(f"Missing .complete for input 
file: {s}") shard_files = sorted(shard_files, key=shard_sort_key) @@ -92,10 +130,19 @@ class ArrowFileIterator(StatefulIterator): raise ByteLatentError( f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow" ) - self.dataset_files = [str(f) for f in shard_files] + self.dataset_files = [f for f in shard_files] else: self.preprocess_dir = None self.dataset_files = dataset_files + if dataset_files[0].startswith("s3://"): + for f in dataset_files: + assert f.startswith("s3://") + if self.fs is None: + self.fs = get_fs(dataset_files[0], s3_profile=s3_profile) + if isinstance(self.fs, s3fs.S3FileSystem): + self.filesystem_type = "s3" + else: + self.filesystem_type = "file" def get_state(self) -> ArrowFileIteratorState: return ArrowFileIteratorState( @@ -107,13 +154,24 @@ class ArrowFileIterator(StatefulIterator): entropy_model_name=self.entropy_model_name, arrow_batch_size=self.arrow_batch_size, dataset_files=self.dataset_files, + s3_profile=self.s3_profile, + filesystem_type=self.filesystem_type, ) def create_iter( self, ) -> Generator[BltExample, Any, None]: if self.dataset is None: - self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow") + if isinstance(self.fs, s3fs.core.S3FileSystem): + print("Using s3fs") + filesystem = self.fs + else: + print("Using local fs") + filesystem = None + print("Iterating over", self.dataset_files) + self.dataset = pa.dataset.dataset( + self.dataset_files, format="arrow", filesystem=filesystem + ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size ) @@ -165,7 +223,13 @@ class ArrowFileIterator(StatefulIterator): self.batch_iterator = None self.batch_to_consume = None else: - self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow") + if isinstance(self.fs, s3fs.core.S3FileSystem): + filesystem = self.fs + else: + filesystem = None + self.dataset = pa.dataset.dataset( + self.dataset_files, format="arrow", filesystem=filesystem + ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size ) @@ -198,9 +262,14 @@ TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" def find_and_sanitize_chunks( - dataset_path: str, world_size: int, file_pattern: str = TRAIN_DATA_FILE_PATTERN + dataset_path: str, + world_size: int, + file_pattern: str = TRAIN_DATA_FILE_PATTERN, + s3_profile: str | None = None, ): - dataset_chunks = [str(p) for p in Path(dataset_path).glob(file_pattern)] + fs = get_fs(dataset_path, s3_profile=s3_profile) + path_with_glob = os.path.join(dataset_path, file_pattern) + dataset_chunks = fs.glob(path_with_glob) n_chunks = len(dataset_chunks) if n_chunks > world_size: diff --git a/bytelatent/logger.py b/bytelatent/logger.py index 6723a84..87f04cc 100644 --- a/bytelatent/logger.py +++ b/bytelatent/logger.py @@ -91,7 +91,7 @@ def init_logger( log_file: str | None = None, *, name: str | None = None, - level: str = "NOTSET", + level: str = "INFO", ): """ Setup logging. 
diff --git a/requirements.txt b/requirements.txt index c6d87f1..59192cd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,3 +20,4 @@ altair submitit typer rich +fsspec[full] diff --git a/setup/download_prepare_hf_data.py b/setup/download_prepare_hf_data.py index 1aacf8f..e194857 100644 --- a/setup/download_prepare_hf_data.py +++ b/setup/download_prepare_hf_data.py @@ -5,6 +5,7 @@ import os import subprocess import time +import fsspec import requests from huggingface_hub import snapshot_download @@ -38,11 +39,21 @@ def download_dataset(repo_id, local_dir, allow_patterns): print(f"Dataset downloaded to {local_dir}") -def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64): +def parquet_to_jsonl( + dataset, work_dir, src_dir, tgt_dir, ntasks=64, s3_profile: str | None = None +): from datatrove.executor import LocalPipelineExecutor from datatrove.pipeline.readers import ParquetReader from datatrove.pipeline.writers import JsonlWriter + if tgt_dir.startswith("s3//"): + if s3_profile is None: + out_spec = tgt_dir + else: + out_spec = (tgt_dir, fsspec.filesystem("s3", profile=s3_profile)) + else: + out_spec = tgt_dir + pipeline_exec = LocalPipelineExecutor( pipeline=[ ParquetReader( @@ -52,7 +63,7 @@ def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64): glob_pattern="**/*.parquet", ), JsonlWriter( - tgt_dir, + out_spec, output_filename=dataset + ".chunk.${rank}.jsonl", compression=None, ), @@ -77,7 +88,7 @@ def setup_terashuf(work_dir): return terashuf_dir -def main(dataset, memory, data_dir, seed=42, nchunks=32): +def main(dataset, memory, data_dir, seed=42, nchunks=32, s3_profile: str | None = None): # Configuration repo_id = { "fineweb_edu": "HuggingFaceFW/fineweb-edu", From 1bf6d15e5a3ff76f128288fc547e2604f1981e2f Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Fri, 10 Jan 2025 01:02:25 +0000 Subject: [PATCH 4/9] Replace regular filesystem calls with fsspec + add s3 support Summary: For compatibility with either local/nfs or S3 datasets, swap to fsspec. 
Add a tool to compare local and remote filesystems Test Plan: - Ran regular train script - Ran with config with data in S3 --- .gitignore | 2 +- bytelatent/args.py | 8 +- bytelatent/data/file_util.py | 116 ++++++++++++++++++++ bytelatent/data/iterators/arrow_iterator.py | 89 ++++++++++++--- bytelatent/logger.py | 2 +- requirements.txt | 1 + setup/download_prepare_hf_data.py | 17 ++- 7 files changed, 216 insertions(+), 19 deletions(-) create mode 100644 bytelatent/data/file_util.py diff --git a/.gitignore b/.gitignore index 56891a9..6c664b8 100644 --- a/.gitignore +++ b/.gitignore @@ -165,4 +165,4 @@ cython_debug/ figures/ .vscode/ .DS_Store - +internal/ diff --git a/bytelatent/args.py b/bytelatent/args.py index 4fba100..cfba3bf 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -46,8 +46,11 @@ def distribute_data_to_rank( arrow_batch_size: int, rank: int, world_size: int, + s3_profile: str | None = None, ) -> ArrowFileIterator: - dataset_chunks = find_and_sanitize_chunks(dataset_path, world_size) + dataset_chunks = find_and_sanitize_chunks( + dataset_path, world_size, s3_profile=s3_profile + ) n_workers_per_chunk = world_size // len(dataset_chunks) rank_to_arrow_iterator_params = [] for chunk_path in dataset_chunks: @@ -61,6 +64,7 @@ def distribute_data_to_rank( dataset_files=None, entropy_model_name=entropy_model_name, arrow_batch_size=arrow_batch_size, + s3_profile=s3_profile, ) ) return rank_to_arrow_iterator_params[rank] @@ -68,6 +72,7 @@ def distribute_data_to_rank( class DataloaderArgs(BaseModel): model_config = ConfigDict(extra="forbid") + s3_profile: str | None = None root_dir: str | None = None sources: dict[str, float] = {} batch_size: int = 2 @@ -107,6 +112,7 @@ class DataloaderArgs(BaseModel): arrow_batch_size=self.arrow_batch_size, rank=rank, world_size=world_size, + s3_profile=self.s3_profile, ) looping_iterator = LoopingIterator(arrow_iterator) preprocess_iterator = PreprocessIterator( diff --git a/bytelatent/data/file_util.py b/bytelatent/data/file_util.py new file mode 100644 index 0000000..312e1b4 --- /dev/null +++ b/bytelatent/data/file_util.py @@ -0,0 +1,116 @@ +import os +from rich.progress import track +import typer +import fsspec +import pyarrow as pa +from pyarrow.lib import ArrowInvalid + +# pyarrow needs the initialization from this import +import pyarrow.dataset # pyright: ignore + + +def is_valid_arrow_file(path: str): + try: + dataset = pa.dataset.dataset(path, format="arrow") + return True + except ArrowInvalid: + return False + + +app = typer.Typer() + +S3_PREFIX = "s3://" + + +def get_fs(path: str, s3_profile: str | None = None) -> fsspec.AbstractFileSystem: + if path.startswith("s3://"): + if s3_profile is None: + return fsspec.filesystem("s3") + else: + return fsspec.filesystem("s3", profile=s3_profile) + else: + return fsspec.filesystem("file") + + +@app.command() +def print_local_to_delete( + blob_dir: str, local_dirs: list[str], s3_profile: str = "blt" +): + for s in local_dirs: + assert s.endswith("/"), "Dirs must end with /" + assert blob_dir.endswith("/"), "Dirs must end with /" + blob_fs = fsspec.filesystem("s3", profile=s3_profile) + blob_files = blob_fs.find(blob_dir) + for f in track(blob_files): + size = blob_fs.info(f)["Size"] + if not f.lower().endswith(".complete"): + assert size != 0, f"Size was invalidly zero for {f}" + + blob_relative_paths = {f[len(blob_dir) - len(S3_PREFIX) :] for f in blob_files} + local_fs = fsspec.filesystem("file") + + files_to_delete = [] + for local_dir in local_dirs: + local_files = 
local_fs.find(local_dir) + for f in local_files: + relative_path = f[len(local_dir) :] + if relative_path in blob_relative_paths and not os.path.islink(f): + files_to_delete.append(f) + print(len(files_to_delete)) + with open("/tmp/files_to_delete.txt", "w") as f: + for file in files_to_delete: + f.write(f"{file}\n") + + +@app.command() +def compare_local_to_blob( + source_dirs: list[str], dst_dir: str, s3_profile: str = "blt" +): + for s in source_dirs: + assert s.endswith("/"), "Dirs must end with /" + assert dst_dir.endswith("/"), "Dirs must end with /" + assert len(source_dirs) != 0 + assert dst_dir.startswith("s3://") + local_fs = fsspec.filesystem("file") + dst_fs = fsspec.filesystem("s3", profile=s3_profile) + source_to_files = {} + all_local_files = set() + for s in source_dirs: + skipped = [] + if s not in source_to_files: + source_to_files[s] = [] + for f in local_fs.find(s): + if os.path.islink(f): + continue + if f.endswith(".COMPLETE") or f.endswith(".complete"): + is_complete_file = True + assert os.path.getsize(f) == 0, ".COMPLETE files should be empty" + else: + is_complete_file = False + + if not is_complete_file and os.path.getsize(f) == 0: + skipped.append(f) + continue + if f.endswith(".arrow"): + if not is_valid_arrow_file(f): + skipped.append(f) + continue + + source_to_files[s].append(f) + all_local_files.add(f[len(s) :]) + print(s, len(source_to_files[s]), "skipped", len(skipped), skipped[:10]) + + dst_files = dst_fs.find(dst_dir) + print(dst_dir, len(dst_files)) + + dst_file_set = {f[len(dst_dir) - len(S3_PREFIX) :] for f in dst_files} + diff = all_local_files.symmetric_difference(dst_file_set) + print("Local files", len(all_local_files)) + print("DST Files", len(dst_file_set)) + print("Symmetric difference", len(diff)) + dst_only_files = dst_file_set - all_local_files + print("DST only", len(dst_only_files), list(dst_only_files)[:10]) + + +if __name__ == "__main__": + app() diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index df5f023..ad72091 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -1,7 +1,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
+import os import re from logging import getLogger -from pathlib import Path from typing import Any, Generator import pyarrow as pa @@ -9,9 +9,12 @@ import pyarrow as pa # pyarrow needs the initialization from this import import pyarrow.dataset # pyright: ignore from pydantic import BaseModel, ConfigDict +import fsspec +import s3fs from bytelatent import ByteLatentError from bytelatent.data.data_types import BltExample +from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator logger = getLogger(__name__) @@ -27,6 +30,8 @@ class ArrowFileIteratorState(BaseModel, IteratorState): dataset_files: list[str] | None entropy_model_name: str | None arrow_batch_size: int = 100 + s3_profile: str | None + filesystem_type: str | None = None def build(self) -> "ArrowFileIterator": arrow_file = ArrowFileIterator( @@ -37,14 +42,17 @@ class ArrowFileIteratorState(BaseModel, IteratorState): entropy_model_name=self.entropy_model_name, arrow_batch_size=self.arrow_batch_size, dataset_files=self.dataset_files, + s3_profile=self.s3_profile, + filesystem_type=self.filesystem_type, ) if self.row_num != 0: arrow_file._set_row_num(self.row_num) return arrow_file -def shard_sort_key(file: str | Path): - match = re.search(r".+\.shard_([0-9]+)\.arrow", str(file)) +def shard_sort_key(file: str): + assert isinstance(file, str) + match = re.search(r".+\.shard_([0-9]+)\.arrow", file) shard_number = int(match.group(1)) return shard_number @@ -60,6 +68,8 @@ class ArrowFileIterator(StatefulIterator): entropy_model_name: str | None, arrow_batch_size: int, dataset_files: list[str] | None = None, + s3_profile: str | None = None, + filesystem_type: str | None = None, ): assert 0 <= worker_id < num_workers, (worker_id, num_workers) if file_path is None and dataset_files is None: @@ -75,16 +85,41 @@ class ArrowFileIterator(StatefulIterator): self.preprocess_dir = preprocess_dir self.entropy_model_name = entropy_model_name self.arrow_batch_size = arrow_batch_size + self.s3_profile = s3_profile + self.filesystem_type = filesystem_type + self.fs = None + if self.filesystem_type is not None: + if self.filesystem_type == "file": + self.fs = fsspec.filesystem("file") + elif self.filesystem_type == "s3": + self.fs = fsspec.filesystem("s3", profile=s3_profile) + if dataset_files is None: # Prepare arrow shards - jsonl_file = Path(file_path) - parts = re.match(r"(.+)\.chunk\.[0-9]+\.jsonl", jsonl_file.name) + jsonl_file = file_path + parts = re.match( + r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file) + ) assert parts is not None dataset = parts.group(1) - data_dir = Path(preprocess_dir) / dataset / entropy_model_name - shard_files = list(data_dir.glob(f"{jsonl_file.name}.shard_*.arrow")) + data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name) + data_dir_with_glob = os.path.join( + data_dir, f"{os.path.basename(jsonl_file)}.shard_*.arrow" + ) + if self.fs is None: + self.fs = get_fs(data_dir_with_glob, s3_profile=s3_profile) + if isinstance(self.fs, s3fs.S3FileSystem): + self.filesystem_type = "s3" + else: + self.filesystem_type = "file" + shard_files = self.fs.glob(data_dir_with_glob) + for s in shard_files: - if not (data_dir / f"{s.name}.complete").exists(): + complete_file = os.path.join( + data_dir, f"{os.path.basename(s)}.complete" + ) + + if not self.fs.exists(complete_file): raise ValueError(f"Missing .complete for input file: {s}") shard_files = sorted(shard_files, key=shard_sort_key) @@ -92,10 +127,19 @@ class 
ArrowFileIterator(StatefulIterator): raise ByteLatentError( f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow" ) - self.dataset_files = [str(f) for f in shard_files] + self.dataset_files = [f for f in shard_files] else: self.preprocess_dir = None self.dataset_files = dataset_files + if dataset_files[0].startswith("s3://"): + for f in dataset_files: + assert f.startswith("s3://") + if self.fs is None: + self.fs = get_fs(dataset_files[0], s3_profile=s3_profile) + if isinstance(self.fs, s3fs.S3FileSystem): + self.filesystem_type = "s3" + else: + self.filesystem_type = "file" def get_state(self) -> ArrowFileIteratorState: return ArrowFileIteratorState( @@ -107,13 +151,21 @@ class ArrowFileIterator(StatefulIterator): entropy_model_name=self.entropy_model_name, arrow_batch_size=self.arrow_batch_size, dataset_files=self.dataset_files, + s3_profile=self.s3_profile, + filesystem_type=self.filesystem_type, ) def create_iter( self, ) -> Generator[BltExample, Any, None]: if self.dataset is None: - self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow") + if isinstance(self.fs, s3fs.core.S3FileSystem): + filesystem = self.fs + else: + filesystem = None + self.dataset = pa.dataset.dataset( + self.dataset_files, format="arrow", filesystem=filesystem + ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size ) @@ -165,7 +217,13 @@ class ArrowFileIterator(StatefulIterator): self.batch_iterator = None self.batch_to_consume = None else: - self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow") + if isinstance(self.fs, s3fs.core.S3FileSystem): + filesystem = self.fs + else: + filesystem = None + self.dataset = pa.dataset.dataset( + self.dataset_files, format="arrow", filesystem=filesystem + ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size ) @@ -198,9 +256,14 @@ TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" def find_and_sanitize_chunks( - dataset_path: str, world_size: int, file_pattern: str = TRAIN_DATA_FILE_PATTERN + dataset_path: str, + world_size: int, + file_pattern: str = TRAIN_DATA_FILE_PATTERN, + s3_profile: str | None = None, ): - dataset_chunks = [str(p) for p in Path(dataset_path).glob(file_pattern)] + fs = get_fs(dataset_path, s3_profile=s3_profile) + path_with_glob = os.path.join(dataset_path, file_pattern) + dataset_chunks = fs.glob(path_with_glob) n_chunks = len(dataset_chunks) if n_chunks > world_size: diff --git a/bytelatent/logger.py b/bytelatent/logger.py index 6723a84..87f04cc 100644 --- a/bytelatent/logger.py +++ b/bytelatent/logger.py @@ -91,7 +91,7 @@ def init_logger( log_file: str | None = None, *, name: str | None = None, - level: str = "NOTSET", + level: str = "INFO", ): """ Setup logging. 
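When the resolved filesystem is s3fs, create_iter hands it to pyarrow explicitly; for local files it passes filesystem=None and lets pyarrow open the paths itself. The same pattern in isolation, with a made-up shard path and profile name:

import fsspec
import pyarrow.dataset
import s3fs

fs = fsspec.filesystem("s3", profile="blt")
filesystem = fs if isinstance(fs, s3fs.core.S3FileSystem) else None
dataset = pyarrow.dataset.dataset(
    "s3://example-bucket/shards/data.chunk.00.jsonl.shard_00.arrow",
    format="arrow",
    filesystem=filesystem,
)
for batch in dataset.to_batches(batch_size=100):
    ...  # each batch is a pyarrow.RecordBatch of preprocessed examples
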
diff --git a/requirements.txt b/requirements.txt index c6d87f1..59192cd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,3 +20,4 @@ altair submitit typer rich +fsspec[full] diff --git a/setup/download_prepare_hf_data.py b/setup/download_prepare_hf_data.py index 1aacf8f..e194857 100644 --- a/setup/download_prepare_hf_data.py +++ b/setup/download_prepare_hf_data.py @@ -5,6 +5,7 @@ import os import subprocess import time +import fsspec import requests from huggingface_hub import snapshot_download @@ -38,11 +39,21 @@ def download_dataset(repo_id, local_dir, allow_patterns): print(f"Dataset downloaded to {local_dir}") -def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64): +def parquet_to_jsonl( + dataset, work_dir, src_dir, tgt_dir, ntasks=64, s3_profile: str | None = None +): from datatrove.executor import LocalPipelineExecutor from datatrove.pipeline.readers import ParquetReader from datatrove.pipeline.writers import JsonlWriter + if tgt_dir.startswith("s3//"): + if s3_profile is None: + out_spec = tgt_dir + else: + out_spec = (tgt_dir, fsspec.filesystem("s3", profile=s3_profile)) + else: + out_spec = tgt_dir + pipeline_exec = LocalPipelineExecutor( pipeline=[ ParquetReader( @@ -52,7 +63,7 @@ def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64): glob_pattern="**/*.parquet", ), JsonlWriter( - tgt_dir, + out_spec, output_filename=dataset + ".chunk.${rank}.jsonl", compression=None, ), @@ -77,7 +88,7 @@ def setup_terashuf(work_dir): return terashuf_dir -def main(dataset, memory, data_dir, seed=42, nchunks=32): +def main(dataset, memory, data_dir, seed=42, nchunks=32, s3_profile: str | None = None): # Configuration repo_id = { "fineweb_edu": "HuggingFaceFW/fineweb-edu", From a1d05403b42cdb111be2cf3034a044e63f45b91b Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Fri, 10 Jan 2025 01:02:25 +0000 Subject: [PATCH 5/9] Replace regular filesystem calls with fsspec + add s3 support Summary: For compatibility with either local/nfs or S3 datasets, swap to fsspec. 
Add a tool to compare local and remote filesystems Test Plan: - Ran regular train script - Ran with config with data in S3 --- .gitignore | 2 +- bytelatent/args.py | 8 +- bytelatent/data/file_util.py | 117 ++++++++++++++++++++ bytelatent/data/iterators/arrow_iterator.py | 89 ++++++++++++--- bytelatent/logger.py | 2 +- requirements.txt | 1 + setup/download_prepare_hf_data.py | 17 ++- 7 files changed, 217 insertions(+), 19 deletions(-) create mode 100644 bytelatent/data/file_util.py diff --git a/.gitignore b/.gitignore index 56891a9..6c664b8 100644 --- a/.gitignore +++ b/.gitignore @@ -165,4 +165,4 @@ cython_debug/ figures/ .vscode/ .DS_Store - +internal/ diff --git a/bytelatent/args.py b/bytelatent/args.py index 4fba100..cfba3bf 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -46,8 +46,11 @@ def distribute_data_to_rank( arrow_batch_size: int, rank: int, world_size: int, + s3_profile: str | None = None, ) -> ArrowFileIterator: - dataset_chunks = find_and_sanitize_chunks(dataset_path, world_size) + dataset_chunks = find_and_sanitize_chunks( + dataset_path, world_size, s3_profile=s3_profile + ) n_workers_per_chunk = world_size // len(dataset_chunks) rank_to_arrow_iterator_params = [] for chunk_path in dataset_chunks: @@ -61,6 +64,7 @@ def distribute_data_to_rank( dataset_files=None, entropy_model_name=entropy_model_name, arrow_batch_size=arrow_batch_size, + s3_profile=s3_profile, ) ) return rank_to_arrow_iterator_params[rank] @@ -68,6 +72,7 @@ def distribute_data_to_rank( class DataloaderArgs(BaseModel): model_config = ConfigDict(extra="forbid") + s3_profile: str | None = None root_dir: str | None = None sources: dict[str, float] = {} batch_size: int = 2 @@ -107,6 +112,7 @@ class DataloaderArgs(BaseModel): arrow_batch_size=self.arrow_batch_size, rank=rank, world_size=world_size, + s3_profile=self.s3_profile, ) looping_iterator = LoopingIterator(arrow_iterator) preprocess_iterator = PreprocessIterator( diff --git a/bytelatent/data/file_util.py b/bytelatent/data/file_util.py new file mode 100644 index 0000000..d67b6db --- /dev/null +++ b/bytelatent/data/file_util.py @@ -0,0 +1,117 @@ +import os + +import fsspec +import pyarrow as pa + +# pyarrow needs the initialization from this import +import pyarrow.dataset # pyright: ignore +import typer +from pyarrow.lib import ArrowInvalid +from rich.progress import track + + +def is_valid_arrow_file(path: str): + try: + dataset = pa.dataset.dataset(path, format="arrow") + return True + except ArrowInvalid: + return False + + +app = typer.Typer() + +S3_PREFIX = "s3://" + + +def get_fs(path: str, s3_profile: str | None = None) -> fsspec.AbstractFileSystem: + if path.startswith("s3://"): + if s3_profile is None: + return fsspec.filesystem("s3") + else: + return fsspec.filesystem("s3", profile=s3_profile) + else: + return fsspec.filesystem("file") + + +@app.command() +def print_local_to_delete( + blob_dir: str, local_dirs: list[str], s3_profile: str = "blt" +): + for s in local_dirs: + assert s.endswith("/"), "Dirs must end with /" + assert blob_dir.endswith("/"), "Dirs must end with /" + blob_fs = fsspec.filesystem("s3", profile=s3_profile) + blob_files = blob_fs.find(blob_dir) + for f in track(blob_files): + size = blob_fs.info(f)["Size"] + if not f.lower().endswith(".complete"): + assert size != 0, f"Size was invalidly zero for {f}" + + blob_relative_paths = {f[len(blob_dir) - len(S3_PREFIX) :] for f in blob_files} + local_fs = fsspec.filesystem("file") + + files_to_delete = [] + for local_dir in local_dirs: + local_files = 
local_fs.find(local_dir) + for f in local_files: + relative_path = f[len(local_dir) :] + if relative_path in blob_relative_paths and not os.path.islink(f): + files_to_delete.append(f) + print(len(files_to_delete)) + with open("/tmp/files_to_delete.txt", "w") as f: + for file in files_to_delete: + f.write(f"{file}\n") + + +@app.command() +def compare_local_to_blob( + source_dirs: list[str], dst_dir: str, s3_profile: str = "blt" +): + for s in source_dirs: + assert s.endswith("/"), "Dirs must end with /" + assert dst_dir.endswith("/"), "Dirs must end with /" + assert len(source_dirs) != 0 + assert dst_dir.startswith("s3://") + local_fs = fsspec.filesystem("file") + dst_fs = fsspec.filesystem("s3", profile=s3_profile) + source_to_files = {} + all_local_files = set() + for s in source_dirs: + skipped = [] + if s not in source_to_files: + source_to_files[s] = [] + for f in local_fs.find(s): + if os.path.islink(f): + continue + if f.endswith(".COMPLETE") or f.endswith(".complete"): + is_complete_file = True + assert os.path.getsize(f) == 0, ".COMPLETE files should be empty" + else: + is_complete_file = False + + if not is_complete_file and os.path.getsize(f) == 0: + skipped.append(f) + continue + if f.endswith(".arrow"): + if not is_valid_arrow_file(f): + skipped.append(f) + continue + + source_to_files[s].append(f) + all_local_files.add(f[len(s) :]) + print(s, len(source_to_files[s]), "skipped", len(skipped), skipped[:10]) + + dst_files = dst_fs.find(dst_dir) + print(dst_dir, len(dst_files)) + + dst_file_set = {f[len(dst_dir) - len(S3_PREFIX) :] for f in dst_files} + diff = all_local_files.symmetric_difference(dst_file_set) + print("Local files", len(all_local_files)) + print("DST Files", len(dst_file_set)) + print("Symmetric difference", len(diff)) + dst_only_files = dst_file_set - all_local_files + print("DST only", len(dst_only_files), list(dst_only_files)[:10]) + + +if __name__ == "__main__": + app() diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index df5f023..4e7b99e 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -1,17 +1,20 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
+import os import re from logging import getLogger -from pathlib import Path from typing import Any, Generator +import fsspec import pyarrow as pa # pyarrow needs the initialization from this import import pyarrow.dataset # pyright: ignore +import s3fs from pydantic import BaseModel, ConfigDict from bytelatent import ByteLatentError from bytelatent.data.data_types import BltExample +from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator logger = getLogger(__name__) @@ -27,6 +30,8 @@ class ArrowFileIteratorState(BaseModel, IteratorState): dataset_files: list[str] | None entropy_model_name: str | None arrow_batch_size: int = 100 + s3_profile: str | None + filesystem_type: str | None = None def build(self) -> "ArrowFileIterator": arrow_file = ArrowFileIterator( @@ -37,14 +42,17 @@ class ArrowFileIteratorState(BaseModel, IteratorState): entropy_model_name=self.entropy_model_name, arrow_batch_size=self.arrow_batch_size, dataset_files=self.dataset_files, + s3_profile=self.s3_profile, + filesystem_type=self.filesystem_type, ) if self.row_num != 0: arrow_file._set_row_num(self.row_num) return arrow_file -def shard_sort_key(file: str | Path): - match = re.search(r".+\.shard_([0-9]+)\.arrow", str(file)) +def shard_sort_key(file: str): + assert isinstance(file, str) + match = re.search(r".+\.shard_([0-9]+)\.arrow", file) shard_number = int(match.group(1)) return shard_number @@ -60,6 +68,8 @@ class ArrowFileIterator(StatefulIterator): entropy_model_name: str | None, arrow_batch_size: int, dataset_files: list[str] | None = None, + s3_profile: str | None = None, + filesystem_type: str | None = None, ): assert 0 <= worker_id < num_workers, (worker_id, num_workers) if file_path is None and dataset_files is None: @@ -75,16 +85,41 @@ class ArrowFileIterator(StatefulIterator): self.preprocess_dir = preprocess_dir self.entropy_model_name = entropy_model_name self.arrow_batch_size = arrow_batch_size + self.s3_profile = s3_profile + self.filesystem_type = filesystem_type + self.fs = None + if self.filesystem_type is not None: + if self.filesystem_type == "file": + self.fs = fsspec.filesystem("file") + elif self.filesystem_type == "s3": + self.fs = fsspec.filesystem("s3", profile=s3_profile) + if dataset_files is None: # Prepare arrow shards - jsonl_file = Path(file_path) - parts = re.match(r"(.+)\.chunk\.[0-9]+\.jsonl", jsonl_file.name) + jsonl_file = file_path + parts = re.match( + r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file) + ) assert parts is not None dataset = parts.group(1) - data_dir = Path(preprocess_dir) / dataset / entropy_model_name - shard_files = list(data_dir.glob(f"{jsonl_file.name}.shard_*.arrow")) + data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name) + data_dir_with_glob = os.path.join( + data_dir, f"{os.path.basename(jsonl_file)}.shard_*.arrow" + ) + if self.fs is None: + self.fs = get_fs(data_dir_with_glob, s3_profile=s3_profile) + if isinstance(self.fs, s3fs.S3FileSystem): + self.filesystem_type = "s3" + else: + self.filesystem_type = "file" + shard_files = self.fs.glob(data_dir_with_glob) + for s in shard_files: - if not (data_dir / f"{s.name}.complete").exists(): + complete_file = os.path.join( + data_dir, f"{os.path.basename(s)}.complete" + ) + + if not self.fs.exists(complete_file): raise ValueError(f"Missing .complete for input file: {s}") shard_files = sorted(shard_files, key=shard_sort_key) @@ -92,10 +127,19 @@ class ArrowFileIterator(StatefulIterator): raise ByteLatentError( 
f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow" ) - self.dataset_files = [str(f) for f in shard_files] + self.dataset_files = [f for f in shard_files] else: self.preprocess_dir = None self.dataset_files = dataset_files + if dataset_files[0].startswith("s3://"): + for f in dataset_files: + assert f.startswith("s3://") + if self.fs is None: + self.fs = get_fs(dataset_files[0], s3_profile=s3_profile) + if isinstance(self.fs, s3fs.S3FileSystem): + self.filesystem_type = "s3" + else: + self.filesystem_type = "file" def get_state(self) -> ArrowFileIteratorState: return ArrowFileIteratorState( @@ -107,13 +151,21 @@ class ArrowFileIterator(StatefulIterator): entropy_model_name=self.entropy_model_name, arrow_batch_size=self.arrow_batch_size, dataset_files=self.dataset_files, + s3_profile=self.s3_profile, + filesystem_type=self.filesystem_type, ) def create_iter( self, ) -> Generator[BltExample, Any, None]: if self.dataset is None: - self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow") + if isinstance(self.fs, s3fs.core.S3FileSystem): + filesystem = self.fs + else: + filesystem = None + self.dataset = pa.dataset.dataset( + self.dataset_files, format="arrow", filesystem=filesystem + ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size ) @@ -165,7 +217,13 @@ class ArrowFileIterator(StatefulIterator): self.batch_iterator = None self.batch_to_consume = None else: - self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow") + if isinstance(self.fs, s3fs.core.S3FileSystem): + filesystem = self.fs + else: + filesystem = None + self.dataset = pa.dataset.dataset( + self.dataset_files, format="arrow", filesystem=filesystem + ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size ) @@ -198,9 +256,14 @@ TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" def find_and_sanitize_chunks( - dataset_path: str, world_size: int, file_pattern: str = TRAIN_DATA_FILE_PATTERN + dataset_path: str, + world_size: int, + file_pattern: str = TRAIN_DATA_FILE_PATTERN, + s3_profile: str | None = None, ): - dataset_chunks = [str(p) for p in Path(dataset_path).glob(file_pattern)] + fs = get_fs(dataset_path, s3_profile=s3_profile) + path_with_glob = os.path.join(dataset_path, file_pattern) + dataset_chunks = fs.glob(path_with_glob) n_chunks = len(dataset_chunks) if n_chunks > world_size: diff --git a/bytelatent/logger.py b/bytelatent/logger.py index 6723a84..87f04cc 100644 --- a/bytelatent/logger.py +++ b/bytelatent/logger.py @@ -91,7 +91,7 @@ def init_logger( log_file: str | None = None, *, name: str | None = None, - level: str = "NOTSET", + level: str = "INFO", ): """ Setup logging. 
diff --git a/requirements.txt b/requirements.txt index c6d87f1..59192cd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,3 +20,4 @@ altair submitit typer rich +fsspec[full] diff --git a/setup/download_prepare_hf_data.py b/setup/download_prepare_hf_data.py index 1aacf8f..e194857 100644 --- a/setup/download_prepare_hf_data.py +++ b/setup/download_prepare_hf_data.py @@ -5,6 +5,7 @@ import os import subprocess import time +import fsspec import requests from huggingface_hub import snapshot_download @@ -38,11 +39,21 @@ def download_dataset(repo_id, local_dir, allow_patterns): print(f"Dataset downloaded to {local_dir}") -def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64): +def parquet_to_jsonl( + dataset, work_dir, src_dir, tgt_dir, ntasks=64, s3_profile: str | None = None +): from datatrove.executor import LocalPipelineExecutor from datatrove.pipeline.readers import ParquetReader from datatrove.pipeline.writers import JsonlWriter + if tgt_dir.startswith("s3//"): + if s3_profile is None: + out_spec = tgt_dir + else: + out_spec = (tgt_dir, fsspec.filesystem("s3", profile=s3_profile)) + else: + out_spec = tgt_dir + pipeline_exec = LocalPipelineExecutor( pipeline=[ ParquetReader( @@ -52,7 +63,7 @@ def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64): glob_pattern="**/*.parquet", ), JsonlWriter( - tgt_dir, + out_spec, output_filename=dataset + ".chunk.${rank}.jsonl", compression=None, ), @@ -77,7 +88,7 @@ def setup_terashuf(work_dir): return terashuf_dir -def main(dataset, memory, data_dir, seed=42, nchunks=32): +def main(dataset, memory, data_dir, seed=42, nchunks=32, s3_profile: str | None = None): # Configuration repo_id = { "fineweb_edu": "HuggingFaceFW/fineweb-edu", From 3f045f11238af657269e2051d870515e6ac7913b Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Mon, 13 Jan 2025 23:13:49 +0000 Subject: [PATCH 6/9] Update preprocess_entropies script to blt inference + add fsspec support Summary: Test Plan: --- bytelatent/data/patcher.py | 8 +- bytelatent/preprocess/preprocess_entropies.py | 105 ++++++++++++------ requirements.txt | 1 + 3 files changed, 74 insertions(+), 40 deletions(-) diff --git a/bytelatent/data/patcher.py b/bytelatent/data/patcher.py index ede8b06..f8477a3 100644 --- a/bytelatent/data/patcher.py +++ b/bytelatent/data/patcher.py @@ -82,16 +82,16 @@ def calculate_entropies( if device is not None: split = split.to(device) assert torch.all(split >= 0) and torch.all(split < 260) - pred, _ = entropy_model(split) + pred = entropy_model(split) pred = pred.reshape(-1, pred.shape[-1])[ : split.numel() - pad_size, : ] # [batch_size * seq_len, vocab] pred_entropies = entropy(pred) entropies.append(pred_entropies) - entropies = torch.cat(entropies, dim=0) - entropies = entropies.reshape(tokens.shape) - return entropies + concat_entropies = torch.cat(entropies, dim=0) + concat_entropies = concat_entropies.reshape(tokens.shape) + return concat_entropies def patch_start_mask_from_entropy_with_monotonicity(entropies, t): diff --git a/bytelatent/preprocess/preprocess_entropies.py b/bytelatent/preprocess/preprocess_entropies.py index 20d1e0c..45c081c 100644 --- a/bytelatent/preprocess/preprocess_entropies.py +++ b/bytelatent/preprocess/preprocess_entropies.py @@ -1,14 +1,59 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
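In the patcher.py hunk above, predictions are flattened to [batch_size * seq_len, vocab], passed through the repo's entropy() helper, and the concatenated result is reshaped back to tokens.shape. The entropy() helper itself is not shown in this hunk; a standard definition consistent with those shapes, for illustration only (the name entropy_from_logits is made up here), is:

    import torch

    def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
        # logits: [N, vocab] -> per-row Shannon entropy in nats
        logp = torch.log_softmax(logits, dim=-1)
        return -(logp.exp() * logp).sum(dim=-1)

    pred = torch.randn(2 * 5, 260)  # [batch_size * seq_len, vocab], vocab of 260 as asserted above
    print(entropy_from_logits(pred).reshape(2, 5).shape)  # torch.Size([2, 5]), i.e. tokens.shape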
+import jsonlines import time -from pathlib import Path +import fsspec import numpy as np import pyarrow as pa import torch import typer from rich.progress import Progress, TextColumn -from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator +from bytelatent.data.file_util import get_fs +from bytelatent.data.patcher import calculate_entropies +from bytelatent.entropy_model import load_entropy_model +from bytelatent.tokenizers.build_tokenizer import TokenizerArgs + + +def get_id_from_doc(doc: dict) -> int: + """ + We need a reliable way to ensure that samples from jsonl + and arrow are the same, but there is no unique id field, + so derive the best possible + """ + if "sample_id" in doc: + sample_id = doc["sample_id"] + elif "title" in doc: + sample_id = doc["title"] + elif "qid" in doc: + sample_id = doc["qid"] + elif "paper_id" in doc: + sample_id = doc["paper_id"] + elif "path" in doc: + sample_id = doc["path"] + elif "url" in doc: + sample_id = doc["url"] + elif "id" in doc: + sample_id = doc["id"] + else: + raise ValueError(f"Could not find a id key from: {doc.keys()}") + return str(sample_id) + + +def get_text(doc: dict): + if "text" in doc: + text = doc["text"] + elif "content" in doc: + text = doc["content"] + else: + raise ValueError(f"Could not find a text key from: {doc.keys()}") + return text + + +def jsonl_file_iterator(fs: fsspec.AbstractFileSystem, path: str): + with fs.open(path) as f: + reader = jsonlines.Reader(f) + yield from reader def main( @@ -16,39 +61,32 @@ def main( output_file: str, patching_device: str = "cuda", log_step: int = 10_000, - entropy_model_checkpoint_dir: str = "entropy_checkpoint_dir", + entropy_model_checkpoint_dir: str = "public_data/entropy_checkpoint", + entropy_model_state_dict_path: str = "public_data/entropy_model.pth", + bpe_tokenizer_path: str = "public_data/tokenizer.model", dry_run: bool = False, + s3_profile: str | None = None, ): - # TODO: Modify this to work with the new code - raise NotImplementedError() - iterator = ArrowFileIterator( - file_path=input_file, - worker_id=0, - num_workers=1, - ) - tokenization_mode = "bytes" print(f"Preprocessing entropies, input: {input_file}, output: {output_file}") print("Loading entropy model", entropy_model_checkpoint_dir) + input_fs = get_fs(input_file, s3_profile=s3_profile) + input_doc_iterator = jsonl_file_iterator(input_fs, input_file) + if dry_run: return entropy_model = load_entropy_model( - entropy_model_checkpoint_dir, device=patching_device + entropy_model_checkpoint_dir, + entropy_model_state_dict_path, + device=patching_device, ) - entropy_model, _ = to_device(entropy_model, patching_device) + print("Creating patcher") patching_batch_size = 32 print("Creating tokenizer") - tokenizer = Tokenizer( - model_path="/home/artidoro/tokenizers/llama_v2.tokenizer.model", - tokenization_mode=tokenization_mode, - # BYTE_UNITS - vocab_size_unit_1=256, - bos=True, - eos=True, - bpe_delim=False, - # This isn't used, just stores a reference for other calls we don't use - patcher=None, + tokenizer_args = TokenizerArgs( + name="blt", init_kwargs={"bpe_tokenizer_path": bpe_tokenizer_path} ) + tokenizer = tokenizer_args.build() step = 0 print("starting") start_time = time.time() @@ -59,8 +97,10 @@ def main( schema = pa.schema([sample_id_field, text_field, entropy_field]) arrow_batch_size = 1_000 + output_fs = get_fs(output_file, s3_profile=s3_profile) + try: - with pa.OSFile(output_file, "wb") as sink: + with output_fs.open(output_file, "wb") as sink: with pa.ipc.new_file(sink, schema) as 
writer: id_buffer = [] entropies_buffer = [] @@ -72,17 +112,9 @@ def main( task = progress.add_task( "[green]Calculating entropies...", total=None ) - for doc in iterator: + for doc in input_doc_iterator: sample_id = get_id_from_doc(doc) - - if "text" in doc: - text = doc["text"] - elif "content" in doc: - text = doc["content"] - else: - raise ValueError( - f"Could not find a text key from: {doc.keys()}" - ) + text = get_text(doc) tokens = torch.tensor(tokenizer.encode(text)) patch_start = time.time() scores = calculate_entropies( @@ -128,9 +160,10 @@ def main( entropies_buffer = [] id_buffer = [] text_buffer = [] - Path(f"{output_file}.complete").touch() + output_fs.touch(f"{output_file}.complete") except: - Path(output_file).unlink(missing_ok=True) + if output_fs.exists(output_file): + output_fs.rm(output_file) raise elapsed = time.time() - start_time print("steps", step) diff --git a/requirements.txt b/requirements.txt index 59192cd..8490556 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,3 +21,4 @@ submitit typer rich fsspec[full] +orjson From d718cfa9a105b37cd8c20ed47081f5d7bc4f0f9d Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Mon, 13 Jan 2025 23:26:26 +0000 Subject: [PATCH 7/9] Update preprocess_entropies script to blt inference + add fsspec support Summary: Test Plan: --- bytelatent/data/patcher.py | 8 +- bytelatent/preprocess/preprocess_entropies.py | 105 ++++++++++++------ requirements.txt | 1 + 3 files changed, 74 insertions(+), 40 deletions(-) diff --git a/bytelatent/data/patcher.py b/bytelatent/data/patcher.py index ede8b06..f8477a3 100644 --- a/bytelatent/data/patcher.py +++ b/bytelatent/data/patcher.py @@ -82,16 +82,16 @@ def calculate_entropies( if device is not None: split = split.to(device) assert torch.all(split >= 0) and torch.all(split < 260) - pred, _ = entropy_model(split) + pred = entropy_model(split) pred = pred.reshape(-1, pred.shape[-1])[ : split.numel() - pad_size, : ] # [batch_size * seq_len, vocab] pred_entropies = entropy(pred) entropies.append(pred_entropies) - entropies = torch.cat(entropies, dim=0) - entropies = entropies.reshape(tokens.shape) - return entropies + concat_entropies = torch.cat(entropies, dim=0) + concat_entropies = concat_entropies.reshape(tokens.shape) + return concat_entropies def patch_start_mask_from_entropy_with_monotonicity(entropies, t): diff --git a/bytelatent/preprocess/preprocess_entropies.py b/bytelatent/preprocess/preprocess_entropies.py index 20d1e0c..1c19a5a 100644 --- a/bytelatent/preprocess/preprocess_entropies.py +++ b/bytelatent/preprocess/preprocess_entropies.py @@ -1,14 +1,59 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
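The output path now goes through fsspec, so the same code writes the Arrow IPC file locally or to S3 and then drops the .complete marker the ArrowFileIterator checks for. A minimal local-filesystem sketch of that write path (simplified schema, illustrative field names; the real script buffers rows and flushes in batches):

    import os
    import tempfile

    import fsspec
    import pyarrow as pa

    out = os.path.join(tempfile.gettempdir(), "example.arrow")
    fs = fsspec.filesystem("file")
    schema = pa.schema(
        [pa.field("sample_id", pa.string()), pa.field("entropies", pa.list_(pa.float32()))]
    )
    with fs.open(out, "wb") as sink:
        with pa.ipc.new_file(sink, schema) as writer:
            batch = pa.record_batch(
                {"sample_id": ["doc-0"], "entropies": [[0.12, 2.34]]}, schema=schema
            )
            writer.write_batch(batch)
    fs.touch(f"{out}.complete")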
import time -from pathlib import Path +import fsspec +import jsonlines import numpy as np import pyarrow as pa import torch import typer from rich.progress import Progress, TextColumn -from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator +from bytelatent.data.file_util import get_fs +from bytelatent.data.patcher import calculate_entropies +from bytelatent.entropy_model import load_entropy_model +from bytelatent.tokenizers.build_tokenizer import TokenizerArgs + + +def get_id_from_doc(doc: dict) -> int: + """ + We need a reliable way to ensure that samples from jsonl + and arrow are the same, but there is no unique id field, + so derive the best possible + """ + if "sample_id" in doc: + sample_id = doc["sample_id"] + elif "title" in doc: + sample_id = doc["title"] + elif "qid" in doc: + sample_id = doc["qid"] + elif "paper_id" in doc: + sample_id = doc["paper_id"] + elif "path" in doc: + sample_id = doc["path"] + elif "url" in doc: + sample_id = doc["url"] + elif "id" in doc: + sample_id = doc["id"] + else: + raise ValueError(f"Could not find a id key from: {doc.keys()}") + return str(sample_id) + + +def get_text(doc: dict): + if "text" in doc: + text = doc["text"] + elif "content" in doc: + text = doc["content"] + else: + raise ValueError(f"Could not find a text key from: {doc.keys()}") + return text + + +def jsonl_file_iterator(fs: fsspec.AbstractFileSystem, path: str): + with fs.open(path) as f: + reader = jsonlines.Reader(f) + yield from reader def main( @@ -16,39 +61,32 @@ def main( output_file: str, patching_device: str = "cuda", log_step: int = 10_000, - entropy_model_checkpoint_dir: str = "entropy_checkpoint_dir", + entropy_model_checkpoint_dir: str = "public_data/entropy_checkpoint", + entropy_model_state_dict_path: str = "public_data/entropy_model.pth", + bpe_tokenizer_path: str = "public_data/tokenizer.model", dry_run: bool = False, + s3_profile: str | None = None, ): - # TODO: Modify this to work with the new code - raise NotImplementedError() - iterator = ArrowFileIterator( - file_path=input_file, - worker_id=0, - num_workers=1, - ) - tokenization_mode = "bytes" print(f"Preprocessing entropies, input: {input_file}, output: {output_file}") print("Loading entropy model", entropy_model_checkpoint_dir) + input_fs = get_fs(input_file, s3_profile=s3_profile) + input_doc_iterator = jsonl_file_iterator(input_fs, input_file) + if dry_run: return entropy_model = load_entropy_model( - entropy_model_checkpoint_dir, device=patching_device + entropy_model_checkpoint_dir, + entropy_model_state_dict_path, + device=patching_device, ) - entropy_model, _ = to_device(entropy_model, patching_device) + print("Creating patcher") patching_batch_size = 32 print("Creating tokenizer") - tokenizer = Tokenizer( - model_path="/home/artidoro/tokenizers/llama_v2.tokenizer.model", - tokenization_mode=tokenization_mode, - # BYTE_UNITS - vocab_size_unit_1=256, - bos=True, - eos=True, - bpe_delim=False, - # This isn't used, just stores a reference for other calls we don't use - patcher=None, + tokenizer_args = TokenizerArgs( + name="blt", init_kwargs={"bpe_tokenizer_path": bpe_tokenizer_path} ) + tokenizer = tokenizer_args.build() step = 0 print("starting") start_time = time.time() @@ -59,8 +97,10 @@ def main( schema = pa.schema([sample_id_field, text_field, entropy_field]) arrow_batch_size = 1_000 + output_fs = get_fs(output_file, s3_profile=s3_profile) + try: - with pa.OSFile(output_file, "wb") as sink: + with output_fs.open(output_file, "wb") as sink: with pa.ipc.new_file(sink, schema) as 
writer: id_buffer = [] entropies_buffer = [] @@ -72,17 +112,9 @@ def main( task = progress.add_task( "[green]Calculating entropies...", total=None ) - for doc in iterator: + for doc in input_doc_iterator: sample_id = get_id_from_doc(doc) - - if "text" in doc: - text = doc["text"] - elif "content" in doc: - text = doc["content"] - else: - raise ValueError( - f"Could not find a text key from: {doc.keys()}" - ) + text = get_text(doc) tokens = torch.tensor(tokenizer.encode(text)) patch_start = time.time() scores = calculate_entropies( @@ -128,9 +160,10 @@ def main( entropies_buffer = [] id_buffer = [] text_buffer = [] - Path(f"{output_file}.complete").touch() + output_fs.touch(f"{output_file}.complete") except: - Path(output_file).unlink(missing_ok=True) + if output_fs.exists(output_file): + output_fs.rm(output_file) raise elapsed = time.time() - start_time print("steps", step) diff --git a/requirements.txt b/requirements.txt index 59192cd..8490556 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,3 +21,4 @@ submitit typer rich fsspec[full] +orjson From 38022ac06e3f5b32daf12dfe3f02c6bbc6d1e840 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Thu, 16 Jan 2025 21:51:04 +0000 Subject: [PATCH 8/9] [WIP] Changes for training entropy model and correcting attention in local models Summary: - Refactor local model configs to be separate and clearer - Add attention arguments and correct which attention is used in local models - Preparation for being able to have an entropy train script - Fix failing unit tests Test Plan: --- bytelatent/args.py | 5 + bytelatent/base_transformer.py | 44 +++++- bytelatent/configs/debug.yaml | 2 +- .../data/iterators/test_arrow_iterator.py | 3 + bytelatent/model/blt.py | 128 ++++++++++-------- .../{transformer.py => global_transformer.py} | 17 ++- bytelatent/model/local_models.py | 94 +++++++++---- bytelatent/model/utils.py | 75 +++++++++- bytelatent/preprocess/fsspec_target.py | 38 ++++++ bytelatent/test_blt.py | 22 +-- bytelatent/test_entropy_model.py | 1 + bytelatent/transformer.py | 31 ++--- 12 files changed, 331 insertions(+), 129 deletions(-) rename bytelatent/model/{transformer.py => global_transformer.py} (93%) create mode 100644 bytelatent/preprocess/fsspec_target.py diff --git a/bytelatent/args.py b/bytelatent/args.py index cfba3bf..c6f714a 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -30,6 +30,7 @@ from bytelatent.model.blt import ByteLatentTransformerArgs from bytelatent.optim import OptimArgs from bytelatent.profiling import ProfilerArgs from bytelatent.tokenizers.build_tokenizer import TokenizerArgs +from bytelatent.transformer import LMTransformerArgs logger = logging.getLogger() @@ -176,6 +177,10 @@ class TrainArgs(BaseModel): data: DataloaderArgs = DataloaderArgs() optim: OptimArgs = OptimArgs() model: ByteLatentTransformerArgs = ByteLatentTransformerArgs() + # This is only needed for training the entropy model + entropy_model: LMTransformerArgs | None = None + # Instead of training main model, train entropy model + train_entropy_model: bool = False distributed: DistributedArgs = DistributedArgs() env: EnvironmentArgs = EnvironmentArgs() diff --git a/bytelatent/base_transformer.py b/bytelatent/base_transformer.py index 45cb7c5..969fc4b 100644 --- a/bytelatent/base_transformer.py +++ b/bytelatent/base_transformer.py @@ -4,7 +4,7 @@ from enum import Enum from typing import Optional, Tuple, Union import torch -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from torch import nn from torch.nn import 
functional as F from torch.nn.attention.flex_attention import ( @@ -15,6 +15,7 @@ from torch.nn.attention.flex_attention import ( from xformers.ops import AttentionBias, fmha from bytelatent import probe +from bytelatent.tokenizers.constants import EOS_ID if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0: flex_attention_comp = torch.compile(flex_attention) @@ -30,13 +31,14 @@ class InitStdFactor(Enum): class BaseTransformerArgs(BaseModel): + model_config = ConfigDict(extra="forbid") dim: int = 512 n_layers: int = 8 - head_dim: Optional[int] = None - n_heads: Optional[int] = None - n_kv_heads: Optional[int] = None + head_dim: int | None = None + n_heads: int | None = None + n_kv_heads: int | None = None - ffn_dim_multiplier: Optional[float] = None + ffn_dim_multiplier: float | None = None multiple_of: int = 256 @@ -44,11 +46,16 @@ class BaseTransformerArgs(BaseModel): rope_theta: float = 10000.0 - init_base_std: Optional[float] = None + init_base_std: float | None = None init_std_factor: InitStdFactor = InitStdFactor.DISABLED max_seqlen: int = 1024 + attn_impl: str | None = "sdpa" + attn_bias_type: str | None = None + # Special token config + eos_id: int | None = EOS_ID + def cross_entropy(pred, target, **kwargs): return F.nll_loss( @@ -294,6 +301,18 @@ class RMSNorm(nn.Module): torch.nn.init.ones_(self.weight) # type: ignore +def _reshape_for_attn_bias( + attn_bias: AttentionBias | None, + *tensors: torch.Tensor, +) -> list[torch.Tensor]: + to_transform = list(tensors) + if isinstance(attn_bias, fmha.attn_bias.BlockDiagonalCausalMask): + # could be `view` instead of reshape during training, but for inference + # have to reshape due to strides mismatch + to_transform = [t.reshape(1, -1, *t.shape[2:]) for t in to_transform] + return to_transform + + class Attention(nn.Module): def __init__( self, @@ -371,9 +390,17 @@ class Attention(nn.Module): output = flex_attention_comp(xq, xk, xv, block_mask=mask) output = output.transpose(1, 2).contiguous() # B H S D -> B S H D - elif attn_impl == "fmha": + elif attn_impl == "xformers": assert mask is None or isinstance(mask, AttentionBias) + query_shape = xq.shape + print("Before reshape", "xq", xq.shape, "xk", xk.shape, "xv", xv.shape) + xq, xk, xv = _reshape_for_attn_bias(mask, xq, xk, xv) + print("Before reshape", "xq", xq.shape, "xk", xk.shape, "xv", xv.shape) output = fmha.memory_efficient_attention(xq, xk, xv, attn_bias=mask) + print("attn out", output.shape, "query_reshape", query_shape) + output_original_shape = output.view(query_shape) + print("Reshape success") + return output_original_shape # This uses B S H D instead of B H S D of pytorch elif attn_impl == "sdpa": @@ -545,6 +572,8 @@ class BaseTransformer(nn.Module): super().__init__() self.dim = args.dim self.init_base_std = args.init_base_std + self.attn_impl = args.attn_impl + self.attn_bias_type = args.attn_bias_type self.init_std_factor = InitStdFactor(args.init_std_factor) self.max_seqlen = args.max_seqlen self.rope_embeddings = RotaryEmbedding( @@ -552,6 +581,7 @@ class BaseTransformer(nn.Module): head_dim=args.head_dim or args.dim // args.n_heads, max_seqlen=args.max_seqlen, ) + self.eos_id = args.eos_id self.layers = nn.ModuleList() for _ in range(args.n_layers): diff --git a/bytelatent/configs/debug.yaml b/bytelatent/configs/debug.yaml index 5f6debb..fc8b943 100644 --- a/bytelatent/configs/debug.yaml +++ b/bytelatent/configs/debug.yaml @@ -58,13 +58,13 @@ model: recompute_attn: false custom_bwd: false layer_ckpt: "none" - efficient_attn: "sdpa" 
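The renamed "xformers" attention path packs every sequence of the batch into a single row so that fmha's BlockDiagonalCausalMask keeps documents from attending across boundaries; _reshape_for_attn_bias above performs exactly that reshape. A minimal sketch with made-up lengths, head count, and head dim (the kernel itself is only run when a GPU is available):

    import torch
    from xformers.ops import fmha

    seqlens = [3, 5, 4]                              # per-document lengths in one packed batch
    q = k = v = torch.randn(1, sum(seqlens), 2, 16)  # (B=1, S, H, D) layout after packing
    bias = fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(q_seqlen=seqlens)
    if torch.cuda.is_available():
        q, k, v = (t.to("cuda", dtype=torch.float16) for t in (q, k, v))
        out = fmha.memory_efficient_attention(q, k, v, attn_bias=bias)
        print(out.shape)                             # torch.Size([1, 12, 2, 16])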
patch_only_encoder: false patch_only_decoder: false use_local_encoder_transformer: true init_use_gaussian: true init_use_depth: "current" attn_bias_type: "block_causal" + attn_impl: "xformers" alpha_depth: "disabled" max_length: 256 local_attention_window_len: 512 diff --git a/bytelatent/data/iterators/test_arrow_iterator.py b/bytelatent/data/iterators/test_arrow_iterator.py index 4266427..fd448eb 100644 --- a/bytelatent/data/iterators/test_arrow_iterator.py +++ b/bytelatent/data/iterators/test_arrow_iterator.py @@ -27,6 +27,7 @@ def test_basic_arrow_file(): dataset_files=[ARROW_TEST_DATA_1], row_num=0, arrow_batch_size=100, + s3_profile=None, ) arrow_file = initial_state.build() start_state = arrow_file.get_state() @@ -55,6 +56,7 @@ def test_basic_arrow_file(): dataset_files=[ARROW_TEST_DATA_1], row_num=251, arrow_batch_size=100, + s3_profile=None, ) arrow_file = resumed_state.build() for example in arrow_file.create_iter(): @@ -74,6 +76,7 @@ def test_basic_arrow_file(): dataset_files=[ARROW_TEST_DATA_1], row_num=0, arrow_batch_size=100, + s3_profile=None, ) arrow_file = rank_state.build() expected_ids = [] diff --git a/bytelatent/model/blt.py b/bytelatent/model/blt.py index 9332d19..3e405da 100644 --- a/bytelatent/model/blt.py +++ b/bytelatent/model/blt.py @@ -15,8 +15,8 @@ from bytelatent.base_transformer import ( TransformerBlock, ) from bytelatent.data.patcher import Patcher, PatcherArgs -from bytelatent.model.local_models import LocalDecoder, LocalEncoder -from bytelatent.model.transformer import GlobalTransformer +from bytelatent.model.local_models import LocalDecoder, LocalEncoder, LocalModelArgs +from bytelatent.model.global_transformer import GlobalTransformer from bytelatent.model.utils import downsample from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID @@ -403,7 +403,6 @@ def patch_ids_from_lengths(patch_lengths, seq_len): class ByteLatentTransformerArgs(BaseTransformerArgs): - model_config = ConfigDict(extra="forbid") # Basic model configuration seed: int = 42 vocab_size: int = -1 @@ -412,7 +411,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): n_heads: int = 8 # TODO: What is the purpose of this parameter? 
weight_tying: bool = False - sliding_window: Optional[int] = None # Architecture and dimensions dim_token: int = 256 @@ -471,11 +469,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): recompute_attn: bool = True custom_bwd: bool = False layer_ckpt: str = "all" - efficient_attn: str | None = None - - # Architecture options - patch_only_encoder: bool = False - patch_only_decoder: bool = False # Initialization and attention init_use_gaussian: bool = True @@ -541,9 +534,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): # Logging full_logging_n_layers: int = 4 - # Special token config - eos_id: int | None = None - @model_validator(mode="after") def check_hash_byte_sizes(self) -> Self: if ( @@ -558,22 +548,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): return self -class LocalEncoderArgs(ByteLatentTransformerArgs): - # Local encoder specific dimensions - n_heads_local_encoder: int = 8 - dim_token_emb: int | None = None - dim_patch_emb: int | None = None - - def __post_init__(self): - # Override base args with local encoder specific values - self.dim = self.dim_local_encoder - self.n_layers = self.n_layers_local_encoder - self.n_heads = self.n_heads_local_encoder - self.cross_attn_decoder = False - self.cross_attn_k = self.cross_attn_k if self.cross_attn_encoder else None - self.attn_bias_type = "local_block_causal" - - class GlobalTransformerArgs(ByteLatentTransformerArgs): # Global encoder specific dimensions dim_token_emb: int | None = None @@ -625,20 +599,42 @@ def create_global_transformer(args: ByteLatentTransformerArgs) -> GlobalTransfor def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder: - # First deep copy the original args - # Replace with local encoder specific values - local_encoder_args = args.model_copy( - deep=True, - update=dict( - dim=args.dim_local_encoder, - n_layers=args.n_layers_local_encoder, - n_heads=args.n_heads_local_encoder, - dim_token_emb=get_encoder_dim_token_emb(args), - dim_patch_emb=get_encoder_dim_patch_emb(args), - cross_attn_decoder=False, - cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None, - attn_bias_type="local_block_causal", - ), + local_encoder_args = LocalModelArgs( + # Updated args + dim=args.dim_local_encoder, + n_layers=args.n_layers_local_encoder, + n_heads=args.n_heads_local_encoder, + dim_token_emb=get_encoder_dim_token_emb(args), + dim_patch_emb=get_encoder_dim_patch_emb(args), + cross_attn_encoder=args.cross_attn_encoder, + cross_attn_decoder=False, + cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None, + cross_attn_init_by_pooling=args.cross_attn_init_by_pooling, + # Defaults + head_dim=args.head_dim, + max_seqlen=args.max_encoder_seq_length, + dropout=args.dropout, + vocab_size=args.vocab_size + args.pm_size, + norm_eps=args.norm_eps, + patch_size=args.patch_size, + sliding_window=args.local_attention_window_len, + use_rope=args.use_rope, + rope_theta=args.rope_theta, + init_base_std=args.init_base_std, + init_std_factor=args.init_std_factor, + n_kv_heads=args.n_kv_heads, + attn_impl=args.attn_impl, + attn_bias_type="local_block_causal", + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + patching_mode=args.patching_mode, + use_local_encoder_transformer=args.use_local_encoder_transformer, + downsampling_by_pooling=args.downsampling_by_pooling, + encoder_hash_byte_group_size=args.encoder_hash_byte_group_size, + cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder, + 
cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder, + cross_attn_nheads=args.cross_attn_nheads, + eos_id=args.eos_id, ) return LocalEncoder(local_encoder_args) @@ -646,18 +642,41 @@ def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder: def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder: # First deep copy the original args - local_decoder_args = args.model_copy( - deep=True, - update=dict( - dim=args.dim_local_decoder, - n_layers=args.n_layers_local_decoder, - n_heads=args.n_heads_local_decoder, - cross_attn_encoder=False, - cross_attn_init_by_pooling=False, # states are already defined - dim_token_emb=get_decoder_dim_token_emb(args), - dim_patch_emb=args.dim_global, - cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None, - ), + local_decoder_args = LocalModelArgs( + dim=args.dim_local_decoder, + n_layers=args.n_layers_local_decoder, + n_heads=args.n_heads_local_decoder, + dim_token_emb=get_decoder_dim_token_emb(args), + dim_patch_emb=args.dim_global, + cross_attn_encoder=False, + cross_attn_decoder=args.cross_attn_decoder, + cross_attn_init_by_pooling=False, # states are already defined + cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None, + # Defaults + head_dim=args.head_dim, + max_seqlen=args.max_encoder_seq_length, + dropout=args.dropout, + vocab_size=args.vocab_size + args.pm_size, + norm_eps=args.norm_eps, + patch_size=args.patch_size, + sliding_window=args.local_attention_window_len, + use_rope=args.use_rope, + rope_theta=args.rope_theta, + init_base_std=args.init_base_std, + init_std_factor=args.init_std_factor, + n_kv_heads=args.n_kv_heads, + attn_impl=args.attn_impl, + attn_bias_type="local_block_causal", + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + patching_mode=args.patching_mode, + use_local_encoder_transformer=args.use_local_encoder_transformer, + downsampling_by_pooling=args.downsampling_by_pooling, + encoder_hash_byte_group_size=args.encoder_hash_byte_group_size, + cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder, + cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder, + cross_attn_nheads=args.cross_attn_nheads, + eos_id=args.eos_id, ) return LocalDecoder(local_decoder_args) @@ -763,7 +782,6 @@ class ByteLatentTransformer(nn.Module): # General configuration self.weight_tying = args.weight_tying - self.sliding_window = args.sliding_window self.patch_size = args.patch_size self.patching_mode = args.patching_mode self.boe_id, self.bos_id, self.pad_id, self.eos_id = ( diff --git a/bytelatent/model/transformer.py b/bytelatent/model/global_transformer.py similarity index 93% rename from bytelatent/model/transformer.py rename to bytelatent/model/global_transformer.py index 24dc057..21c3f0c 100644 --- a/bytelatent/model/transformer.py +++ b/bytelatent/model/global_transformer.py @@ -11,6 +11,7 @@ from xformers.ops import AttentionBias from bytelatent.base_transformer import ( BaseTransformer, + BaseTransformerArgs, RMSNorm, flex_attention_comp, repeat_kv, @@ -142,11 +143,10 @@ class CrossAttention(nn.Module): class GlobalTransformer(BaseTransformer): - def __init__(self, args): + def __init__(self, args: BaseTransformerArgs): super().__init__(args) self.dropout = args.dropout - self.sliding_window = args.sliding_window - self.efficient_attn = args.efficient_attn + self.eos_id = args.eos_id self.token_embedding_projection = None if args.dim_token_emb is not None and args.dim_token_emb != self.dim: @@ -169,14 +169,19 @@ class 
GlobalTransformer(BaseTransformer): and projection to the token space. """ bs, seqlen = tokens.shape - attn_impl = self.efficient_attn h = embeds mask = ( mask if mask is not None - else create_causal_mask(seqlen, attn_impl, self.sliding_window) + else create_causal_mask( + seqlen, + self.attn_impl, + self.attn_bias_type, + tokens=tokens, + eos_id=self.eos_id, + ) ) if self.token_embedding_projection is not None and h.shape[-1] != self.dim: @@ -184,7 +189,7 @@ class GlobalTransformer(BaseTransformer): h = F.dropout(h, p=self.dropout, training=self.training) - h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl) + h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl) return h, cache def init_weights(self, init_base_std: float): diff --git a/bytelatent/model/local_models.py b/bytelatent/model/local_models.py index 8255504..4ce8fb5 100644 --- a/bytelatent/model/local_models.py +++ b/bytelatent/model/local_models.py @@ -1,8 +1,9 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. import logging -from typing import List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union +from pydantic import BaseModel, ConfigDict import torch import torch.nn import torch.nn as nn @@ -16,29 +17,69 @@ from bytelatent.base_transformer import ( RotaryEmbedding, TransformerBlock, ) -from bytelatent.model.transformer import CrossAttention +from bytelatent.model.global_transformer import CrossAttention from bytelatent.model.utils import create_causal_mask, downsample from bytelatent.tokenizers.blt_tokenizer import BOE_ID logger = logging.getLogger() +class LocalModelArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + # Local encoder specific dimensions + head_dim: int | None + dim: int + dropout: float + vocab_size: int + patch_size: int + sliding_window: int | None + use_rope: bool + init_base_std: float | None = None + init_std_factor: InitStdFactor + cross_attn_encoder: bool | None + cross_attn_decoder: bool | None + cross_attn_k: int | None + cross_attn_init_by_pooling: bool + norm_eps: float + rope_theta: float + max_seqlen: int + ffn_dim_multiplier: float | None = None + patching_mode: str + use_local_encoder_transformer: bool + downsampling_by_pooling: str | None + encoder_hash_byte_group_size: Any | None = None + cross_attn_all_layers_encoder: bool = False + cross_attn_all_layers_decoder: bool = False + cross_attn_nheads: int | None + + n_layers: int + n_heads: int + n_kv_heads: int | None = None + dim_token_emb: int + dim_patch_emb: int | None + attn_impl: str | None = "xformers" + attn_bias_type: str | None = "local_block_causal" + multiple_of: int = 256 + eos_id: int | None = None + + class LocalModelBase(nn.Module): - def __init__(self, args): + def __init__(self, args: LocalModelArgs): super().__init__() self.dim = args.dim self.dropout = args.dropout - self.vocab_size = args.vocab_size + args.pm_size + self.vocab_size = args.vocab_size self.patch_size = args.patch_size - self.efficient_attn = args.efficient_attn + self.attn_impl = args.attn_impl self.sliding_window = args.sliding_window self.use_rope = args.use_rope self.init_std_factor = args.init_std_factor self.cross_attn_encoder = getattr(args, "cross_attn_encoder", None) self.cross_attn_decoder = getattr(args, "cross_attn_decoder", None) self.cross_attn_k = getattr(args, "cross_attn_k", None) + self.eos_id = args.eos_id self.boe_id = BOE_ID @@ -54,7 +95,7 @@ class LocalModelBase(nn.Module): self.rope = RotaryEmbedding( theta=args.rope_theta, head_dim=args.head_dim or 
args.dim // args.n_heads, - max_seqlen=getattr(args, "max_encoder_seq_length", args.max_length), + max_seqlen=args.max_seqlen, ) self.pos_embeddings = None @@ -66,21 +107,15 @@ class LocalModelBase(nn.Module): self.patch_embedding_projection = self._create_patch_projection(args) - def _should_create_patch_projection(self, args): + def _should_create_patch_projection(self, args: LocalModelArgs): dimension_mismatch = ( getattr(args, "dim_patch_emb") and args.dim_patch_emb != self.dim ) # Check cross attention conditions cross_attn_conditions = ( - hasattr(args, "cross_attn_encoder") - and args.cross_attn_encoder - and getattr(args, "cross_attn_init_by_pooling") - ) or ( - hasattr(args, "cross_attn_decoder") - and args.cross_attn_decoder - and getattr(args, "cross_attn_init_by_pooling") - ) + args.cross_attn_encoder and args.cross_attn_init_by_pooling + ) or (args.cross_attn_decoder and args.cross_attn_init_by_pooling) return dimension_mismatch or cross_attn_conditions @@ -172,7 +207,7 @@ class LocalModelBase(nn.Module): class LocalEncoder(LocalModelBase): - def __init__(self, args): + def __init__(self, args: LocalModelArgs): super().__init__(args) self.output_proj = ( args.patching_mode in ["entropy", "probmax"] @@ -180,7 +215,6 @@ class LocalEncoder(LocalModelBase): self.apply_transformer = args.use_local_encoder_transformer self.downsampling_by_pooling = args.downsampling_by_pooling - self.patch_only = args.patch_only_encoder self.expects_hash_embeddings = args.encoder_hash_byte_group_size is not None self.cross_attn_encoder = args.cross_attn_encoder self.cross_attn_all_layers_encoder = args.cross_attn_all_layers_encoder @@ -224,7 +258,14 @@ class LocalEncoder(LocalModelBase): """ """ bs, seqlen = tokens.shape if mask is None: - mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window) + mask = create_causal_mask( + seqlen, + self.attn_impl, + "local_block_causal", + sliding_window=self.sliding_window, + tokens=tokens, + eos_id=self.eos_id, + ) h = self.apply_embedding(tokens, embeds) freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None @@ -232,7 +273,7 @@ class LocalEncoder(LocalModelBase): h = F.dropout(h, p=self.dropout, training=self.training) for i, layer in enumerate(self.layers): - h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn) + h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl) # check if cross attention should be applied to either all layer or only the last layer if self.cross_attn_encoder and ( i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder @@ -273,12 +314,10 @@ class LocalEncoder(LocalModelBase): class LocalDecoder(LocalModelBase): - def __init__(self, args): + def __init__(self, args: LocalModelArgs): super().__init__(args) # Model configuration flags - self.patch_only = args.patch_only_decoder - self.expects_embeddings = args.share_encoder_decoder_emb self.cross_attn_decoder = args.cross_attn_decoder self.cross_attn_all_layers_decoder = args.cross_attn_all_layers_decoder self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling @@ -317,7 +356,14 @@ class LocalDecoder(LocalModelBase): assert embeds is not None, "Embeddings must be provided" if mask is None: - mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window) + mask = create_causal_mask( + seqlen, + self.attn_impl, + "local_block_causal", + sliding_window=self.sliding_window, + tokens=tokens, + eos_id=self.eos_id, + ) h = embeds @@ -347,7 +393,7 @@ class LocalDecoder(LocalModelBase): ) h = h + h_cross 
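These block-causal masks are driven by tokens_to_seqlen, added in the bytelatent/model/utils.py hunk just below, which turns EOS positions into per-document lengths. A quick check of its documented behavior, assuming this patch is applied (1 stands in for the EOS id, as in its docstring):

    import torch
    from bytelatent.model.utils import tokens_to_seqlen

    batch = torch.tensor(
        [
            [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0],
            [0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
        ]
    )
    print(tokens_to_seqlen(batch, eos_id=1))  # [4, 4, 3, 2, 4, 5]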
- h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn) + h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl) h_preds = self.norm(h) h_preds = F.dropout(h_preds, p=self.dropout, training=self.training) diff --git a/bytelatent/model/utils.py b/bytelatent/model/utils.py index ce52a30..eac2c68 100644 --- a/bytelatent/model/utils.py +++ b/bytelatent/model/utils.py @@ -1,8 +1,11 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. +import logging import torch from torch.nn.attention.flex_attention import create_block_mask from xformers.ops import fmha +logger = logging.getLogger() + def patch_reduce(h, max_num_patches, reduction, patch_ids): """ @@ -97,14 +100,72 @@ def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx -def create_causal_mask(seqlen, attn_impl, sliding_window): - if sliding_window is not None and attn_impl == "xformers": - return fmha.attn_bias.LocalAttentionFromBottomRightMask( - window_left=sliding_window - 1, window_right=0 - ) - elif attn_impl == "xformers": - return fmha.attn_bias.LowerTriangularMask() +def tokens_to_seqlen(batch: torch.Tensor, eos_id: int): + """ + 0 0 0 1 0 0 0 1 0 0 0 + 0 1 0 0 0 1 0 0 0 0 0 + -> 4 4 3 2 4 5 + """ + mask = batch == eos_id + mask[:, -1] = True # virtual eos at the end of each row + + # 0 0 0 1 0 0 0 1 0 0 X + # 0 1 0 0 0 1 0 0 0 0 X + row, col = torch.where(mask) + + # row = 0, 0, 0, 1, 1, 1 + # col = 3, 7, 10, 1, 5, 10 + seqlens = (col[1:] - col[:-1]) + (row[1:] - row[:-1]) * mask.shape[1] + # seqlens = (4, 3, -9, 4, 5) + (0, 0, 11, 0, 0) = (4, 3, 2, 4, 5) + return [int(col[0].item() + 1)] + seqlens.tolist() + + +WARNED_SDPA = False + + +def create_causal_mask( + seqlen, + attn_impl: str, + attn_bias_type: str | None, + *, + eos_id: int | None = None, + tokens: torch.Tensor | None = None, + sliding_window: int | None = None, +): + if attn_impl == "xformers": + if attn_bias_type is None: + return fmha.attn_bias.LowerTriangularMask() + elif attn_bias_type == "causal": + assert sliding_window is None + print("attn: causal") + return fmha.attn_bias.LowerTriangularMask() + elif attn_bias_type == "block_causal": + assert sliding_window is None + assert eos_id is not None + assert tokens is not None + print("attn: block_causal") + return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens( + q_seqlen=tokens_to_seqlen(tokens, eos_id) + ) + elif attn_bias_type == "local_block_causal": + assert sliding_window is not None + assert eos_id is not None + assert tokens is not None + print("attn: local_block_causal") + return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens( + q_seqlen=tokens_to_seqlen(tokens, eos_id) + ).make_local_attention(sliding_window) + else: + return fmha.attn_bias.LocalAttentionFromBottomRightMask( + window_left=sliding_window - 1, window_right=0 + ) elif attn_impl == "sdpa": + global WARNED_SDPA + if not WARNED_SDPA: + logging.warning( + "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention." 
+ ) + WARNED_SDPA = True return "causal" elif attn_impl == "flex_attention": return create_block_mask(causal_mask, None, None, seqlen, seqlen) diff --git a/bytelatent/preprocess/fsspec_target.py b/bytelatent/preprocess/fsspec_target.py new file mode 100644 index 0000000..eacb101 --- /dev/null +++ b/bytelatent/preprocess/fsspec_target.py @@ -0,0 +1,38 @@ +import fsspec +from luigi.target import FileSystem, FileSystemTarget + + +class FSSpecFileSystem(FileSystem): + def __init__(self, fs: fsspec.AbstractFileSystem): + self.fs = fs + + def exists(self, path): + return self.fs.exists() + + def remove(self, path, recursive=True, skip_trash=True): + raise NotImplementedError() + + def isdir(self, path): + return self.fs.isdir(path) + + def listdir(self, path): + return self.fs.ls(path) + + +class FSSpecTarget(FileSystemTarget): + def __init__(self, path, fs: fsspec.AbstractFileSystem | None = None): + self.path = path + if fs is None: + self.fsspec_fs = fsspec.filesystem("file") + else: + self.fsspec_fs = fs + self._fs = None + + @property + def fs(self): + if self._fs is None: + self._fs = FSSpecFileSystem(self.fsspec_fs) + return self._fs + + def open(self, mode): + return self.fs.open(self.path, mode=mode) diff --git a/bytelatent/test_blt.py b/bytelatent/test_blt.py index 73ad9f7..d7cda05 100644 --- a/bytelatent/test_blt.py +++ b/bytelatent/test_blt.py @@ -23,9 +23,10 @@ from bytelatent.model.blt import ( init_embeddings, patch_ids_from_lengths, ) -from bytelatent.model.transformer import CrossAttention +from bytelatent.model.global_transformer import CrossAttention from bytelatent.model.utils import create_causal_mask from bytelatent.optim import OptimArgs, build_optimizer +from bytelatent.tokenizers.constants import EOS_ID from bytelatent.train import compute_loss @@ -51,7 +52,7 @@ def batch_to_tensors_and_gpu(batch): def fake_batch(): - batch_dict = torch.load(os.path.join(BLT_DATA, "test_batch.pt")) + batch_dict = torch.load(os.path.join(BLT_DATA, "test_batch.pt"), weights_only=False) del batch_dict["x2"] del batch_dict["y2"] del batch_dict["src_names"] @@ -98,18 +99,17 @@ def create_args(cross_attention=False): recompute_attn=False, custom_bwd=False, layer_ckpt="none", - efficient_attn="sdpa", - patch_only_encoder=False, - patch_only_decoder=False, use_local_encoder_transformer=True, init_use_gaussian=True, init_use_depth="current", attn_bias_type="block_causal", + attn_impl="xformers", alpha_depth="disabled", max_length=256, local_attention_window_len=512, max_seqlen=12288, downsampling_by_pooling="max", + eos_id=EOS_ID, ) return transformer_args @@ -341,10 +341,10 @@ class TestByteLatentTransformer: model = ByteLatentTransformer(args) assert model is not None - @pytest.mark.parametrize("attn_type", ["fmha", "sdpa"]) - def test_blt_transformer_forward(self, attn_type): + @pytest.mark.parametrize("attn_impl", ["fmha", "sdpa", "xformers"]) + def test_blt_transformer_forward(self, attn_impl): args = create_args() - args = args.model_copy(update=dict(efficient_attn=attn_type)) + args = args.model_copy(update=dict(attn_impl=attn_impl)) model = ByteLatentTransformer(args) model = model.cuda() batch = fake_batch() @@ -393,7 +393,9 @@ class TestByteLatentTransformer: n_kv_heads=4, norm_eps=1e-6, ).to("cuda") - mask = create_causal_mask(x.shape[1], "flex_attention", sliding_window=None) + mask = create_causal_mask( + x.shape[1], "flex_attention", None, sliding_window=None + ) output = cross_attention(x, kv, mask) assert output is not None assert output.shape == (2, 256, 512) @@ -440,7 +442,7 @@ 
class TestByteLatentTransformer: def test_loss_backward(self): args = create_args() - args = args.model_copy(update=dict(efficient_attn="sdpa")) + args = args.model_copy(update=dict(attn_impl="sdpa")) batch = fake_batch() model = ByteLatentTransformer(args) steps = 10 diff --git a/bytelatent/test_entropy_model.py b/bytelatent/test_entropy_model.py index 3acc42d..af81638 100644 --- a/bytelatent/test_entropy_model.py +++ b/bytelatent/test_entropy_model.py @@ -24,6 +24,7 @@ def test_entropy_model(): dataset_files=[ARROW_TEST_DATA], row_num=0, arrow_batch_size=100, + s3_profile=None, ) arrow_file = initial_state.build() tokenizer_args = TokenizerArgs( diff --git a/bytelatent/transformer.py b/bytelatent/transformer.py index 432f7df..92c5ff5 100644 --- a/bytelatent/transformer.py +++ b/bytelatent/transformer.py @@ -22,23 +22,7 @@ from bytelatent.base_transformer import ( RMSNorm, cross_entropy, ) - - -def create_causal_mask(seqlen, attn_impl, sliding_window): - if sliding_window is not None and attn_impl == "xformers": - return fmha.attn_bias.LocalAttentionFromBottomRightMask( - window_left=sliding_window - 1, window_right=0 - ) - elif attn_impl == "xformers": - return fmha.attn_bias.LowerTriangularMask() - elif attn_impl == "sdpa": - return "causal" - elif attn_impl == "flex_attention": - return create_block_mask(causal_mask, None, None, seqlen, seqlen) - else: - raise NotImplementedError( - f"Attention {attn_impl} with {sliding_window} sliding window not implemented" - ) +from bytelatent.model.utils import create_causal_mask def attention_flops_per_token(n_layers, seq_len, dim, causal): @@ -94,8 +78,10 @@ class LMTransformer(BaseTransformer): target: Optional[torch.Tensor] = None, tok_idx: Optional[torch.Tensor] = None, mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None, - attn_impl: str = "sdpa", + attn_impl: str | None = None, ): + if attn_impl is None: + attn_impl = self.attn_impl bsz, seqlen = token_values.shape h = self.tok_embeddings(token_values) @@ -103,7 +89,14 @@ class LMTransformer(BaseTransformer): mask = ( mask if mask is not None - else create_causal_mask(seqlen, attn_impl, self.sliding_window) + else create_causal_mask( + seqlen, + attn_impl, + self.attn_bias_type, + sliding_window=self.sliding_window, + tokens=token_values, + eos_id=self.eos_id, + ) ) h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl) From 374409fa3b90e856f3b4c306c16fff21ac210ca7 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Fri, 17 Jan 2025 01:01:29 +0000 Subject: [PATCH 9/9] [WIP] Changes for training entropy model and correcting attention in local models Summary: - Refactor local model configs to be separate and clearer - Add attention arguments and correct which attention is used in local models - Preparation for being able to have an entropy train script - Fix failing unit tests Test Plan: --- bytelatent/args.py | 7 + bytelatent/base_transformer.py | 45 ++++-- bytelatent/configs/debug.yaml | 3 +- .../data/iterators/test_arrow_iterator.py | 3 + bytelatent/distributed.py | 1 - bytelatent/model/blt.py | 128 ++++++++++-------- .../{transformer.py => global_transformer.py} | 17 ++- bytelatent/model/local_models.py | 94 +++++++++---- bytelatent/model/utils.py | 73 +++++++++- bytelatent/preprocess/fsspec_target.py | 38 ++++++ bytelatent/test_blt.py | 22 +-- bytelatent/test_entropy_model.py | 1 + bytelatent/train.py | 4 + bytelatent/transformer.py | 31 ++--- 14 files changed, 334 insertions(+), 133 deletions(-) rename bytelatent/model/{transformer.py => 
global_transformer.py} (93%) create mode 100644 bytelatent/preprocess/fsspec_target.py diff --git a/bytelatent/args.py b/bytelatent/args.py index cfba3bf..b9144c6 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -30,6 +30,7 @@ from bytelatent.model.blt import ByteLatentTransformerArgs from bytelatent.optim import OptimArgs from bytelatent.profiling import ProfilerArgs from bytelatent.tokenizers.build_tokenizer import TokenizerArgs +from bytelatent.transformer import LMTransformerArgs logger = logging.getLogger() @@ -163,6 +164,8 @@ class TrainArgs(BaseModel): seed: int = 42 + debug_dynamo: bool = False + # Number of gradient accumulation steps # Total batch size is batch_size*grad_acc_steps grad_acc_steps: int = 1 @@ -176,6 +179,10 @@ class TrainArgs(BaseModel): data: DataloaderArgs = DataloaderArgs() optim: OptimArgs = OptimArgs() model: ByteLatentTransformerArgs = ByteLatentTransformerArgs() + # This is only needed for training the entropy model + entropy_model: LMTransformerArgs | None = None + # Instead of training main model, train entropy model + train_entropy_model: bool = False distributed: DistributedArgs = DistributedArgs() env: EnvironmentArgs = EnvironmentArgs() diff --git a/bytelatent/base_transformer.py b/bytelatent/base_transformer.py index 45cb7c5..dd0cce6 100644 --- a/bytelatent/base_transformer.py +++ b/bytelatent/base_transformer.py @@ -4,7 +4,7 @@ from enum import Enum from typing import Optional, Tuple, Union import torch -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from torch import nn from torch.nn import functional as F from torch.nn.attention.flex_attention import ( @@ -15,6 +15,7 @@ from torch.nn.attention.flex_attention import ( from xformers.ops import AttentionBias, fmha from bytelatent import probe +from bytelatent.tokenizers.constants import EOS_ID if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0: flex_attention_comp = torch.compile(flex_attention) @@ -30,13 +31,14 @@ class InitStdFactor(Enum): class BaseTransformerArgs(BaseModel): + model_config = ConfigDict(extra="forbid") dim: int = 512 n_layers: int = 8 - head_dim: Optional[int] = None - n_heads: Optional[int] = None - n_kv_heads: Optional[int] = None + head_dim: int | None = None + n_heads: int | None = None + n_kv_heads: int | None = None - ffn_dim_multiplier: Optional[float] = None + ffn_dim_multiplier: float | None = None multiple_of: int = 256 @@ -44,11 +46,16 @@ class BaseTransformerArgs(BaseModel): rope_theta: float = 10000.0 - init_base_std: Optional[float] = None + init_base_std: float | None = None init_std_factor: InitStdFactor = InitStdFactor.DISABLED max_seqlen: int = 1024 + attn_impl: str | None = "sdpa" + attn_bias_type: str | None = None + # Special token config + eos_id: int | None = EOS_ID + def cross_entropy(pred, target, **kwargs): return F.nll_loss( @@ -294,6 +301,18 @@ class RMSNorm(nn.Module): torch.nn.init.ones_(self.weight) # type: ignore +def _reshape_for_attn_bias( + attn_bias: AttentionBias | None, + *tensors: torch.Tensor, +) -> list[torch.Tensor]: + to_transform = list(tensors) + if isinstance(attn_bias, fmha.attn_bias.BlockDiagonalCausalMask): + # could be `view` instead of reshape during training, but for inference + # have to reshape due to strides mismatch + to_transform = [t.reshape(1, -1, *t.shape[2:]) for t in to_transform] + return to_transform + + class Attention(nn.Module): def __init__( self, @@ -371,9 +390,12 @@ class Attention(nn.Module): output = flex_attention_comp(xq, xk, xv, block_mask=mask) 
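BaseTransformerArgs now carries model_config = ConfigDict(extra="forbid") together with the attn_impl, attn_bias_type, and eos_id fields, so its subclasses reject stale keys such as the removed efficient_attn at validation time rather than ignoring them. A small pydantic v2 illustration, assuming the patch is applied:

    from pydantic import ValidationError

    from bytelatent.base_transformer import BaseTransformerArgs

    args = BaseTransformerArgs(dim=256, n_layers=2, attn_impl="xformers", attn_bias_type="block_causal")
    print(args.eos_id)  # defaults to EOS_ID from bytelatent.tokenizers.constants
    try:
        BaseTransformerArgs(efficient_attn="sdpa")
    except ValidationError as err:
        print(err.errors()[0]["type"])  # extra_forbidden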
output = output.transpose(1, 2).contiguous() # B H S D -> B S H D - elif attn_impl == "fmha": + elif attn_impl == "xformers": assert mask is None or isinstance(mask, AttentionBias) + query_shape = xq.shape + xq, xk, xv = _reshape_for_attn_bias(mask, xq, xk, xv) output = fmha.memory_efficient_attention(xq, xk, xv, attn_bias=mask) + output = output.view(query_shape) # This uses B S H D instead of B H S D of pytorch elif attn_impl == "sdpa": @@ -522,14 +544,16 @@ class TransformerBlock(nn.Module): mask: Optional[Union[BlockMask, AttentionBias, str]] = None, attn_impl: str = "sdpa", ) -> torch.Tensor: - h = x + self.attention( + attn_out = self.attention( self.attention_norm(x), freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl, ) - out = h + self.feed_forward(self.ffn_norm(h)) + h = x + attn_out + h_norm = self.ffn_norm(h) + out = h + self.feed_forward(h_norm) return out def init_weights(self, init_std=None, factor=1.0): @@ -545,6 +569,8 @@ class BaseTransformer(nn.Module): super().__init__() self.dim = args.dim self.init_base_std = args.init_base_std + self.attn_impl = args.attn_impl + self.attn_bias_type = args.attn_bias_type self.init_std_factor = InitStdFactor(args.init_std_factor) self.max_seqlen = args.max_seqlen self.rope_embeddings = RotaryEmbedding( @@ -552,6 +578,7 @@ class BaseTransformer(nn.Module): head_dim=args.head_dim or args.dim // args.n_heads, max_seqlen=args.max_seqlen, ) + self.eos_id = args.eos_id self.layers = nn.ModuleList() for _ in range(args.n_layers): diff --git a/bytelatent/configs/debug.yaml b/bytelatent/configs/debug.yaml index 5f6debb..4ae4459 100644 --- a/bytelatent/configs/debug.yaml +++ b/bytelatent/configs/debug.yaml @@ -15,7 +15,6 @@ optim: distributed: fsdp_type: full_shard - compile: true model_dtype: bf16 matmul_allow_tf32: false selective_activation_checkpointing: false @@ -58,13 +57,13 @@ model: recompute_attn: false custom_bwd: false layer_ckpt: "none" - efficient_attn: "sdpa" patch_only_encoder: false patch_only_decoder: false use_local_encoder_transformer: true init_use_gaussian: true init_use_depth: "current" attn_bias_type: "block_causal" + attn_impl: "xformers" alpha_depth: "disabled" max_length: 256 local_attention_window_len: 512 diff --git a/bytelatent/data/iterators/test_arrow_iterator.py b/bytelatent/data/iterators/test_arrow_iterator.py index 4266427..fd448eb 100644 --- a/bytelatent/data/iterators/test_arrow_iterator.py +++ b/bytelatent/data/iterators/test_arrow_iterator.py @@ -27,6 +27,7 @@ def test_basic_arrow_file(): dataset_files=[ARROW_TEST_DATA_1], row_num=0, arrow_batch_size=100, + s3_profile=None, ) arrow_file = initial_state.build() start_state = arrow_file.get_state() @@ -55,6 +56,7 @@ def test_basic_arrow_file(): dataset_files=[ARROW_TEST_DATA_1], row_num=251, arrow_batch_size=100, + s3_profile=None, ) arrow_file = resumed_state.build() for example in arrow_file.create_iter(): @@ -74,6 +76,7 @@ def test_basic_arrow_file(): dataset_files=[ARROW_TEST_DATA_1], row_num=0, arrow_batch_size=100, + s3_profile=None, ) arrow_file = rank_state.build() expected_ids = [] diff --git a/bytelatent/distributed.py b/bytelatent/distributed.py index b211858..168cb7c 100644 --- a/bytelatent/distributed.py +++ b/bytelatent/distributed.py @@ -11,7 +11,6 @@ import socket import subprocess import sys import tempfile -from dataclasses import asdict, dataclass from functools import lru_cache, partial, reduce from itertools import chain from typing import List, Optional, Tuple, Union diff --git a/bytelatent/model/blt.py b/bytelatent/model/blt.py 
index 9332d19..1d20cfa 100644 --- a/bytelatent/model/blt.py +++ b/bytelatent/model/blt.py @@ -15,8 +15,8 @@ from bytelatent.base_transformer import ( TransformerBlock, ) from bytelatent.data.patcher import Patcher, PatcherArgs -from bytelatent.model.local_models import LocalDecoder, LocalEncoder -from bytelatent.model.transformer import GlobalTransformer +from bytelatent.model.global_transformer import GlobalTransformer +from bytelatent.model.local_models import LocalDecoder, LocalEncoder, LocalModelArgs from bytelatent.model.utils import downsample from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID @@ -403,7 +403,6 @@ def patch_ids_from_lengths(patch_lengths, seq_len): class ByteLatentTransformerArgs(BaseTransformerArgs): - model_config = ConfigDict(extra="forbid") # Basic model configuration seed: int = 42 vocab_size: int = -1 @@ -412,7 +411,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): n_heads: int = 8 # TODO: What is the purpose of this parameter? weight_tying: bool = False - sliding_window: Optional[int] = None # Architecture and dimensions dim_token: int = 256 @@ -471,11 +469,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): recompute_attn: bool = True custom_bwd: bool = False layer_ckpt: str = "all" - efficient_attn: str | None = None - - # Architecture options - patch_only_encoder: bool = False - patch_only_decoder: bool = False # Initialization and attention init_use_gaussian: bool = True @@ -541,9 +534,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): # Logging full_logging_n_layers: int = 4 - # Special token config - eos_id: int | None = None - @model_validator(mode="after") def check_hash_byte_sizes(self) -> Self: if ( @@ -558,22 +548,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): return self -class LocalEncoderArgs(ByteLatentTransformerArgs): - # Local encoder specific dimensions - n_heads_local_encoder: int = 8 - dim_token_emb: int | None = None - dim_patch_emb: int | None = None - - def __post_init__(self): - # Override base args with local encoder specific values - self.dim = self.dim_local_encoder - self.n_layers = self.n_layers_local_encoder - self.n_heads = self.n_heads_local_encoder - self.cross_attn_decoder = False - self.cross_attn_k = self.cross_attn_k if self.cross_attn_encoder else None - self.attn_bias_type = "local_block_causal" - - class GlobalTransformerArgs(ByteLatentTransformerArgs): # Global encoder specific dimensions dim_token_emb: int | None = None @@ -625,20 +599,42 @@ def create_global_transformer(args: ByteLatentTransformerArgs) -> GlobalTransfor def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder: - # First deep copy the original args - # Replace with local encoder specific values - local_encoder_args = args.model_copy( - deep=True, - update=dict( - dim=args.dim_local_encoder, - n_layers=args.n_layers_local_encoder, - n_heads=args.n_heads_local_encoder, - dim_token_emb=get_encoder_dim_token_emb(args), - dim_patch_emb=get_encoder_dim_patch_emb(args), - cross_attn_decoder=False, - cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None, - attn_bias_type="local_block_causal", - ), + local_encoder_args = LocalModelArgs( + # Updated args + dim=args.dim_local_encoder, + n_layers=args.n_layers_local_encoder, + n_heads=args.n_heads_local_encoder, + dim_token_emb=get_encoder_dim_token_emb(args), + dim_patch_emb=get_encoder_dim_patch_emb(args), + cross_attn_encoder=args.cross_attn_encoder, + cross_attn_decoder=False, + cross_attn_k=args.cross_attn_k if 
args.cross_attn_encoder else None, + cross_attn_init_by_pooling=args.cross_attn_init_by_pooling, + # Defaults + head_dim=args.head_dim, + max_seqlen=args.max_encoder_seq_length, + dropout=args.dropout, + vocab_size=args.vocab_size + args.pm_size, + norm_eps=args.norm_eps, + patch_size=args.patch_size, + sliding_window=args.local_attention_window_len, + use_rope=args.use_rope, + rope_theta=args.rope_theta, + init_base_std=args.init_base_std, + init_std_factor=args.init_std_factor, + n_kv_heads=args.n_kv_heads, + attn_impl=args.attn_impl, + attn_bias_type="local_block_causal", + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + patching_mode=args.patching_mode, + use_local_encoder_transformer=args.use_local_encoder_transformer, + downsampling_by_pooling=args.downsampling_by_pooling, + encoder_hash_byte_group_size=args.encoder_hash_byte_group_size, + cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder, + cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder, + cross_attn_nheads=args.cross_attn_nheads, + eos_id=args.eos_id, ) return LocalEncoder(local_encoder_args) @@ -646,18 +642,41 @@ def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder: def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder: # First deep copy the original args - local_decoder_args = args.model_copy( - deep=True, - update=dict( - dim=args.dim_local_decoder, - n_layers=args.n_layers_local_decoder, - n_heads=args.n_heads_local_decoder, - cross_attn_encoder=False, - cross_attn_init_by_pooling=False, # states are already defined - dim_token_emb=get_decoder_dim_token_emb(args), - dim_patch_emb=args.dim_global, - cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None, - ), + local_decoder_args = LocalModelArgs( + dim=args.dim_local_decoder, + n_layers=args.n_layers_local_decoder, + n_heads=args.n_heads_local_decoder, + dim_token_emb=get_decoder_dim_token_emb(args), + dim_patch_emb=args.dim_global, + cross_attn_encoder=False, + cross_attn_decoder=args.cross_attn_decoder, + cross_attn_init_by_pooling=False, # states are already defined + cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None, + # Defaults + head_dim=args.head_dim, + max_seqlen=args.max_encoder_seq_length, + dropout=args.dropout, + vocab_size=args.vocab_size + args.pm_size, + norm_eps=args.norm_eps, + patch_size=args.patch_size, + sliding_window=args.local_attention_window_len, + use_rope=args.use_rope, + rope_theta=args.rope_theta, + init_base_std=args.init_base_std, + init_std_factor=args.init_std_factor, + n_kv_heads=args.n_kv_heads, + attn_impl=args.attn_impl, + attn_bias_type="local_block_causal", + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + patching_mode=args.patching_mode, + use_local_encoder_transformer=args.use_local_encoder_transformer, + downsampling_by_pooling=args.downsampling_by_pooling, + encoder_hash_byte_group_size=args.encoder_hash_byte_group_size, + cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder, + cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder, + cross_attn_nheads=args.cross_attn_nheads, + eos_id=args.eos_id, ) return LocalDecoder(local_decoder_args) @@ -763,7 +782,6 @@ class ByteLatentTransformer(nn.Module): # General configuration self.weight_tying = args.weight_tying - self.sliding_window = args.sliding_window self.patch_size = args.patch_size self.patching_mode = args.patching_mode self.boe_id, self.bos_id, self.pad_id, self.eos_id = ( diff --git 
a/bytelatent/model/transformer.py b/bytelatent/model/global_transformer.py similarity index 93% rename from bytelatent/model/transformer.py rename to bytelatent/model/global_transformer.py index 24dc057..21c3f0c 100644 --- a/bytelatent/model/transformer.py +++ b/bytelatent/model/global_transformer.py @@ -11,6 +11,7 @@ from xformers.ops import AttentionBias from bytelatent.base_transformer import ( BaseTransformer, + BaseTransformerArgs, RMSNorm, flex_attention_comp, repeat_kv, @@ -142,11 +143,10 @@ class CrossAttention(nn.Module): class GlobalTransformer(BaseTransformer): - def __init__(self, args): + def __init__(self, args: BaseTransformerArgs): super().__init__(args) self.dropout = args.dropout - self.sliding_window = args.sliding_window - self.efficient_attn = args.efficient_attn + self.eos_id = args.eos_id self.token_embedding_projection = None if args.dim_token_emb is not None and args.dim_token_emb != self.dim: @@ -169,14 +169,19 @@ class GlobalTransformer(BaseTransformer): and projection to the token space. """ bs, seqlen = tokens.shape - attn_impl = self.efficient_attn h = embeds mask = ( mask if mask is not None - else create_causal_mask(seqlen, attn_impl, self.sliding_window) + else create_causal_mask( + seqlen, + self.attn_impl, + self.attn_bias_type, + tokens=tokens, + eos_id=self.eos_id, + ) ) if self.token_embedding_projection is not None and h.shape[-1] != self.dim: @@ -184,7 +189,7 @@ class GlobalTransformer(BaseTransformer): h = F.dropout(h, p=self.dropout, training=self.training) - h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl) + h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl) return h, cache def init_weights(self, init_base_std: float): diff --git a/bytelatent/model/local_models.py b/bytelatent/model/local_models.py index 8255504..f182780 100644 --- a/bytelatent/model/local_models.py +++ b/bytelatent/model/local_models.py @@ -1,11 +1,12 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
import logging -from typing import List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union import torch import torch.nn import torch.nn as nn +from pydantic import BaseModel, ConfigDict from torch.nn import functional as F from torch.nn.attention.flex_attention import BlockMask from xformers.ops import AttentionBias @@ -16,29 +17,69 @@ from bytelatent.base_transformer import ( RotaryEmbedding, TransformerBlock, ) -from bytelatent.model.transformer import CrossAttention +from bytelatent.model.global_transformer import CrossAttention from bytelatent.model.utils import create_causal_mask, downsample from bytelatent.tokenizers.blt_tokenizer import BOE_ID logger = logging.getLogger() +class LocalModelArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + # Local encoder specific dimensions + head_dim: int | None + dim: int + dropout: float + vocab_size: int + patch_size: int + sliding_window: int | None + use_rope: bool + init_base_std: float | None = None + init_std_factor: InitStdFactor + cross_attn_encoder: bool | None + cross_attn_decoder: bool | None + cross_attn_k: int | None + cross_attn_init_by_pooling: bool + norm_eps: float + rope_theta: float + max_seqlen: int + ffn_dim_multiplier: float | None = None + patching_mode: str + use_local_encoder_transformer: bool + downsampling_by_pooling: str | None + encoder_hash_byte_group_size: Any | None = None + cross_attn_all_layers_encoder: bool = False + cross_attn_all_layers_decoder: bool = False + cross_attn_nheads: int | None + + n_layers: int + n_heads: int + n_kv_heads: int | None = None + dim_token_emb: int + dim_patch_emb: int | None + attn_impl: str | None = "xformers" + attn_bias_type: str | None = "local_block_causal" + multiple_of: int = 256 + eos_id: int | None = None + + class LocalModelBase(nn.Module): - def __init__(self, args): + def __init__(self, args: LocalModelArgs): super().__init__() self.dim = args.dim self.dropout = args.dropout - self.vocab_size = args.vocab_size + args.pm_size + self.vocab_size = args.vocab_size self.patch_size = args.patch_size - self.efficient_attn = args.efficient_attn + self.attn_impl = args.attn_impl self.sliding_window = args.sliding_window self.use_rope = args.use_rope self.init_std_factor = args.init_std_factor self.cross_attn_encoder = getattr(args, "cross_attn_encoder", None) self.cross_attn_decoder = getattr(args, "cross_attn_decoder", None) self.cross_attn_k = getattr(args, "cross_attn_k", None) + self.eos_id = args.eos_id self.boe_id = BOE_ID @@ -54,7 +95,7 @@ class LocalModelBase(nn.Module): self.rope = RotaryEmbedding( theta=args.rope_theta, head_dim=args.head_dim or args.dim // args.n_heads, - max_seqlen=getattr(args, "max_encoder_seq_length", args.max_length), + max_seqlen=args.max_seqlen, ) self.pos_embeddings = None @@ -66,21 +107,15 @@ class LocalModelBase(nn.Module): self.patch_embedding_projection = self._create_patch_projection(args) - def _should_create_patch_projection(self, args): + def _should_create_patch_projection(self, args: LocalModelArgs): dimension_mismatch = ( getattr(args, "dim_patch_emb") and args.dim_patch_emb != self.dim ) # Check cross attention conditions cross_attn_conditions = ( - hasattr(args, "cross_attn_encoder") - and args.cross_attn_encoder - and getattr(args, "cross_attn_init_by_pooling") - ) or ( - hasattr(args, "cross_attn_decoder") - and args.cross_attn_decoder - and getattr(args, "cross_attn_init_by_pooling") - ) + args.cross_attn_encoder and args.cross_attn_init_by_pooling + ) or (args.cross_attn_decoder and 
args.cross_attn_init_by_pooling) return dimension_mismatch or cross_attn_conditions @@ -172,7 +207,7 @@ class LocalModelBase(nn.Module): class LocalEncoder(LocalModelBase): - def __init__(self, args): + def __init__(self, args: LocalModelArgs): super().__init__(args) self.output_proj = ( args.patching_mode in ["entropy", "probmax"] @@ -180,7 +215,6 @@ class LocalEncoder(LocalModelBase): self.apply_transformer = args.use_local_encoder_transformer self.downsampling_by_pooling = args.downsampling_by_pooling - self.patch_only = args.patch_only_encoder self.expects_hash_embeddings = args.encoder_hash_byte_group_size is not None self.cross_attn_encoder = args.cross_attn_encoder self.cross_attn_all_layers_encoder = args.cross_attn_all_layers_encoder @@ -224,7 +258,14 @@ class LocalEncoder(LocalModelBase): """ """ bs, seqlen = tokens.shape if mask is None: - mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window) + mask = create_causal_mask( + seqlen, + self.attn_impl, + "local_block_causal", + sliding_window=self.sliding_window, + tokens=tokens, + eos_id=self.eos_id, + ) h = self.apply_embedding(tokens, embeds) freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None @@ -232,7 +273,7 @@ class LocalEncoder(LocalModelBase): h = F.dropout(h, p=self.dropout, training=self.training) for i, layer in enumerate(self.layers): - h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn) + h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl) # check if cross attention should be applied to either all layer or only the last layer if self.cross_attn_encoder and ( i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder @@ -273,12 +314,10 @@ class LocalEncoder(LocalModelBase): class LocalDecoder(LocalModelBase): - def __init__(self, args): + def __init__(self, args: LocalModelArgs): super().__init__(args) # Model configuration flags - self.patch_only = args.patch_only_decoder - self.expects_embeddings = args.share_encoder_decoder_emb self.cross_attn_decoder = args.cross_attn_decoder self.cross_attn_all_layers_decoder = args.cross_attn_all_layers_decoder self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling @@ -317,7 +356,14 @@ class LocalDecoder(LocalModelBase): assert embeds is not None, "Embeddings must be provided" if mask is None: - mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window) + mask = create_causal_mask( + seqlen, + self.attn_impl, + "local_block_causal", + sliding_window=self.sliding_window, + tokens=tokens, + eos_id=self.eos_id, + ) h = embeds @@ -347,7 +393,7 @@ class LocalDecoder(LocalModelBase): ) h = h + h_cross - h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn) + h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl) h_preds = self.norm(h) h_preds = F.dropout(h_preds, p=self.dropout, training=self.training) diff --git a/bytelatent/model/utils.py b/bytelatent/model/utils.py index ce52a30..42eb185 100644 --- a/bytelatent/model/utils.py +++ b/bytelatent/model/utils.py @@ -1,8 +1,12 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
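Replacing the model_copy(deep=True, update=...) calls with an explicit LocalModelArgs means the encoder/decoder arguments are validated up front: with extra="forbid", a stale or misspelled field name fails loudly instead of being carried along silently. A small illustration of that behaviour with a made-up pydantic model (ExampleArgs is illustrative, not part of the repository):

from pydantic import BaseModel, ConfigDict, ValidationError

class ExampleArgs(BaseModel):
    model_config = ConfigDict(extra="forbid")
    dim: int
    n_heads: int
    attn_impl: str = "xformers"

print(ExampleArgs(dim=512, n_heads=8))  # accepted
try:
    # A removed field name such as efficient_attn is rejected at construction time.
    ExampleArgs(dim=512, n_heads=8, efficient_attn="sdpa")
except ValidationError as err:
    print(err)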
+import logging + import torch from torch.nn.attention.flex_attention import create_block_mask from xformers.ops import fmha +logger = logging.getLogger() + def patch_reduce(h, max_num_patches, reduction, patch_ids): """ @@ -97,14 +101,69 @@ def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx -def create_causal_mask(seqlen, attn_impl, sliding_window): - if sliding_window is not None and attn_impl == "xformers": - return fmha.attn_bias.LocalAttentionFromBottomRightMask( - window_left=sliding_window - 1, window_right=0 - ) - elif attn_impl == "xformers": - return fmha.attn_bias.LowerTriangularMask() +def tokens_to_seqlen(batch: torch.Tensor, eos_id: int): + """ + 0 0 0 1 0 0 0 1 0 0 0 + 0 1 0 0 0 1 0 0 0 0 0 + -> 4 4 3 2 4 5 + """ + mask = batch == eos_id + mask[:, -1] = True # virtual eos at the end of each row + + # 0 0 0 1 0 0 0 1 0 0 X + # 0 1 0 0 0 1 0 0 0 0 X + row, col = torch.where(mask) + + # row = 0, 0, 0, 1, 1, 1 + # col = 3, 7, 10, 1, 5, 10 + seqlens = (col[1:] - col[:-1]) + (row[1:] - row[:-1]) * mask.shape[1] + # seqlens = (4, 3, -9, 4, 5) + (0, 0, 11, 0, 0) = (4, 3, 2, 4, 5) + return [int(col[0].item() + 1)] + seqlens.tolist() + + +WARNED_SDPA = False + + +def create_causal_mask( + seqlen, + attn_impl: str, + attn_bias_type: str | None, + *, + eos_id: int | None = None, + tokens: torch.Tensor | None = None, + sliding_window: int | None = None, +): + if attn_impl == "xformers": + if attn_bias_type is None: + return fmha.attn_bias.LowerTriangularMask() + elif attn_bias_type == "causal": + assert sliding_window is None + return fmha.attn_bias.LowerTriangularMask() + elif attn_bias_type == "block_causal": + assert sliding_window is None + assert eos_id is not None + assert tokens is not None + return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens( + q_seqlen=tokens_to_seqlen(tokens, eos_id) + ) + elif attn_bias_type == "local_block_causal": + assert sliding_window is not None + assert eos_id is not None + assert tokens is not None + return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens( + q_seqlen=tokens_to_seqlen(tokens, eos_id) + ).make_local_attention(sliding_window) + else: + return fmha.attn_bias.LocalAttentionFromBottomRightMask( + window_left=sliding_window - 1, window_right=0 + ) elif attn_impl == "sdpa": + global WARNED_SDPA + if not WARNED_SDPA: + logging.warning( + "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention." 
+ ) + WARNED_SDPA = True return "causal" elif attn_impl == "flex_attention": return create_block_mask(causal_mask, None, None, seqlen, seqlen) diff --git a/bytelatent/preprocess/fsspec_target.py b/bytelatent/preprocess/fsspec_target.py new file mode 100644 index 0000000..eacb101 --- /dev/null +++ b/bytelatent/preprocess/fsspec_target.py @@ -0,0 +1,38 @@ +import fsspec +from luigi.target import FileSystem, FileSystemTarget + + +class FSSpecFileSystem(FileSystem): + def __init__(self, fs: fsspec.AbstractFileSystem): + self.fs = fs + + def exists(self, path): + return self.fs.exists(path) + + def remove(self, path, recursive=True, skip_trash=True): + raise NotImplementedError() + + def isdir(self, path): + return self.fs.isdir(path) + + def listdir(self, path): + return self.fs.ls(path) + + +class FSSpecTarget(FileSystemTarget): + def __init__(self, path, fs: fsspec.AbstractFileSystem | None = None): + self.path = path + if fs is None: + self.fsspec_fs = fsspec.filesystem("file") + else: + self.fsspec_fs = fs + self._fs = None + + @property + def fs(self): + if self._fs is None: + self._fs = FSSpecFileSystem(self.fsspec_fs) + return self._fs + + def open(self, mode): + return self.fsspec_fs.open(self.path, mode=mode) diff --git a/bytelatent/test_blt.py b/bytelatent/test_blt.py index 73ad9f7..4d8e9c7 100644 --- a/bytelatent/test_blt.py +++ b/bytelatent/test_blt.py @@ -23,9 +23,10 @@ from bytelatent.model.blt import ( init_embeddings, patch_ids_from_lengths, ) -from bytelatent.model.transformer import CrossAttention +from bytelatent.model.global_transformer import CrossAttention from bytelatent.model.utils import create_causal_mask from bytelatent.optim import OptimArgs, build_optimizer +from bytelatent.tokenizers.constants import EOS_ID from bytelatent.train import compute_loss @@ -51,7 +52,7 @@ def batch_to_tensors_and_gpu(batch): def fake_batch(): - batch_dict = torch.load(os.path.join(BLT_DATA, "test_batch.pt")) + batch_dict = torch.load(os.path.join(BLT_DATA, "test_batch.pt"), weights_only=False) del batch_dict["x2"] del batch_dict["y2"] del batch_dict["src_names"] @@ -98,18 +99,17 @@ def create_args(cross_attention=False): recompute_attn=False, custom_bwd=False, layer_ckpt="none", - efficient_attn="sdpa", - patch_only_encoder=False, - patch_only_decoder=False, use_local_encoder_transformer=True, init_use_gaussian=True, init_use_depth="current", attn_bias_type="block_causal", + attn_impl="xformers", alpha_depth="disabled", max_length=256, local_attention_window_len=512, max_seqlen=12288, downsampling_by_pooling="max", + eos_id=EOS_ID, ) return transformer_args @@ -341,10 +341,10 @@ class TestByteLatentTransformer: model = ByteLatentTransformer(args) assert model is not None - @pytest.mark.parametrize("attn_type", ["fmha", "sdpa"]) - def test_blt_transformer_forward(self, attn_type): + @pytest.mark.parametrize("attn_impl", ["sdpa", "xformers"]) + def test_blt_transformer_forward(self, attn_impl): args = create_args() - args = args.model_copy(update=dict(efficient_attn=attn_type)) + args = args.model_copy(update=dict(attn_impl=attn_impl)) model = ByteLatentTransformer(args) model = model.cuda() batch = fake_batch() @@ -393,7 +393,9 @@ class TestByteLatentTransformer: n_kv_heads=4, norm_eps=1e-6, ).to("cuda") - mask = create_causal_mask(x.shape[1], "flex_attention", sliding_window=None) + mask = create_causal_mask( + x.shape[1], "flex_attention", None, sliding_window=None + ) output = cross_attention(x, kv, mask) assert output is not None assert output.shape == (2, 256, 512) @@ -440,7 +442,7 @@ class
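The worked example in the tokens_to_seqlen docstring can be reproduced directly; a small sketch assuming the patched bytelatent.model.utils is importable (the EOS id of 42 is an arbitrary illustrative value):

import torch

from bytelatent.model.utils import tokens_to_seqlen

# Two rows of tokens with EOS markers (42) at the positions from the docstring.
batch = torch.tensor([
    [7, 7, 7, 42, 7, 7, 7, 42, 7, 7, 7],
    [7, 42, 7, 7, 7, 42, 7, 7, 7, 7, 7],
])
# Each document segment contributes its length (a virtual EOS is assumed at the
# end of every row), and the lengths sum to the flattened batch size 2 * 11.
print(tokens_to_seqlen(batch, eos_id=42))  # [4, 4, 3, 2, 4, 5]

These per-document lengths are what fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens consumes inside create_causal_mask, so block-causal attention never crosses an EOS boundary.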
TestByteLatentTransformer: def test_loss_backward(self): args = create_args() - args = args.model_copy(update=dict(efficient_attn="sdpa")) + args = args.model_copy(update=dict(attn_impl="sdpa")) batch = fake_batch() model = ByteLatentTransformer(args) steps = 10 diff --git a/bytelatent/test_entropy_model.py b/bytelatent/test_entropy_model.py index 3acc42d..af81638 100644 --- a/bytelatent/test_entropy_model.py +++ b/bytelatent/test_entropy_model.py @@ -24,6 +24,7 @@ def test_entropy_model(): dataset_files=[ARROW_TEST_DATA], row_num=0, arrow_batch_size=100, + s3_profile=None, ) arrow_file = initial_state.build() tokenizer_args = TokenizerArgs( diff --git a/bytelatent/train.py b/bytelatent/train.py index 6cb13b9..80bd393 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -644,6 +644,10 @@ def main(): cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) train_args = TrainArgs.model_validate(cfg) + if train_args.debug_dynamo: + import torch._dynamo + + torch._dynamo.config.suppress_errors = True train(train_args) diff --git a/bytelatent/transformer.py b/bytelatent/transformer.py index 432f7df..92c5ff5 100644 --- a/bytelatent/transformer.py +++ b/bytelatent/transformer.py @@ -22,23 +22,7 @@ from bytelatent.base_transformer import ( RMSNorm, cross_entropy, ) - - -def create_causal_mask(seqlen, attn_impl, sliding_window): - if sliding_window is not None and attn_impl == "xformers": - return fmha.attn_bias.LocalAttentionFromBottomRightMask( - window_left=sliding_window - 1, window_right=0 - ) - elif attn_impl == "xformers": - return fmha.attn_bias.LowerTriangularMask() - elif attn_impl == "sdpa": - return "causal" - elif attn_impl == "flex_attention": - return create_block_mask(causal_mask, None, None, seqlen, seqlen) - else: - raise NotImplementedError( - f"Attention {attn_impl} with {sliding_window} sliding window not implemented" - ) +from bytelatent.model.utils import create_causal_mask def attention_flops_per_token(n_layers, seq_len, dim, causal): @@ -94,8 +78,10 @@ class LMTransformer(BaseTransformer): target: Optional[torch.Tensor] = None, tok_idx: Optional[torch.Tensor] = None, mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None, - attn_impl: str = "sdpa", + attn_impl: str | None = None, ): + if attn_impl is None: + attn_impl = self.attn_impl bsz, seqlen = token_values.shape h = self.tok_embeddings(token_values) @@ -103,7 +89,14 @@ class LMTransformer(BaseTransformer): mask = ( mask if mask is not None - else create_causal_mask(seqlen, attn_impl, self.sliding_window) + else create_causal_mask( + seqlen, + attn_impl, + self.attn_bias_type, + sliding_window=self.sliding_window, + tokens=token_values, + eos_id=self.eos_id, + ) ) h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
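The fsspec-backed luigi target added in bytelatent/preprocess/fsspec_target.py can be exercised roughly as follows; a sketch with placeholder bucket, key, and profile names, assuming s3fs is installed for the "s3" protocol and that FSSpecFileSystem.exists and FSSpecTarget.open forward the path and mode to the underlying fsspec filesystem:

import fsspec

from bytelatent.preprocess.fsspec_target import FSSpecTarget

# S3-backed target; the profile, bucket, and key are placeholders.
fs = fsspec.filesystem("s3", profile="my-profile")
target = FSSpecTarget("s3://my-bucket/preprocess/shard_0000.done", fs=fs)

if not target.exists():
    with target.open("wb") as f:
        f.write(b"done")

Because FSSpecTarget falls back to fsspec.filesystem("file") when no filesystem is passed, the same task definition can run against local paths or S3 without code changes.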