From 28016f144d5fbcc2f279955f0038463186f0de30 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Thu, 9 Jan 2025 20:06:18 +0000 Subject: [PATCH 01/18] Add plotting code from paper Summary: Test Plan: --- bytelatent/plotting/config_entropy_figure.yaml | 3 ++- bytelatent/plotting/entropy_figure.py | 17 ++++++++++++++++- plot_data/scores.json | 1 + 3 files changed, 19 insertions(+), 2 deletions(-) create mode 100644 plot_data/scores.json diff --git a/bytelatent/plotting/config_entropy_figure.yaml b/bytelatent/plotting/config_entropy_figure.yaml index 4d7bfd7..296ea07 100644 --- a/bytelatent/plotting/config_entropy_figure.yaml +++ b/bytelatent/plotting/config_entropy_figure.yaml @@ -1,3 +1,4 @@ data_path: plot_data/entropy_figure.json chart_path: figures/entropy_figure.pdf -# chart_path: figures/entropy_figure.pdf +threshold_override: 1.7171002626419067 +score_override_path: plot_data/scores.json diff --git a/bytelatent/plotting/entropy_figure.py b/bytelatent/plotting/entropy_figure.py index c1966a1..f401d7c 100644 --- a/bytelatent/plotting/entropy_figure.py +++ b/bytelatent/plotting/entropy_figure.py @@ -1,4 +1,5 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. +import json import os import sys from pathlib import Path @@ -12,6 +13,8 @@ from pydantic import BaseModel class PlotEntropiesConfig(BaseModel): data_path: str | None chart_path: str + score_override_path: str | None = None + threshold_override: float | None = None class Config: extra = "forbid" @@ -37,8 +40,20 @@ def main(): plot_config = PlotEntropiesConfig(**conf_dict) with open(plot_config.data_path) as f: json_data = f.read() + plot_data = PlotEntropiesData.model_validate_json(json_data) df = pd.read_json(plot_data.dataframe_json) + print("LEN", len(df)) + if plot_config.threshold_override is None: + threshold = plot_data.threshold + else: + threshold = plot_config.threshold_override + if plot_config.score_override_path is not None: + with open(plot_config.score_override_path) as f: + scores = json.load(f)["score"] + assert len(scores) == len(df) + df["entropies"] = scores + df["start"] = [1] + (df["entropies"] > threshold).values.tolist()[:-1] x_ticks = [] for row in df.itertuples(): @@ -65,7 +80,7 @@ def main(): ), ) rule = base.mark_rule(color="red", strokeDash=[4, 4]).encode( - y=alt.datum(plot_data.threshold), + y=alt.datum(threshold), ) patch_rules = ( alt.Chart(df[df["start"] > 0]) diff --git a/plot_data/scores.json b/plot_data/scores.json new file mode 100644 index 0000000..202cafc --- /dev/null +++ b/plot_data/scores.json @@ -0,0 +1 @@ +{"score": [3.3949153423309326, 2.1647746562957764, 2.3216569423675537, 2.8114914894104004, 1.505232334136963, 0.04055612534284592, 0.09150367230176926, 0.06008715182542801, 0.3453567624092102, 1.0483067035675049, 0.1967127025127411, 0.12737397849559784, 0.05923430994153023, 0.001597292022779584, 0.004362526815384626, 0.005547997076064348, 0.0011689786333590746, 0.0010273229563608766, 1.0228447914123535, 3.6863417625427246, 0.46605175733566284, 0.048645928502082825, 2.2544963359832764, 0.37329360842704773, 1.001160979270935, 2.9116122722625732, 1.8948925733566284, 1.4017235040664673, 0.3879640996456146, 0.2652309536933899, 1.780383825302124, 0.013964788988232613, 0.005456871818751097, 0.5426468253135681, 0.20666983723640442, 0.0051853349432349205, 0.0005802579107694328, 0.0007443525246344507, 0.0004390323010738939, 0.005452247802168131, 1.1932975053787231, 0.023798620328307152, 3.1230878829956055, 1.3915895223617554, 3.0489213466644287, 1.7018193006515503, 1.873910903930664, 
1.4662408828735352, 0.004920408595353365, 0.02599342167377472, 0.6620859503746033, 0.31743818521499634, 2.8409600257873535, 1.1354060173034668, 0.0520976223051548, 0.3519965708255768, 0.40707266330718994, 2.5438783168792725, 1.3343133926391602, 0.023993035778403282, 3.445943832397461, 1.8542104959487915, 0.7849258780479431, 0.6848396062850952, 0.06938046962022781, 0.20923230051994324, 0.10084306448698044, 0.18334199488162994, 0.4126923978328705, 0.5505472421646118, 0.1042013093829155, 0.019447727128863335, 0.0014866517158225179, 0.0009848219342529774, 0.00021391961490735412, 0.007746236398816109, 0.00038792978739365935, 0.0007933690212666988, 1.2369810342788696, 0.4436197578907013, 4.6366687456611544e-05]} \ No newline at end of file From 84854423c49fa59187855785c935d6cdbc9e0c45 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Fri, 10 Jan 2025 01:00:54 +0000 Subject: [PATCH 02/18] Replace regular filesystem calls with fsspec + add s3 support Summary: For compatibility with either local/nfs or S3 datasets, swap to fsspec. Add a tool to compare local and remote filesystems Test Plan: - Ran regular train script - Ran with config with data in S3 --- .gitignore | 2 +- bytelatent/args.py | 8 +- bytelatent/data/file_util.py | 116 ++++++++++++++++++++ bytelatent/data/iterators/arrow_iterator.py | 95 +++++++++++++--- bytelatent/logger.py | 2 +- requirements.txt | 1 + setup/download_prepare_hf_data.py | 17 ++- 7 files changed, 222 insertions(+), 19 deletions(-) create mode 100644 bytelatent/data/file_util.py diff --git a/.gitignore b/.gitignore index 56891a9..6c664b8 100644 --- a/.gitignore +++ b/.gitignore @@ -165,4 +165,4 @@ cython_debug/ figures/ .vscode/ .DS_Store - +internal/ diff --git a/bytelatent/args.py b/bytelatent/args.py index 4fba100..cfba3bf 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -46,8 +46,11 @@ def distribute_data_to_rank( arrow_batch_size: int, rank: int, world_size: int, + s3_profile: str | None = None, ) -> ArrowFileIterator: - dataset_chunks = find_and_sanitize_chunks(dataset_path, world_size) + dataset_chunks = find_and_sanitize_chunks( + dataset_path, world_size, s3_profile=s3_profile + ) n_workers_per_chunk = world_size // len(dataset_chunks) rank_to_arrow_iterator_params = [] for chunk_path in dataset_chunks: @@ -61,6 +64,7 @@ def distribute_data_to_rank( dataset_files=None, entropy_model_name=entropy_model_name, arrow_batch_size=arrow_batch_size, + s3_profile=s3_profile, ) ) return rank_to_arrow_iterator_params[rank] @@ -68,6 +72,7 @@ def distribute_data_to_rank( class DataloaderArgs(BaseModel): model_config = ConfigDict(extra="forbid") + s3_profile: str | None = None root_dir: str | None = None sources: dict[str, float] = {} batch_size: int = 2 @@ -107,6 +112,7 @@ class DataloaderArgs(BaseModel): arrow_batch_size=self.arrow_batch_size, rank=rank, world_size=world_size, + s3_profile=self.s3_profile, ) looping_iterator = LoopingIterator(arrow_iterator) preprocess_iterator = PreprocessIterator( diff --git a/bytelatent/data/file_util.py b/bytelatent/data/file_util.py new file mode 100644 index 0000000..312e1b4 --- /dev/null +++ b/bytelatent/data/file_util.py @@ -0,0 +1,116 @@ +import os +from rich.progress import track +import typer +import fsspec +import pyarrow as pa +from pyarrow.lib import ArrowInvalid + +# pyarrow needs the initialization from this import +import pyarrow.dataset # pyright: ignore + + +def is_valid_arrow_file(path: str): + try: + dataset = pa.dataset.dataset(path, format="arrow") + return True + except ArrowInvalid: + return False + 
+ +app = typer.Typer() + +S3_PREFIX = "s3://" + + +def get_fs(path: str, s3_profile: str | None = None) -> fsspec.AbstractFileSystem: + if path.startswith("s3://"): + if s3_profile is None: + return fsspec.filesystem("s3") + else: + return fsspec.filesystem("s3", profile=s3_profile) + else: + return fsspec.filesystem("file") + + +@app.command() +def print_local_to_delete( + blob_dir: str, local_dirs: list[str], s3_profile: str = "blt" +): + for s in local_dirs: + assert s.endswith("/"), "Dirs must end with /" + assert blob_dir.endswith("/"), "Dirs must end with /" + blob_fs = fsspec.filesystem("s3", profile=s3_profile) + blob_files = blob_fs.find(blob_dir) + for f in track(blob_files): + size = blob_fs.info(f)["Size"] + if not f.lower().endswith(".complete"): + assert size != 0, f"Size was invalidly zero for {f}" + + blob_relative_paths = {f[len(blob_dir) - len(S3_PREFIX) :] for f in blob_files} + local_fs = fsspec.filesystem("file") + + files_to_delete = [] + for local_dir in local_dirs: + local_files = local_fs.find(local_dir) + for f in local_files: + relative_path = f[len(local_dir) :] + if relative_path in blob_relative_paths and not os.path.islink(f): + files_to_delete.append(f) + print(len(files_to_delete)) + with open("/tmp/files_to_delete.txt", "w") as f: + for file in files_to_delete: + f.write(f"{file}\n") + + +@app.command() +def compare_local_to_blob( + source_dirs: list[str], dst_dir: str, s3_profile: str = "blt" +): + for s in source_dirs: + assert s.endswith("/"), "Dirs must end with /" + assert dst_dir.endswith("/"), "Dirs must end with /" + assert len(source_dirs) != 0 + assert dst_dir.startswith("s3://") + local_fs = fsspec.filesystem("file") + dst_fs = fsspec.filesystem("s3", profile=s3_profile) + source_to_files = {} + all_local_files = set() + for s in source_dirs: + skipped = [] + if s not in source_to_files: + source_to_files[s] = [] + for f in local_fs.find(s): + if os.path.islink(f): + continue + if f.endswith(".COMPLETE") or f.endswith(".complete"): + is_complete_file = True + assert os.path.getsize(f) == 0, ".COMPLETE files should be empty" + else: + is_complete_file = False + + if not is_complete_file and os.path.getsize(f) == 0: + skipped.append(f) + continue + if f.endswith(".arrow"): + if not is_valid_arrow_file(f): + skipped.append(f) + continue + + source_to_files[s].append(f) + all_local_files.add(f[len(s) :]) + print(s, len(source_to_files[s]), "skipped", len(skipped), skipped[:10]) + + dst_files = dst_fs.find(dst_dir) + print(dst_dir, len(dst_files)) + + dst_file_set = {f[len(dst_dir) - len(S3_PREFIX) :] for f in dst_files} + diff = all_local_files.symmetric_difference(dst_file_set) + print("Local files", len(all_local_files)) + print("DST Files", len(dst_file_set)) + print("Symmetric difference", len(diff)) + dst_only_files = dst_file_set - all_local_files + print("DST only", len(dst_only_files), list(dst_only_files)[:10]) + + +if __name__ == "__main__": + app() diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index df5f023..17569ee 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -1,7 +1,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
+import os import re from logging import getLogger -from pathlib import Path from typing import Any, Generator import pyarrow as pa @@ -9,9 +9,12 @@ import pyarrow as pa # pyarrow needs the initialization from this import import pyarrow.dataset # pyright: ignore from pydantic import BaseModel, ConfigDict +import fsspec +import s3fs from bytelatent import ByteLatentError from bytelatent.data.data_types import BltExample +from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator logger = getLogger(__name__) @@ -27,6 +30,8 @@ class ArrowFileIteratorState(BaseModel, IteratorState): dataset_files: list[str] | None entropy_model_name: str | None arrow_batch_size: int = 100 + s3_profile: str | None + filesystem_type: str | None = None def build(self) -> "ArrowFileIterator": arrow_file = ArrowFileIterator( @@ -37,14 +42,17 @@ class ArrowFileIteratorState(BaseModel, IteratorState): entropy_model_name=self.entropy_model_name, arrow_batch_size=self.arrow_batch_size, dataset_files=self.dataset_files, + s3_profile=self.s3_profile, + filesystem_type=self.filesystem_type, ) if self.row_num != 0: arrow_file._set_row_num(self.row_num) return arrow_file -def shard_sort_key(file: str | Path): - match = re.search(r".+\.shard_([0-9]+)\.arrow", str(file)) +def shard_sort_key(file: str): + assert isinstance(file, str) + match = re.search(r".+\.shard_([0-9]+)\.arrow", file) shard_number = int(match.group(1)) return shard_number @@ -60,6 +68,8 @@ class ArrowFileIterator(StatefulIterator): entropy_model_name: str | None, arrow_batch_size: int, dataset_files: list[str] | None = None, + s3_profile: str | None = None, + filesystem_type: str | None = None, ): assert 0 <= worker_id < num_workers, (worker_id, num_workers) if file_path is None and dataset_files is None: @@ -75,16 +85,44 @@ class ArrowFileIterator(StatefulIterator): self.preprocess_dir = preprocess_dir self.entropy_model_name = entropy_model_name self.arrow_batch_size = arrow_batch_size + self.s3_profile = s3_profile + self.filesystem_type = filesystem_type + self.fs = None + if self.filesystem_type is not None: + if self.filesystem_type == "file": + self.fs = fsspec.filesystem("file") + elif self.filesystem_type == "s3": + self.fs = fsspec.filesystem("s3", profile=s3_profile) + if dataset_files is None: # Prepare arrow shards - jsonl_file = Path(file_path) - parts = re.match(r"(.+)\.chunk\.[0-9]+\.jsonl", jsonl_file.name) + jsonl_file = file_path + parts = re.match( + r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file) + ) assert parts is not None dataset = parts.group(1) - data_dir = Path(preprocess_dir) / dataset / entropy_model_name - shard_files = list(data_dir.glob(f"{jsonl_file.name}.shard_*.arrow")) + data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name) + data_dir_with_glob = os.path.join( + data_dir, f"{os.path.basename(jsonl_file)}.shard_*.arrow" + ) + print("preprocess_dir", preprocess_dir) + print("data_dir_with_glob", data_dir_with_glob) + if self.fs is None: + self.fs = get_fs(data_dir_with_glob, s3_profile=s3_profile) + if isinstance(self.fs, s3fs.S3FileSystem): + self.filesystem_type = "s3" + else: + self.filesystem_type = "file" + print("fs", self.fs) + shard_files = self.fs.glob(data_dir_with_glob) + for s in shard_files: - if not (data_dir / f"{s.name}.complete").exists(): + complete_file = os.path.join( + data_dir, f"{os.path.basename(s)}.complete" + ) + + if not self.fs.exists(complete_file): raise ValueError(f"Missing .complete for input 
file: {s}") shard_files = sorted(shard_files, key=shard_sort_key) @@ -92,10 +130,19 @@ class ArrowFileIterator(StatefulIterator): raise ByteLatentError( f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow" ) - self.dataset_files = [str(f) for f in shard_files] + self.dataset_files = [f for f in shard_files] else: self.preprocess_dir = None self.dataset_files = dataset_files + if dataset_files[0].startswith("s3://"): + for f in dataset_files: + assert f.startswith("s3://") + if self.fs is None: + self.fs = get_fs(dataset_files[0], s3_profile=s3_profile) + if isinstance(self.fs, s3fs.S3FileSystem): + self.filesystem_type = "s3" + else: + self.filesystem_type = "file" def get_state(self) -> ArrowFileIteratorState: return ArrowFileIteratorState( @@ -107,13 +154,24 @@ class ArrowFileIterator(StatefulIterator): entropy_model_name=self.entropy_model_name, arrow_batch_size=self.arrow_batch_size, dataset_files=self.dataset_files, + s3_profile=self.s3_profile, + filesystem_type=self.filesystem_type, ) def create_iter( self, ) -> Generator[BltExample, Any, None]: if self.dataset is None: - self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow") + if isinstance(self.fs, s3fs.core.S3FileSystem): + print("Using s3fs") + filesystem = self.fs + else: + print("Using local fs") + filesystem = None + print("Iterating over", self.dataset_files) + self.dataset = pa.dataset.dataset( + self.dataset_files, format="arrow", filesystem=filesystem + ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size ) @@ -165,7 +223,13 @@ class ArrowFileIterator(StatefulIterator): self.batch_iterator = None self.batch_to_consume = None else: - self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow") + if isinstance(self.fs, s3fs.core.S3FileSystem): + filesystem = self.fs + else: + filesystem = None + self.dataset = pa.dataset.dataset( + self.dataset_files, format="arrow", filesystem=filesystem + ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size ) @@ -198,9 +262,14 @@ TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" def find_and_sanitize_chunks( - dataset_path: str, world_size: int, file_pattern: str = TRAIN_DATA_FILE_PATTERN + dataset_path: str, + world_size: int, + file_pattern: str = TRAIN_DATA_FILE_PATTERN, + s3_profile: str | None = None, ): - dataset_chunks = [str(p) for p in Path(dataset_path).glob(file_pattern)] + fs = get_fs(dataset_path, s3_profile=s3_profile) + path_with_glob = os.path.join(dataset_path, file_pattern) + dataset_chunks = fs.glob(path_with_glob) n_chunks = len(dataset_chunks) if n_chunks > world_size: diff --git a/bytelatent/logger.py b/bytelatent/logger.py index 6723a84..87f04cc 100644 --- a/bytelatent/logger.py +++ b/bytelatent/logger.py @@ -91,7 +91,7 @@ def init_logger( log_file: str | None = None, *, name: str | None = None, - level: str = "NOTSET", + level: str = "INFO", ): """ Setup logging. 
diff --git a/requirements.txt b/requirements.txt index c6d87f1..59192cd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,3 +20,4 @@ altair submitit typer rich +fsspec[full] diff --git a/setup/download_prepare_hf_data.py b/setup/download_prepare_hf_data.py index 1aacf8f..e194857 100644 --- a/setup/download_prepare_hf_data.py +++ b/setup/download_prepare_hf_data.py @@ -5,6 +5,7 @@ import os import subprocess import time +import fsspec import requests from huggingface_hub import snapshot_download @@ -38,11 +39,21 @@ def download_dataset(repo_id, local_dir, allow_patterns): print(f"Dataset downloaded to {local_dir}") -def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64): +def parquet_to_jsonl( + dataset, work_dir, src_dir, tgt_dir, ntasks=64, s3_profile: str | None = None +): from datatrove.executor import LocalPipelineExecutor from datatrove.pipeline.readers import ParquetReader from datatrove.pipeline.writers import JsonlWriter + if tgt_dir.startswith("s3//"): + if s3_profile is None: + out_spec = tgt_dir + else: + out_spec = (tgt_dir, fsspec.filesystem("s3", profile=s3_profile)) + else: + out_spec = tgt_dir + pipeline_exec = LocalPipelineExecutor( pipeline=[ ParquetReader( @@ -52,7 +63,7 @@ def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64): glob_pattern="**/*.parquet", ), JsonlWriter( - tgt_dir, + out_spec, output_filename=dataset + ".chunk.${rank}.jsonl", compression=None, ), @@ -77,7 +88,7 @@ def setup_terashuf(work_dir): return terashuf_dir -def main(dataset, memory, data_dir, seed=42, nchunks=32): +def main(dataset, memory, data_dir, seed=42, nchunks=32, s3_profile: str | None = None): # Configuration repo_id = { "fineweb_edu": "HuggingFaceFW/fineweb-edu", From c137f4e636ee85ef4110b416b404c455b86efba9 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Fri, 10 Jan 2025 01:02:25 +0000 Subject: [PATCH 03/18] Replace regular filesystem calls with fsspec + add s3 support Summary: For compatibility with either local/nfs or S3 datasets, swap to fsspec. 
Add a tool to compare local and remote filesystems Test Plan: - Ran regular train script - Ran with config with data in S3 --- .gitignore | 2 +- bytelatent/args.py | 8 +- bytelatent/data/file_util.py | 116 ++++++++++++++++++++ bytelatent/data/iterators/arrow_iterator.py | 95 +++++++++++++--- bytelatent/logger.py | 2 +- requirements.txt | 1 + setup/download_prepare_hf_data.py | 17 ++- 7 files changed, 222 insertions(+), 19 deletions(-) create mode 100644 bytelatent/data/file_util.py diff --git a/.gitignore b/.gitignore index 56891a9..6c664b8 100644 --- a/.gitignore +++ b/.gitignore @@ -165,4 +165,4 @@ cython_debug/ figures/ .vscode/ .DS_Store - +internal/ diff --git a/bytelatent/args.py b/bytelatent/args.py index 4fba100..cfba3bf 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -46,8 +46,11 @@ def distribute_data_to_rank( arrow_batch_size: int, rank: int, world_size: int, + s3_profile: str | None = None, ) -> ArrowFileIterator: - dataset_chunks = find_and_sanitize_chunks(dataset_path, world_size) + dataset_chunks = find_and_sanitize_chunks( + dataset_path, world_size, s3_profile=s3_profile + ) n_workers_per_chunk = world_size // len(dataset_chunks) rank_to_arrow_iterator_params = [] for chunk_path in dataset_chunks: @@ -61,6 +64,7 @@ def distribute_data_to_rank( dataset_files=None, entropy_model_name=entropy_model_name, arrow_batch_size=arrow_batch_size, + s3_profile=s3_profile, ) ) return rank_to_arrow_iterator_params[rank] @@ -68,6 +72,7 @@ def distribute_data_to_rank( class DataloaderArgs(BaseModel): model_config = ConfigDict(extra="forbid") + s3_profile: str | None = None root_dir: str | None = None sources: dict[str, float] = {} batch_size: int = 2 @@ -107,6 +112,7 @@ class DataloaderArgs(BaseModel): arrow_batch_size=self.arrow_batch_size, rank=rank, world_size=world_size, + s3_profile=self.s3_profile, ) looping_iterator = LoopingIterator(arrow_iterator) preprocess_iterator = PreprocessIterator( diff --git a/bytelatent/data/file_util.py b/bytelatent/data/file_util.py new file mode 100644 index 0000000..312e1b4 --- /dev/null +++ b/bytelatent/data/file_util.py @@ -0,0 +1,116 @@ +import os +from rich.progress import track +import typer +import fsspec +import pyarrow as pa +from pyarrow.lib import ArrowInvalid + +# pyarrow needs the initialization from this import +import pyarrow.dataset # pyright: ignore + + +def is_valid_arrow_file(path: str): + try: + dataset = pa.dataset.dataset(path, format="arrow") + return True + except ArrowInvalid: + return False + + +app = typer.Typer() + +S3_PREFIX = "s3://" + + +def get_fs(path: str, s3_profile: str | None = None) -> fsspec.AbstractFileSystem: + if path.startswith("s3://"): + if s3_profile is None: + return fsspec.filesystem("s3") + else: + return fsspec.filesystem("s3", profile=s3_profile) + else: + return fsspec.filesystem("file") + + +@app.command() +def print_local_to_delete( + blob_dir: str, local_dirs: list[str], s3_profile: str = "blt" +): + for s in local_dirs: + assert s.endswith("/"), "Dirs must end with /" + assert blob_dir.endswith("/"), "Dirs must end with /" + blob_fs = fsspec.filesystem("s3", profile=s3_profile) + blob_files = blob_fs.find(blob_dir) + for f in track(blob_files): + size = blob_fs.info(f)["Size"] + if not f.lower().endswith(".complete"): + assert size != 0, f"Size was invalidly zero for {f}" + + blob_relative_paths = {f[len(blob_dir) - len(S3_PREFIX) :] for f in blob_files} + local_fs = fsspec.filesystem("file") + + files_to_delete = [] + for local_dir in local_dirs: + local_files = 
local_fs.find(local_dir) + for f in local_files: + relative_path = f[len(local_dir) :] + if relative_path in blob_relative_paths and not os.path.islink(f): + files_to_delete.append(f) + print(len(files_to_delete)) + with open("/tmp/files_to_delete.txt", "w") as f: + for file in files_to_delete: + f.write(f"{file}\n") + + +@app.command() +def compare_local_to_blob( + source_dirs: list[str], dst_dir: str, s3_profile: str = "blt" +): + for s in source_dirs: + assert s.endswith("/"), "Dirs must end with /" + assert dst_dir.endswith("/"), "Dirs must end with /" + assert len(source_dirs) != 0 + assert dst_dir.startswith("s3://") + local_fs = fsspec.filesystem("file") + dst_fs = fsspec.filesystem("s3", profile=s3_profile) + source_to_files = {} + all_local_files = set() + for s in source_dirs: + skipped = [] + if s not in source_to_files: + source_to_files[s] = [] + for f in local_fs.find(s): + if os.path.islink(f): + continue + if f.endswith(".COMPLETE") or f.endswith(".complete"): + is_complete_file = True + assert os.path.getsize(f) == 0, ".COMPLETE files should be empty" + else: + is_complete_file = False + + if not is_complete_file and os.path.getsize(f) == 0: + skipped.append(f) + continue + if f.endswith(".arrow"): + if not is_valid_arrow_file(f): + skipped.append(f) + continue + + source_to_files[s].append(f) + all_local_files.add(f[len(s) :]) + print(s, len(source_to_files[s]), "skipped", len(skipped), skipped[:10]) + + dst_files = dst_fs.find(dst_dir) + print(dst_dir, len(dst_files)) + + dst_file_set = {f[len(dst_dir) - len(S3_PREFIX) :] for f in dst_files} + diff = all_local_files.symmetric_difference(dst_file_set) + print("Local files", len(all_local_files)) + print("DST Files", len(dst_file_set)) + print("Symmetric difference", len(diff)) + dst_only_files = dst_file_set - all_local_files + print("DST only", len(dst_only_files), list(dst_only_files)[:10]) + + +if __name__ == "__main__": + app() diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index df5f023..17569ee 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -1,7 +1,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
+import os import re from logging import getLogger -from pathlib import Path from typing import Any, Generator import pyarrow as pa @@ -9,9 +9,12 @@ import pyarrow as pa # pyarrow needs the initialization from this import import pyarrow.dataset # pyright: ignore from pydantic import BaseModel, ConfigDict +import fsspec +import s3fs from bytelatent import ByteLatentError from bytelatent.data.data_types import BltExample +from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator logger = getLogger(__name__) @@ -27,6 +30,8 @@ class ArrowFileIteratorState(BaseModel, IteratorState): dataset_files: list[str] | None entropy_model_name: str | None arrow_batch_size: int = 100 + s3_profile: str | None + filesystem_type: str | None = None def build(self) -> "ArrowFileIterator": arrow_file = ArrowFileIterator( @@ -37,14 +42,17 @@ class ArrowFileIteratorState(BaseModel, IteratorState): entropy_model_name=self.entropy_model_name, arrow_batch_size=self.arrow_batch_size, dataset_files=self.dataset_files, + s3_profile=self.s3_profile, + filesystem_type=self.filesystem_type, ) if self.row_num != 0: arrow_file._set_row_num(self.row_num) return arrow_file -def shard_sort_key(file: str | Path): - match = re.search(r".+\.shard_([0-9]+)\.arrow", str(file)) +def shard_sort_key(file: str): + assert isinstance(file, str) + match = re.search(r".+\.shard_([0-9]+)\.arrow", file) shard_number = int(match.group(1)) return shard_number @@ -60,6 +68,8 @@ class ArrowFileIterator(StatefulIterator): entropy_model_name: str | None, arrow_batch_size: int, dataset_files: list[str] | None = None, + s3_profile: str | None = None, + filesystem_type: str | None = None, ): assert 0 <= worker_id < num_workers, (worker_id, num_workers) if file_path is None and dataset_files is None: @@ -75,16 +85,44 @@ class ArrowFileIterator(StatefulIterator): self.preprocess_dir = preprocess_dir self.entropy_model_name = entropy_model_name self.arrow_batch_size = arrow_batch_size + self.s3_profile = s3_profile + self.filesystem_type = filesystem_type + self.fs = None + if self.filesystem_type is not None: + if self.filesystem_type == "file": + self.fs = fsspec.filesystem("file") + elif self.filesystem_type == "s3": + self.fs = fsspec.filesystem("s3", profile=s3_profile) + if dataset_files is None: # Prepare arrow shards - jsonl_file = Path(file_path) - parts = re.match(r"(.+)\.chunk\.[0-9]+\.jsonl", jsonl_file.name) + jsonl_file = file_path + parts = re.match( + r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file) + ) assert parts is not None dataset = parts.group(1) - data_dir = Path(preprocess_dir) / dataset / entropy_model_name - shard_files = list(data_dir.glob(f"{jsonl_file.name}.shard_*.arrow")) + data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name) + data_dir_with_glob = os.path.join( + data_dir, f"{os.path.basename(jsonl_file)}.shard_*.arrow" + ) + print("preprocess_dir", preprocess_dir) + print("data_dir_with_glob", data_dir_with_glob) + if self.fs is None: + self.fs = get_fs(data_dir_with_glob, s3_profile=s3_profile) + if isinstance(self.fs, s3fs.S3FileSystem): + self.filesystem_type = "s3" + else: + self.filesystem_type = "file" + print("fs", self.fs) + shard_files = self.fs.glob(data_dir_with_glob) + for s in shard_files: - if not (data_dir / f"{s.name}.complete").exists(): + complete_file = os.path.join( + data_dir, f"{os.path.basename(s)}.complete" + ) + + if not self.fs.exists(complete_file): raise ValueError(f"Missing .complete for input 
file: {s}") shard_files = sorted(shard_files, key=shard_sort_key) @@ -92,10 +130,19 @@ class ArrowFileIterator(StatefulIterator): raise ByteLatentError( f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow" ) - self.dataset_files = [str(f) for f in shard_files] + self.dataset_files = [f for f in shard_files] else: self.preprocess_dir = None self.dataset_files = dataset_files + if dataset_files[0].startswith("s3://"): + for f in dataset_files: + assert f.startswith("s3://") + if self.fs is None: + self.fs = get_fs(dataset_files[0], s3_profile=s3_profile) + if isinstance(self.fs, s3fs.S3FileSystem): + self.filesystem_type = "s3" + else: + self.filesystem_type = "file" def get_state(self) -> ArrowFileIteratorState: return ArrowFileIteratorState( @@ -107,13 +154,24 @@ class ArrowFileIterator(StatefulIterator): entropy_model_name=self.entropy_model_name, arrow_batch_size=self.arrow_batch_size, dataset_files=self.dataset_files, + s3_profile=self.s3_profile, + filesystem_type=self.filesystem_type, ) def create_iter( self, ) -> Generator[BltExample, Any, None]: if self.dataset is None: - self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow") + if isinstance(self.fs, s3fs.core.S3FileSystem): + print("Using s3fs") + filesystem = self.fs + else: + print("Using local fs") + filesystem = None + print("Iterating over", self.dataset_files) + self.dataset = pa.dataset.dataset( + self.dataset_files, format="arrow", filesystem=filesystem + ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size ) @@ -165,7 +223,13 @@ class ArrowFileIterator(StatefulIterator): self.batch_iterator = None self.batch_to_consume = None else: - self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow") + if isinstance(self.fs, s3fs.core.S3FileSystem): + filesystem = self.fs + else: + filesystem = None + self.dataset = pa.dataset.dataset( + self.dataset_files, format="arrow", filesystem=filesystem + ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size ) @@ -198,9 +262,14 @@ TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" def find_and_sanitize_chunks( - dataset_path: str, world_size: int, file_pattern: str = TRAIN_DATA_FILE_PATTERN + dataset_path: str, + world_size: int, + file_pattern: str = TRAIN_DATA_FILE_PATTERN, + s3_profile: str | None = None, ): - dataset_chunks = [str(p) for p in Path(dataset_path).glob(file_pattern)] + fs = get_fs(dataset_path, s3_profile=s3_profile) + path_with_glob = os.path.join(dataset_path, file_pattern) + dataset_chunks = fs.glob(path_with_glob) n_chunks = len(dataset_chunks) if n_chunks > world_size: diff --git a/bytelatent/logger.py b/bytelatent/logger.py index 6723a84..87f04cc 100644 --- a/bytelatent/logger.py +++ b/bytelatent/logger.py @@ -91,7 +91,7 @@ def init_logger( log_file: str | None = None, *, name: str | None = None, - level: str = "NOTSET", + level: str = "INFO", ): """ Setup logging. 
diff --git a/requirements.txt b/requirements.txt index c6d87f1..59192cd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,3 +20,4 @@ altair submitit typer rich +fsspec[full] diff --git a/setup/download_prepare_hf_data.py b/setup/download_prepare_hf_data.py index 1aacf8f..e194857 100644 --- a/setup/download_prepare_hf_data.py +++ b/setup/download_prepare_hf_data.py @@ -5,6 +5,7 @@ import os import subprocess import time +import fsspec import requests from huggingface_hub import snapshot_download @@ -38,11 +39,21 @@ def download_dataset(repo_id, local_dir, allow_patterns): print(f"Dataset downloaded to {local_dir}") -def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64): +def parquet_to_jsonl( + dataset, work_dir, src_dir, tgt_dir, ntasks=64, s3_profile: str | None = None +): from datatrove.executor import LocalPipelineExecutor from datatrove.pipeline.readers import ParquetReader from datatrove.pipeline.writers import JsonlWriter + if tgt_dir.startswith("s3//"): + if s3_profile is None: + out_spec = tgt_dir + else: + out_spec = (tgt_dir, fsspec.filesystem("s3", profile=s3_profile)) + else: + out_spec = tgt_dir + pipeline_exec = LocalPipelineExecutor( pipeline=[ ParquetReader( @@ -52,7 +63,7 @@ def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64): glob_pattern="**/*.parquet", ), JsonlWriter( - tgt_dir, + out_spec, output_filename=dataset + ".chunk.${rank}.jsonl", compression=None, ), @@ -77,7 +88,7 @@ def setup_terashuf(work_dir): return terashuf_dir -def main(dataset, memory, data_dir, seed=42, nchunks=32): +def main(dataset, memory, data_dir, seed=42, nchunks=32, s3_profile: str | None = None): # Configuration repo_id = { "fineweb_edu": "HuggingFaceFW/fineweb-edu", From 1bf6d15e5a3ff76f128288fc547e2604f1981e2f Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Fri, 10 Jan 2025 01:02:25 +0000 Subject: [PATCH 04/18] Replace regular filesystem calls with fsspec + add s3 support Summary: For compatibility with either local/nfs or S3 datasets, swap to fsspec. 
Add a tool to compare local and remote filesystems Test Plan: - Ran regular train script - Ran with config with data in S3 --- .gitignore | 2 +- bytelatent/args.py | 8 +- bytelatent/data/file_util.py | 116 ++++++++++++++++++++ bytelatent/data/iterators/arrow_iterator.py | 89 ++++++++++++--- bytelatent/logger.py | 2 +- requirements.txt | 1 + setup/download_prepare_hf_data.py | 17 ++- 7 files changed, 216 insertions(+), 19 deletions(-) create mode 100644 bytelatent/data/file_util.py diff --git a/.gitignore b/.gitignore index 56891a9..6c664b8 100644 --- a/.gitignore +++ b/.gitignore @@ -165,4 +165,4 @@ cython_debug/ figures/ .vscode/ .DS_Store - +internal/ diff --git a/bytelatent/args.py b/bytelatent/args.py index 4fba100..cfba3bf 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -46,8 +46,11 @@ def distribute_data_to_rank( arrow_batch_size: int, rank: int, world_size: int, + s3_profile: str | None = None, ) -> ArrowFileIterator: - dataset_chunks = find_and_sanitize_chunks(dataset_path, world_size) + dataset_chunks = find_and_sanitize_chunks( + dataset_path, world_size, s3_profile=s3_profile + ) n_workers_per_chunk = world_size // len(dataset_chunks) rank_to_arrow_iterator_params = [] for chunk_path in dataset_chunks: @@ -61,6 +64,7 @@ def distribute_data_to_rank( dataset_files=None, entropy_model_name=entropy_model_name, arrow_batch_size=arrow_batch_size, + s3_profile=s3_profile, ) ) return rank_to_arrow_iterator_params[rank] @@ -68,6 +72,7 @@ def distribute_data_to_rank( class DataloaderArgs(BaseModel): model_config = ConfigDict(extra="forbid") + s3_profile: str | None = None root_dir: str | None = None sources: dict[str, float] = {} batch_size: int = 2 @@ -107,6 +112,7 @@ class DataloaderArgs(BaseModel): arrow_batch_size=self.arrow_batch_size, rank=rank, world_size=world_size, + s3_profile=self.s3_profile, ) looping_iterator = LoopingIterator(arrow_iterator) preprocess_iterator = PreprocessIterator( diff --git a/bytelatent/data/file_util.py b/bytelatent/data/file_util.py new file mode 100644 index 0000000..312e1b4 --- /dev/null +++ b/bytelatent/data/file_util.py @@ -0,0 +1,116 @@ +import os +from rich.progress import track +import typer +import fsspec +import pyarrow as pa +from pyarrow.lib import ArrowInvalid + +# pyarrow needs the initialization from this import +import pyarrow.dataset # pyright: ignore + + +def is_valid_arrow_file(path: str): + try: + dataset = pa.dataset.dataset(path, format="arrow") + return True + except ArrowInvalid: + return False + + +app = typer.Typer() + +S3_PREFIX = "s3://" + + +def get_fs(path: str, s3_profile: str | None = None) -> fsspec.AbstractFileSystem: + if path.startswith("s3://"): + if s3_profile is None: + return fsspec.filesystem("s3") + else: + return fsspec.filesystem("s3", profile=s3_profile) + else: + return fsspec.filesystem("file") + + +@app.command() +def print_local_to_delete( + blob_dir: str, local_dirs: list[str], s3_profile: str = "blt" +): + for s in local_dirs: + assert s.endswith("/"), "Dirs must end with /" + assert blob_dir.endswith("/"), "Dirs must end with /" + blob_fs = fsspec.filesystem("s3", profile=s3_profile) + blob_files = blob_fs.find(blob_dir) + for f in track(blob_files): + size = blob_fs.info(f)["Size"] + if not f.lower().endswith(".complete"): + assert size != 0, f"Size was invalidly zero for {f}" + + blob_relative_paths = {f[len(blob_dir) - len(S3_PREFIX) :] for f in blob_files} + local_fs = fsspec.filesystem("file") + + files_to_delete = [] + for local_dir in local_dirs: + local_files = 
local_fs.find(local_dir) + for f in local_files: + relative_path = f[len(local_dir) :] + if relative_path in blob_relative_paths and not os.path.islink(f): + files_to_delete.append(f) + print(len(files_to_delete)) + with open("/tmp/files_to_delete.txt", "w") as f: + for file in files_to_delete: + f.write(f"{file}\n") + + +@app.command() +def compare_local_to_blob( + source_dirs: list[str], dst_dir: str, s3_profile: str = "blt" +): + for s in source_dirs: + assert s.endswith("/"), "Dirs must end with /" + assert dst_dir.endswith("/"), "Dirs must end with /" + assert len(source_dirs) != 0 + assert dst_dir.startswith("s3://") + local_fs = fsspec.filesystem("file") + dst_fs = fsspec.filesystem("s3", profile=s3_profile) + source_to_files = {} + all_local_files = set() + for s in source_dirs: + skipped = [] + if s not in source_to_files: + source_to_files[s] = [] + for f in local_fs.find(s): + if os.path.islink(f): + continue + if f.endswith(".COMPLETE") or f.endswith(".complete"): + is_complete_file = True + assert os.path.getsize(f) == 0, ".COMPLETE files should be empty" + else: + is_complete_file = False + + if not is_complete_file and os.path.getsize(f) == 0: + skipped.append(f) + continue + if f.endswith(".arrow"): + if not is_valid_arrow_file(f): + skipped.append(f) + continue + + source_to_files[s].append(f) + all_local_files.add(f[len(s) :]) + print(s, len(source_to_files[s]), "skipped", len(skipped), skipped[:10]) + + dst_files = dst_fs.find(dst_dir) + print(dst_dir, len(dst_files)) + + dst_file_set = {f[len(dst_dir) - len(S3_PREFIX) :] for f in dst_files} + diff = all_local_files.symmetric_difference(dst_file_set) + print("Local files", len(all_local_files)) + print("DST Files", len(dst_file_set)) + print("Symmetric difference", len(diff)) + dst_only_files = dst_file_set - all_local_files + print("DST only", len(dst_only_files), list(dst_only_files)[:10]) + + +if __name__ == "__main__": + app() diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index df5f023..ad72091 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -1,7 +1,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
+import os import re from logging import getLogger -from pathlib import Path from typing import Any, Generator import pyarrow as pa @@ -9,9 +9,12 @@ import pyarrow as pa # pyarrow needs the initialization from this import import pyarrow.dataset # pyright: ignore from pydantic import BaseModel, ConfigDict +import fsspec +import s3fs from bytelatent import ByteLatentError from bytelatent.data.data_types import BltExample +from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator logger = getLogger(__name__) @@ -27,6 +30,8 @@ class ArrowFileIteratorState(BaseModel, IteratorState): dataset_files: list[str] | None entropy_model_name: str | None arrow_batch_size: int = 100 + s3_profile: str | None + filesystem_type: str | None = None def build(self) -> "ArrowFileIterator": arrow_file = ArrowFileIterator( @@ -37,14 +42,17 @@ class ArrowFileIteratorState(BaseModel, IteratorState): entropy_model_name=self.entropy_model_name, arrow_batch_size=self.arrow_batch_size, dataset_files=self.dataset_files, + s3_profile=self.s3_profile, + filesystem_type=self.filesystem_type, ) if self.row_num != 0: arrow_file._set_row_num(self.row_num) return arrow_file -def shard_sort_key(file: str | Path): - match = re.search(r".+\.shard_([0-9]+)\.arrow", str(file)) +def shard_sort_key(file: str): + assert isinstance(file, str) + match = re.search(r".+\.shard_([0-9]+)\.arrow", file) shard_number = int(match.group(1)) return shard_number @@ -60,6 +68,8 @@ class ArrowFileIterator(StatefulIterator): entropy_model_name: str | None, arrow_batch_size: int, dataset_files: list[str] | None = None, + s3_profile: str | None = None, + filesystem_type: str | None = None, ): assert 0 <= worker_id < num_workers, (worker_id, num_workers) if file_path is None and dataset_files is None: @@ -75,16 +85,41 @@ class ArrowFileIterator(StatefulIterator): self.preprocess_dir = preprocess_dir self.entropy_model_name = entropy_model_name self.arrow_batch_size = arrow_batch_size + self.s3_profile = s3_profile + self.filesystem_type = filesystem_type + self.fs = None + if self.filesystem_type is not None: + if self.filesystem_type == "file": + self.fs = fsspec.filesystem("file") + elif self.filesystem_type == "s3": + self.fs = fsspec.filesystem("s3", profile=s3_profile) + if dataset_files is None: # Prepare arrow shards - jsonl_file = Path(file_path) - parts = re.match(r"(.+)\.chunk\.[0-9]+\.jsonl", jsonl_file.name) + jsonl_file = file_path + parts = re.match( + r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file) + ) assert parts is not None dataset = parts.group(1) - data_dir = Path(preprocess_dir) / dataset / entropy_model_name - shard_files = list(data_dir.glob(f"{jsonl_file.name}.shard_*.arrow")) + data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name) + data_dir_with_glob = os.path.join( + data_dir, f"{os.path.basename(jsonl_file)}.shard_*.arrow" + ) + if self.fs is None: + self.fs = get_fs(data_dir_with_glob, s3_profile=s3_profile) + if isinstance(self.fs, s3fs.S3FileSystem): + self.filesystem_type = "s3" + else: + self.filesystem_type = "file" + shard_files = self.fs.glob(data_dir_with_glob) + for s in shard_files: - if not (data_dir / f"{s.name}.complete").exists(): + complete_file = os.path.join( + data_dir, f"{os.path.basename(s)}.complete" + ) + + if not self.fs.exists(complete_file): raise ValueError(f"Missing .complete for input file: {s}") shard_files = sorted(shard_files, key=shard_sort_key) @@ -92,10 +127,19 @@ class 
ArrowFileIterator(StatefulIterator): raise ByteLatentError( f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow" ) - self.dataset_files = [str(f) for f in shard_files] + self.dataset_files = [f for f in shard_files] else: self.preprocess_dir = None self.dataset_files = dataset_files + if dataset_files[0].startswith("s3://"): + for f in dataset_files: + assert f.startswith("s3://") + if self.fs is None: + self.fs = get_fs(dataset_files[0], s3_profile=s3_profile) + if isinstance(self.fs, s3fs.S3FileSystem): + self.filesystem_type = "s3" + else: + self.filesystem_type = "file" def get_state(self) -> ArrowFileIteratorState: return ArrowFileIteratorState( @@ -107,13 +151,21 @@ class ArrowFileIterator(StatefulIterator): entropy_model_name=self.entropy_model_name, arrow_batch_size=self.arrow_batch_size, dataset_files=self.dataset_files, + s3_profile=self.s3_profile, + filesystem_type=self.filesystem_type, ) def create_iter( self, ) -> Generator[BltExample, Any, None]: if self.dataset is None: - self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow") + if isinstance(self.fs, s3fs.core.S3FileSystem): + filesystem = self.fs + else: + filesystem = None + self.dataset = pa.dataset.dataset( + self.dataset_files, format="arrow", filesystem=filesystem + ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size ) @@ -165,7 +217,13 @@ class ArrowFileIterator(StatefulIterator): self.batch_iterator = None self.batch_to_consume = None else: - self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow") + if isinstance(self.fs, s3fs.core.S3FileSystem): + filesystem = self.fs + else: + filesystem = None + self.dataset = pa.dataset.dataset( + self.dataset_files, format="arrow", filesystem=filesystem + ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size ) @@ -198,9 +256,14 @@ TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" def find_and_sanitize_chunks( - dataset_path: str, world_size: int, file_pattern: str = TRAIN_DATA_FILE_PATTERN + dataset_path: str, + world_size: int, + file_pattern: str = TRAIN_DATA_FILE_PATTERN, + s3_profile: str | None = None, ): - dataset_chunks = [str(p) for p in Path(dataset_path).glob(file_pattern)] + fs = get_fs(dataset_path, s3_profile=s3_profile) + path_with_glob = os.path.join(dataset_path, file_pattern) + dataset_chunks = fs.glob(path_with_glob) n_chunks = len(dataset_chunks) if n_chunks > world_size: diff --git a/bytelatent/logger.py b/bytelatent/logger.py index 6723a84..87f04cc 100644 --- a/bytelatent/logger.py +++ b/bytelatent/logger.py @@ -91,7 +91,7 @@ def init_logger( log_file: str | None = None, *, name: str | None = None, - level: str = "NOTSET", + level: str = "INFO", ): """ Setup logging. 
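
In create_iter above, the filesystem is only handed to pyarrow explicitly for S3; local reads keep filesystem=None so pyarrow falls back to its own local filesystem. A rough standalone sketch of the same branch, with a placeholder shard path and profile:

    import pyarrow as pa

    # pyarrow needs the initialization from this import
    import pyarrow.dataset  # pyright: ignore
    import s3fs

    from bytelatent.data.file_util import get_fs

    files = ["s3://my-bucket/preprocessed/corpus/model/corpus.chunk.0.jsonl.shard_0.arrow"]
    fs = get_fs(files[0], s3_profile="blt")

    # pa.dataset.dataset accepts an fsspec filesystem; mirror the iterator's branch.
    filesystem = fs if isinstance(fs, s3fs.S3FileSystem) else None
    dataset = pa.dataset.dataset(files, format="arrow", filesystem=filesystem)
    for batch in dataset.to_batches(batch_size=100):
        print(batch.num_rows)
        break
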
diff --git a/requirements.txt b/requirements.txt index c6d87f1..59192cd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,3 +20,4 @@ altair submitit typer rich +fsspec[full] diff --git a/setup/download_prepare_hf_data.py b/setup/download_prepare_hf_data.py index 1aacf8f..e194857 100644 --- a/setup/download_prepare_hf_data.py +++ b/setup/download_prepare_hf_data.py @@ -5,6 +5,7 @@ import os import subprocess import time +import fsspec import requests from huggingface_hub import snapshot_download @@ -38,11 +39,21 @@ def download_dataset(repo_id, local_dir, allow_patterns): print(f"Dataset downloaded to {local_dir}") -def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64): +def parquet_to_jsonl( + dataset, work_dir, src_dir, tgt_dir, ntasks=64, s3_profile: str | None = None +): from datatrove.executor import LocalPipelineExecutor from datatrove.pipeline.readers import ParquetReader from datatrove.pipeline.writers import JsonlWriter + if tgt_dir.startswith("s3//"): + if s3_profile is None: + out_spec = tgt_dir + else: + out_spec = (tgt_dir, fsspec.filesystem("s3", profile=s3_profile)) + else: + out_spec = tgt_dir + pipeline_exec = LocalPipelineExecutor( pipeline=[ ParquetReader( @@ -52,7 +63,7 @@ def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64): glob_pattern="**/*.parquet", ), JsonlWriter( - tgt_dir, + out_spec, output_filename=dataset + ".chunk.${rank}.jsonl", compression=None, ), @@ -77,7 +88,7 @@ def setup_terashuf(work_dir): return terashuf_dir -def main(dataset, memory, data_dir, seed=42, nchunks=32): +def main(dataset, memory, data_dir, seed=42, nchunks=32, s3_profile: str | None = None): # Configuration repo_id = { "fineweb_edu": "HuggingFaceFW/fineweb-edu", From a1d05403b42cdb111be2cf3034a044e63f45b91b Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Fri, 10 Jan 2025 01:02:25 +0000 Subject: [PATCH 05/18] Replace regular filesystem calls with fsspec + add s3 support Summary: For compatibility with either local/nfs or S3 datasets, swap to fsspec. 
Add a tool to compare local and remote filesystems Test Plan: - Ran regular train script - Ran with config with data in S3 --- .gitignore | 2 +- bytelatent/args.py | 8 +- bytelatent/data/file_util.py | 117 ++++++++++++++++++++ bytelatent/data/iterators/arrow_iterator.py | 89 ++++++++++++--- bytelatent/logger.py | 2 +- requirements.txt | 1 + setup/download_prepare_hf_data.py | 17 ++- 7 files changed, 217 insertions(+), 19 deletions(-) create mode 100644 bytelatent/data/file_util.py diff --git a/.gitignore b/.gitignore index 56891a9..6c664b8 100644 --- a/.gitignore +++ b/.gitignore @@ -165,4 +165,4 @@ cython_debug/ figures/ .vscode/ .DS_Store - +internal/ diff --git a/bytelatent/args.py b/bytelatent/args.py index 4fba100..cfba3bf 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -46,8 +46,11 @@ def distribute_data_to_rank( arrow_batch_size: int, rank: int, world_size: int, + s3_profile: str | None = None, ) -> ArrowFileIterator: - dataset_chunks = find_and_sanitize_chunks(dataset_path, world_size) + dataset_chunks = find_and_sanitize_chunks( + dataset_path, world_size, s3_profile=s3_profile + ) n_workers_per_chunk = world_size // len(dataset_chunks) rank_to_arrow_iterator_params = [] for chunk_path in dataset_chunks: @@ -61,6 +64,7 @@ def distribute_data_to_rank( dataset_files=None, entropy_model_name=entropy_model_name, arrow_batch_size=arrow_batch_size, + s3_profile=s3_profile, ) ) return rank_to_arrow_iterator_params[rank] @@ -68,6 +72,7 @@ def distribute_data_to_rank( class DataloaderArgs(BaseModel): model_config = ConfigDict(extra="forbid") + s3_profile: str | None = None root_dir: str | None = None sources: dict[str, float] = {} batch_size: int = 2 @@ -107,6 +112,7 @@ class DataloaderArgs(BaseModel): arrow_batch_size=self.arrow_batch_size, rank=rank, world_size=world_size, + s3_profile=self.s3_profile, ) looping_iterator = LoopingIterator(arrow_iterator) preprocess_iterator = PreprocessIterator( diff --git a/bytelatent/data/file_util.py b/bytelatent/data/file_util.py new file mode 100644 index 0000000..d67b6db --- /dev/null +++ b/bytelatent/data/file_util.py @@ -0,0 +1,117 @@ +import os + +import fsspec +import pyarrow as pa + +# pyarrow needs the initialization from this import +import pyarrow.dataset # pyright: ignore +import typer +from pyarrow.lib import ArrowInvalid +from rich.progress import track + + +def is_valid_arrow_file(path: str): + try: + dataset = pa.dataset.dataset(path, format="arrow") + return True + except ArrowInvalid: + return False + + +app = typer.Typer() + +S3_PREFIX = "s3://" + + +def get_fs(path: str, s3_profile: str | None = None) -> fsspec.AbstractFileSystem: + if path.startswith("s3://"): + if s3_profile is None: + return fsspec.filesystem("s3") + else: + return fsspec.filesystem("s3", profile=s3_profile) + else: + return fsspec.filesystem("file") + + +@app.command() +def print_local_to_delete( + blob_dir: str, local_dirs: list[str], s3_profile: str = "blt" +): + for s in local_dirs: + assert s.endswith("/"), "Dirs must end with /" + assert blob_dir.endswith("/"), "Dirs must end with /" + blob_fs = fsspec.filesystem("s3", profile=s3_profile) + blob_files = blob_fs.find(blob_dir) + for f in track(blob_files): + size = blob_fs.info(f)["Size"] + if not f.lower().endswith(".complete"): + assert size != 0, f"Size was invalidly zero for {f}" + + blob_relative_paths = {f[len(blob_dir) - len(S3_PREFIX) :] for f in blob_files} + local_fs = fsspec.filesystem("file") + + files_to_delete = [] + for local_dir in local_dirs: + local_files = 
local_fs.find(local_dir) + for f in local_files: + relative_path = f[len(local_dir) :] + if relative_path in blob_relative_paths and not os.path.islink(f): + files_to_delete.append(f) + print(len(files_to_delete)) + with open("/tmp/files_to_delete.txt", "w") as f: + for file in files_to_delete: + f.write(f"{file}\n") + + +@app.command() +def compare_local_to_blob( + source_dirs: list[str], dst_dir: str, s3_profile: str = "blt" +): + for s in source_dirs: + assert s.endswith("/"), "Dirs must end with /" + assert dst_dir.endswith("/"), "Dirs must end with /" + assert len(source_dirs) != 0 + assert dst_dir.startswith("s3://") + local_fs = fsspec.filesystem("file") + dst_fs = fsspec.filesystem("s3", profile=s3_profile) + source_to_files = {} + all_local_files = set() + for s in source_dirs: + skipped = [] + if s not in source_to_files: + source_to_files[s] = [] + for f in local_fs.find(s): + if os.path.islink(f): + continue + if f.endswith(".COMPLETE") or f.endswith(".complete"): + is_complete_file = True + assert os.path.getsize(f) == 0, ".COMPLETE files should be empty" + else: + is_complete_file = False + + if not is_complete_file and os.path.getsize(f) == 0: + skipped.append(f) + continue + if f.endswith(".arrow"): + if not is_valid_arrow_file(f): + skipped.append(f) + continue + + source_to_files[s].append(f) + all_local_files.add(f[len(s) :]) + print(s, len(source_to_files[s]), "skipped", len(skipped), skipped[:10]) + + dst_files = dst_fs.find(dst_dir) + print(dst_dir, len(dst_files)) + + dst_file_set = {f[len(dst_dir) - len(S3_PREFIX) :] for f in dst_files} + diff = all_local_files.symmetric_difference(dst_file_set) + print("Local files", len(all_local_files)) + print("DST Files", len(dst_file_set)) + print("Symmetric difference", len(diff)) + dst_only_files = dst_file_set - all_local_files + print("DST only", len(dst_only_files), list(dst_only_files)[:10]) + + +if __name__ == "__main__": + app() diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index df5f023..4e7b99e 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -1,17 +1,20 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
+import os import re from logging import getLogger -from pathlib import Path from typing import Any, Generator +import fsspec import pyarrow as pa # pyarrow needs the initialization from this import import pyarrow.dataset # pyright: ignore +import s3fs from pydantic import BaseModel, ConfigDict from bytelatent import ByteLatentError from bytelatent.data.data_types import BltExample +from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator logger = getLogger(__name__) @@ -27,6 +30,8 @@ class ArrowFileIteratorState(BaseModel, IteratorState): dataset_files: list[str] | None entropy_model_name: str | None arrow_batch_size: int = 100 + s3_profile: str | None + filesystem_type: str | None = None def build(self) -> "ArrowFileIterator": arrow_file = ArrowFileIterator( @@ -37,14 +42,17 @@ class ArrowFileIteratorState(BaseModel, IteratorState): entropy_model_name=self.entropy_model_name, arrow_batch_size=self.arrow_batch_size, dataset_files=self.dataset_files, + s3_profile=self.s3_profile, + filesystem_type=self.filesystem_type, ) if self.row_num != 0: arrow_file._set_row_num(self.row_num) return arrow_file -def shard_sort_key(file: str | Path): - match = re.search(r".+\.shard_([0-9]+)\.arrow", str(file)) +def shard_sort_key(file: str): + assert isinstance(file, str) + match = re.search(r".+\.shard_([0-9]+)\.arrow", file) shard_number = int(match.group(1)) return shard_number @@ -60,6 +68,8 @@ class ArrowFileIterator(StatefulIterator): entropy_model_name: str | None, arrow_batch_size: int, dataset_files: list[str] | None = None, + s3_profile: str | None = None, + filesystem_type: str | None = None, ): assert 0 <= worker_id < num_workers, (worker_id, num_workers) if file_path is None and dataset_files is None: @@ -75,16 +85,41 @@ class ArrowFileIterator(StatefulIterator): self.preprocess_dir = preprocess_dir self.entropy_model_name = entropy_model_name self.arrow_batch_size = arrow_batch_size + self.s3_profile = s3_profile + self.filesystem_type = filesystem_type + self.fs = None + if self.filesystem_type is not None: + if self.filesystem_type == "file": + self.fs = fsspec.filesystem("file") + elif self.filesystem_type == "s3": + self.fs = fsspec.filesystem("s3", profile=s3_profile) + if dataset_files is None: # Prepare arrow shards - jsonl_file = Path(file_path) - parts = re.match(r"(.+)\.chunk\.[0-9]+\.jsonl", jsonl_file.name) + jsonl_file = file_path + parts = re.match( + r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file) + ) assert parts is not None dataset = parts.group(1) - data_dir = Path(preprocess_dir) / dataset / entropy_model_name - shard_files = list(data_dir.glob(f"{jsonl_file.name}.shard_*.arrow")) + data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name) + data_dir_with_glob = os.path.join( + data_dir, f"{os.path.basename(jsonl_file)}.shard_*.arrow" + ) + if self.fs is None: + self.fs = get_fs(data_dir_with_glob, s3_profile=s3_profile) + if isinstance(self.fs, s3fs.S3FileSystem): + self.filesystem_type = "s3" + else: + self.filesystem_type = "file" + shard_files = self.fs.glob(data_dir_with_glob) + for s in shard_files: - if not (data_dir / f"{s.name}.complete").exists(): + complete_file = os.path.join( + data_dir, f"{os.path.basename(s)}.complete" + ) + + if not self.fs.exists(complete_file): raise ValueError(f"Missing .complete for input file: {s}") shard_files = sorted(shard_files, key=shard_sort_key) @@ -92,10 +127,19 @@ class ArrowFileIterator(StatefulIterator): raise ByteLatentError( 
f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow" ) - self.dataset_files = [str(f) for f in shard_files] + self.dataset_files = [f for f in shard_files] else: self.preprocess_dir = None self.dataset_files = dataset_files + if dataset_files[0].startswith("s3://"): + for f in dataset_files: + assert f.startswith("s3://") + if self.fs is None: + self.fs = get_fs(dataset_files[0], s3_profile=s3_profile) + if isinstance(self.fs, s3fs.S3FileSystem): + self.filesystem_type = "s3" + else: + self.filesystem_type = "file" def get_state(self) -> ArrowFileIteratorState: return ArrowFileIteratorState( @@ -107,13 +151,21 @@ class ArrowFileIterator(StatefulIterator): entropy_model_name=self.entropy_model_name, arrow_batch_size=self.arrow_batch_size, dataset_files=self.dataset_files, + s3_profile=self.s3_profile, + filesystem_type=self.filesystem_type, ) def create_iter( self, ) -> Generator[BltExample, Any, None]: if self.dataset is None: - self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow") + if isinstance(self.fs, s3fs.core.S3FileSystem): + filesystem = self.fs + else: + filesystem = None + self.dataset = pa.dataset.dataset( + self.dataset_files, format="arrow", filesystem=filesystem + ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size ) @@ -165,7 +217,13 @@ class ArrowFileIterator(StatefulIterator): self.batch_iterator = None self.batch_to_consume = None else: - self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow") + if isinstance(self.fs, s3fs.core.S3FileSystem): + filesystem = self.fs + else: + filesystem = None + self.dataset = pa.dataset.dataset( + self.dataset_files, format="arrow", filesystem=filesystem + ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size ) @@ -198,9 +256,14 @@ TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" def find_and_sanitize_chunks( - dataset_path: str, world_size: int, file_pattern: str = TRAIN_DATA_FILE_PATTERN + dataset_path: str, + world_size: int, + file_pattern: str = TRAIN_DATA_FILE_PATTERN, + s3_profile: str | None = None, ): - dataset_chunks = [str(p) for p in Path(dataset_path).glob(file_pattern)] + fs = get_fs(dataset_path, s3_profile=s3_profile) + path_with_glob = os.path.join(dataset_path, file_pattern) + dataset_chunks = fs.glob(path_with_glob) n_chunks = len(dataset_chunks) if n_chunks > world_size: diff --git a/bytelatent/logger.py b/bytelatent/logger.py index 6723a84..87f04cc 100644 --- a/bytelatent/logger.py +++ b/bytelatent/logger.py @@ -91,7 +91,7 @@ def init_logger( log_file: str | None = None, *, name: str | None = None, - level: str = "NOTSET", + level: str = "INFO", ): """ Setup logging. 
diff --git a/requirements.txt b/requirements.txt index c6d87f1..59192cd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,3 +20,4 @@ altair submitit typer rich +fsspec[full] diff --git a/setup/download_prepare_hf_data.py b/setup/download_prepare_hf_data.py index 1aacf8f..e194857 100644 --- a/setup/download_prepare_hf_data.py +++ b/setup/download_prepare_hf_data.py @@ -5,6 +5,7 @@ import os import subprocess import time +import fsspec import requests from huggingface_hub import snapshot_download @@ -38,11 +39,21 @@ def download_dataset(repo_id, local_dir, allow_patterns): print(f"Dataset downloaded to {local_dir}") -def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64): +def parquet_to_jsonl( + dataset, work_dir, src_dir, tgt_dir, ntasks=64, s3_profile: str | None = None +): from datatrove.executor import LocalPipelineExecutor from datatrove.pipeline.readers import ParquetReader from datatrove.pipeline.writers import JsonlWriter + if tgt_dir.startswith("s3://"): + if s3_profile is None: + out_spec = tgt_dir + else: + out_spec = (tgt_dir, fsspec.filesystem("s3", profile=s3_profile)) + else: + out_spec = tgt_dir + pipeline_exec = LocalPipelineExecutor( pipeline=[ ParquetReader( @@ -52,7 +63,7 @@ def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64): glob_pattern="**/*.parquet", ), JsonlWriter( - tgt_dir, + out_spec, output_filename=dataset + ".chunk.${rank}.jsonl", compression=None, ), @@ -77,7 +88,7 @@ def setup_terashuf(work_dir): return terashuf_dir -def main(dataset, memory, data_dir, seed=42, nchunks=32): +def main(dataset, memory, data_dir, seed=42, nchunks=32, s3_profile: str | None = None): # Configuration repo_id = { "fineweb_edu": "HuggingFaceFW/fineweb-edu", From 3f045f11238af657269e2051d870515e6ac7913b Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Mon, 13 Jan 2025 23:13:49 +0000 Subject: [PATCH 06/18] Update preprocess_entropies script to blt inference + add fsspec support Summary: Test Plan: --- bytelatent/data/patcher.py | 8 +- bytelatent/preprocess/preprocess_entropies.py | 105 ++++++++++++------ requirements.txt | 1 + 3 files changed, 74 insertions(+), 40 deletions(-) diff --git a/bytelatent/data/patcher.py b/bytelatent/data/patcher.py index ede8b06..f8477a3 100644 --- a/bytelatent/data/patcher.py +++ b/bytelatent/data/patcher.py @@ -82,16 +82,16 @@ def calculate_entropies( if device is not None: split = split.to(device) assert torch.all(split >= 0) and torch.all(split < 260) - pred, _ = entropy_model(split) + pred = entropy_model(split) pred = pred.reshape(-1, pred.shape[-1])[ : split.numel() - pad_size, : ] # [batch_size * seq_len, vocab] pred_entropies = entropy(pred) entropies.append(pred_entropies) - entropies = torch.cat(entropies, dim=0) - entropies = entropies.reshape(tokens.shape) - return entropies + concat_entropies = torch.cat(entropies, dim=0) + concat_entropies = concat_entropies.reshape(tokens.shape) + return concat_entropies def patch_start_mask_from_entropy_with_monotonicity(entropies, t): diff --git a/bytelatent/preprocess/preprocess_entropies.py b/bytelatent/preprocess/preprocess_entropies.py index 20d1e0c..45c081c 100644 --- a/bytelatent/preprocess/preprocess_entropies.py +++ b/bytelatent/preprocess/preprocess_entropies.py @@ -1,14 +1,59 @@ # Copyright (c) Meta Platforms, Inc. and affiliates.
+import jsonlines import time -from pathlib import Path +import fsspec import numpy as np import pyarrow as pa import torch import typer from rich.progress import Progress, TextColumn -from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator +from bytelatent.data.file_util import get_fs +from bytelatent.data.patcher import calculate_entropies +from bytelatent.entropy_model import load_entropy_model +from bytelatent.tokenizers.build_tokenizer import TokenizerArgs + + +def get_id_from_doc(doc: dict) -> int: + """ + We need a reliable way to ensure that samples from jsonl + and arrow are the same, but there is no unique id field, + so derive the best possible + """ + if "sample_id" in doc: + sample_id = doc["sample_id"] + elif "title" in doc: + sample_id = doc["title"] + elif "qid" in doc: + sample_id = doc["qid"] + elif "paper_id" in doc: + sample_id = doc["paper_id"] + elif "path" in doc: + sample_id = doc["path"] + elif "url" in doc: + sample_id = doc["url"] + elif "id" in doc: + sample_id = doc["id"] + else: + raise ValueError(f"Could not find a id key from: {doc.keys()}") + return str(sample_id) + + +def get_text(doc: dict): + if "text" in doc: + text = doc["text"] + elif "content" in doc: + text = doc["content"] + else: + raise ValueError(f"Could not find a text key from: {doc.keys()}") + return text + + +def jsonl_file_iterator(fs: fsspec.AbstractFileSystem, path: str): + with fs.open(path) as f: + reader = jsonlines.Reader(f) + yield from reader def main( @@ -16,39 +61,32 @@ def main( output_file: str, patching_device: str = "cuda", log_step: int = 10_000, - entropy_model_checkpoint_dir: str = "entropy_checkpoint_dir", + entropy_model_checkpoint_dir: str = "public_data/entropy_checkpoint", + entropy_model_state_dict_path: str = "public_data/entropy_model.pth", + bpe_tokenizer_path: str = "public_data/tokenizer.model", dry_run: bool = False, + s3_profile: str | None = None, ): - # TODO: Modify this to work with the new code - raise NotImplementedError() - iterator = ArrowFileIterator( - file_path=input_file, - worker_id=0, - num_workers=1, - ) - tokenization_mode = "bytes" print(f"Preprocessing entropies, input: {input_file}, output: {output_file}") print("Loading entropy model", entropy_model_checkpoint_dir) + input_fs = get_fs(input_file, s3_profile=s3_profile) + input_doc_iterator = jsonl_file_iterator(input_fs, input_file) + if dry_run: return entropy_model = load_entropy_model( - entropy_model_checkpoint_dir, device=patching_device + entropy_model_checkpoint_dir, + entropy_model_state_dict_path, + device=patching_device, ) - entropy_model, _ = to_device(entropy_model, patching_device) + print("Creating patcher") patching_batch_size = 32 print("Creating tokenizer") - tokenizer = Tokenizer( - model_path="/home/artidoro/tokenizers/llama_v2.tokenizer.model", - tokenization_mode=tokenization_mode, - # BYTE_UNITS - vocab_size_unit_1=256, - bos=True, - eos=True, - bpe_delim=False, - # This isn't used, just stores a reference for other calls we don't use - patcher=None, + tokenizer_args = TokenizerArgs( + name="blt", init_kwargs={"bpe_tokenizer_path": bpe_tokenizer_path} ) + tokenizer = tokenizer_args.build() step = 0 print("starting") start_time = time.time() @@ -59,8 +97,10 @@ def main( schema = pa.schema([sample_id_field, text_field, entropy_field]) arrow_batch_size = 1_000 + output_fs = get_fs(output_file, s3_profile=s3_profile) + try: - with pa.OSFile(output_file, "wb") as sink: + with output_fs.open(output_file, "wb") as sink: with pa.ipc.new_file(sink, schema) as 
writer: id_buffer = [] entropies_buffer = [] @@ -72,17 +112,9 @@ def main( task = progress.add_task( "[green]Calculating entropies...", total=None ) - for doc in iterator: + for doc in input_doc_iterator: sample_id = get_id_from_doc(doc) - - if "text" in doc: - text = doc["text"] - elif "content" in doc: - text = doc["content"] - else: - raise ValueError( - f"Could not find a text key from: {doc.keys()}" - ) + text = get_text(doc) tokens = torch.tensor(tokenizer.encode(text)) patch_start = time.time() scores = calculate_entropies( @@ -128,9 +160,10 @@ def main( entropies_buffer = [] id_buffer = [] text_buffer = [] - Path(f"{output_file}.complete").touch() + output_fs.touch(f"{output_file}.complete") except: - Path(output_file).unlink(missing_ok=True) + if output_fs.exists(output_file): + output_fs.rm(output_file) raise elapsed = time.time() - start_time print("steps", step) diff --git a/requirements.txt b/requirements.txt index 59192cd..8490556 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,3 +21,4 @@ submitit typer rich fsspec[full] +orjson From d718cfa9a105b37cd8c20ed47081f5d7bc4f0f9d Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Mon, 13 Jan 2025 23:26:26 +0000 Subject: [PATCH 07/18] Update preprocess_entropies script to blt inference + add fsspec support Summary: Test Plan: --- bytelatent/data/patcher.py | 8 +- bytelatent/preprocess/preprocess_entropies.py | 105 ++++++++++++------ requirements.txt | 1 + 3 files changed, 74 insertions(+), 40 deletions(-) diff --git a/bytelatent/data/patcher.py b/bytelatent/data/patcher.py index ede8b06..f8477a3 100644 --- a/bytelatent/data/patcher.py +++ b/bytelatent/data/patcher.py @@ -82,16 +82,16 @@ def calculate_entropies( if device is not None: split = split.to(device) assert torch.all(split >= 0) and torch.all(split < 260) - pred, _ = entropy_model(split) + pred = entropy_model(split) pred = pred.reshape(-1, pred.shape[-1])[ : split.numel() - pad_size, : ] # [batch_size * seq_len, vocab] pred_entropies = entropy(pred) entropies.append(pred_entropies) - entropies = torch.cat(entropies, dim=0) - entropies = entropies.reshape(tokens.shape) - return entropies + concat_entropies = torch.cat(entropies, dim=0) + concat_entropies = concat_entropies.reshape(tokens.shape) + return concat_entropies def patch_start_mask_from_entropy_with_monotonicity(entropies, t): diff --git a/bytelatent/preprocess/preprocess_entropies.py b/bytelatent/preprocess/preprocess_entropies.py index 20d1e0c..1c19a5a 100644 --- a/bytelatent/preprocess/preprocess_entropies.py +++ b/bytelatent/preprocess/preprocess_entropies.py @@ -1,14 +1,59 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
import time -from pathlib import Path +import fsspec +import jsonlines import numpy as np import pyarrow as pa import torch import typer from rich.progress import Progress, TextColumn -from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator +from bytelatent.data.file_util import get_fs +from bytelatent.data.patcher import calculate_entropies +from bytelatent.entropy_model import load_entropy_model +from bytelatent.tokenizers.build_tokenizer import TokenizerArgs + + +def get_id_from_doc(doc: dict) -> int: + """ + We need a reliable way to ensure that samples from jsonl + and arrow are the same, but there is no unique id field, + so derive the best possible + """ + if "sample_id" in doc: + sample_id = doc["sample_id"] + elif "title" in doc: + sample_id = doc["title"] + elif "qid" in doc: + sample_id = doc["qid"] + elif "paper_id" in doc: + sample_id = doc["paper_id"] + elif "path" in doc: + sample_id = doc["path"] + elif "url" in doc: + sample_id = doc["url"] + elif "id" in doc: + sample_id = doc["id"] + else: + raise ValueError(f"Could not find a id key from: {doc.keys()}") + return str(sample_id) + + +def get_text(doc: dict): + if "text" in doc: + text = doc["text"] + elif "content" in doc: + text = doc["content"] + else: + raise ValueError(f"Could not find a text key from: {doc.keys()}") + return text + + +def jsonl_file_iterator(fs: fsspec.AbstractFileSystem, path: str): + with fs.open(path) as f: + reader = jsonlines.Reader(f) + yield from reader def main( @@ -16,39 +61,32 @@ def main( output_file: str, patching_device: str = "cuda", log_step: int = 10_000, - entropy_model_checkpoint_dir: str = "entropy_checkpoint_dir", + entropy_model_checkpoint_dir: str = "public_data/entropy_checkpoint", + entropy_model_state_dict_path: str = "public_data/entropy_model.pth", + bpe_tokenizer_path: str = "public_data/tokenizer.model", dry_run: bool = False, + s3_profile: str | None = None, ): - # TODO: Modify this to work with the new code - raise NotImplementedError() - iterator = ArrowFileIterator( - file_path=input_file, - worker_id=0, - num_workers=1, - ) - tokenization_mode = "bytes" print(f"Preprocessing entropies, input: {input_file}, output: {output_file}") print("Loading entropy model", entropy_model_checkpoint_dir) + input_fs = get_fs(input_file, s3_profile=s3_profile) + input_doc_iterator = jsonl_file_iterator(input_fs, input_file) + if dry_run: return entropy_model = load_entropy_model( - entropy_model_checkpoint_dir, device=patching_device + entropy_model_checkpoint_dir, + entropy_model_state_dict_path, + device=patching_device, ) - entropy_model, _ = to_device(entropy_model, patching_device) + print("Creating patcher") patching_batch_size = 32 print("Creating tokenizer") - tokenizer = Tokenizer( - model_path="/home/artidoro/tokenizers/llama_v2.tokenizer.model", - tokenization_mode=tokenization_mode, - # BYTE_UNITS - vocab_size_unit_1=256, - bos=True, - eos=True, - bpe_delim=False, - # This isn't used, just stores a reference for other calls we don't use - patcher=None, + tokenizer_args = TokenizerArgs( + name="blt", init_kwargs={"bpe_tokenizer_path": bpe_tokenizer_path} ) + tokenizer = tokenizer_args.build() step = 0 print("starting") start_time = time.time() @@ -59,8 +97,10 @@ def main( schema = pa.schema([sample_id_field, text_field, entropy_field]) arrow_batch_size = 1_000 + output_fs = get_fs(output_file, s3_profile=s3_profile) + try: - with pa.OSFile(output_file, "wb") as sink: + with output_fs.open(output_file, "wb") as sink: with pa.ipc.new_file(sink, schema) as 
writer: id_buffer = [] entropies_buffer = [] @@ -72,17 +112,9 @@ def main( task = progress.add_task( "[green]Calculating entropies...", total=None ) - for doc in iterator: + for doc in input_doc_iterator: sample_id = get_id_from_doc(doc) - - if "text" in doc: - text = doc["text"] - elif "content" in doc: - text = doc["content"] - else: - raise ValueError( - f"Could not find a text key from: {doc.keys()}" - ) + text = get_text(doc) tokens = torch.tensor(tokenizer.encode(text)) patch_start = time.time() scores = calculate_entropies( @@ -128,9 +160,10 @@ def main( entropies_buffer = [] id_buffer = [] text_buffer = [] - Path(f"{output_file}.complete").touch() + output_fs.touch(f"{output_file}.complete") except: - Path(output_file).unlink(missing_ok=True) + if output_fs.exists(output_file): + output_fs.rm(output_file) raise elapsed = time.time() - start_time print("steps", step) diff --git a/requirements.txt b/requirements.txt index 59192cd..8490556 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,3 +21,4 @@ submitit typer rich fsspec[full] +orjson From 38022ac06e3f5b32daf12dfe3f02c6bbc6d1e840 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Thu, 16 Jan 2025 21:51:04 +0000 Subject: [PATCH 08/18] [WIP] Changes for training entropy model and correcting attention in local models Summary: - Refactor local model configs to be separate and clearer - Add attention arguments and correct which attention is used in local models - Preparation for being able to have an entropy train script - Fix failing unit tests Test Plan: --- bytelatent/args.py | 5 + bytelatent/base_transformer.py | 44 +++++- bytelatent/configs/debug.yaml | 2 +- .../data/iterators/test_arrow_iterator.py | 3 + bytelatent/model/blt.py | 128 ++++++++++-------- .../{transformer.py => global_transformer.py} | 17 ++- bytelatent/model/local_models.py | 94 +++++++++---- bytelatent/model/utils.py | 75 +++++++++- bytelatent/preprocess/fsspec_target.py | 38 ++++++ bytelatent/test_blt.py | 22 +-- bytelatent/test_entropy_model.py | 1 + bytelatent/transformer.py | 31 ++--- 12 files changed, 331 insertions(+), 129 deletions(-) rename bytelatent/model/{transformer.py => global_transformer.py} (93%) create mode 100644 bytelatent/preprocess/fsspec_target.py diff --git a/bytelatent/args.py b/bytelatent/args.py index cfba3bf..c6f714a 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -30,6 +30,7 @@ from bytelatent.model.blt import ByteLatentTransformerArgs from bytelatent.optim import OptimArgs from bytelatent.profiling import ProfilerArgs from bytelatent.tokenizers.build_tokenizer import TokenizerArgs +from bytelatent.transformer import LMTransformerArgs logger = logging.getLogger() @@ -176,6 +177,10 @@ class TrainArgs(BaseModel): data: DataloaderArgs = DataloaderArgs() optim: OptimArgs = OptimArgs() model: ByteLatentTransformerArgs = ByteLatentTransformerArgs() + # This is only needed for training the entropy model + entropy_model: LMTransformerArgs | None = None + # Instead of training main model, train entropy model + train_entropy_model: bool = False distributed: DistributedArgs = DistributedArgs() env: EnvironmentArgs = EnvironmentArgs() diff --git a/bytelatent/base_transformer.py b/bytelatent/base_transformer.py index 45cb7c5..969fc4b 100644 --- a/bytelatent/base_transformer.py +++ b/bytelatent/base_transformer.py @@ -4,7 +4,7 @@ from enum import Enum from typing import Optional, Tuple, Union import torch -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from torch import nn from torch.nn import 
functional as F from torch.nn.attention.flex_attention import ( @@ -15,6 +15,7 @@ from torch.nn.attention.flex_attention import ( from xformers.ops import AttentionBias, fmha from bytelatent import probe +from bytelatent.tokenizers.constants import EOS_ID if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0: flex_attention_comp = torch.compile(flex_attention) @@ -30,13 +31,14 @@ class InitStdFactor(Enum): class BaseTransformerArgs(BaseModel): + model_config = ConfigDict(extra="forbid") dim: int = 512 n_layers: int = 8 - head_dim: Optional[int] = None - n_heads: Optional[int] = None - n_kv_heads: Optional[int] = None + head_dim: int | None = None + n_heads: int | None = None + n_kv_heads: int | None = None - ffn_dim_multiplier: Optional[float] = None + ffn_dim_multiplier: float | None = None multiple_of: int = 256 @@ -44,11 +46,16 @@ class BaseTransformerArgs(BaseModel): rope_theta: float = 10000.0 - init_base_std: Optional[float] = None + init_base_std: float | None = None init_std_factor: InitStdFactor = InitStdFactor.DISABLED max_seqlen: int = 1024 + attn_impl: str | None = "sdpa" + attn_bias_type: str | None = None + # Special token config + eos_id: int | None = EOS_ID + def cross_entropy(pred, target, **kwargs): return F.nll_loss( @@ -294,6 +301,18 @@ class RMSNorm(nn.Module): torch.nn.init.ones_(self.weight) # type: ignore +def _reshape_for_attn_bias( + attn_bias: AttentionBias | None, + *tensors: torch.Tensor, +) -> list[torch.Tensor]: + to_transform = list(tensors) + if isinstance(attn_bias, fmha.attn_bias.BlockDiagonalCausalMask): + # could be `view` instead of reshape during training, but for inference + # have to reshape due to strides mismatch + to_transform = [t.reshape(1, -1, *t.shape[2:]) for t in to_transform] + return to_transform + + class Attention(nn.Module): def __init__( self, @@ -371,9 +390,17 @@ class Attention(nn.Module): output = flex_attention_comp(xq, xk, xv, block_mask=mask) output = output.transpose(1, 2).contiguous() # B H S D -> B S H D - elif attn_impl == "fmha": + elif attn_impl == "xformers": assert mask is None or isinstance(mask, AttentionBias) + query_shape = xq.shape + print("Before reshape", "xq", xq.shape, "xk", xk.shape, "xv", xv.shape) + xq, xk, xv = _reshape_for_attn_bias(mask, xq, xk, xv) + print("Before reshape", "xq", xq.shape, "xk", xk.shape, "xv", xv.shape) output = fmha.memory_efficient_attention(xq, xk, xv, attn_bias=mask) + print("attn out", output.shape, "query_reshape", query_shape) + output_original_shape = output.view(query_shape) + print("Reshape success") + return output_original_shape # This uses B S H D instead of B H S D of pytorch elif attn_impl == "sdpa": @@ -545,6 +572,8 @@ class BaseTransformer(nn.Module): super().__init__() self.dim = args.dim self.init_base_std = args.init_base_std + self.attn_impl = args.attn_impl + self.attn_bias_type = args.attn_bias_type self.init_std_factor = InitStdFactor(args.init_std_factor) self.max_seqlen = args.max_seqlen self.rope_embeddings = RotaryEmbedding( @@ -552,6 +581,7 @@ class BaseTransformer(nn.Module): head_dim=args.head_dim or args.dim // args.n_heads, max_seqlen=args.max_seqlen, ) + self.eos_id = args.eos_id self.layers = nn.ModuleList() for _ in range(args.n_layers): diff --git a/bytelatent/configs/debug.yaml b/bytelatent/configs/debug.yaml index 5f6debb..fc8b943 100644 --- a/bytelatent/configs/debug.yaml +++ b/bytelatent/configs/debug.yaml @@ -58,13 +58,13 @@ model: recompute_attn: false custom_bwd: false layer_ckpt: "none" - efficient_attn: "sdpa" 
patch_only_encoder: false patch_only_decoder: false use_local_encoder_transformer: true init_use_gaussian: true init_use_depth: "current" attn_bias_type: "block_causal" + attn_impl: "xformers" alpha_depth: "disabled" max_length: 256 local_attention_window_len: 512 diff --git a/bytelatent/data/iterators/test_arrow_iterator.py b/bytelatent/data/iterators/test_arrow_iterator.py index 4266427..fd448eb 100644 --- a/bytelatent/data/iterators/test_arrow_iterator.py +++ b/bytelatent/data/iterators/test_arrow_iterator.py @@ -27,6 +27,7 @@ def test_basic_arrow_file(): dataset_files=[ARROW_TEST_DATA_1], row_num=0, arrow_batch_size=100, + s3_profile=None, ) arrow_file = initial_state.build() start_state = arrow_file.get_state() @@ -55,6 +56,7 @@ def test_basic_arrow_file(): dataset_files=[ARROW_TEST_DATA_1], row_num=251, arrow_batch_size=100, + s3_profile=None, ) arrow_file = resumed_state.build() for example in arrow_file.create_iter(): @@ -74,6 +76,7 @@ def test_basic_arrow_file(): dataset_files=[ARROW_TEST_DATA_1], row_num=0, arrow_batch_size=100, + s3_profile=None, ) arrow_file = rank_state.build() expected_ids = [] diff --git a/bytelatent/model/blt.py b/bytelatent/model/blt.py index 9332d19..3e405da 100644 --- a/bytelatent/model/blt.py +++ b/bytelatent/model/blt.py @@ -15,8 +15,8 @@ from bytelatent.base_transformer import ( TransformerBlock, ) from bytelatent.data.patcher import Patcher, PatcherArgs -from bytelatent.model.local_models import LocalDecoder, LocalEncoder -from bytelatent.model.transformer import GlobalTransformer +from bytelatent.model.local_models import LocalDecoder, LocalEncoder, LocalModelArgs +from bytelatent.model.global_transformer import GlobalTransformer from bytelatent.model.utils import downsample from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID @@ -403,7 +403,6 @@ def patch_ids_from_lengths(patch_lengths, seq_len): class ByteLatentTransformerArgs(BaseTransformerArgs): - model_config = ConfigDict(extra="forbid") # Basic model configuration seed: int = 42 vocab_size: int = -1 @@ -412,7 +411,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): n_heads: int = 8 # TODO: What is the purpose of this parameter? 
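    # (Assumption, not from the original code: weight tying usually means sharing the
    #  byte/token embedding matrix with the output projection head; confirm against
    #  where weight_tying is consumed in the model before relying on this reading.)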
weight_tying: bool = False - sliding_window: Optional[int] = None # Architecture and dimensions dim_token: int = 256 @@ -471,11 +469,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): recompute_attn: bool = True custom_bwd: bool = False layer_ckpt: str = "all" - efficient_attn: str | None = None - - # Architecture options - patch_only_encoder: bool = False - patch_only_decoder: bool = False # Initialization and attention init_use_gaussian: bool = True @@ -541,9 +534,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): # Logging full_logging_n_layers: int = 4 - # Special token config - eos_id: int | None = None - @model_validator(mode="after") def check_hash_byte_sizes(self) -> Self: if ( @@ -558,22 +548,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): return self -class LocalEncoderArgs(ByteLatentTransformerArgs): - # Local encoder specific dimensions - n_heads_local_encoder: int = 8 - dim_token_emb: int | None = None - dim_patch_emb: int | None = None - - def __post_init__(self): - # Override base args with local encoder specific values - self.dim = self.dim_local_encoder - self.n_layers = self.n_layers_local_encoder - self.n_heads = self.n_heads_local_encoder - self.cross_attn_decoder = False - self.cross_attn_k = self.cross_attn_k if self.cross_attn_encoder else None - self.attn_bias_type = "local_block_causal" - - class GlobalTransformerArgs(ByteLatentTransformerArgs): # Global encoder specific dimensions dim_token_emb: int | None = None @@ -625,20 +599,42 @@ def create_global_transformer(args: ByteLatentTransformerArgs) -> GlobalTransfor def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder: - # First deep copy the original args - # Replace with local encoder specific values - local_encoder_args = args.model_copy( - deep=True, - update=dict( - dim=args.dim_local_encoder, - n_layers=args.n_layers_local_encoder, - n_heads=args.n_heads_local_encoder, - dim_token_emb=get_encoder_dim_token_emb(args), - dim_patch_emb=get_encoder_dim_patch_emb(args), - cross_attn_decoder=False, - cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None, - attn_bias_type="local_block_causal", - ), + local_encoder_args = LocalModelArgs( + # Updated args + dim=args.dim_local_encoder, + n_layers=args.n_layers_local_encoder, + n_heads=args.n_heads_local_encoder, + dim_token_emb=get_encoder_dim_token_emb(args), + dim_patch_emb=get_encoder_dim_patch_emb(args), + cross_attn_encoder=args.cross_attn_encoder, + cross_attn_decoder=False, + cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None, + cross_attn_init_by_pooling=args.cross_attn_init_by_pooling, + # Defaults + head_dim=args.head_dim, + max_seqlen=args.max_encoder_seq_length, + dropout=args.dropout, + vocab_size=args.vocab_size + args.pm_size, + norm_eps=args.norm_eps, + patch_size=args.patch_size, + sliding_window=args.local_attention_window_len, + use_rope=args.use_rope, + rope_theta=args.rope_theta, + init_base_std=args.init_base_std, + init_std_factor=args.init_std_factor, + n_kv_heads=args.n_kv_heads, + attn_impl=args.attn_impl, + attn_bias_type="local_block_causal", + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + patching_mode=args.patching_mode, + use_local_encoder_transformer=args.use_local_encoder_transformer, + downsampling_by_pooling=args.downsampling_by_pooling, + encoder_hash_byte_group_size=args.encoder_hash_byte_group_size, + cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder, + 
cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder, + cross_attn_nheads=args.cross_attn_nheads, + eos_id=args.eos_id, ) return LocalEncoder(local_encoder_args) @@ -646,18 +642,41 @@ def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder: def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder: # First deep copy the original args - local_decoder_args = args.model_copy( - deep=True, - update=dict( - dim=args.dim_local_decoder, - n_layers=args.n_layers_local_decoder, - n_heads=args.n_heads_local_decoder, - cross_attn_encoder=False, - cross_attn_init_by_pooling=False, # states are already defined - dim_token_emb=get_decoder_dim_token_emb(args), - dim_patch_emb=args.dim_global, - cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None, - ), + local_decoder_args = LocalModelArgs( + dim=args.dim_local_decoder, + n_layers=args.n_layers_local_decoder, + n_heads=args.n_heads_local_decoder, + dim_token_emb=get_decoder_dim_token_emb(args), + dim_patch_emb=args.dim_global, + cross_attn_encoder=False, + cross_attn_decoder=args.cross_attn_decoder, + cross_attn_init_by_pooling=False, # states are already defined + cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None, + # Defaults + head_dim=args.head_dim, + max_seqlen=args.max_encoder_seq_length, + dropout=args.dropout, + vocab_size=args.vocab_size + args.pm_size, + norm_eps=args.norm_eps, + patch_size=args.patch_size, + sliding_window=args.local_attention_window_len, + use_rope=args.use_rope, + rope_theta=args.rope_theta, + init_base_std=args.init_base_std, + init_std_factor=args.init_std_factor, + n_kv_heads=args.n_kv_heads, + attn_impl=args.attn_impl, + attn_bias_type="local_block_causal", + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + patching_mode=args.patching_mode, + use_local_encoder_transformer=args.use_local_encoder_transformer, + downsampling_by_pooling=args.downsampling_by_pooling, + encoder_hash_byte_group_size=args.encoder_hash_byte_group_size, + cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder, + cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder, + cross_attn_nheads=args.cross_attn_nheads, + eos_id=args.eos_id, ) return LocalDecoder(local_decoder_args) @@ -763,7 +782,6 @@ class ByteLatentTransformer(nn.Module): # General configuration self.weight_tying = args.weight_tying - self.sliding_window = args.sliding_window self.patch_size = args.patch_size self.patching_mode = args.patching_mode self.boe_id, self.bos_id, self.pad_id, self.eos_id = ( diff --git a/bytelatent/model/transformer.py b/bytelatent/model/global_transformer.py similarity index 93% rename from bytelatent/model/transformer.py rename to bytelatent/model/global_transformer.py index 24dc057..21c3f0c 100644 --- a/bytelatent/model/transformer.py +++ b/bytelatent/model/global_transformer.py @@ -11,6 +11,7 @@ from xformers.ops import AttentionBias from bytelatent.base_transformer import ( BaseTransformer, + BaseTransformerArgs, RMSNorm, flex_attention_comp, repeat_kv, @@ -142,11 +143,10 @@ class CrossAttention(nn.Module): class GlobalTransformer(BaseTransformer): - def __init__(self, args): + def __init__(self, args: BaseTransformerArgs): super().__init__(args) self.dropout = args.dropout - self.sliding_window = args.sliding_window - self.efficient_attn = args.efficient_attn + self.eos_id = args.eos_id self.token_embedding_projection = None if args.dim_token_emb is not None and args.dim_token_emb != self.dim: @@ -169,14 +169,19 @@ class 
GlobalTransformer(BaseTransformer): and projection to the token space. """ bs, seqlen = tokens.shape - attn_impl = self.efficient_attn h = embeds mask = ( mask if mask is not None - else create_causal_mask(seqlen, attn_impl, self.sliding_window) + else create_causal_mask( + seqlen, + self.attn_impl, + self.attn_bias_type, + tokens=tokens, + eos_id=self.eos_id, + ) ) if self.token_embedding_projection is not None and h.shape[-1] != self.dim: @@ -184,7 +189,7 @@ class GlobalTransformer(BaseTransformer): h = F.dropout(h, p=self.dropout, training=self.training) - h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl) + h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl) return h, cache def init_weights(self, init_base_std: float): diff --git a/bytelatent/model/local_models.py b/bytelatent/model/local_models.py index 8255504..4ce8fb5 100644 --- a/bytelatent/model/local_models.py +++ b/bytelatent/model/local_models.py @@ -1,8 +1,9 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. import logging -from typing import List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union +from pydantic import BaseModel, ConfigDict import torch import torch.nn import torch.nn as nn @@ -16,29 +17,69 @@ from bytelatent.base_transformer import ( RotaryEmbedding, TransformerBlock, ) -from bytelatent.model.transformer import CrossAttention +from bytelatent.model.global_transformer import CrossAttention from bytelatent.model.utils import create_causal_mask, downsample from bytelatent.tokenizers.blt_tokenizer import BOE_ID logger = logging.getLogger() +class LocalModelArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + # Local encoder specific dimensions + head_dim: int | None + dim: int + dropout: float + vocab_size: int + patch_size: int + sliding_window: int | None + use_rope: bool + init_base_std: float | None = None + init_std_factor: InitStdFactor + cross_attn_encoder: bool | None + cross_attn_decoder: bool | None + cross_attn_k: int | None + cross_attn_init_by_pooling: bool + norm_eps: float + rope_theta: float + max_seqlen: int + ffn_dim_multiplier: float | None = None + patching_mode: str + use_local_encoder_transformer: bool + downsampling_by_pooling: str | None + encoder_hash_byte_group_size: Any | None = None + cross_attn_all_layers_encoder: bool = False + cross_attn_all_layers_decoder: bool = False + cross_attn_nheads: int | None + + n_layers: int + n_heads: int + n_kv_heads: int | None = None + dim_token_emb: int + dim_patch_emb: int | None + attn_impl: str | None = "xformers" + attn_bias_type: str | None = "local_block_causal" + multiple_of: int = 256 + eos_id: int | None = None + + class LocalModelBase(nn.Module): - def __init__(self, args): + def __init__(self, args: LocalModelArgs): super().__init__() self.dim = args.dim self.dropout = args.dropout - self.vocab_size = args.vocab_size + args.pm_size + self.vocab_size = args.vocab_size self.patch_size = args.patch_size - self.efficient_attn = args.efficient_attn + self.attn_impl = args.attn_impl self.sliding_window = args.sliding_window self.use_rope = args.use_rope self.init_std_factor = args.init_std_factor self.cross_attn_encoder = getattr(args, "cross_attn_encoder", None) self.cross_attn_decoder = getattr(args, "cross_attn_decoder", None) self.cross_attn_k = getattr(args, "cross_attn_k", None) + self.eos_id = args.eos_id self.boe_id = BOE_ID @@ -54,7 +95,7 @@ class LocalModelBase(nn.Module): self.rope = RotaryEmbedding( theta=args.rope_theta, head_dim=args.head_dim or 
args.dim // args.n_heads, - max_seqlen=getattr(args, "max_encoder_seq_length", args.max_length), + max_seqlen=args.max_seqlen, ) self.pos_embeddings = None @@ -66,21 +107,15 @@ class LocalModelBase(nn.Module): self.patch_embedding_projection = self._create_patch_projection(args) - def _should_create_patch_projection(self, args): + def _should_create_patch_projection(self, args: LocalModelArgs): dimension_mismatch = ( getattr(args, "dim_patch_emb") and args.dim_patch_emb != self.dim ) # Check cross attention conditions cross_attn_conditions = ( - hasattr(args, "cross_attn_encoder") - and args.cross_attn_encoder - and getattr(args, "cross_attn_init_by_pooling") - ) or ( - hasattr(args, "cross_attn_decoder") - and args.cross_attn_decoder - and getattr(args, "cross_attn_init_by_pooling") - ) + args.cross_attn_encoder and args.cross_attn_init_by_pooling + ) or (args.cross_attn_decoder and args.cross_attn_init_by_pooling) return dimension_mismatch or cross_attn_conditions @@ -172,7 +207,7 @@ class LocalModelBase(nn.Module): class LocalEncoder(LocalModelBase): - def __init__(self, args): + def __init__(self, args: LocalModelArgs): super().__init__(args) self.output_proj = ( args.patching_mode in ["entropy", "probmax"] @@ -180,7 +215,6 @@ class LocalEncoder(LocalModelBase): self.apply_transformer = args.use_local_encoder_transformer self.downsampling_by_pooling = args.downsampling_by_pooling - self.patch_only = args.patch_only_encoder self.expects_hash_embeddings = args.encoder_hash_byte_group_size is not None self.cross_attn_encoder = args.cross_attn_encoder self.cross_attn_all_layers_encoder = args.cross_attn_all_layers_encoder @@ -224,7 +258,14 @@ class LocalEncoder(LocalModelBase): """ """ bs, seqlen = tokens.shape if mask is None: - mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window) + mask = create_causal_mask( + seqlen, + self.attn_impl, + "local_block_causal", + sliding_window=self.sliding_window, + tokens=tokens, + eos_id=self.eos_id, + ) h = self.apply_embedding(tokens, embeds) freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None @@ -232,7 +273,7 @@ class LocalEncoder(LocalModelBase): h = F.dropout(h, p=self.dropout, training=self.training) for i, layer in enumerate(self.layers): - h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn) + h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl) # check if cross attention should be applied to either all layer or only the last layer if self.cross_attn_encoder and ( i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder @@ -273,12 +314,10 @@ class LocalEncoder(LocalModelBase): class LocalDecoder(LocalModelBase): - def __init__(self, args): + def __init__(self, args: LocalModelArgs): super().__init__(args) # Model configuration flags - self.patch_only = args.patch_only_decoder - self.expects_embeddings = args.share_encoder_decoder_emb self.cross_attn_decoder = args.cross_attn_decoder self.cross_attn_all_layers_decoder = args.cross_attn_all_layers_decoder self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling @@ -317,7 +356,14 @@ class LocalDecoder(LocalModelBase): assert embeds is not None, "Embeddings must be provided" if mask is None: - mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window) + mask = create_causal_mask( + seqlen, + self.attn_impl, + "local_block_causal", + sliding_window=self.sliding_window, + tokens=tokens, + eos_id=self.eos_id, + ) h = embeds @@ -347,7 +393,7 @@ class LocalDecoder(LocalModelBase): ) h = h + h_cross 
- h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn) + h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl) h_preds = self.norm(h) h_preds = F.dropout(h_preds, p=self.dropout, training=self.training) diff --git a/bytelatent/model/utils.py b/bytelatent/model/utils.py index ce52a30..eac2c68 100644 --- a/bytelatent/model/utils.py +++ b/bytelatent/model/utils.py @@ -1,8 +1,11 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. +import logging import torch from torch.nn.attention.flex_attention import create_block_mask from xformers.ops import fmha +logger = logging.getLogger() + def patch_reduce(h, max_num_patches, reduction, patch_ids): """ @@ -97,14 +100,72 @@ def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx -def create_causal_mask(seqlen, attn_impl, sliding_window): - if sliding_window is not None and attn_impl == "xformers": - return fmha.attn_bias.LocalAttentionFromBottomRightMask( - window_left=sliding_window - 1, window_right=0 - ) - elif attn_impl == "xformers": - return fmha.attn_bias.LowerTriangularMask() +def tokens_to_seqlen(batch: torch.Tensor, eos_id: int): + """ + 0 0 0 1 0 0 0 1 0 0 0 + 0 1 0 0 0 1 0 0 0 0 0 + -> 4 4 3 2 4 5 + """ + mask = batch == eos_id + mask[:, -1] = True # virtual eos at the end of each row + + # 0 0 0 1 0 0 0 1 0 0 X + # 0 1 0 0 0 1 0 0 0 0 X + row, col = torch.where(mask) + + # row = 0, 0, 0, 1, 1, 1 + # col = 3, 7, 10, 1, 5, 10 + seqlens = (col[1:] - col[:-1]) + (row[1:] - row[:-1]) * mask.shape[1] + # seqlens = (4, 3, -9, 4, 5) + (0, 0, 11, 0, 0) = (4, 3, 2, 4, 5) + return [int(col[0].item() + 1)] + seqlens.tolist() + + +WARNED_SDPA = False + + +def create_causal_mask( + seqlen, + attn_impl: str, + attn_bias_type: str | None, + *, + eos_id: int | None = None, + tokens: torch.Tensor | None = None, + sliding_window: int | None = None, +): + if attn_impl == "xformers": + if attn_bias_type is None: + return fmha.attn_bias.LowerTriangularMask() + elif attn_bias_type == "causal": + assert sliding_window is None + print("attn: causal") + return fmha.attn_bias.LowerTriangularMask() + elif attn_bias_type == "block_causal": + assert sliding_window is None + assert eos_id is not None + assert tokens is not None + print("attn: block_causal") + return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens( + q_seqlen=tokens_to_seqlen(tokens, eos_id) + ) + elif attn_bias_type == "local_block_causal": + assert sliding_window is not None + assert eos_id is not None + assert tokens is not None + print("attn: local_block_causal") + return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens( + q_seqlen=tokens_to_seqlen(tokens, eos_id) + ).make_local_attention(sliding_window) + else: + return fmha.attn_bias.LocalAttentionFromBottomRightMask( + window_left=sliding_window - 1, window_right=0 + ) elif attn_impl == "sdpa": + global WARNED_SDPA + if not WARNED_SDPA: + logging.warning( + "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention." 
+ ) + WARNED_SDPA = True return "causal" elif attn_impl == "flex_attention": return create_block_mask(causal_mask, None, None, seqlen, seqlen) diff --git a/bytelatent/preprocess/fsspec_target.py b/bytelatent/preprocess/fsspec_target.py new file mode 100644 index 0000000..eacb101 --- /dev/null +++ b/bytelatent/preprocess/fsspec_target.py @@ -0,0 +1,38 @@ +import fsspec +from luigi.target import FileSystem, FileSystemTarget + + +class FSSpecFileSystem(FileSystem): + def __init__(self, fs: fsspec.AbstractFileSystem): + self.fs = fs + + def exists(self, path): + return self.fs.exists(path) + + def remove(self, path, recursive=True, skip_trash=True): + raise NotImplementedError() + + def isdir(self, path): + return self.fs.isdir(path) + + def listdir(self, path): + return self.fs.ls(path) + + +class FSSpecTarget(FileSystemTarget): + def __init__(self, path, fs: fsspec.AbstractFileSystem | None = None): + self.path = path + if fs is None: + self.fsspec_fs = fsspec.filesystem("file") + else: + self.fsspec_fs = fs + self._fs = None + + @property + def fs(self): + if self._fs is None: + self._fs = FSSpecFileSystem(self.fsspec_fs) + return self._fs + + def open(self, mode): + return self.fsspec_fs.open(self.path, mode=mode) diff --git a/bytelatent/test_blt.py b/bytelatent/test_blt.py index 73ad9f7..d7cda05 100644 --- a/bytelatent/test_blt.py +++ b/bytelatent/test_blt.py @@ -23,9 +23,10 @@ from bytelatent.model.blt import ( init_embeddings, patch_ids_from_lengths, ) -from bytelatent.model.transformer import CrossAttention +from bytelatent.model.global_transformer import CrossAttention from bytelatent.model.utils import create_causal_mask from bytelatent.optim import OptimArgs, build_optimizer +from bytelatent.tokenizers.constants import EOS_ID from bytelatent.train import compute_loss @@ -51,7 +52,7 @@ def batch_to_tensors_and_gpu(batch): def fake_batch(): - batch_dict = torch.load(os.path.join(BLT_DATA, "test_batch.pt")) + batch_dict = torch.load(os.path.join(BLT_DATA, "test_batch.pt"), weights_only=False) del batch_dict["x2"] del batch_dict["y2"] del batch_dict["src_names"] @@ -98,18 +99,17 @@ def create_args(cross_attention=False): recompute_attn=False, custom_bwd=False, layer_ckpt="none", - efficient_attn="sdpa", - patch_only_encoder=False, - patch_only_decoder=False, use_local_encoder_transformer=True, init_use_gaussian=True, init_use_depth="current", attn_bias_type="block_causal", + attn_impl="xformers", alpha_depth="disabled", max_length=256, local_attention_window_len=512, max_seqlen=12288, downsampling_by_pooling="max", + eos_id=EOS_ID, ) return transformer_args @@ -341,10 +341,10 @@ class TestByteLatentTransformer: model = ByteLatentTransformer(args) assert model is not None - @pytest.mark.parametrize("attn_type", ["fmha", "sdpa"]) - def test_blt_transformer_forward(self, attn_type): + @pytest.mark.parametrize("attn_impl", ["fmha", "sdpa", "xformers"]) + def test_blt_transformer_forward(self, attn_impl): args = create_args() - args = args.model_copy(update=dict(efficient_attn=attn_type)) + args = args.model_copy(update=dict(attn_impl=attn_impl)) model = ByteLatentTransformer(args) model = model.cuda() batch = fake_batch() @@ -393,7 +393,9 @@ class TestByteLatentTransformer: n_kv_heads=4, norm_eps=1e-6, ).to("cuda") - mask = create_causal_mask(x.shape[1], "flex_attention", sliding_window=None) + mask = create_causal_mask( + x.shape[1], "flex_attention", None, sliding_window=None + ) output = cross_attention(x, kv, mask) assert output is not None assert output.shape == (2, 256, 512) @@ -440,7 +442,7 @@
class TestByteLatentTransformer: def test_loss_backward(self): args = create_args() - args = args.model_copy(update=dict(efficient_attn="sdpa")) + args = args.model_copy(update=dict(attn_impl="sdpa")) batch = fake_batch() model = ByteLatentTransformer(args) steps = 10 diff --git a/bytelatent/test_entropy_model.py b/bytelatent/test_entropy_model.py index 3acc42d..af81638 100644 --- a/bytelatent/test_entropy_model.py +++ b/bytelatent/test_entropy_model.py @@ -24,6 +24,7 @@ def test_entropy_model(): dataset_files=[ARROW_TEST_DATA], row_num=0, arrow_batch_size=100, + s3_profile=None, ) arrow_file = initial_state.build() tokenizer_args = TokenizerArgs( diff --git a/bytelatent/transformer.py b/bytelatent/transformer.py index 432f7df..92c5ff5 100644 --- a/bytelatent/transformer.py +++ b/bytelatent/transformer.py @@ -22,23 +22,7 @@ from bytelatent.base_transformer import ( RMSNorm, cross_entropy, ) - - -def create_causal_mask(seqlen, attn_impl, sliding_window): - if sliding_window is not None and attn_impl == "xformers": - return fmha.attn_bias.LocalAttentionFromBottomRightMask( - window_left=sliding_window - 1, window_right=0 - ) - elif attn_impl == "xformers": - return fmha.attn_bias.LowerTriangularMask() - elif attn_impl == "sdpa": - return "causal" - elif attn_impl == "flex_attention": - return create_block_mask(causal_mask, None, None, seqlen, seqlen) - else: - raise NotImplementedError( - f"Attention {attn_impl} with {sliding_window} sliding window not implemented" - ) +from bytelatent.model.utils import create_causal_mask def attention_flops_per_token(n_layers, seq_len, dim, causal): @@ -94,8 +78,10 @@ class LMTransformer(BaseTransformer): target: Optional[torch.Tensor] = None, tok_idx: Optional[torch.Tensor] = None, mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None, - attn_impl: str = "sdpa", + attn_impl: str | None = None, ): + if attn_impl is None: + attn_impl = self.attn_impl bsz, seqlen = token_values.shape h = self.tok_embeddings(token_values) @@ -103,7 +89,14 @@ class LMTransformer(BaseTransformer): mask = ( mask if mask is not None - else create_causal_mask(seqlen, attn_impl, self.sliding_window) + else create_causal_mask( + seqlen, + attn_impl, + self.attn_bias_type, + sliding_window=self.sliding_window, + tokens=token_values, + eos_id=self.eos_id, + ) ) h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl) From 374409fa3b90e856f3b4c306c16fff21ac210ca7 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Fri, 17 Jan 2025 01:01:29 +0000 Subject: [PATCH 09/18] [WIP] Changes for training entropy model and correcting attention in local models Summary: - Refactor local model configs to be separate and clearer - Add attention arguments and correct which attention is used in local models - Preparation for being able to have an entropy train script - Fix failing unit tests Test Plan: --- bytelatent/args.py | 7 + bytelatent/base_transformer.py | 45 ++++-- bytelatent/configs/debug.yaml | 3 +- .../data/iterators/test_arrow_iterator.py | 3 + bytelatent/distributed.py | 1 - bytelatent/model/blt.py | 128 ++++++++++-------- .../{transformer.py => global_transformer.py} | 17 ++- bytelatent/model/local_models.py | 94 +++++++++---- bytelatent/model/utils.py | 73 +++++++++- bytelatent/preprocess/fsspec_target.py | 38 ++++++ bytelatent/test_blt.py | 22 +-- bytelatent/test_entropy_model.py | 1 + bytelatent/train.py | 4 + bytelatent/transformer.py | 31 ++--- 14 files changed, 334 insertions(+), 133 deletions(-) rename bytelatent/model/{transformer.py => 
global_transformer.py} (93%) create mode 100644 bytelatent/preprocess/fsspec_target.py diff --git a/bytelatent/args.py b/bytelatent/args.py index cfba3bf..b9144c6 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -30,6 +30,7 @@ from bytelatent.model.blt import ByteLatentTransformerArgs from bytelatent.optim import OptimArgs from bytelatent.profiling import ProfilerArgs from bytelatent.tokenizers.build_tokenizer import TokenizerArgs +from bytelatent.transformer import LMTransformerArgs logger = logging.getLogger() @@ -163,6 +164,8 @@ class TrainArgs(BaseModel): seed: int = 42 + debug_dynamo: bool = False + # Number of gradient accumulation steps # Total batch size is batch_size*grad_acc_steps grad_acc_steps: int = 1 @@ -176,6 +179,10 @@ class TrainArgs(BaseModel): data: DataloaderArgs = DataloaderArgs() optim: OptimArgs = OptimArgs() model: ByteLatentTransformerArgs = ByteLatentTransformerArgs() + # This is only needed for training the entropy model + entropy_model: LMTransformerArgs | None = None + # Instead of training main model, train entropy model + train_entropy_model: bool = False distributed: DistributedArgs = DistributedArgs() env: EnvironmentArgs = EnvironmentArgs() diff --git a/bytelatent/base_transformer.py b/bytelatent/base_transformer.py index 45cb7c5..dd0cce6 100644 --- a/bytelatent/base_transformer.py +++ b/bytelatent/base_transformer.py @@ -4,7 +4,7 @@ from enum import Enum from typing import Optional, Tuple, Union import torch -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from torch import nn from torch.nn import functional as F from torch.nn.attention.flex_attention import ( @@ -15,6 +15,7 @@ from torch.nn.attention.flex_attention import ( from xformers.ops import AttentionBias, fmha from bytelatent import probe +from bytelatent.tokenizers.constants import EOS_ID if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0: flex_attention_comp = torch.compile(flex_attention) @@ -30,13 +31,14 @@ class InitStdFactor(Enum): class BaseTransformerArgs(BaseModel): + model_config = ConfigDict(extra="forbid") dim: int = 512 n_layers: int = 8 - head_dim: Optional[int] = None - n_heads: Optional[int] = None - n_kv_heads: Optional[int] = None + head_dim: int | None = None + n_heads: int | None = None + n_kv_heads: int | None = None - ffn_dim_multiplier: Optional[float] = None + ffn_dim_multiplier: float | None = None multiple_of: int = 256 @@ -44,11 +46,16 @@ class BaseTransformerArgs(BaseModel): rope_theta: float = 10000.0 - init_base_std: Optional[float] = None + init_base_std: float | None = None init_std_factor: InitStdFactor = InitStdFactor.DISABLED max_seqlen: int = 1024 + attn_impl: str | None = "sdpa" + attn_bias_type: str | None = None + # Special token config + eos_id: int | None = EOS_ID + def cross_entropy(pred, target, **kwargs): return F.nll_loss( @@ -294,6 +301,18 @@ class RMSNorm(nn.Module): torch.nn.init.ones_(self.weight) # type: ignore +def _reshape_for_attn_bias( + attn_bias: AttentionBias | None, + *tensors: torch.Tensor, +) -> list[torch.Tensor]: + to_transform = list(tensors) + if isinstance(attn_bias, fmha.attn_bias.BlockDiagonalCausalMask): + # could be `view` instead of reshape during training, but for inference + # have to reshape due to strides mismatch + to_transform = [t.reshape(1, -1, *t.shape[2:]) for t in to_transform] + return to_transform + + class Attention(nn.Module): def __init__( self, @@ -371,9 +390,12 @@ class Attention(nn.Module): output = flex_attention_comp(xq, xk, xv, block_mask=mask) 
output = output.transpose(1, 2).contiguous() # B H S D -> B S H D - elif attn_impl == "fmha": + elif attn_impl == "xformers": assert mask is None or isinstance(mask, AttentionBias) + query_shape = xq.shape + xq, xk, xv = _reshape_for_attn_bias(mask, xq, xk, xv) output = fmha.memory_efficient_attention(xq, xk, xv, attn_bias=mask) + output = output.view(query_shape) # This uses B S H D instead of B H S D of pytorch elif attn_impl == "sdpa": @@ -522,14 +544,16 @@ class TransformerBlock(nn.Module): mask: Optional[Union[BlockMask, AttentionBias, str]] = None, attn_impl: str = "sdpa", ) -> torch.Tensor: - h = x + self.attention( + attn_out = self.attention( self.attention_norm(x), freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl, ) - out = h + self.feed_forward(self.ffn_norm(h)) + h = x + attn_out + h_norm = self.ffn_norm(h) + out = h + self.feed_forward(h_norm) return out def init_weights(self, init_std=None, factor=1.0): @@ -545,6 +569,8 @@ class BaseTransformer(nn.Module): super().__init__() self.dim = args.dim self.init_base_std = args.init_base_std + self.attn_impl = args.attn_impl + self.attn_bias_type = args.attn_bias_type self.init_std_factor = InitStdFactor(args.init_std_factor) self.max_seqlen = args.max_seqlen self.rope_embeddings = RotaryEmbedding( @@ -552,6 +578,7 @@ class BaseTransformer(nn.Module): head_dim=args.head_dim or args.dim // args.n_heads, max_seqlen=args.max_seqlen, ) + self.eos_id = args.eos_id self.layers = nn.ModuleList() for _ in range(args.n_layers): diff --git a/bytelatent/configs/debug.yaml b/bytelatent/configs/debug.yaml index 5f6debb..4ae4459 100644 --- a/bytelatent/configs/debug.yaml +++ b/bytelatent/configs/debug.yaml @@ -15,7 +15,6 @@ optim: distributed: fsdp_type: full_shard - compile: true model_dtype: bf16 matmul_allow_tf32: false selective_activation_checkpointing: false @@ -58,13 +57,13 @@ model: recompute_attn: false custom_bwd: false layer_ckpt: "none" - efficient_attn: "sdpa" patch_only_encoder: false patch_only_decoder: false use_local_encoder_transformer: true init_use_gaussian: true init_use_depth: "current" attn_bias_type: "block_causal" + attn_impl: "xformers" alpha_depth: "disabled" max_length: 256 local_attention_window_len: 512 diff --git a/bytelatent/data/iterators/test_arrow_iterator.py b/bytelatent/data/iterators/test_arrow_iterator.py index 4266427..fd448eb 100644 --- a/bytelatent/data/iterators/test_arrow_iterator.py +++ b/bytelatent/data/iterators/test_arrow_iterator.py @@ -27,6 +27,7 @@ def test_basic_arrow_file(): dataset_files=[ARROW_TEST_DATA_1], row_num=0, arrow_batch_size=100, + s3_profile=None, ) arrow_file = initial_state.build() start_state = arrow_file.get_state() @@ -55,6 +56,7 @@ def test_basic_arrow_file(): dataset_files=[ARROW_TEST_DATA_1], row_num=251, arrow_batch_size=100, + s3_profile=None, ) arrow_file = resumed_state.build() for example in arrow_file.create_iter(): @@ -74,6 +76,7 @@ def test_basic_arrow_file(): dataset_files=[ARROW_TEST_DATA_1], row_num=0, arrow_batch_size=100, + s3_profile=None, ) arrow_file = rank_state.build() expected_ids = [] diff --git a/bytelatent/distributed.py b/bytelatent/distributed.py index b211858..168cb7c 100644 --- a/bytelatent/distributed.py +++ b/bytelatent/distributed.py @@ -11,7 +11,6 @@ import socket import subprocess import sys import tempfile -from dataclasses import asdict, dataclass from functools import lru_cache, partial, reduce from itertools import chain from typing import List, Optional, Tuple, Union diff --git a/bytelatent/model/blt.py b/bytelatent/model/blt.py 
index 9332d19..1d20cfa 100644 --- a/bytelatent/model/blt.py +++ b/bytelatent/model/blt.py @@ -15,8 +15,8 @@ from bytelatent.base_transformer import ( TransformerBlock, ) from bytelatent.data.patcher import Patcher, PatcherArgs -from bytelatent.model.local_models import LocalDecoder, LocalEncoder -from bytelatent.model.transformer import GlobalTransformer +from bytelatent.model.global_transformer import GlobalTransformer +from bytelatent.model.local_models import LocalDecoder, LocalEncoder, LocalModelArgs from bytelatent.model.utils import downsample from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID @@ -403,7 +403,6 @@ def patch_ids_from_lengths(patch_lengths, seq_len): class ByteLatentTransformerArgs(BaseTransformerArgs): - model_config = ConfigDict(extra="forbid") # Basic model configuration seed: int = 42 vocab_size: int = -1 @@ -412,7 +411,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): n_heads: int = 8 # TODO: What is the purpose of this parameter? weight_tying: bool = False - sliding_window: Optional[int] = None # Architecture and dimensions dim_token: int = 256 @@ -471,11 +469,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): recompute_attn: bool = True custom_bwd: bool = False layer_ckpt: str = "all" - efficient_attn: str | None = None - - # Architecture options - patch_only_encoder: bool = False - patch_only_decoder: bool = False # Initialization and attention init_use_gaussian: bool = True @@ -541,9 +534,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): # Logging full_logging_n_layers: int = 4 - # Special token config - eos_id: int | None = None - @model_validator(mode="after") def check_hash_byte_sizes(self) -> Self: if ( @@ -558,22 +548,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): return self -class LocalEncoderArgs(ByteLatentTransformerArgs): - # Local encoder specific dimensions - n_heads_local_encoder: int = 8 - dim_token_emb: int | None = None - dim_patch_emb: int | None = None - - def __post_init__(self): - # Override base args with local encoder specific values - self.dim = self.dim_local_encoder - self.n_layers = self.n_layers_local_encoder - self.n_heads = self.n_heads_local_encoder - self.cross_attn_decoder = False - self.cross_attn_k = self.cross_attn_k if self.cross_attn_encoder else None - self.attn_bias_type = "local_block_causal" - - class GlobalTransformerArgs(ByteLatentTransformerArgs): # Global encoder specific dimensions dim_token_emb: int | None = None @@ -625,20 +599,42 @@ def create_global_transformer(args: ByteLatentTransformerArgs) -> GlobalTransfor def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder: - # First deep copy the original args - # Replace with local encoder specific values - local_encoder_args = args.model_copy( - deep=True, - update=dict( - dim=args.dim_local_encoder, - n_layers=args.n_layers_local_encoder, - n_heads=args.n_heads_local_encoder, - dim_token_emb=get_encoder_dim_token_emb(args), - dim_patch_emb=get_encoder_dim_patch_emb(args), - cross_attn_decoder=False, - cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None, - attn_bias_type="local_block_causal", - ), + local_encoder_args = LocalModelArgs( + # Updated args + dim=args.dim_local_encoder, + n_layers=args.n_layers_local_encoder, + n_heads=args.n_heads_local_encoder, + dim_token_emb=get_encoder_dim_token_emb(args), + dim_patch_emb=get_encoder_dim_patch_emb(args), + cross_attn_encoder=args.cross_attn_encoder, + cross_attn_decoder=False, + cross_attn_k=args.cross_attn_k if 
args.cross_attn_encoder else None, + cross_attn_init_by_pooling=args.cross_attn_init_by_pooling, + # Defaults + head_dim=args.head_dim, + max_seqlen=args.max_encoder_seq_length, + dropout=args.dropout, + vocab_size=args.vocab_size + args.pm_size, + norm_eps=args.norm_eps, + patch_size=args.patch_size, + sliding_window=args.local_attention_window_len, + use_rope=args.use_rope, + rope_theta=args.rope_theta, + init_base_std=args.init_base_std, + init_std_factor=args.init_std_factor, + n_kv_heads=args.n_kv_heads, + attn_impl=args.attn_impl, + attn_bias_type="local_block_causal", + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + patching_mode=args.patching_mode, + use_local_encoder_transformer=args.use_local_encoder_transformer, + downsampling_by_pooling=args.downsampling_by_pooling, + encoder_hash_byte_group_size=args.encoder_hash_byte_group_size, + cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder, + cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder, + cross_attn_nheads=args.cross_attn_nheads, + eos_id=args.eos_id, ) return LocalEncoder(local_encoder_args) @@ -646,18 +642,41 @@ def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder: def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder: # First deep copy the original args - local_decoder_args = args.model_copy( - deep=True, - update=dict( - dim=args.dim_local_decoder, - n_layers=args.n_layers_local_decoder, - n_heads=args.n_heads_local_decoder, - cross_attn_encoder=False, - cross_attn_init_by_pooling=False, # states are already defined - dim_token_emb=get_decoder_dim_token_emb(args), - dim_patch_emb=args.dim_global, - cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None, - ), + local_decoder_args = LocalModelArgs( + dim=args.dim_local_decoder, + n_layers=args.n_layers_local_decoder, + n_heads=args.n_heads_local_decoder, + dim_token_emb=get_decoder_dim_token_emb(args), + dim_patch_emb=args.dim_global, + cross_attn_encoder=False, + cross_attn_decoder=args.cross_attn_decoder, + cross_attn_init_by_pooling=False, # states are already defined + cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None, + # Defaults + head_dim=args.head_dim, + max_seqlen=args.max_encoder_seq_length, + dropout=args.dropout, + vocab_size=args.vocab_size + args.pm_size, + norm_eps=args.norm_eps, + patch_size=args.patch_size, + sliding_window=args.local_attention_window_len, + use_rope=args.use_rope, + rope_theta=args.rope_theta, + init_base_std=args.init_base_std, + init_std_factor=args.init_std_factor, + n_kv_heads=args.n_kv_heads, + attn_impl=args.attn_impl, + attn_bias_type="local_block_causal", + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + patching_mode=args.patching_mode, + use_local_encoder_transformer=args.use_local_encoder_transformer, + downsampling_by_pooling=args.downsampling_by_pooling, + encoder_hash_byte_group_size=args.encoder_hash_byte_group_size, + cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder, + cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder, + cross_attn_nheads=args.cross_attn_nheads, + eos_id=args.eos_id, ) return LocalDecoder(local_decoder_args) @@ -763,7 +782,6 @@ class ByteLatentTransformer(nn.Module): # General configuration self.weight_tying = args.weight_tying - self.sliding_window = args.sliding_window self.patch_size = args.patch_size self.patching_mode = args.patching_mode self.boe_id, self.bos_id, self.pad_id, self.eos_id = ( diff --git 
a/bytelatent/model/transformer.py b/bytelatent/model/global_transformer.py similarity index 93% rename from bytelatent/model/transformer.py rename to bytelatent/model/global_transformer.py index 24dc057..21c3f0c 100644 --- a/bytelatent/model/transformer.py +++ b/bytelatent/model/global_transformer.py @@ -11,6 +11,7 @@ from xformers.ops import AttentionBias from bytelatent.base_transformer import ( BaseTransformer, + BaseTransformerArgs, RMSNorm, flex_attention_comp, repeat_kv, @@ -142,11 +143,10 @@ class CrossAttention(nn.Module): class GlobalTransformer(BaseTransformer): - def __init__(self, args): + def __init__(self, args: BaseTransformerArgs): super().__init__(args) self.dropout = args.dropout - self.sliding_window = args.sliding_window - self.efficient_attn = args.efficient_attn + self.eos_id = args.eos_id self.token_embedding_projection = None if args.dim_token_emb is not None and args.dim_token_emb != self.dim: @@ -169,14 +169,19 @@ class GlobalTransformer(BaseTransformer): and projection to the token space. """ bs, seqlen = tokens.shape - attn_impl = self.efficient_attn h = embeds mask = ( mask if mask is not None - else create_causal_mask(seqlen, attn_impl, self.sliding_window) + else create_causal_mask( + seqlen, + self.attn_impl, + self.attn_bias_type, + tokens=tokens, + eos_id=self.eos_id, + ) ) if self.token_embedding_projection is not None and h.shape[-1] != self.dim: @@ -184,7 +189,7 @@ class GlobalTransformer(BaseTransformer): h = F.dropout(h, p=self.dropout, training=self.training) - h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl) + h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl) return h, cache def init_weights(self, init_base_std: float): diff --git a/bytelatent/model/local_models.py b/bytelatent/model/local_models.py index 8255504..f182780 100644 --- a/bytelatent/model/local_models.py +++ b/bytelatent/model/local_models.py @@ -1,11 +1,12 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
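# GlobalTransformer above now derives its mask from attn_impl / attn_bias_type / eos_id
# instead of the removed efficient_attn / sliding_window fields. A minimal sketch of that
# mask construction, assuming the create_causal_mask signature added to
# bytelatent/model/utils.py later in this patch (token values are made up):
import torch

from bytelatent.model.utils import create_causal_mask
from bytelatent.tokenizers.constants import EOS_ID

tokens = torch.tensor([[65, 66, EOS_ID, 67, 68, 69, EOS_ID, 70]])
mask = create_causal_mask(
    tokens.shape[1],
    "xformers",       # attn_impl
    "block_causal",   # attn_bias_type
    tokens=tokens,
    eos_id=EOS_ID,
)
# -> an xformers BlockDiagonalCausalMask whose causal blocks end at each EOS token,
#    passed unchanged through BaseTransformer.forward to the attention layers.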
import logging -from typing import List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union import torch import torch.nn import torch.nn as nn +from pydantic import BaseModel, ConfigDict from torch.nn import functional as F from torch.nn.attention.flex_attention import BlockMask from xformers.ops import AttentionBias @@ -16,29 +17,69 @@ from bytelatent.base_transformer import ( RotaryEmbedding, TransformerBlock, ) -from bytelatent.model.transformer import CrossAttention +from bytelatent.model.global_transformer import CrossAttention from bytelatent.model.utils import create_causal_mask, downsample from bytelatent.tokenizers.blt_tokenizer import BOE_ID logger = logging.getLogger() +class LocalModelArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + # Local encoder specific dimensions + head_dim: int | None + dim: int + dropout: float + vocab_size: int + patch_size: int + sliding_window: int | None + use_rope: bool + init_base_std: float | None = None + init_std_factor: InitStdFactor + cross_attn_encoder: bool | None + cross_attn_decoder: bool | None + cross_attn_k: int | None + cross_attn_init_by_pooling: bool + norm_eps: float + rope_theta: float + max_seqlen: int + ffn_dim_multiplier: float | None = None + patching_mode: str + use_local_encoder_transformer: bool + downsampling_by_pooling: str | None + encoder_hash_byte_group_size: Any | None = None + cross_attn_all_layers_encoder: bool = False + cross_attn_all_layers_decoder: bool = False + cross_attn_nheads: int | None + + n_layers: int + n_heads: int + n_kv_heads: int | None = None + dim_token_emb: int + dim_patch_emb: int | None + attn_impl: str | None = "xformers" + attn_bias_type: str | None = "local_block_causal" + multiple_of: int = 256 + eos_id: int | None = None + + class LocalModelBase(nn.Module): - def __init__(self, args): + def __init__(self, args: LocalModelArgs): super().__init__() self.dim = args.dim self.dropout = args.dropout - self.vocab_size = args.vocab_size + args.pm_size + self.vocab_size = args.vocab_size self.patch_size = args.patch_size - self.efficient_attn = args.efficient_attn + self.attn_impl = args.attn_impl self.sliding_window = args.sliding_window self.use_rope = args.use_rope self.init_std_factor = args.init_std_factor self.cross_attn_encoder = getattr(args, "cross_attn_encoder", None) self.cross_attn_decoder = getattr(args, "cross_attn_decoder", None) self.cross_attn_k = getattr(args, "cross_attn_k", None) + self.eos_id = args.eos_id self.boe_id = BOE_ID @@ -54,7 +95,7 @@ class LocalModelBase(nn.Module): self.rope = RotaryEmbedding( theta=args.rope_theta, head_dim=args.head_dim or args.dim // args.n_heads, - max_seqlen=getattr(args, "max_encoder_seq_length", args.max_length), + max_seqlen=args.max_seqlen, ) self.pos_embeddings = None @@ -66,21 +107,15 @@ class LocalModelBase(nn.Module): self.patch_embedding_projection = self._create_patch_projection(args) - def _should_create_patch_projection(self, args): + def _should_create_patch_projection(self, args: LocalModelArgs): dimension_mismatch = ( getattr(args, "dim_patch_emb") and args.dim_patch_emb != self.dim ) # Check cross attention conditions cross_attn_conditions = ( - hasattr(args, "cross_attn_encoder") - and args.cross_attn_encoder - and getattr(args, "cross_attn_init_by_pooling") - ) or ( - hasattr(args, "cross_attn_decoder") - and args.cross_attn_decoder - and getattr(args, "cross_attn_init_by_pooling") - ) + args.cross_attn_encoder and args.cross_attn_init_by_pooling + ) or (args.cross_attn_decoder and 
args.cross_attn_init_by_pooling) return dimension_mismatch or cross_attn_conditions @@ -172,7 +207,7 @@ class LocalModelBase(nn.Module): class LocalEncoder(LocalModelBase): - def __init__(self, args): + def __init__(self, args: LocalModelArgs): super().__init__(args) self.output_proj = ( args.patching_mode in ["entropy", "probmax"] @@ -180,7 +215,6 @@ class LocalEncoder(LocalModelBase): self.apply_transformer = args.use_local_encoder_transformer self.downsampling_by_pooling = args.downsampling_by_pooling - self.patch_only = args.patch_only_encoder self.expects_hash_embeddings = args.encoder_hash_byte_group_size is not None self.cross_attn_encoder = args.cross_attn_encoder self.cross_attn_all_layers_encoder = args.cross_attn_all_layers_encoder @@ -224,7 +258,14 @@ class LocalEncoder(LocalModelBase): """ """ bs, seqlen = tokens.shape if mask is None: - mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window) + mask = create_causal_mask( + seqlen, + self.attn_impl, + "local_block_causal", + sliding_window=self.sliding_window, + tokens=tokens, + eos_id=self.eos_id, + ) h = self.apply_embedding(tokens, embeds) freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None @@ -232,7 +273,7 @@ class LocalEncoder(LocalModelBase): h = F.dropout(h, p=self.dropout, training=self.training) for i, layer in enumerate(self.layers): - h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn) + h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl) # check if cross attention should be applied to either all layer or only the last layer if self.cross_attn_encoder and ( i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder @@ -273,12 +314,10 @@ class LocalEncoder(LocalModelBase): class LocalDecoder(LocalModelBase): - def __init__(self, args): + def __init__(self, args: LocalModelArgs): super().__init__(args) # Model configuration flags - self.patch_only = args.patch_only_decoder - self.expects_embeddings = args.share_encoder_decoder_emb self.cross_attn_decoder = args.cross_attn_decoder self.cross_attn_all_layers_decoder = args.cross_attn_all_layers_decoder self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling @@ -317,7 +356,14 @@ class LocalDecoder(LocalModelBase): assert embeds is not None, "Embeddings must be provided" if mask is None: - mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window) + mask = create_causal_mask( + seqlen, + self.attn_impl, + "local_block_causal", + sliding_window=self.sliding_window, + tokens=tokens, + eos_id=self.eos_id, + ) h = embeds @@ -347,7 +393,7 @@ class LocalDecoder(LocalModelBase): ) h = h + h_cross - h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn) + h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl) h_preds = self.norm(h) h_preds = F.dropout(h_preds, p=self.dropout, training=self.training) diff --git a/bytelatent/model/utils.py b/bytelatent/model/utils.py index ce52a30..42eb185 100644 --- a/bytelatent/model/utils.py +++ b/bytelatent/model/utils.py @@ -1,8 +1,12 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
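# The local encoder/decoder masks above and the global "block_causal" mask are xformers
# block-diagonal biases, which expect queries/keys/values packed into a single sequence.
# That is what the new _reshape_for_attn_bias helper in bytelatent/base_transformer.py
# does before fmha.memory_efficient_attention; a rough shape-only sketch with
# hypothetical sizes:
import torch

B, S, H, D = 2, 512, 8, 64
xq = torch.randn(B, S, H, D)
query_shape = xq.shape
packed = xq.reshape(1, B * S, H, D)  # applied to xq, xk, xv when the bias is a BlockDiagonalCausalMask
# output = fmha.memory_efficient_attention(packed_q, packed_k, packed_v, attn_bias=mask)
output = packed                      # stand-in for the attention output in this sketch
output = output.view(query_shape)    # restored to B S H D, as in Attention.forward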
+import logging + import torch from torch.nn.attention.flex_attention import create_block_mask from xformers.ops import fmha +logger = logging.getLogger() + def patch_reduce(h, max_num_patches, reduction, patch_ids): """ @@ -97,14 +101,69 @@ def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx -def create_causal_mask(seqlen, attn_impl, sliding_window): - if sliding_window is not None and attn_impl == "xformers": - return fmha.attn_bias.LocalAttentionFromBottomRightMask( - window_left=sliding_window - 1, window_right=0 - ) - elif attn_impl == "xformers": - return fmha.attn_bias.LowerTriangularMask() +def tokens_to_seqlen(batch: torch.Tensor, eos_id: int): + """ + 0 0 0 1 0 0 0 1 0 0 0 + 0 1 0 0 0 1 0 0 0 0 0 + -> 4 4 3 2 4 5 + """ + mask = batch == eos_id + mask[:, -1] = True # virtual eos at the end of each row + + # 0 0 0 1 0 0 0 1 0 0 X + # 0 1 0 0 0 1 0 0 0 0 X + row, col = torch.where(mask) + + # row = 0, 0, 0, 1, 1, 1 + # col = 3, 7, 10, 1, 5, 10 + seqlens = (col[1:] - col[:-1]) + (row[1:] - row[:-1]) * mask.shape[1] + # seqlens = (4, 3, -9, 4, 5) + (0, 0, 11, 0, 0) = (4, 3, 2, 4, 5) + return [int(col[0].item() + 1)] + seqlens.tolist() + + +WARNED_SDPA = False + + +def create_causal_mask( + seqlen, + attn_impl: str, + attn_bias_type: str | None, + *, + eos_id: int | None = None, + tokens: torch.Tensor | None = None, + sliding_window: int | None = None, +): + if attn_impl == "xformers": + if attn_bias_type is None: + return fmha.attn_bias.LowerTriangularMask() + elif attn_bias_type == "causal": + assert sliding_window is None + return fmha.attn_bias.LowerTriangularMask() + elif attn_bias_type == "block_causal": + assert sliding_window is None + assert eos_id is not None + assert tokens is not None + return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens( + q_seqlen=tokens_to_seqlen(tokens, eos_id) + ) + elif attn_bias_type == "local_block_causal": + assert sliding_window is not None + assert eos_id is not None + assert tokens is not None + return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens( + q_seqlen=tokens_to_seqlen(tokens, eos_id) + ).make_local_attention(sliding_window) + else: + return fmha.attn_bias.LocalAttentionFromBottomRightMask( + window_left=sliding_window - 1, window_right=0 + ) elif attn_impl == "sdpa": + global WARNED_SDPA + if not WARNED_SDPA: + logging.warning( + "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention." 
+ ) + WARNED_SDPA = True return "causal" elif attn_impl == "flex_attention": return create_block_mask(causal_mask, None, None, seqlen, seqlen) diff --git a/bytelatent/preprocess/fsspec_target.py b/bytelatent/preprocess/fsspec_target.py new file mode 100644 index 0000000..eacb101 --- /dev/null +++ b/bytelatent/preprocess/fsspec_target.py @@ -0,0 +1,38 @@ +import fsspec +from luigi.target import FileSystem, FileSystemTarget + + +class FSSpecFileSystem(FileSystem): + def __init__(self, fs: fsspec.AbstractFileSystem): + self.fs = fs + + def exists(self, path): + return self.fs.exists() + + def remove(self, path, recursive=True, skip_trash=True): + raise NotImplementedError() + + def isdir(self, path): + return self.fs.isdir(path) + + def listdir(self, path): + return self.fs.ls(path) + + +class FSSpecTarget(FileSystemTarget): + def __init__(self, path, fs: fsspec.AbstractFileSystem | None = None): + self.path = path + if fs is None: + self.fsspec_fs = fsspec.filesystem("file") + else: + self.fsspec_fs = fs + self._fs = None + + @property + def fs(self): + if self._fs is None: + self._fs = FSSpecFileSystem(self.fsspec_fs) + return self._fs + + def open(self, mode): + return self.fs.open(self.path, mode=mode) diff --git a/bytelatent/test_blt.py b/bytelatent/test_blt.py index 73ad9f7..4d8e9c7 100644 --- a/bytelatent/test_blt.py +++ b/bytelatent/test_blt.py @@ -23,9 +23,10 @@ from bytelatent.model.blt import ( init_embeddings, patch_ids_from_lengths, ) -from bytelatent.model.transformer import CrossAttention +from bytelatent.model.global_transformer import CrossAttention from bytelatent.model.utils import create_causal_mask from bytelatent.optim import OptimArgs, build_optimizer +from bytelatent.tokenizers.constants import EOS_ID from bytelatent.train import compute_loss @@ -51,7 +52,7 @@ def batch_to_tensors_and_gpu(batch): def fake_batch(): - batch_dict = torch.load(os.path.join(BLT_DATA, "test_batch.pt")) + batch_dict = torch.load(os.path.join(BLT_DATA, "test_batch.pt"), weights_only=False) del batch_dict["x2"] del batch_dict["y2"] del batch_dict["src_names"] @@ -98,18 +99,17 @@ def create_args(cross_attention=False): recompute_attn=False, custom_bwd=False, layer_ckpt="none", - efficient_attn="sdpa", - patch_only_encoder=False, - patch_only_decoder=False, use_local_encoder_transformer=True, init_use_gaussian=True, init_use_depth="current", attn_bias_type="block_causal", + attn_impl="xformers", alpha_depth="disabled", max_length=256, local_attention_window_len=512, max_seqlen=12288, downsampling_by_pooling="max", + eos_id=EOS_ID, ) return transformer_args @@ -341,10 +341,10 @@ class TestByteLatentTransformer: model = ByteLatentTransformer(args) assert model is not None - @pytest.mark.parametrize("attn_type", ["fmha", "sdpa"]) - def test_blt_transformer_forward(self, attn_type): + @pytest.mark.parametrize("attn_impl", ["sdpa", "xformers"]) + def test_blt_transformer_forward(self, attn_impl): args = create_args() - args = args.model_copy(update=dict(efficient_attn=attn_type)) + args = args.model_copy(update=dict(attn_impl=attn_impl)) model = ByteLatentTransformer(args) model = model.cuda() batch = fake_batch() @@ -393,7 +393,9 @@ class TestByteLatentTransformer: n_kv_heads=4, norm_eps=1e-6, ).to("cuda") - mask = create_causal_mask(x.shape[1], "flex_attention", sliding_window=None) + mask = create_causal_mask( + x.shape[1], "flex_attention", None, sliding_window=None + ) output = cross_attention(x, kv, mask) assert output is not None assert output.shape == (2, 256, 512) @@ -440,7 +442,7 @@ class 
TestByteLatentTransformer: def test_loss_backward(self): args = create_args() - args = args.model_copy(update=dict(efficient_attn="sdpa")) + args = args.model_copy(update=dict(attn_impl="sdpa")) batch = fake_batch() model = ByteLatentTransformer(args) steps = 10 diff --git a/bytelatent/test_entropy_model.py b/bytelatent/test_entropy_model.py index 3acc42d..af81638 100644 --- a/bytelatent/test_entropy_model.py +++ b/bytelatent/test_entropy_model.py @@ -24,6 +24,7 @@ def test_entropy_model(): dataset_files=[ARROW_TEST_DATA], row_num=0, arrow_batch_size=100, + s3_profile=None, ) arrow_file = initial_state.build() tokenizer_args = TokenizerArgs( diff --git a/bytelatent/train.py b/bytelatent/train.py index 6cb13b9..80bd393 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -644,6 +644,10 @@ def main(): cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) train_args = TrainArgs.model_validate(cfg) + if train_args.debug_dynamo: + import torch._dynamo + + torch._dynamo.config.suppress_errors = True train(train_args) diff --git a/bytelatent/transformer.py b/bytelatent/transformer.py index 432f7df..92c5ff5 100644 --- a/bytelatent/transformer.py +++ b/bytelatent/transformer.py @@ -22,23 +22,7 @@ from bytelatent.base_transformer import ( RMSNorm, cross_entropy, ) - - -def create_causal_mask(seqlen, attn_impl, sliding_window): - if sliding_window is not None and attn_impl == "xformers": - return fmha.attn_bias.LocalAttentionFromBottomRightMask( - window_left=sliding_window - 1, window_right=0 - ) - elif attn_impl == "xformers": - return fmha.attn_bias.LowerTriangularMask() - elif attn_impl == "sdpa": - return "causal" - elif attn_impl == "flex_attention": - return create_block_mask(causal_mask, None, None, seqlen, seqlen) - else: - raise NotImplementedError( - f"Attention {attn_impl} with {sliding_window} sliding window not implemented" - ) +from bytelatent.model.utils import create_causal_mask def attention_flops_per_token(n_layers, seq_len, dim, causal): @@ -94,8 +78,10 @@ class LMTransformer(BaseTransformer): target: Optional[torch.Tensor] = None, tok_idx: Optional[torch.Tensor] = None, mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None, - attn_impl: str = "sdpa", + attn_impl: str | None = None, ): + if attn_impl is None: + attn_impl = self.attn_impl bsz, seqlen = token_values.shape h = self.tok_embeddings(token_values) @@ -103,7 +89,14 @@ class LMTransformer(BaseTransformer): mask = ( mask if mask is not None - else create_causal_mask(seqlen, attn_impl, self.sliding_window) + else create_causal_mask( + seqlen, + attn_impl, + self.attn_bias_type, + sliding_window=self.sliding_window, + tokens=token_values, + eos_id=self.eos_id, + ) ) h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl) From 7f305b38719c8ffb788753f9bc2f41df22834162 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Fri, 17 Jan 2025 22:21:50 +0000 Subject: [PATCH 10/18] [WIP] Changes for training entropy model and correcting attention in local models Summary: - Refactor local model configs to be separate and clearer - Add attention arguments and correct which attention is used in local models - Preparation for being able to have an entropy train script - Fix failing unit tests Test Plan: --- bytelatent/args.py | 7 + bytelatent/base_transformer.py | 45 ++++-- bytelatent/configs/debug.yaml | 3 +- .../data/iterators/test_arrow_iterator.py | 3 + bytelatent/distributed.py | 1 - bytelatent/entropy_model.py | 10 +- 
bytelatent/model/blt.py | 128 ++++++++++-------- .../{transformer.py => latent_transformer.py} | 17 ++- bytelatent/model/local_models.py | 84 ++++++++---- bytelatent/model/utils.py | 80 +++++++++-- bytelatent/preprocess/fsspec_target.py | 38 ++++++ bytelatent/test_blt.py | 27 ++-- bytelatent/test_entropy_model.py | 9 +- bytelatent/train.py | 4 + bytelatent/transformer.py | 31 ++--- 15 files changed, 349 insertions(+), 138 deletions(-) rename bytelatent/model/{transformer.py => latent_transformer.py} (93%) create mode 100644 bytelatent/preprocess/fsspec_target.py diff --git a/bytelatent/args.py b/bytelatent/args.py index cfba3bf..b9144c6 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -30,6 +30,7 @@ from bytelatent.model.blt import ByteLatentTransformerArgs from bytelatent.optim import OptimArgs from bytelatent.profiling import ProfilerArgs from bytelatent.tokenizers.build_tokenizer import TokenizerArgs +from bytelatent.transformer import LMTransformerArgs logger = logging.getLogger() @@ -163,6 +164,8 @@ class TrainArgs(BaseModel): seed: int = 42 + debug_dynamo: bool = False + # Number of gradient accumulation steps # Total batch size is batch_size*grad_acc_steps grad_acc_steps: int = 1 @@ -176,6 +179,10 @@ class TrainArgs(BaseModel): data: DataloaderArgs = DataloaderArgs() optim: OptimArgs = OptimArgs() model: ByteLatentTransformerArgs = ByteLatentTransformerArgs() + # This is only needed for training the entropy model + entropy_model: LMTransformerArgs | None = None + # Instead of training main model, train entropy model + train_entropy_model: bool = False distributed: DistributedArgs = DistributedArgs() env: EnvironmentArgs = EnvironmentArgs() diff --git a/bytelatent/base_transformer.py b/bytelatent/base_transformer.py index 45cb7c5..dd0cce6 100644 --- a/bytelatent/base_transformer.py +++ b/bytelatent/base_transformer.py @@ -4,7 +4,7 @@ from enum import Enum from typing import Optional, Tuple, Union import torch -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from torch import nn from torch.nn import functional as F from torch.nn.attention.flex_attention import ( @@ -15,6 +15,7 @@ from torch.nn.attention.flex_attention import ( from xformers.ops import AttentionBias, fmha from bytelatent import probe +from bytelatent.tokenizers.constants import EOS_ID if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0: flex_attention_comp = torch.compile(flex_attention) @@ -30,13 +31,14 @@ class InitStdFactor(Enum): class BaseTransformerArgs(BaseModel): + model_config = ConfigDict(extra="forbid") dim: int = 512 n_layers: int = 8 - head_dim: Optional[int] = None - n_heads: Optional[int] = None - n_kv_heads: Optional[int] = None + head_dim: int | None = None + n_heads: int | None = None + n_kv_heads: int | None = None - ffn_dim_multiplier: Optional[float] = None + ffn_dim_multiplier: float | None = None multiple_of: int = 256 @@ -44,11 +46,16 @@ class BaseTransformerArgs(BaseModel): rope_theta: float = 10000.0 - init_base_std: Optional[float] = None + init_base_std: float | None = None init_std_factor: InitStdFactor = InitStdFactor.DISABLED max_seqlen: int = 1024 + attn_impl: str | None = "sdpa" + attn_bias_type: str | None = None + # Special token config + eos_id: int | None = EOS_ID + def cross_entropy(pred, target, **kwargs): return F.nll_loss( @@ -294,6 +301,18 @@ class RMSNorm(nn.Module): torch.nn.init.ones_(self.weight) # type: ignore +def _reshape_for_attn_bias( + attn_bias: AttentionBias | None, + *tensors: torch.Tensor, +) -> 
list[torch.Tensor]: + to_transform = list(tensors) + if isinstance(attn_bias, fmha.attn_bias.BlockDiagonalCausalMask): + # could be `view` instead of reshape during training, but for inference + # have to reshape due to strides mismatch + to_transform = [t.reshape(1, -1, *t.shape[2:]) for t in to_transform] + return to_transform + + class Attention(nn.Module): def __init__( self, @@ -371,9 +390,12 @@ class Attention(nn.Module): output = flex_attention_comp(xq, xk, xv, block_mask=mask) output = output.transpose(1, 2).contiguous() # B H S D -> B S H D - elif attn_impl == "fmha": + elif attn_impl == "xformers": assert mask is None or isinstance(mask, AttentionBias) + query_shape = xq.shape + xq, xk, xv = _reshape_for_attn_bias(mask, xq, xk, xv) output = fmha.memory_efficient_attention(xq, xk, xv, attn_bias=mask) + output = output.view(query_shape) # This uses B S H D instead of B H S D of pytorch elif attn_impl == "sdpa": @@ -522,14 +544,16 @@ class TransformerBlock(nn.Module): mask: Optional[Union[BlockMask, AttentionBias, str]] = None, attn_impl: str = "sdpa", ) -> torch.Tensor: - h = x + self.attention( + attn_out = self.attention( self.attention_norm(x), freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl, ) - out = h + self.feed_forward(self.ffn_norm(h)) + h = x + attn_out + h_norm = self.ffn_norm(h) + out = h + self.feed_forward(h_norm) return out def init_weights(self, init_std=None, factor=1.0): @@ -545,6 +569,8 @@ class BaseTransformer(nn.Module): super().__init__() self.dim = args.dim self.init_base_std = args.init_base_std + self.attn_impl = args.attn_impl + self.attn_bias_type = args.attn_bias_type self.init_std_factor = InitStdFactor(args.init_std_factor) self.max_seqlen = args.max_seqlen self.rope_embeddings = RotaryEmbedding( @@ -552,6 +578,7 @@ class BaseTransformer(nn.Module): head_dim=args.head_dim or args.dim // args.n_heads, max_seqlen=args.max_seqlen, ) + self.eos_id = args.eos_id self.layers = nn.ModuleList() for _ in range(args.n_layers): diff --git a/bytelatent/configs/debug.yaml b/bytelatent/configs/debug.yaml index 5f6debb..4ae4459 100644 --- a/bytelatent/configs/debug.yaml +++ b/bytelatent/configs/debug.yaml @@ -15,7 +15,6 @@ optim: distributed: fsdp_type: full_shard - compile: true model_dtype: bf16 matmul_allow_tf32: false selective_activation_checkpointing: false @@ -58,13 +57,13 @@ model: recompute_attn: false custom_bwd: false layer_ckpt: "none" - efficient_attn: "sdpa" patch_only_encoder: false patch_only_decoder: false use_local_encoder_transformer: true init_use_gaussian: true init_use_depth: "current" attn_bias_type: "block_causal" + attn_impl: "xformers" alpha_depth: "disabled" max_length: 256 local_attention_window_len: 512 diff --git a/bytelatent/data/iterators/test_arrow_iterator.py b/bytelatent/data/iterators/test_arrow_iterator.py index 4266427..fd448eb 100644 --- a/bytelatent/data/iterators/test_arrow_iterator.py +++ b/bytelatent/data/iterators/test_arrow_iterator.py @@ -27,6 +27,7 @@ def test_basic_arrow_file(): dataset_files=[ARROW_TEST_DATA_1], row_num=0, arrow_batch_size=100, + s3_profile=None, ) arrow_file = initial_state.build() start_state = arrow_file.get_state() @@ -55,6 +56,7 @@ def test_basic_arrow_file(): dataset_files=[ARROW_TEST_DATA_1], row_num=251, arrow_batch_size=100, + s3_profile=None, ) arrow_file = resumed_state.build() for example in arrow_file.create_iter(): @@ -74,6 +76,7 @@ def test_basic_arrow_file(): dataset_files=[ARROW_TEST_DATA_1], row_num=0, arrow_batch_size=100, + s3_profile=None, ) arrow_file = 
rank_state.build() expected_ids = [] diff --git a/bytelatent/distributed.py b/bytelatent/distributed.py index b211858..168cb7c 100644 --- a/bytelatent/distributed.py +++ b/bytelatent/distributed.py @@ -11,7 +11,6 @@ import socket import subprocess import sys import tempfile -from dataclasses import asdict, dataclass from functools import lru_cache, partial, reduce from itertools import chain from typing import List, Optional, Tuple, Union diff --git a/bytelatent/entropy_model.py b/bytelatent/entropy_model.py index 1bd1766..30754ee 100644 --- a/bytelatent/entropy_model.py +++ b/bytelatent/entropy_model.py @@ -1,12 +1,14 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. import json +import logging import os -import re import torch from bytelatent.transformer import LMTransformer, LMTransformerArgs +logger = logging.getLogger() + def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cpu"): with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr: @@ -14,6 +16,9 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp torch.set_default_dtype(torch.bfloat16) model_params = reloaded["model"] + logger.warning( + "Update checkpoint to load attn and sliding window args from checkpoint" + ) entropy_model = LMTransformer( LMTransformerArgs( dim=model_params["dim"], @@ -22,6 +27,9 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp max_seqlen=model_params["max_length"], ffn_dim_multiplier=model_params["ffn_dim_multiplier"], vocab_size=model_params["vocab_size"], + attn_bias_type="local_block_causal", + attn_impl="xformers", + sliding_window=512, ) ) diff --git a/bytelatent/model/blt.py b/bytelatent/model/blt.py index 9332d19..843ad34 100644 --- a/bytelatent/model/blt.py +++ b/bytelatent/model/blt.py @@ -15,8 +15,8 @@ from bytelatent.base_transformer import ( TransformerBlock, ) from bytelatent.data.patcher import Patcher, PatcherArgs -from bytelatent.model.local_models import LocalDecoder, LocalEncoder -from bytelatent.model.transformer import GlobalTransformer +from bytelatent.model.latent_transformer import GlobalTransformer +from bytelatent.model.local_models import LocalDecoder, LocalEncoder, LocalModelArgs from bytelatent.model.utils import downsample from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID @@ -403,7 +403,6 @@ def patch_ids_from_lengths(patch_lengths, seq_len): class ByteLatentTransformerArgs(BaseTransformerArgs): - model_config = ConfigDict(extra="forbid") # Basic model configuration seed: int = 42 vocab_size: int = -1 @@ -412,7 +411,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): n_heads: int = 8 # TODO: What is the purpose of this parameter? 
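# load_entropy_model above now hard-codes attn_impl="xformers",
# attn_bias_type="local_block_causal" and sliding_window=512 until those values are
# stored in the checkpoint itself. A hedged sketch of scoring bytes with the loaded
# model (paths are placeholders; the entropy is computed inline here rather than with
# the repo's entropy() helper used in bytelatent/test_entropy_model.py):
import torch

from bytelatent.entropy_model import load_entropy_model

ckpt_dir = "path/to/entropy_model_checkpoint"              # placeholder
entropy_model = load_entropy_model(ckpt_dir, f"{ckpt_dir}/entropy_model.pth").cuda()

byte_tokens = torch.randint(0, 260, size=(1, 128)).cuda()  # hypothetical byte ids
logits = entropy_model(byte_tokens)
per_byte_entropy = torch.distributions.Categorical(logits=logits).entropy()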
weight_tying: bool = False - sliding_window: Optional[int] = None # Architecture and dimensions dim_token: int = 256 @@ -471,11 +469,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): recompute_attn: bool = True custom_bwd: bool = False layer_ckpt: str = "all" - efficient_attn: str | None = None - - # Architecture options - patch_only_encoder: bool = False - patch_only_decoder: bool = False # Initialization and attention init_use_gaussian: bool = True @@ -541,9 +534,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): # Logging full_logging_n_layers: int = 4 - # Special token config - eos_id: int | None = None - @model_validator(mode="after") def check_hash_byte_sizes(self) -> Self: if ( @@ -558,22 +548,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): return self -class LocalEncoderArgs(ByteLatentTransformerArgs): - # Local encoder specific dimensions - n_heads_local_encoder: int = 8 - dim_token_emb: int | None = None - dim_patch_emb: int | None = None - - def __post_init__(self): - # Override base args with local encoder specific values - self.dim = self.dim_local_encoder - self.n_layers = self.n_layers_local_encoder - self.n_heads = self.n_heads_local_encoder - self.cross_attn_decoder = False - self.cross_attn_k = self.cross_attn_k if self.cross_attn_encoder else None - self.attn_bias_type = "local_block_causal" - - class GlobalTransformerArgs(ByteLatentTransformerArgs): # Global encoder specific dimensions dim_token_emb: int | None = None @@ -625,20 +599,42 @@ def create_global_transformer(args: ByteLatentTransformerArgs) -> GlobalTransfor def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder: - # First deep copy the original args - # Replace with local encoder specific values - local_encoder_args = args.model_copy( - deep=True, - update=dict( - dim=args.dim_local_encoder, - n_layers=args.n_layers_local_encoder, - n_heads=args.n_heads_local_encoder, - dim_token_emb=get_encoder_dim_token_emb(args), - dim_patch_emb=get_encoder_dim_patch_emb(args), - cross_attn_decoder=False, - cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None, - attn_bias_type="local_block_causal", - ), + local_encoder_args = LocalModelArgs( + # Updated args + dim=args.dim_local_encoder, + n_layers=args.n_layers_local_encoder, + n_heads=args.n_heads_local_encoder, + dim_token_emb=get_encoder_dim_token_emb(args), + dim_patch_emb=get_encoder_dim_patch_emb(args), + cross_attn_encoder=args.cross_attn_encoder, + cross_attn_decoder=False, + cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None, + cross_attn_init_by_pooling=args.cross_attn_init_by_pooling, + # Defaults + head_dim=args.head_dim, + max_seqlen=args.max_encoder_seq_length, + dropout=args.dropout, + vocab_size=args.vocab_size + args.pm_size, + norm_eps=args.norm_eps, + patch_size=args.patch_size, + sliding_window=args.local_attention_window_len, + use_rope=args.use_rope, + rope_theta=args.rope_theta, + init_base_std=args.init_base_std, + init_std_factor=args.init_std_factor, + n_kv_heads=args.n_kv_heads, + attn_impl=args.attn_impl, + attn_bias_type="local_block_causal", + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + patching_mode=args.patching_mode, + use_local_encoder_transformer=args.use_local_encoder_transformer, + downsampling_by_pooling=args.downsampling_by_pooling, + encoder_hash_byte_group_size=args.encoder_hash_byte_group_size, + cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder, + 
cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder, + cross_attn_nheads=args.cross_attn_nheads, + eos_id=args.eos_id, ) return LocalEncoder(local_encoder_args) @@ -646,18 +642,41 @@ def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder: def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder: # First deep copy the original args - local_decoder_args = args.model_copy( - deep=True, - update=dict( - dim=args.dim_local_decoder, - n_layers=args.n_layers_local_decoder, - n_heads=args.n_heads_local_decoder, - cross_attn_encoder=False, - cross_attn_init_by_pooling=False, # states are already defined - dim_token_emb=get_decoder_dim_token_emb(args), - dim_patch_emb=args.dim_global, - cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None, - ), + local_decoder_args = LocalModelArgs( + dim=args.dim_local_decoder, + n_layers=args.n_layers_local_decoder, + n_heads=args.n_heads_local_decoder, + dim_token_emb=get_decoder_dim_token_emb(args), + dim_patch_emb=args.dim_global, + cross_attn_encoder=False, + cross_attn_decoder=args.cross_attn_decoder, + cross_attn_init_by_pooling=False, # states are already defined + cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None, + # Defaults + head_dim=args.head_dim, + max_seqlen=args.max_encoder_seq_length, + dropout=args.dropout, + vocab_size=args.vocab_size + args.pm_size, + norm_eps=args.norm_eps, + patch_size=args.patch_size, + sliding_window=args.local_attention_window_len, + use_rope=args.use_rope, + rope_theta=args.rope_theta, + init_base_std=args.init_base_std, + init_std_factor=args.init_std_factor, + n_kv_heads=args.n_kv_heads, + attn_impl=args.attn_impl, + attn_bias_type="local_block_causal", + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + patching_mode=args.patching_mode, + use_local_encoder_transformer=args.use_local_encoder_transformer, + downsampling_by_pooling=args.downsampling_by_pooling, + encoder_hash_byte_group_size=args.encoder_hash_byte_group_size, + cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder, + cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder, + cross_attn_nheads=args.cross_attn_nheads, + eos_id=args.eos_id, ) return LocalDecoder(local_decoder_args) @@ -763,7 +782,6 @@ class ByteLatentTransformer(nn.Module): # General configuration self.weight_tying = args.weight_tying - self.sliding_window = args.sliding_window self.patch_size = args.patch_size self.patching_mode = args.patching_mode self.boe_id, self.bos_id, self.pad_id, self.eos_id = ( diff --git a/bytelatent/model/transformer.py b/bytelatent/model/latent_transformer.py similarity index 93% rename from bytelatent/model/transformer.py rename to bytelatent/model/latent_transformer.py index 24dc057..21c3f0c 100644 --- a/bytelatent/model/transformer.py +++ b/bytelatent/model/latent_transformer.py @@ -11,6 +11,7 @@ from xformers.ops import AttentionBias from bytelatent.base_transformer import ( BaseTransformer, + BaseTransformerArgs, RMSNorm, flex_attention_comp, repeat_kv, @@ -142,11 +143,10 @@ class CrossAttention(nn.Module): class GlobalTransformer(BaseTransformer): - def __init__(self, args): + def __init__(self, args: BaseTransformerArgs): super().__init__(args) self.dropout = args.dropout - self.sliding_window = args.sliding_window - self.efficient_attn = args.efficient_attn + self.eos_id = args.eos_id self.token_embedding_projection = None if args.dim_token_emb is not None and args.dim_token_emb != self.dim: @@ -169,14 +169,19 @@ class 
GlobalTransformer(BaseTransformer): and projection to the token space. """ bs, seqlen = tokens.shape - attn_impl = self.efficient_attn h = embeds mask = ( mask if mask is not None - else create_causal_mask(seqlen, attn_impl, self.sliding_window) + else create_causal_mask( + seqlen, + self.attn_impl, + self.attn_bias_type, + tokens=tokens, + eos_id=self.eos_id, + ) ) if self.token_embedding_projection is not None and h.shape[-1] != self.dim: @@ -184,7 +189,7 @@ class GlobalTransformer(BaseTransformer): h = F.dropout(h, p=self.dropout, training=self.training) - h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl) + h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl) return h, cache def init_weights(self, init_base_std: float): diff --git a/bytelatent/model/local_models.py b/bytelatent/model/local_models.py index 8255504..59fa76d 100644 --- a/bytelatent/model/local_models.py +++ b/bytelatent/model/local_models.py @@ -1,44 +1,75 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. import logging -from typing import List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union import torch import torch.nn import torch.nn as nn +from pydantic import BaseModel, ConfigDict from torch.nn import functional as F from torch.nn.attention.flex_attention import BlockMask from xformers.ops import AttentionBias from bytelatent.base_transformer import ( + BaseTransformerArgs, InitStdFactor, RMSNorm, RotaryEmbedding, TransformerBlock, ) -from bytelatent.model.transformer import CrossAttention +from bytelatent.model.latent_transformer import CrossAttention from bytelatent.model.utils import create_causal_mask, downsample from bytelatent.tokenizers.blt_tokenizer import BOE_ID logger = logging.getLogger() +class LocalModelArgs(BaseTransformerArgs): + model_config = ConfigDict(extra="forbid") + # Override defaults + attn_impl: str | None = "xformers" + attn_bias_type: str | None = "local_block_causal" + + # Local encoder specific dimensions + dropout: float + vocab_size: int + patch_size: int + sliding_window: int | None + use_rope: bool + cross_attn_encoder: bool | None + cross_attn_decoder: bool | None + cross_attn_k: int | None + cross_attn_init_by_pooling: bool + patching_mode: str + use_local_encoder_transformer: bool + downsampling_by_pooling: str | None + encoder_hash_byte_group_size: Any | None = None + cross_attn_all_layers_encoder: bool = False + cross_attn_all_layers_decoder: bool = False + cross_attn_nheads: int | None + + dim_token_emb: int + dim_patch_emb: int | None + + class LocalModelBase(nn.Module): - def __init__(self, args): + def __init__(self, args: LocalModelArgs): super().__init__() self.dim = args.dim self.dropout = args.dropout - self.vocab_size = args.vocab_size + args.pm_size + self.vocab_size = args.vocab_size self.patch_size = args.patch_size - self.efficient_attn = args.efficient_attn + self.attn_impl = args.attn_impl self.sliding_window = args.sliding_window self.use_rope = args.use_rope self.init_std_factor = args.init_std_factor self.cross_attn_encoder = getattr(args, "cross_attn_encoder", None) self.cross_attn_decoder = getattr(args, "cross_attn_decoder", None) self.cross_attn_k = getattr(args, "cross_attn_k", None) + self.eos_id = args.eos_id self.boe_id = BOE_ID @@ -54,7 +85,7 @@ class LocalModelBase(nn.Module): self.rope = RotaryEmbedding( theta=args.rope_theta, head_dim=args.head_dim or args.dim // args.n_heads, - max_seqlen=getattr(args, "max_encoder_seq_length", args.max_length), + 
max_seqlen=args.max_seqlen, ) self.pos_embeddings = None @@ -66,21 +97,15 @@ class LocalModelBase(nn.Module): self.patch_embedding_projection = self._create_patch_projection(args) - def _should_create_patch_projection(self, args): + def _should_create_patch_projection(self, args: LocalModelArgs): dimension_mismatch = ( getattr(args, "dim_patch_emb") and args.dim_patch_emb != self.dim ) # Check cross attention conditions cross_attn_conditions = ( - hasattr(args, "cross_attn_encoder") - and args.cross_attn_encoder - and getattr(args, "cross_attn_init_by_pooling") - ) or ( - hasattr(args, "cross_attn_decoder") - and args.cross_attn_decoder - and getattr(args, "cross_attn_init_by_pooling") - ) + args.cross_attn_encoder and args.cross_attn_init_by_pooling + ) or (args.cross_attn_decoder and args.cross_attn_init_by_pooling) return dimension_mismatch or cross_attn_conditions @@ -172,7 +197,7 @@ class LocalModelBase(nn.Module): class LocalEncoder(LocalModelBase): - def __init__(self, args): + def __init__(self, args: LocalModelArgs): super().__init__(args) self.output_proj = ( args.patching_mode in ["entropy", "probmax"] @@ -180,7 +205,6 @@ class LocalEncoder(LocalModelBase): self.apply_transformer = args.use_local_encoder_transformer self.downsampling_by_pooling = args.downsampling_by_pooling - self.patch_only = args.patch_only_encoder self.expects_hash_embeddings = args.encoder_hash_byte_group_size is not None self.cross_attn_encoder = args.cross_attn_encoder self.cross_attn_all_layers_encoder = args.cross_attn_all_layers_encoder @@ -224,7 +248,14 @@ class LocalEncoder(LocalModelBase): """ """ bs, seqlen = tokens.shape if mask is None: - mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window) + mask = create_causal_mask( + seqlen, + self.attn_impl, + "local_block_causal", + sliding_window=self.sliding_window, + tokens=tokens, + eos_id=self.eos_id, + ) h = self.apply_embedding(tokens, embeds) freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None @@ -232,7 +263,7 @@ class LocalEncoder(LocalModelBase): h = F.dropout(h, p=self.dropout, training=self.training) for i, layer in enumerate(self.layers): - h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn) + h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl) # check if cross attention should be applied to either all layer or only the last layer if self.cross_attn_encoder and ( i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder @@ -273,12 +304,10 @@ class LocalEncoder(LocalModelBase): class LocalDecoder(LocalModelBase): - def __init__(self, args): + def __init__(self, args: LocalModelArgs): super().__init__(args) # Model configuration flags - self.patch_only = args.patch_only_decoder - self.expects_embeddings = args.share_encoder_decoder_emb self.cross_attn_decoder = args.cross_attn_decoder self.cross_attn_all_layers_decoder = args.cross_attn_all_layers_decoder self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling @@ -317,7 +346,14 @@ class LocalDecoder(LocalModelBase): assert embeds is not None, "Embeddings must be provided" if mask is None: - mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window) + mask = create_causal_mask( + seqlen, + self.attn_impl, + "local_block_causal", + sliding_window=self.sliding_window, + tokens=tokens, + eos_id=self.eos_id, + ) h = embeds @@ -347,7 +383,7 @@ class LocalDecoder(LocalModelBase): ) h = h + h_cross - h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn) + h = layer(h, 
mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl) h_preds = self.norm(h) h_preds = F.dropout(h_preds, p=self.dropout, training=self.training) diff --git a/bytelatent/model/utils.py b/bytelatent/model/utils.py index ce52a30..7ca979d 100644 --- a/bytelatent/model/utils.py +++ b/bytelatent/model/utils.py @@ -1,8 +1,13 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. +import logging +import os + import torch from torch.nn.attention.flex_attention import create_block_mask from xformers.ops import fmha +logger = logging.getLogger() + def patch_reduce(h, max_num_patches, reduction, patch_ids): """ @@ -97,15 +102,74 @@ def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx -def create_causal_mask(seqlen, attn_impl, sliding_window): - if sliding_window is not None and attn_impl == "xformers": - return fmha.attn_bias.LocalAttentionFromBottomRightMask( - window_left=sliding_window - 1, window_right=0 - ) - elif attn_impl == "xformers": - return fmha.attn_bias.LowerTriangularMask() +def tokens_to_seqlen(batch: torch.Tensor, eos_id: int): + """ + 0 0 0 1 0 0 0 1 0 0 0 + 0 1 0 0 0 1 0 0 0 0 0 + -> 4 4 3 2 4 5 + """ + mask = batch == eos_id + mask[:, -1] = True # virtual eos at the end of each row + + # 0 0 0 1 0 0 0 1 0 0 X + # 0 1 0 0 0 1 0 0 0 0 X + row, col = torch.where(mask) + + # row = 0, 0, 0, 1, 1, 1 + # col = 3, 7, 10, 1, 5, 10 + seqlens = (col[1:] - col[:-1]) + (row[1:] - row[:-1]) * mask.shape[1] + # seqlens = (4, 3, -9, 4, 5) + (0, 0, 11, 0, 0) = (4, 3, 2, 4, 5) + return [int(col[0].item() + 1)] + seqlens.tolist() + + +def create_causal_mask( + seqlen, + attn_impl: str, + attn_bias_type: str | None, + *, + eos_id: int | None = None, + tokens: torch.Tensor | None = None, + sliding_window: int | None = None, +): + if attn_impl == "xformers": + if attn_bias_type is None: + return fmha.attn_bias.LowerTriangularMask() + elif attn_bias_type == "causal": + assert sliding_window is None + return fmha.attn_bias.LowerTriangularMask() + elif attn_bias_type == "block_causal": + assert sliding_window is None + assert eos_id is not None + assert tokens is not None + return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens( + q_seqlen=tokens_to_seqlen(tokens, eos_id) + ) + elif attn_bias_type == "local_block_causal": + assert sliding_window is not None + assert eos_id is not None + assert tokens is not None + return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens( + q_seqlen=tokens_to_seqlen(tokens, eos_id) + ).make_local_attention(sliding_window) + else: + return fmha.attn_bias.LocalAttentionFromBottomRightMask( + window_left=sliding_window - 1, window_right=0 + ) elif attn_impl == "sdpa": - return "causal" + BLT_SUPPRESS_ATTN_ERROR = int(os.environ.get("BLT_SUPPRESS_ATTN_ERROR", 0)) + + if attn_bias_type == "causal": + return "causal" + + if BLT_SUPPRESS_ATTN_ERROR == 1: + logging.warning( + "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. Allowing model to run since BLT_SUPPRESS_ATTN_ERROR=1" + ) + return "causal" + else: + raise ValueError( + "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. 
To suppress this error and run the model anyway, set the environment variable BLT_SUPPRESS_ATTN_ERROR=1" + ) elif attn_impl == "flex_attention": return create_block_mask(causal_mask, None, None, seqlen, seqlen) elif attn_impl == "fmha": diff --git a/bytelatent/preprocess/fsspec_target.py b/bytelatent/preprocess/fsspec_target.py new file mode 100644 index 0000000..eacb101 --- /dev/null +++ b/bytelatent/preprocess/fsspec_target.py @@ -0,0 +1,38 @@ +import fsspec +from luigi.target import FileSystem, FileSystemTarget + + +class FSSpecFileSystem(FileSystem): + def __init__(self, fs: fsspec.AbstractFileSystem): + self.fs = fs + + def exists(self, path): + return self.fs.exists() + + def remove(self, path, recursive=True, skip_trash=True): + raise NotImplementedError() + + def isdir(self, path): + return self.fs.isdir(path) + + def listdir(self, path): + return self.fs.ls(path) + + +class FSSpecTarget(FileSystemTarget): + def __init__(self, path, fs: fsspec.AbstractFileSystem | None = None): + self.path = path + if fs is None: + self.fsspec_fs = fsspec.filesystem("file") + else: + self.fsspec_fs = fs + self._fs = None + + @property + def fs(self): + if self._fs is None: + self._fs = FSSpecFileSystem(self.fsspec_fs) + return self._fs + + def open(self, mode): + return self.fs.open(self.path, mode=mode) diff --git a/bytelatent/test_blt.py b/bytelatent/test_blt.py index 73ad9f7..36a9882 100644 --- a/bytelatent/test_blt.py +++ b/bytelatent/test_blt.py @@ -23,9 +23,10 @@ from bytelatent.model.blt import ( init_embeddings, patch_ids_from_lengths, ) -from bytelatent.model.transformer import CrossAttention +from bytelatent.model.latent_transformer import CrossAttention from bytelatent.model.utils import create_causal_mask from bytelatent.optim import OptimArgs, build_optimizer +from bytelatent.tokenizers.constants import EOS_ID from bytelatent.train import compute_loss @@ -51,7 +52,7 @@ def batch_to_tensors_and_gpu(batch): def fake_batch(): - batch_dict = torch.load(os.path.join(BLT_DATA, "test_batch.pt")) + batch_dict = torch.load(os.path.join(BLT_DATA, "test_batch.pt"), weights_only=False) del batch_dict["x2"] del batch_dict["y2"] del batch_dict["src_names"] @@ -98,18 +99,17 @@ def create_args(cross_attention=False): recompute_attn=False, custom_bwd=False, layer_ckpt="none", - efficient_attn="sdpa", - patch_only_encoder=False, - patch_only_decoder=False, use_local_encoder_transformer=True, init_use_gaussian=True, init_use_depth="current", attn_bias_type="block_causal", + attn_impl="xformers", alpha_depth="disabled", max_length=256, local_attention_window_len=512, max_seqlen=12288, downsampling_by_pooling="max", + eos_id=EOS_ID, ) return transformer_args @@ -341,10 +341,15 @@ class TestByteLatentTransformer: model = ByteLatentTransformer(args) assert model is not None - @pytest.mark.parametrize("attn_type", ["fmha", "sdpa"]) - def test_blt_transformer_forward(self, attn_type): + @pytest.mark.parametrize("attn_impl", ["sdpa", "xformers"]) + def test_blt_transformer_forward(self, attn_impl): args = create_args() - args = args.model_copy(update=dict(efficient_attn=attn_type)) + if attn_impl == "sdpa": + os.environ["BLT_SUPPRESS_ATTN_ERROR"] = "1" + else: + os.environ["BLT_SUPPRESS_ATTN_ERROR"] = "0" + + args = args.model_copy(update=dict(attn_impl=attn_impl)) model = ByteLatentTransformer(args) model = model.cuda() batch = fake_batch() @@ -393,7 +398,9 @@ class TestByteLatentTransformer: n_kv_heads=4, norm_eps=1e-6, ).to("cuda") - mask = create_causal_mask(x.shape[1], "flex_attention", 
sliding_window=None) + mask = create_causal_mask( + x.shape[1], "flex_attention", None, sliding_window=None + ) output = cross_attention(x, kv, mask) assert output is not None assert output.shape == (2, 256, 512) @@ -440,7 +447,7 @@ class TestByteLatentTransformer: def test_loss_backward(self): args = create_args() - args = args.model_copy(update=dict(efficient_attn="sdpa")) + args = args.model_copy(update=dict(attn_impl="xformers")) batch = fake_batch() model = ByteLatentTransformer(args) steps = 10 diff --git a/bytelatent/test_entropy_model.py b/bytelatent/test_entropy_model.py index 3acc42d..9db7ff6 100644 --- a/bytelatent/test_entropy_model.py +++ b/bytelatent/test_entropy_model.py @@ -24,6 +24,7 @@ def test_entropy_model(): dataset_files=[ARROW_TEST_DATA], row_num=0, arrow_batch_size=100, + s3_profile=None, ) arrow_file = initial_state.build() tokenizer_args = TokenizerArgs( @@ -38,7 +39,7 @@ def test_entropy_model(): BLT_DATA, "entropy_model.pth", ), - ) + ).cuda() preprocess_iter = PreprocessIterator( arrow_file, tokenizer_args=tokenizer_args, @@ -48,8 +49,10 @@ def test_entropy_model(): for example in preprocess_iter.create_iter(): tokens = torch.tensor(example.tokens).unsqueeze(0) expected_entropies = torch.tensor(example.entropies).unsqueeze(0) - preds = entropy_model(tokens) + preds = entropy_model(tokens.cuda()) pred_entropies = entropy(preds) assert pred_entropies.shape == expected_entropies.shape - assert torch.allclose(pred_entropies, expected_entropies, rtol=1.0, atol=3.5) + assert torch.allclose( + pred_entropies.cpu(), expected_entropies, rtol=1.0, atol=3.5 + ) break diff --git a/bytelatent/train.py b/bytelatent/train.py index 6cb13b9..80bd393 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -644,6 +644,10 @@ def main(): cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) train_args = TrainArgs.model_validate(cfg) + if train_args.debug_dynamo: + import torch._dynamo + + torch._dynamo.config.suppress_errors = True train(train_args) diff --git a/bytelatent/transformer.py b/bytelatent/transformer.py index 432f7df..92c5ff5 100644 --- a/bytelatent/transformer.py +++ b/bytelatent/transformer.py @@ -22,23 +22,7 @@ from bytelatent.base_transformer import ( RMSNorm, cross_entropy, ) - - -def create_causal_mask(seqlen, attn_impl, sliding_window): - if sliding_window is not None and attn_impl == "xformers": - return fmha.attn_bias.LocalAttentionFromBottomRightMask( - window_left=sliding_window - 1, window_right=0 - ) - elif attn_impl == "xformers": - return fmha.attn_bias.LowerTriangularMask() - elif attn_impl == "sdpa": - return "causal" - elif attn_impl == "flex_attention": - return create_block_mask(causal_mask, None, None, seqlen, seqlen) - else: - raise NotImplementedError( - f"Attention {attn_impl} with {sliding_window} sliding window not implemented" - ) +from bytelatent.model.utils import create_causal_mask def attention_flops_per_token(n_layers, seq_len, dim, causal): @@ -94,8 +78,10 @@ class LMTransformer(BaseTransformer): target: Optional[torch.Tensor] = None, tok_idx: Optional[torch.Tensor] = None, mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None, - attn_impl: str = "sdpa", + attn_impl: str | None = None, ): + if attn_impl is None: + attn_impl = self.attn_impl bsz, seqlen = token_values.shape h = self.tok_embeddings(token_values) @@ -103,7 +89,14 @@ class LMTransformer(BaseTransformer): mask = ( mask if mask is not None - else create_causal_mask(seqlen, 
attn_impl, self.sliding_window) + else create_causal_mask( + seqlen, + attn_impl, + self.attn_bias_type, + sliding_window=self.sliding_window, + tokens=token_values, + eos_id=self.eos_id, + ) ) h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl) From 8a3084c346863cba88836aba087027cd96fa1c3c Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Wed, 22 Jan 2025 19:57:54 +0000 Subject: [PATCH 11/18] Update file check script to check sizes Summary: Test Plan: --- bytelatent/data/file_util.py | 51 +++++++++++++++++++++++++++++++++--- 1 file changed, 48 insertions(+), 3 deletions(-) diff --git a/bytelatent/data/file_util.py b/bytelatent/data/file_util.py index d67b6db..5165eab 100644 --- a/bytelatent/data/file_util.py +++ b/bytelatent/data/file_util.py @@ -65,7 +65,10 @@ def print_local_to_delete( @app.command() def compare_local_to_blob( - source_dirs: list[str], dst_dir: str, s3_profile: str = "blt" + source_dirs: list[str], + dst_dir: str, + s3_profile: str = "blt", + print_sizes: bool = False, ): for s in source_dirs: assert s.endswith("/"), "Dirs must end with /" @@ -75,6 +78,7 @@ def compare_local_to_blob( local_fs = fsspec.filesystem("file") dst_fs = fsspec.filesystem("s3", profile=s3_profile) source_to_files = {} + source_file_to_size = {} all_local_files = set() for s in source_dirs: skipped = [] @@ -97,14 +101,28 @@ def compare_local_to_blob( skipped.append(f) continue + file_without_prefix = f[len(s) :] + if file_without_prefix not in source_file_to_size: + source_file_to_size[file_without_prefix] = os.path.getsize(f) + else: + source_file_to_size[file_without_prefix] = max( + source_file_to_size[file_without_prefix], os.path.getsize(f) + ) + source_to_files[s].append(f) - all_local_files.add(f[len(s) :]) + all_local_files.add(file_without_prefix) print(s, len(source_to_files[s]), "skipped", len(skipped), skipped[:10]) dst_files = dst_fs.find(dst_dir) print(dst_dir, len(dst_files)) - dst_file_set = {f[len(dst_dir) - len(S3_PREFIX) :] for f in dst_files} + dst_file_to_size = {} + dst_file_set = set() + for f in dst_files: + dst_file_without_prefix = f[len(dst_dir) - len(S3_PREFIX) :] + dst_file_set.add(dst_file_without_prefix) + dst_file_to_size[dst_file_without_prefix] = dst_fs.size(f) + diff = all_local_files.symmetric_difference(dst_file_set) print("Local files", len(all_local_files)) print("DST Files", len(dst_file_set)) @@ -112,6 +130,33 @@ def compare_local_to_blob( dst_only_files = dst_file_set - all_local_files print("DST only", len(dst_only_files), list(dst_only_files)[:10]) + all_files = dst_file_set | all_local_files + print("Check that files match") + size_success = True + for f in sorted(all_files): + if f in source_file_to_size and f in dst_file_to_size: + if source_file_to_size[f] != dst_file_to_size[f]: + size_success = False + print( + f"Mismatch file size for {f}, Local: {source_file_to_size[f]} Blob: {dst_file_to_size[f]}" + ) + else: + if print_sizes: + print(f"Matching file size: {dst_file_to_size[f]} for {f}") + elif f not in source_file_to_size: + size_success = False + print(f"Missing file in source: {f}") + elif f not in dst_file_to_size: + size_success = False + print(f"missing file in dst: {f}") + else: + raise ValueError("Unexpected to be missing file in src and dst") + + if size_success: + print("All files pass size check") + else: + raise ValueError("At least one file failed size comparison check") + if __name__ == "__main__": app() From bd461af91a86b572442fb4170c288603dbd5e6e5 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Fri, 
24 Jan 2025 18:56:18 +0000 Subject: [PATCH 12/18] Use load_async flag to not start MP iterator Summary: Test Plan: --- bytelatent/args.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/bytelatent/args.py b/bytelatent/args.py index b9144c6..a332c89 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -150,11 +150,13 @@ class DataloaderArgs(BaseModel): enable_byte_ngrams=self.enable_byte_ngrams, ) packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args) - mp_iterator = MultiprocessIterator( - packing_iterator, n_batches_to_prefetch=self.prefetch_size - ) - - return mp_iterator + if self.load_async: + mp_iterator = MultiprocessIterator( + packing_iterator, n_batches_to_prefetch=self.prefetch_size + ) + return mp_iterator + else: + return packing_iterator class TrainArgs(BaseModel): From fb09022e5e40706d901dc1f0eb19f0333a6127e3 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Fri, 24 Jan 2025 21:55:24 +0000 Subject: [PATCH 13/18] Initial codes and scripts for training entropy model Summary: Test Plan: --- .gitignore | 1 + bytelatent/args.py | 13 ++- bytelatent/configs/debug.yaml | 3 +- bytelatent/configs/entropy_model.yaml | 82 +++++++++++++++++++ bytelatent/data/data_types.py | 2 +- bytelatent/data/iterators/packing_iterator.py | 42 ++++++++++ .../data/iterators/sequence_iterator.py | 30 +++++-- bytelatent/data/patcher.py | 10 ++- bytelatent/model/blt.py | 5 +- bytelatent/test_blt.py | 3 +- bytelatent/train.py | 52 +++++++++--- 11 files changed, 209 insertions(+), 34 deletions(-) create mode 100644 bytelatent/configs/entropy_model.yaml diff --git a/.gitignore b/.gitignore index 6c664b8..d1d7c2a 100644 --- a/.gitignore +++ b/.gitignore @@ -166,3 +166,4 @@ figures/ .vscode/ .DS_Store internal/ +jobs_parallel-copy/ diff --git a/bytelatent/args.py b/bytelatent/args.py index a332c89..56de22d 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -93,6 +93,8 @@ class DataloaderArgs(BaseModel): max_encoder_seq_length: int = 12288 enable_byte_ngrams: bool = False + add_patches: bool = True + tokenizer_args: TokenizerArgs = TokenizerArgs() patcher_args: PatcherArgs = PatcherArgs() @@ -120,6 +122,7 @@ class DataloaderArgs(BaseModel): looping_iterator, patcher_args=self.patcher_args, tokenizer_args=self.tokenizer_args, + add_patches=self.add_patches, ) sequence_iterator = SequenceIterator( preprocess_iterator, @@ -141,13 +144,19 @@ class DataloaderArgs(BaseModel): source_to_iterator=source_to_sequence_iterators, ) tokenizer = self.tokenizer_args.build() + if self.tokenizer_args.name == "bytes": + # TODO: Check this with Artidoro + pad_id = 0 + else: + pad_id = tokenizer.boe_id packing_args = PackingArgs( batch_size=self.batch_size, seq_len=self.seq_len, - pad_id=tokenizer.boe_id, + pad_id=pad_id, max_length=self.max_encoder_seq_length, pad_to_max_length=self.pad_to_max_length, enable_byte_ngrams=self.enable_byte_ngrams, + tokenizer_name=self.tokenizer_args.name, ) packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args) if self.load_async: @@ -180,7 +189,7 @@ class TrainArgs(BaseModel): data: DataloaderArgs = DataloaderArgs() optim: OptimArgs = OptimArgs() - model: ByteLatentTransformerArgs = ByteLatentTransformerArgs() + model: ByteLatentTransformerArgs | None = ByteLatentTransformerArgs() # This is only needed for training the entropy model entropy_model: LMTransformerArgs | None = None # Instead of training main model, train entropy model diff --git a/bytelatent/configs/debug.yaml 
b/bytelatent/configs/debug.yaml index 4ae4459..1098ff5 100644 --- a/bytelatent/configs/debug.yaml +++ b/bytelatent/configs/debug.yaml @@ -26,10 +26,9 @@ model: vocab_size: 260 dim_token: 256 patch_size: 6 - tokenization_mode: "bytes" patching_mode: "space" tie_local_encoder_decoder_logits: false - data_loader_patching: true + patch_in_forward: false max_encoder_seq_length: 12288 pad_to_max_length: true patching_threshold: 3.1439168453216553 diff --git a/bytelatent/configs/entropy_model.yaml b/bytelatent/configs/entropy_model.yaml new file mode 100644 index 0000000..51b65d4 --- /dev/null +++ b/bytelatent/configs/entropy_model.yaml @@ -0,0 +1,82 @@ +# Template config, need to change dump_dir, data.root_dir and tokenizer.path +# Evals can be activated by uncommenting its config +# python -m launchers.stool config=apps/main/configs/debug.yaml nodes=8 account=fair_amaia_cw_codegen qos=lowest + +dump_dir: /tmp/ +name: "debug" +steps: 100_000 +probe_freq: null +seed: 777 +optim: + lr: 4e-04 + warmup: 500 + lr_min_ratio: 0.1 + clip: 10.0 + +distributed: + fsdp_type: full_shard + model_dtype: bf16 + matmul_allow_tf32: false + selective_activation_checkpointing: false + tp_size: 1 + +train_entropy_model: true +model: null +entropy_model: + dim: 768 + n_layers: 14 + n_heads: 12 + max_seqlen: 8192 + # vocab_size: -1 + vocab_size: 260 + ffn_dim_multiplier: 1.0 + sliding_window: 512 + attn_bias_type: "local_block_causal" + attn_impl: "xformers" + +data: + s3_profile: blt + root_dir: ??? + sources: + dclm_baseline_1.0: 1.0 + batch_size: 2 + prefetch_size: 64 + # seqlen is in terms of patches and + # max_encoder_seq_length is in terms of bytes. + # For entropy model, these are the same since 1 patch=1 byte + seq_len: 8192 + max_encoder_seq_length: 8192 + load_async: true + preprocess_dir: ??? + # We don't need patches for this model + add_patches: false + patcher_args: + # This doesn't matter since byte entropy model doesn't use patching, + # so pick the most efficient, so static + patching_mode: byte + tokenizer_args: + name: bytes + +profiling: + run: false + +checkpoint: + dump: + every: 500 + keep: 3 + eval: + every: 1000 + keep: -1 + +logging: + freq: 10 + +eval_on_gpus: 8 +eval: + dataset_dir: ??? + tasks: ??? 
+ generator: + max_tokens: 65536 + dtype: bf16 + + mp_size: 1 diff --git a/bytelatent/data/data_types.py b/bytelatent/data/data_types.py index 7e142e4..aa2daa9 100644 --- a/bytelatent/data/data_types.py +++ b/bytelatent/data/data_types.py @@ -53,7 +53,7 @@ BltIterator = Iterator[tuple[BltExample, DataLoaderState]] class BltSequence(BaseModel): tokens: list[int] mask: list[bool] - patch_lengths: list[int] + patch_lengths: list[int] | None @dataclass diff --git a/bytelatent/data/iterators/packing_iterator.py b/bytelatent/data/iterators/packing_iterator.py index 361fc03..fa29149 100644 --- a/bytelatent/data/iterators/packing_iterator.py +++ b/bytelatent/data/iterators/packing_iterator.py @@ -17,6 +17,7 @@ class PackingArgs(BaseModel): max_length: int | None pad_to_max_length: bool enable_byte_ngrams: bool + tokenizer_name: str class PackingIteratorState(BaseModel, IteratorState): @@ -151,6 +152,43 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]): ) def create_iter(self): + if self.packing_args.tokenizer_name == "bytes": + return self._create_iter_from_bytes() + else: + return self._create_iter_from_patch_lengths() + + def _create_iter_from_bytes(self): + sequence_iter = self.sequence_iterator.create_iter() + batch_size = self.packing_args.batch_size + pad_id = self.packing_args.pad_id + seq_len = self.packing_args.seq_len + while True: + tokens: list[list[int]] = [] + masks: list[list[bool]] = [] + + for _ in range(self.packing_args.batch_size): + sequence = next(sequence_iter) + _tokens = sequence.tokens + _mask = sequence.mask + assert ( + sequence.patch_lengths is None + ), "patch_lengths should not be used in byte packing" + tokens.append(_tokens) + masks.append(_mask) + + x = np.full((batch_size, seq_len), fill_value=pad_id) + y = np.full((batch_size, seq_len), fill_value=pad_id) + + for i, tok_seq in enumerate(tokens): + x[i, : len(tok_seq)] = tok_seq + y[i, : len(tok_seq) - 1] = tok_seq[1:] + batch = Batch(x=x, y=y) + assert ( + batch.mask is None or np.sum(x != pad_id) == batch.mask.sum() + ), f"{np.sum(x != pad_id)} != {batch.mask.sum()}" + yield batch + + def _create_iter_from_patch_lengths(self): sequence_iter = self.sequence_iterator.create_iter() batch_size = self.packing_args.batch_size pad_id = self.packing_args.pad_id @@ -168,6 +206,10 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]): _tokens = sequence.tokens _mask = sequence.mask _patch_lengths = sequence.patch_lengths + assert ( + _patch_lengths is not None + ), "patch lengths are required for packing based on patches." 
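# Illustrative sketch, not part of the patch: the byte-level packing that
# _create_iter_from_bytes performs above, pulled out of the iterator machinery
# so the padding and target construction are easy to see. pack_byte_batch is an
# invented name for this example only.
import numpy as np

def pack_byte_batch(token_seqs: list[list[int]], seq_len: int, pad_id: int):
    # x holds the inputs, y the next-byte targets; both are padded with pad_id.
    batch_size = len(token_seqs)
    x = np.full((batch_size, seq_len), fill_value=pad_id)
    y = np.full((batch_size, seq_len), fill_value=pad_id)
    for i, tok_seq in enumerate(token_seqs):
        x[i, : len(tok_seq)] = tok_seq
        y[i, : len(tok_seq) - 1] = tok_seq[1:]  # targets are the inputs shifted by one
    return x, y

# Example: two byte sequences packed into a fixed seq_len of 6 with pad_id=0 give
# x = [[3 5 7 0 0 0], [2 4 0 0 0 0]] and y = [[5 7 0 0 0 0], [4 0 0 0 0 0]].
x, y = pack_byte_batch([[3, 5, 7], [2, 4]], seq_len=6, pad_id=0)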
+ # Reminder: seq_len is in terms of patches assert len(sequence.patch_lengths) == self.packing_args.seq_len last_patch_length = 0 if _patch_lengths[0] > 1: diff --git a/bytelatent/data/iterators/sequence_iterator.py b/bytelatent/data/iterators/sequence_iterator.py index 14e3747..d90ea31 100644 --- a/bytelatent/data/iterators/sequence_iterator.py +++ b/bytelatent/data/iterators/sequence_iterator.py @@ -70,15 +70,22 @@ class SequenceIterator(StatefulIterator): for example in example_iter: assert example.tokens is not None assert example.mask is not None - assert example.patch_lengths is not None + if self.preprocess_iterator.add_patches: + assert example.patch_lengths is not None + assert len(example.tokens) == sum(example.patch_lengths) + else: + assert example.patch_lengths is None assert len(example.tokens) != 0 assert len(example.mask) != 0 assert len(example.tokens) == len(example.mask) - assert len(example.tokens) == sum(example.patch_lengths) tokens.extend(example.tokens) mask.extend(example.mask) - patch_lengths.extend(example.patch_lengths) + if self.preprocess_iterator.add_patches: + patch_lengths.extend(example.patch_lengths) + else: + # This lets the rest of the code work as expected and just yield byte seqs + patch_lengths.extend([1] * len(example.tokens)) while len(patch_lengths) >= n_buffer_patches: if first: @@ -115,8 +122,15 @@ class SequenceIterator(StatefulIterator): == len(seq_mask[idx]) ), f"{sum(seq_patch_lengths[idx])}, {len(seq_tokens[idx])} {len(seq_mask[idx])}, idx={idx}" assert seq_patch_lengths[idx][0] > 0, f"{seq_patch_lengths[idx]}" - yield BltSequence( - tokens=seq_tokens[idx], - mask=seq_mask[idx], - patch_lengths=seq_patch_lengths[idx], - ) + if self.preprocess_iterator.add_patches: + yield BltSequence( + tokens=seq_tokens[idx], + mask=seq_mask[idx], + patch_lengths=seq_patch_lengths[idx], + ) + else: + yield BltSequence( + tokens=seq_tokens[idx], + mask=seq_mask[idx], + patch_lengths=None, + ) diff --git a/bytelatent/data/patcher.py b/bytelatent/data/patcher.py index afcfa2e..44ff5e9 100644 --- a/bytelatent/data/patcher.py +++ b/bytelatent/data/patcher.py @@ -22,6 +22,8 @@ class PatchingModeEnum(str, Enum): bpe = "bpe" bpe_patcher = "bpe_patcher" space = "space" + static = "static" + byte = "byte" class PatcherArgs(BaseModel): @@ -34,7 +36,6 @@ class PatcherArgs(BaseModel): max_patch_length: int | None = None patch_size: float = 4.5 patching_batch_size: int = 1 - data_loader_patching: bool = False device: str = "cuda" monotonicity: bool = False log_time: bool = False @@ -486,7 +487,6 @@ class Patcher: self.max_patch_length = patcher_args.max_patch_length self.patch_size = patcher_args.patch_size self.patching_batch_size = patcher_args.patching_batch_size - self.data_loader_patching = patcher_args.data_loader_patching self.device = patcher_args.device self.monotonicity = patcher_args.monotonicity self.log_time = patcher_args.log_time @@ -528,7 +528,7 @@ class Patcher: seq_len_next_tok = seq_len + 1 if include_next_token else seq_len scores = None # STATIC - if self.patching_mode is None: + if self.patching_mode == PatchingModeEnum.static: patch_lengths = torch.zeros( (bs, math.ceil(seq_len_next_tok / self.patch_size)), dtype=tokens.dtype, @@ -536,6 +536,10 @@ class Patcher: ).fill_(self.patch_size) if seq_len_next_tok % self.patch_size != 0: patch_lengths[:, -1] = seq_len_next_tok % self.patch_size + elif self.patching_mode == PatchingModeEnum.byte: + patch_lengths = torch.ones( + (bs, seq_len_next_tok), dtype=tokens.dtype, device=tokens.device + ) # ENTROPY 
elif self.patching_mode == PatchingModeEnum.entropy: if self.log_time: diff --git a/bytelatent/model/blt.py b/bytelatent/model/blt.py index 843ad34..a62be23 100644 --- a/bytelatent/model/blt.py +++ b/bytelatent/model/blt.py @@ -411,6 +411,7 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): n_heads: int = 8 # TODO: What is the purpose of this parameter? weight_tying: bool = False + patch_in_forward: bool = False # Architecture and dimensions dim_token: int = 256 @@ -422,7 +423,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): n_layers_local_encoder: int = 8 # Tokenization and patching - tokenization_mode: str = "bpe" patch_size: float | None = None patching_mode: str | None = None patching_threshold: float | None = None @@ -430,7 +430,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): monotonicity: bool = False patching_batch_size: int = 1 patching_device: str = "cuda" - data_loader_patching: bool = False max_patch_length: int | None = None # Encoder/Decoder configuration @@ -856,7 +855,7 @@ class ByteLatentTransformer(nn.Module): self.output.weight = self.tok_embeddings.weight # Patcher module - if not args.data_loader_patching: + if args.patch_in_forward: self.patcher = Patcher( PatcherArgs( patch_size=args.patch_size, diff --git a/bytelatent/test_blt.py b/bytelatent/test_blt.py index 36a9882..eb94df3 100644 --- a/bytelatent/test_blt.py +++ b/bytelatent/test_blt.py @@ -68,10 +68,9 @@ def create_args(cross_attention=False): # Additional args from command line dim_token=256, patch_size=6, - tokenization_mode="bytes", patching_mode="space", tie_local_encoder_decoder_logits=False, - data_loader_patching=True, + patch_in_forward=False, max_encoder_seq_length=12288, pad_to_max_length=True, encoder_lm_loss=False, diff --git a/bytelatent/train.py b/bytelatent/train.py index 80bd393..a7ca405 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -51,6 +51,7 @@ from bytelatent.transformer import ( get_no_recompute_ops, get_num_flop_per_token, tp_parallelize, + LMTransformer, ) logger = logging.getLogger() @@ -103,10 +104,15 @@ class TrainState(Stateful): def validate_train_args(args: TrainArgs, output_size: int): - if args.model.vocab_size < 0: + assert args.model is not None or args.entropy_model is not None + if args.model is not None: logger.info(f"Setting model output size to {args.model.vocab_size}") args.model.vocab_size = output_size + if args.entropy_model is not None: + logger.info(f"Setting model output size to {args.entropy_model.vocab_size}") + args.entropy_model.vocab_size = output_size + assert args.dump_dir, "Dump dir not set" if args.checkpoint.path is None: @@ -147,7 +153,10 @@ def validate_train_args(args: TrainArgs, output_size: int): and args.distributed.dp_replicate == get_world_size() ) - args.model.max_seqlen = args.data.seq_len + if args.model is not None: + args.model.max_seqlen = args.data.seq_len + if args.entropy_model is not None: + args.entropy_model.max_seqlen = args.data.seq_len if args.distributed.tp_size == 1: logger.warning( @@ -237,7 +246,14 @@ def train(args: TrainArgs): # Initializing Model in meta device allows us to initialize models much bigger than 1 gpu's memory with torch.device("meta"): - model = ByteLatentTransformer(args.model) + if args.train_entropy_model: + assert args.entropy_model is not None + model = LMTransformer(args.entropy_model) + model_args = args.entropy_model + else: + assert args.model is not None + model = ByteLatentTransformer(args.model) + model_args = args.model logger.info("Model is built !") 
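# Illustrative recap, not part of the diff: the model dispatch that train() now
# performs, factored into a helper. build_train_model is an invented name for
# this sketch only; a hypothetical launch of the entropy-model run would look
# like `python -m bytelatent.train config=bytelatent/configs/entropy_model.yaml`,
# assuming the usual `config=...` OmegaConf CLI convention used by main().
import torch
from bytelatent.args import TrainArgs
from bytelatent.model.blt import ByteLatentTransformer
from bytelatent.transformer import LMTransformer

def build_train_model(args: TrainArgs):
    # The meta device lets either model be constructed without allocating real
    # GPU memory; FSDP wrapping and weight materialization happen afterwards.
    with torch.device("meta"):
        if args.train_entropy_model:
            assert args.entropy_model is not None
            return LMTransformer(args.entropy_model), args.entropy_model
        assert args.model is not None
        return ByteLatentTransformer(args.model), args.model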
model_param_count = get_num_params(model) @@ -247,7 +263,7 @@ def train(args: TrainArgs): world_mesh, args.model, args.distributed, - fsdp_grouping_plan=build_fsdp_grouping_plan(args.model), + fsdp_grouping_plan=build_fsdp_grouping_plan(model_args), tp_parallelize=tp_parallelize, no_recompute_ops=get_no_recompute_ops(), ) @@ -267,7 +283,7 @@ def train(args: TrainArgs): model.rope_embeddings.reset_parameters() # For RoPe initialization since it's a buffer it might not be loaded else: with torch.random.fork_rng(devices=[torch.cuda.current_device()]): - torch.manual_seed(args.model.seed) + torch.manual_seed(model_args.seed) model.init_weights() check_model_value_range(model, range=10.0, std=1.0) @@ -342,10 +358,17 @@ def train(args: TrainArgs): batch.x, ).cuda() batch_y = torch.from_numpy(batch.y).cuda() - batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda() + if batch.patch_lengths is None: + batch_patch_lengths = None + else: + batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda() mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda() - if args.model.encoder_enable_byte_ngrams and batch.ngram_ids is None: + if ( + not args.train_entropy_model + and args.model.encoder_enable_byte_ngrams + and batch.ngram_ids is None + ): raise ValueError( "Cannot enable byte ngrams and have batch.ngram_ids be None" ) @@ -408,9 +431,12 @@ def train(args: TrainArgs): next(probe_mod.parameters()).grad is None ), "Probe model shouldn't have grads at this point" - pred = model( - batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids - ) + if args.train_entropy_model: + pred = model(batch_x) + else: + pred = model( + batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids + ) loss, _ = compute_loss(pred, batch_y, mask, train_state.scale) @@ -474,9 +500,9 @@ def train(args: TrainArgs): # Use xformer's analyze profile trace to get actual measurement FLOPS = ( get_num_flop_per_token( - model_param_count - args.model.vocab_size * args.model.dim, - args.model.n_layers, - args.model.dim, + model_param_count - model_args.vocab_size * model_args.dim, + model_args.n_layers, + model_args.dim, args.data.seq_len, ) * wps From 34ca1f7d4b40030bbaba2cf51715434c902018bd Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Fri, 24 Jan 2025 21:55:24 +0000 Subject: [PATCH 14/18] Initial codes and scripts for training entropy model Summary: Test Plan: --- .gitignore | 1 + bytelatent/args.py | 13 ++- bytelatent/configs/debug.yaml | 3 +- bytelatent/configs/entropy_model.yaml | 82 +++++++++++++++++++ bytelatent/data/data_types.py | 2 +- bytelatent/data/iterators/packing_iterator.py | 42 ++++++++++ .../data/iterators/sequence_iterator.py | 30 +++++-- bytelatent/data/patcher.py | 10 ++- bytelatent/model/blt.py | 5 +- bytelatent/test_blt.py | 3 +- bytelatent/train.py | 52 +++++++++--- 11 files changed, 209 insertions(+), 34 deletions(-) create mode 100644 bytelatent/configs/entropy_model.yaml diff --git a/.gitignore b/.gitignore index 6c664b8..d1d7c2a 100644 --- a/.gitignore +++ b/.gitignore @@ -166,3 +166,4 @@ figures/ .vscode/ .DS_Store internal/ +jobs_parallel-copy/ diff --git a/bytelatent/args.py b/bytelatent/args.py index a332c89..56de22d 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -93,6 +93,8 @@ class DataloaderArgs(BaseModel): max_encoder_seq_length: int = 12288 enable_byte_ngrams: bool = False + add_patches: bool = True + tokenizer_args: TokenizerArgs = TokenizerArgs() patcher_args: PatcherArgs = PatcherArgs() @@ -120,6 +122,7 @@ class 
DataloaderArgs(BaseModel): looping_iterator, patcher_args=self.patcher_args, tokenizer_args=self.tokenizer_args, + add_patches=self.add_patches, ) sequence_iterator = SequenceIterator( preprocess_iterator, @@ -141,13 +144,19 @@ class DataloaderArgs(BaseModel): source_to_iterator=source_to_sequence_iterators, ) tokenizer = self.tokenizer_args.build() + if self.tokenizer_args.name == "bytes": + # TODO: Check this with Artidoro + pad_id = 0 + else: + pad_id = tokenizer.boe_id packing_args = PackingArgs( batch_size=self.batch_size, seq_len=self.seq_len, - pad_id=tokenizer.boe_id, + pad_id=pad_id, max_length=self.max_encoder_seq_length, pad_to_max_length=self.pad_to_max_length, enable_byte_ngrams=self.enable_byte_ngrams, + tokenizer_name=self.tokenizer_args.name, ) packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args) if self.load_async: @@ -180,7 +189,7 @@ class TrainArgs(BaseModel): data: DataloaderArgs = DataloaderArgs() optim: OptimArgs = OptimArgs() - model: ByteLatentTransformerArgs = ByteLatentTransformerArgs() + model: ByteLatentTransformerArgs | None = ByteLatentTransformerArgs() # This is only needed for training the entropy model entropy_model: LMTransformerArgs | None = None # Instead of training main model, train entropy model diff --git a/bytelatent/configs/debug.yaml b/bytelatent/configs/debug.yaml index 4ae4459..1098ff5 100644 --- a/bytelatent/configs/debug.yaml +++ b/bytelatent/configs/debug.yaml @@ -26,10 +26,9 @@ model: vocab_size: 260 dim_token: 256 patch_size: 6 - tokenization_mode: "bytes" patching_mode: "space" tie_local_encoder_decoder_logits: false - data_loader_patching: true + patch_in_forward: false max_encoder_seq_length: 12288 pad_to_max_length: true patching_threshold: 3.1439168453216553 diff --git a/bytelatent/configs/entropy_model.yaml b/bytelatent/configs/entropy_model.yaml new file mode 100644 index 0000000..51b65d4 --- /dev/null +++ b/bytelatent/configs/entropy_model.yaml @@ -0,0 +1,82 @@ +# Template config, need to change dump_dir, data.root_dir and tokenizer.path +# Evals can be activated by uncommenting its config +# python -m launchers.stool config=apps/main/configs/debug.yaml nodes=8 account=fair_amaia_cw_codegen qos=lowest + +dump_dir: /tmp/ +name: "debug" +steps: 100_000 +probe_freq: null +seed: 777 +optim: + lr: 4e-04 + warmup: 500 + lr_min_ratio: 0.1 + clip: 10.0 + +distributed: + fsdp_type: full_shard + model_dtype: bf16 + matmul_allow_tf32: false + selective_activation_checkpointing: false + tp_size: 1 + +train_entropy_model: true +model: null +entropy_model: + dim: 768 + n_layers: 14 + n_heads: 12 + max_seqlen: 8192 + # vocab_size: -1 + vocab_size: 260 + ffn_dim_multiplier: 1.0 + sliding_window: 512 + attn_bias_type: "local_block_causal" + attn_impl: "xformers" + +data: + s3_profile: blt + root_dir: ??? + sources: + dclm_baseline_1.0: 1.0 + batch_size: 2 + prefetch_size: 64 + # seqlen is in terms of patches and + # max_encoder_seq_length is in terms of bytes. + # For entropy model, these are the same since 1 patch=1 byte + seq_len: 8192 + max_encoder_seq_length: 8192 + load_async: true + preprocess_dir: ??? + # We don't need patches for this model + add_patches: false + patcher_args: + # This doesn't matter since byte entropy model doesn't use patching, + # so pick the most efficient, so static + patching_mode: byte + tokenizer_args: + name: bytes + +profiling: + run: false + +checkpoint: + dump: + every: 500 + keep: 3 + eval: + every: 1000 + keep: -1 + +logging: + freq: 10 + +eval_on_gpus: 8 +eval: + dataset_dir: ??? 
+ tasks: ??? + generator: + max_tokens: 65536 + dtype: bf16 + + mp_size: 1 diff --git a/bytelatent/data/data_types.py b/bytelatent/data/data_types.py index 7e142e4..aa2daa9 100644 --- a/bytelatent/data/data_types.py +++ b/bytelatent/data/data_types.py @@ -53,7 +53,7 @@ BltIterator = Iterator[tuple[BltExample, DataLoaderState]] class BltSequence(BaseModel): tokens: list[int] mask: list[bool] - patch_lengths: list[int] + patch_lengths: list[int] | None @dataclass diff --git a/bytelatent/data/iterators/packing_iterator.py b/bytelatent/data/iterators/packing_iterator.py index 361fc03..fa29149 100644 --- a/bytelatent/data/iterators/packing_iterator.py +++ b/bytelatent/data/iterators/packing_iterator.py @@ -17,6 +17,7 @@ class PackingArgs(BaseModel): max_length: int | None pad_to_max_length: bool enable_byte_ngrams: bool + tokenizer_name: str class PackingIteratorState(BaseModel, IteratorState): @@ -151,6 +152,43 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]): ) def create_iter(self): + if self.packing_args.tokenizer_name == "bytes": + return self._create_iter_from_bytes() + else: + return self._create_iter_from_patch_lengths() + + def _create_iter_from_bytes(self): + sequence_iter = self.sequence_iterator.create_iter() + batch_size = self.packing_args.batch_size + pad_id = self.packing_args.pad_id + seq_len = self.packing_args.seq_len + while True: + tokens: list[list[int]] = [] + masks: list[list[bool]] = [] + + for _ in range(self.packing_args.batch_size): + sequence = next(sequence_iter) + _tokens = sequence.tokens + _mask = sequence.mask + assert ( + sequence.patch_lengths is None + ), "patch_lengths should not be used in byte packing" + tokens.append(_tokens) + masks.append(_mask) + + x = np.full((batch_size, seq_len), fill_value=pad_id) + y = np.full((batch_size, seq_len), fill_value=pad_id) + + for i, tok_seq in enumerate(tokens): + x[i, : len(tok_seq)] = tok_seq + y[i, : len(tok_seq) - 1] = tok_seq[1:] + batch = Batch(x=x, y=y) + assert ( + batch.mask is None or np.sum(x != pad_id) == batch.mask.sum() + ), f"{np.sum(x != pad_id)} != {batch.mask.sum()}" + yield batch + + def _create_iter_from_patch_lengths(self): sequence_iter = self.sequence_iterator.create_iter() batch_size = self.packing_args.batch_size pad_id = self.packing_args.pad_id @@ -168,6 +206,10 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]): _tokens = sequence.tokens _mask = sequence.mask _patch_lengths = sequence.patch_lengths + assert ( + _patch_lengths is not None + ), "patch lengths are required for packing based on patches." 
+ # Reminder: seq_len is in terms of patches assert len(sequence.patch_lengths) == self.packing_args.seq_len last_patch_length = 0 if _patch_lengths[0] > 1: diff --git a/bytelatent/data/iterators/sequence_iterator.py b/bytelatent/data/iterators/sequence_iterator.py index 14e3747..d90ea31 100644 --- a/bytelatent/data/iterators/sequence_iterator.py +++ b/bytelatent/data/iterators/sequence_iterator.py @@ -70,15 +70,22 @@ class SequenceIterator(StatefulIterator): for example in example_iter: assert example.tokens is not None assert example.mask is not None - assert example.patch_lengths is not None + if self.preprocess_iterator.add_patches: + assert example.patch_lengths is not None + assert len(example.tokens) == sum(example.patch_lengths) + else: + assert example.patch_lengths is None assert len(example.tokens) != 0 assert len(example.mask) != 0 assert len(example.tokens) == len(example.mask) - assert len(example.tokens) == sum(example.patch_lengths) tokens.extend(example.tokens) mask.extend(example.mask) - patch_lengths.extend(example.patch_lengths) + if self.preprocess_iterator.add_patches: + patch_lengths.extend(example.patch_lengths) + else: + # This lets the rest of the code work as expected and just yield byte seqs + patch_lengths.extend([1] * len(example.tokens)) while len(patch_lengths) >= n_buffer_patches: if first: @@ -115,8 +122,15 @@ class SequenceIterator(StatefulIterator): == len(seq_mask[idx]) ), f"{sum(seq_patch_lengths[idx])}, {len(seq_tokens[idx])} {len(seq_mask[idx])}, idx={idx}" assert seq_patch_lengths[idx][0] > 0, f"{seq_patch_lengths[idx]}" - yield BltSequence( - tokens=seq_tokens[idx], - mask=seq_mask[idx], - patch_lengths=seq_patch_lengths[idx], - ) + if self.preprocess_iterator.add_patches: + yield BltSequence( + tokens=seq_tokens[idx], + mask=seq_mask[idx], + patch_lengths=seq_patch_lengths[idx], + ) + else: + yield BltSequence( + tokens=seq_tokens[idx], + mask=seq_mask[idx], + patch_lengths=None, + ) diff --git a/bytelatent/data/patcher.py b/bytelatent/data/patcher.py index afcfa2e..44ff5e9 100644 --- a/bytelatent/data/patcher.py +++ b/bytelatent/data/patcher.py @@ -22,6 +22,8 @@ class PatchingModeEnum(str, Enum): bpe = "bpe" bpe_patcher = "bpe_patcher" space = "space" + static = "static" + byte = "byte" class PatcherArgs(BaseModel): @@ -34,7 +36,6 @@ class PatcherArgs(BaseModel): max_patch_length: int | None = None patch_size: float = 4.5 patching_batch_size: int = 1 - data_loader_patching: bool = False device: str = "cuda" monotonicity: bool = False log_time: bool = False @@ -486,7 +487,6 @@ class Patcher: self.max_patch_length = patcher_args.max_patch_length self.patch_size = patcher_args.patch_size self.patching_batch_size = patcher_args.patching_batch_size - self.data_loader_patching = patcher_args.data_loader_patching self.device = patcher_args.device self.monotonicity = patcher_args.monotonicity self.log_time = patcher_args.log_time @@ -528,7 +528,7 @@ class Patcher: seq_len_next_tok = seq_len + 1 if include_next_token else seq_len scores = None # STATIC - if self.patching_mode is None: + if self.patching_mode == PatchingModeEnum.static: patch_lengths = torch.zeros( (bs, math.ceil(seq_len_next_tok / self.patch_size)), dtype=tokens.dtype, @@ -536,6 +536,10 @@ class Patcher: ).fill_(self.patch_size) if seq_len_next_tok % self.patch_size != 0: patch_lengths[:, -1] = seq_len_next_tok % self.patch_size + elif self.patching_mode == PatchingModeEnum.byte: + patch_lengths = torch.ones( + (bs, seq_len_next_tok), dtype=tokens.dtype, device=tokens.device + ) # ENTROPY 
elif self.patching_mode == PatchingModeEnum.entropy: if self.log_time: diff --git a/bytelatent/model/blt.py b/bytelatent/model/blt.py index 843ad34..a62be23 100644 --- a/bytelatent/model/blt.py +++ b/bytelatent/model/blt.py @@ -411,6 +411,7 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): n_heads: int = 8 # TODO: What is the purpose of this parameter? weight_tying: bool = False + patch_in_forward: bool = False # Architecture and dimensions dim_token: int = 256 @@ -422,7 +423,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): n_layers_local_encoder: int = 8 # Tokenization and patching - tokenization_mode: str = "bpe" patch_size: float | None = None patching_mode: str | None = None patching_threshold: float | None = None @@ -430,7 +430,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): monotonicity: bool = False patching_batch_size: int = 1 patching_device: str = "cuda" - data_loader_patching: bool = False max_patch_length: int | None = None # Encoder/Decoder configuration @@ -856,7 +855,7 @@ class ByteLatentTransformer(nn.Module): self.output.weight = self.tok_embeddings.weight # Patcher module - if not args.data_loader_patching: + if args.patch_in_forward: self.patcher = Patcher( PatcherArgs( patch_size=args.patch_size, diff --git a/bytelatent/test_blt.py b/bytelatent/test_blt.py index 36a9882..eb94df3 100644 --- a/bytelatent/test_blt.py +++ b/bytelatent/test_blt.py @@ -68,10 +68,9 @@ def create_args(cross_attention=False): # Additional args from command line dim_token=256, patch_size=6, - tokenization_mode="bytes", patching_mode="space", tie_local_encoder_decoder_logits=False, - data_loader_patching=True, + patch_in_forward=False, max_encoder_seq_length=12288, pad_to_max_length=True, encoder_lm_loss=False, diff --git a/bytelatent/train.py b/bytelatent/train.py index 80bd393..1d0fa40 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -47,6 +47,7 @@ from bytelatent.probe import AutoProbeD from bytelatent.profiling import maybe_run_profiler from bytelatent.stool import StoolArgs, launch_job from bytelatent.transformer import ( + LMTransformer, build_fsdp_grouping_plan, get_no_recompute_ops, get_num_flop_per_token, @@ -103,10 +104,15 @@ class TrainState(Stateful): def validate_train_args(args: TrainArgs, output_size: int): - if args.model.vocab_size < 0: + assert args.model is not None or args.entropy_model is not None + if args.model is not None: logger.info(f"Setting model output size to {args.model.vocab_size}") args.model.vocab_size = output_size + if args.entropy_model is not None: + logger.info(f"Setting model output size to {args.entropy_model.vocab_size}") + args.entropy_model.vocab_size = output_size + assert args.dump_dir, "Dump dir not set" if args.checkpoint.path is None: @@ -147,7 +153,10 @@ def validate_train_args(args: TrainArgs, output_size: int): and args.distributed.dp_replicate == get_world_size() ) - args.model.max_seqlen = args.data.seq_len + if args.model is not None: + args.model.max_seqlen = args.data.seq_len + if args.entropy_model is not None: + args.entropy_model.max_seqlen = args.data.seq_len if args.distributed.tp_size == 1: logger.warning( @@ -237,7 +246,14 @@ def train(args: TrainArgs): # Initializing Model in meta device allows us to initialize models much bigger than 1 gpu's memory with torch.device("meta"): - model = ByteLatentTransformer(args.model) + if args.train_entropy_model: + assert args.entropy_model is not None + model = LMTransformer(args.entropy_model) + model_args = args.entropy_model + else: + assert args.model 
is not None + model = ByteLatentTransformer(args.model) + model_args = args.model logger.info("Model is built !") model_param_count = get_num_params(model) @@ -247,7 +263,7 @@ def train(args: TrainArgs): world_mesh, args.model, args.distributed, - fsdp_grouping_plan=build_fsdp_grouping_plan(args.model), + fsdp_grouping_plan=build_fsdp_grouping_plan(model_args), tp_parallelize=tp_parallelize, no_recompute_ops=get_no_recompute_ops(), ) @@ -267,7 +283,7 @@ def train(args: TrainArgs): model.rope_embeddings.reset_parameters() # For RoPe initialization since it's a buffer it might not be loaded else: with torch.random.fork_rng(devices=[torch.cuda.current_device()]): - torch.manual_seed(args.model.seed) + torch.manual_seed(model_args.seed) model.init_weights() check_model_value_range(model, range=10.0, std=1.0) @@ -342,10 +358,17 @@ def train(args: TrainArgs): batch.x, ).cuda() batch_y = torch.from_numpy(batch.y).cuda() - batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda() + if batch.patch_lengths is None: + batch_patch_lengths = None + else: + batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda() mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda() - if args.model.encoder_enable_byte_ngrams and batch.ngram_ids is None: + if ( + not args.train_entropy_model + and args.model.encoder_enable_byte_ngrams + and batch.ngram_ids is None + ): raise ValueError( "Cannot enable byte ngrams and have batch.ngram_ids be None" ) @@ -408,9 +431,12 @@ def train(args: TrainArgs): next(probe_mod.parameters()).grad is None ), "Probe model shouldn't have grads at this point" - pred = model( - batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids - ) + if args.train_entropy_model: + pred = model(batch_x) + else: + pred = model( + batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids + ) loss, _ = compute_loss(pred, batch_y, mask, train_state.scale) @@ -474,9 +500,9 @@ def train(args: TrainArgs): # Use xformer's analyze profile trace to get actual measurement FLOPS = ( get_num_flop_per_token( - model_param_count - args.model.vocab_size * args.model.dim, - args.model.n_layers, - args.model.dim, + model_param_count - model_args.vocab_size * model_args.dim, + model_args.n_layers, + model_args.dim, args.data.seq_len, ) * wps From e02ba763b00aa7075752a3615f59243c210ed188 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Tue, 28 Jan 2025 00:38:16 +0000 Subject: [PATCH 15/18] This includes fixes that make checkpointing and reloading work correctly. 
It also batches in a first set of changes for fixing eval code Summary: Test Plan: --- apps/main/lingua_train.py | 2 +- bytelatent/args.py | 87 ++++++++- bytelatent/checkpoint.py | 9 +- bytelatent/data/data_types.py | 10 - .../data/iterators/multiprocess_iterator.py | 15 ++ {apps/main => bytelatent}/eval.py | 179 ++++-------------- {apps/main => bytelatent}/generate.py | 57 +++--- bytelatent/train.py | 81 ++++---- 8 files changed, 219 insertions(+), 221 deletions(-) rename {apps/main => bytelatent}/eval.py (56%) rename {apps/main => bytelatent}/generate.py (91%) diff --git a/apps/main/lingua_train.py b/apps/main/lingua_train.py index bdb47da..7925ec6 100644 --- a/apps/main/lingua_train.py +++ b/apps/main/lingua_train.py @@ -544,7 +544,7 @@ def train(args: TrainArgs): if args.eval is not None and every_n_steps( train_state, args.checkpoint.eval.every, acc_step=0 ): - from apps.main.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval + from bytelatent.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval eval_args = dataclass_from_dict(EvalArgs, args.eval) diff --git a/bytelatent/args.py b/bytelatent/args.py index 56de22d..d1bac46 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -5,6 +5,7 @@ from typing import Any import numpy as np import yaml +from omegaconf import OmegaConf from pydantic import BaseModel, ConfigDict from bytelatent.checkpoint import CheckpointArgs @@ -39,6 +40,19 @@ def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]: return np.random.default_rng((seed, rank, world_size)).bit_generator.state +def parse_args(args_cls): + cli_args = OmegaConf.from_cli() + file_cfg = OmegaConf.load(cli_args.config) + # We remove 'config' attribute from config as the underlying DataClass does not have it + del cli_args.config + + default_cfg = OmegaConf.create(args_cls().model_dump()) + cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) + cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) + pydantic_args = args_cls.model_validate(cfg) + return pydantic_args + + def distribute_data_to_rank( *, dataset_path: str, @@ -71,6 +85,22 @@ def distribute_data_to_rank( return rank_to_arrow_iterator_params[rank] +class PackedCausalTransformerGeneratorArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + temperature: float = 0.0 + top_p: float | None = None + top_k: float | None = None + max_gen_len: int = 512 # Maximum number of tokens to generate + max_tokens: int = 1024 # Maximum number of tokens that can go through the model + max_prompt_len: int | None = None + until: list[str] = [] + compile_prefilling: bool = False + reduce_generation_overhead: bool = False + show_progress: bool = False + dtype: str | None = "bf16" + device: str | None = "cuda" + + class DataloaderArgs(BaseModel): model_config = ConfigDict(extra="forbid") s3_profile: str | None = None @@ -168,6 +198,58 @@ class DataloaderArgs(BaseModel): return packing_iterator +class LMHarnessArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + tasks: list[Any] | None = None + num_fewshot: int | None = None + device: str | None = None + use_cache: str | None = None + cache_requests: bool = False + rewrite_requests_cache: bool = False + delete_requests_cache: bool = False + limit: int | float | None = None + bootstrap_iters: int = 100000 + check_integrity: bool = False + write_out: bool = False + log_samples: bool = True + system_instruction: str | None = None + apply_chat_template: bool | str = False + fewshot_as_multiturn: bool = False + gen_kwargs: str | None = None + 
verbosity: str = "INFO" + predict_only: bool = False + random_seed: int = 0 + numpy_random_seed: int = 1234 + torch_random_seed: int = 1234 + fewshot_random_seed: int = 1234 + + +class ValidationArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + max_steps: int | None = ( + None # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu) + ) + use_val_from_train_src: bool = True # Use the validation set from training sources + root_dir: str = "" + sources: list[str] = [] # Other sources to eval on + + +class EvalArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + dump_dir: str + ckpt_dir: str + metric_log_dir: str | None = None + generator: PackedCausalTransformerGeneratorArgs = ( + PackedCausalTransformerGeneratorArgs() + ) + + harness: LMHarnessArgs | None = LMHarnessArgs() + validation: ValidationArgs | None = ValidationArgs() + + global_step: int | None = None # for in-training evaluation + s3_profile: str | None = None + + class TrainArgs(BaseModel): model_config = ConfigDict(extra="forbid") name: str = "lingua" @@ -186,6 +268,9 @@ class TrainArgs(BaseModel): # Nb optimizer steps to take steps: int = 1000 + # If not None, halt training after this many steps, + # useful for debugging + max_steps: int | None = None data: DataloaderArgs = DataloaderArgs() optim: OptimArgs = OptimArgs() @@ -203,7 +288,7 @@ class TrainArgs(BaseModel): # If set to None, eval is run locally otherwise it launches a new job with the given number of gpus async_eval_gpus: int | None = None - eval: Any | None = None + eval: EvalArgs | None = None eval_on_gpus: int | None = None def dump_to_yaml_file( diff --git a/bytelatent/checkpoint.py b/bytelatent/checkpoint.py index bcf591e..f213c84 100644 --- a/bytelatent/checkpoint.py +++ b/bytelatent/checkpoint.py @@ -7,6 +7,7 @@ import re from pathlib import Path from typing import List, Optional, Tuple +import fsspec import torch import torch.distributed as dist import torch.distributed.checkpoint as dcp @@ -21,6 +22,7 @@ from torch.distributed.checkpoint.state_dict import ( set_state_dict, ) +from bytelatent.data.file_util import get_fs from bytelatent.distributed import get_is_master logger = logging.getLogger("CHECKPOINT") @@ -51,13 +53,14 @@ class CheckpointArgs(BaseModel): path: str | None = None init_ckpt_path: str | None = None continue_training_from_init: bool = False + s3_profile: str | None = None def _get_key_step(name: str): return int(re.findall(RE_DIGITS, name)[-1]) -def consolidate_checkpoints(ckpt_dir: str): +def consolidate_checkpoints(fs: fsspec.AbstractFileSystem, ckpt_dir: str): """ Consolidates all FSDP checkpoints in a directory to a single file Consolidate checkpoint is saved in a subdirectory of ckpt_dir @@ -102,15 +105,17 @@ def load_from_checkpoint( dcp.load(state_dict, checkpoint_id=ckpt_dir) +# TODO: Rewrite the file operations here to use fsspec to enable s3 writing. 
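# Illustrative usage, not from the diffs: the fsspec-backed helper that
# CheckpointManager below and the eval code rely on. Assumption based on its
# call sites in this patch series: get_fs(path, s3_profile=...) returns the
# "s3" filesystem for s3:// paths and the local "file" filesystem otherwise,
# so the same code path works for both storage backends.
import os
from bytelatent.data.file_util import get_fs

ckpt_dir = "/tmp/blt_ckpt_example"        # could equally be "s3://bucket/ckpts"
fs = get_fs(ckpt_dir, s3_profile=None)    # s3_profile only matters for s3:// paths
fs.mkdirs(ckpt_dir, exist_ok=True)
with fs.open(os.path.join(ckpt_dir, "params.json"), "w") as f:
    f.write("{}")
assert fs.exists(os.path.join(ckpt_dir, "params.json"))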
class CheckpointManager: def __init__(self, args: CheckpointArgs): self.path = args.path + self.fs = get_fs(self.path, s3_profile=args.s3_profile) self.dump_every = args.dump self.eval_every = args.eval self.init_ckpt_path = args.init_ckpt_path self.continue_training_from_init = args.continue_training_from_init - assert os.path.exists( + assert self.fs.exists( self.path ), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)" diff --git a/bytelatent/data/data_types.py b/bytelatent/data/data_types.py index aa2daa9..f4bbc07 100644 --- a/bytelatent/data/data_types.py +++ b/bytelatent/data/data_types.py @@ -40,16 +40,6 @@ class BltPackTokensState(BaseModel): n_views: int = 2 -class DataLoaderState(BaseModel): - model_config = ConfigDict(extra="forbid") - multi_choice_state: MultiChoiceState - pack_tokens_state: BltPackTokensState - prefetch_state: PrefetchState - - -BltIterator = Iterator[tuple[BltExample, DataLoaderState]] - - class BltSequence(BaseModel): tokens: list[int] mask: list[bool] diff --git a/bytelatent/data/iterators/multiprocess_iterator.py b/bytelatent/data/iterators/multiprocess_iterator.py index f17ca6e..49d99ac 100644 --- a/bytelatent/data/iterators/multiprocess_iterator.py +++ b/bytelatent/data/iterators/multiprocess_iterator.py @@ -128,6 +128,13 @@ class MultiprocessIterator(StatefulIterator): self.producer = None self.stop_iterating_event = None self.state_dumped_event = None + self.force_shutdown = False + + def shutdown(self): + if self.producer is not None: + # This properly shuts things down + self.producer.kill() + self.force_shutdown = True def get_state(self) -> MultiprocessIteratorState: """ @@ -135,6 +142,10 @@ class MultiprocessIterator(StatefulIterator): to halt the background process and allow it to write the state to the main loop in order to not lose data """ + if self.force_shutdown: + raise ValueError( + "State will be invalid if shutdown was forced before state persisted." + ) if self.producer is None: serialized_prefetch_buffer = json.dumps( [b.to_python_dict() for b in self.prefetch_buffer] @@ -187,6 +198,10 @@ class MultiprocessIterator(StatefulIterator): ) def create_iter(self): + if self.force_shutdown: + raise ValueError( + "Iterator may be invalid if shutdown was forced before state persisted." + ) logging.info("Main thread: Creating MP iterator") # First yield from the stored prefetch buffer. 
if self.prefetch_buffer is not None: diff --git a/apps/main/eval.py b/bytelatent/eval.py similarity index 56% rename from apps/main/eval.py rename to bytelatent/eval.py index ed20f49..ae73066 100644 --- a/apps/main/eval.py +++ b/bytelatent/eval.py @@ -4,20 +4,20 @@ import json import logging import os from collections import defaultdict -from dataclasses import asdict, dataclass, field from datetime import datetime from pathlib import Path -from typing import Any, List, Optional, Tuple, Union +from typing import Any import torch -from lingua.args import dump_config -from lingua.data import init_choice_state, setup_sources from lm_eval import simple_evaluate from lm_eval.api.instance import Instance from lm_eval.api.model import LM from omegaconf import OmegaConf +from pydantic import BaseModel, ConfigDict +from bytelatent.args import EvalArgs, ValidationArgs, parse_args from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints +from bytelatent.data.file_util import get_fs from bytelatent.distributed import ( DistributedArgs, dist_mean_dict, @@ -25,72 +25,17 @@ from bytelatent.distributed import ( get_world_size, setup_torch_distributed, ) -from bytelatent.transformer import LMTransformer, LMTransformerArgs - -from apps.main.generate import ( +from bytelatent.generate import ( PackedCausalTransformerGenerator, - PackedCausalTransformerGeneratorArgs, load_consolidated_model_and_tokenizer, ) +from bytelatent.transformer import LMTransformer, LMTransformerArgs EVAL_FOLDER_NAME = "{:010d}" logger = logging.getLogger() -@dataclass -class LMHarnessArgs: - tasks: Optional[List[Any]] = None - num_fewshot: Optional[int] = None - device: Optional[str] = None - use_cache: Optional[str] = None - cache_requests: bool = False - rewrite_requests_cache: bool = False - delete_requests_cache: bool = False - limit: Optional[Union[int, float]] = None - bootstrap_iters: int = 100000 - check_integrity: bool = False - write_out: bool = False - log_samples: bool = True - system_instruction: Optional[str] = None - apply_chat_template: Union[bool, str] = False - fewshot_as_multiturn: bool = False - gen_kwargs: Optional[str] = None - verbosity: str = "INFO" - predict_only: bool = False - random_seed: int = 0 - numpy_random_seed: int = 1234 - torch_random_seed: int = 1234 - fewshot_random_seed: int = 1234 - - -@dataclass -class ValidationArgs: - max_steps: Optional[int] = ( - None # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu) - ) - use_val_from_train_src: bool = True # Use the validation set from training sources - root_dir: str = "" - sources: List[str] = field(default_factory=list) # Other sources to eval on - - -@dataclass -class EvalArgs: - name: str = "evals" - dump_dir: Optional[str] = None - metric_log_dir: Optional[str] = None - ckpt_dir: str = "" - generator: PackedCausalTransformerGeneratorArgs = field( - default_factory=PackedCausalTransformerGeneratorArgs - ) - harness: Optional[LMHarnessArgs] = field(default_factory=LMHarnessArgs) - validation: Optional[ValidationArgs] = field(default_factory=ValidationArgs) - - wandb: Optional[Any] = None - - global_step: Optional[int] = None # for in-training evaluation - - def all_dicts_same(dict_list): if not dict_list: # Check if the list is empty return True @@ -120,7 +65,7 @@ class EvalHarnessLM(LM): self._world_size = get_world_size() self.device = generator.device - def generate_until(self, requests: List[Instance]) -> List[str]: + def 
generate_until(self, requests: list[Instance]) -> list[str]: prompts, gen_args = zip(*[req.args for req in requests]) assert all_dicts_same(gen_args), "Doesn't support different gen args for now" gen_args = gen_args[0] @@ -141,7 +86,7 @@ class EvalHarnessLM(LM): filtered_gen.append(g) return filtered_gen - def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: + def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: prompts, continuations = zip(*[req.args for req in requests]) inputs = [req.args[0] + req.args[1] for req in requests] max_gen_len = self.generator.max_gen_len @@ -158,7 +103,7 @@ class EvalHarnessLM(LM): self.generator.max_gen_len = max_gen_len return results - def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]: + def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]: prompts = [req.args[0] for req in requests] max_gen_len = self.generator.max_gen_len # We temporarily lower max gen len @@ -232,68 +177,73 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg): return all_val_metrics -def launch_eval(cfg: EvalArgs): +def launch_eval(eval_args: EvalArgs): if not torch.distributed.is_initialized(): setup_torch_distributed(DistributedArgs()) + + fs = get_fs(eval_args.ckpt_dir, s3_profile=eval_args.s3_profile) if ( - Path(cfg.ckpt_dir).exists() - and (Path(cfg.ckpt_dir) / "params.json").exists() - and next(Path(cfg.ckpt_dir).glob("*.pth"), None) is not None + fs.exists(eval_args.ckpt_dir) + and fs.exists(os.path.join(eval_args.ckpt_dir, "params.json")) + and len(fs.glob(os.path.join(eval_args.ckpt_dir, "*.pth"))) != 0 ): - consolidate_path = Path(cfg.ckpt_dir) + consolidate_path = eval_args.ckpt_dir else: - consolidate_path = Path(cfg.ckpt_dir) / CONSOLIDATE_FOLDER - if not consolidate_path.exists() and get_global_rank() == 0: - consolidate_path = consolidate_checkpoints(cfg.ckpt_dir) + consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER) + if not fs.exists(consolidate_path) and get_global_rank() == 0: + consolidate_path = consolidate_checkpoints(eval_args.ckpt_dir) - Path(cfg.dump_dir).mkdir(parents=True, exist_ok=True) - dump_config(cfg, Path(cfg.dump_dir) / "config.yaml", log_config=False) + fs.mkdirs(eval_args.dump_dir, exist_ok=True) + with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f: + f.write(eval_args.model_dump_json()) - consolidate_path = str(consolidate_path) torch.distributed.barrier() logger.info("Loading model") + # TODO: Make this general so that it works with either + # LMTransformer or Blt, similar with args model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer( consolidate_path, - model_cls=LMTransformer, - model_args_cls=LMTransformerArgs, ) logger.info("Model loaded") model.eval() - generator = PackedCausalTransformerGenerator(cfg.generator, model, tokenizer) + generator = PackedCausalTransformerGenerator(eval_args.generator, model, tokenizer) wrap = EvalHarnessLM(generator) - results = simple_evaluate(wrap, **asdict(cfg.harness)) + # Redo + results = simple_evaluate(wrap, eval_args.harness.model_dump()) val_results = None - if cfg.validation: - val_results = eval_on_val(generator, cfg.validation, train_cfg) + if eval_args.validation: + val_results = eval_on_val(generator, eval_args.validation, train_cfg) if get_global_rank() == 0: - with open(Path(cfg.dump_dir) / "results.json", "w") as f: + with fs.open(os.path.join(eval_args.dump_dir, "results.json"), "w") as f: f.write(json.dumps(results)) 
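# Sketch, not from the diffs: the pydantic + OmegaConf CLI flow that replaces
# the old dataclass-based entry point, mirroring what bytelatent/eval.py's
# main() does. parse_args is the helper added to bytelatent/args.py in this
# series: defaults come from the pydantic model, the YAML file passed as
# `config=...` overrides them, and any remaining dot-list arguments
# (e.g. generator.max_gen_len=256) win over both before validation.
from bytelatent.args import EvalArgs, parse_args
from bytelatent.eval import launch_eval

if __name__ == "__main__":
    # e.g. python this_script.py config=eval.yaml generator.max_gen_len=256
    eval_args = parse_args(EvalArgs)  # returns a validated EvalArgs instance
    launch_eval(eval_args)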
logger.info(f"All evaluation results: {results['results']}") if val_results is not None: - with open(Path(cfg.dump_dir) / "validation.json", "w") as f: + with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f: f.write(json.dumps(val_results)) logger.info(f"All validation results: {val_results}") - if cfg.metric_log_dir and get_global_rank() == 0: - metric_log_path = Path(cfg.metric_log_dir) / "metrics.eval.jsonl" + if eval_args.metric_log_dir and get_global_rank() == 0: + metric_log_path = os.path.join(eval_args.metric_log_dir, "metrics.eval.jsonl") logger.info(f"Writing metric logs to {metric_log_path}") timestamp = { "created_at": datetime.utcnow().isoformat(), } - if cfg.global_step is not None: - timestamp["global_step"] = cfg.global_step + if eval_args.global_step is not None: + timestamp["global_step"] = eval_args.global_step print( json.dumps(timestamp | results["results"]), - file=open(metric_log_path, mode="a"), + file=fs.open(metric_log_path, mode="a"), flush=True, ) - val_log_path = Path(cfg.metric_log_dir) / "metrics.validation.jsonl" + val_log_path = os.path.join( + eval_args.metric_log_dir, "metrics.validation.jsonl" + ) if val_results is not None: print( json.dumps(timestamp | val_results), - file=open(val_log_path, mode="a"), + file=fs.open(val_log_path, mode="a"), flush=True, ) @@ -301,53 +251,8 @@ def launch_eval(cfg: EvalArgs): def main(): - """ - The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments - This accepts arguments as a dot list - So if the dataclass looks like - - @dataclass - class DummyArgs: - name: str - model: LMTransformerArgsgs - - @dataclass - class LMTransformerArgsgs: - dim: int - - Then you can pass model.dim=32 to change values in LMTransformerArgsgs - or just name=tictac for top level attributes. - - The behavior here is as follows: - 1. We instantiate EvalArgs with its default values - 2. We override those default values with the ones in the provided config file - 3. We override the result with the additional arguments provided through command line - - For example, if the config is the following - - model: - dim: 128 - n_layers: 4 - - and you call eval.py with eval.py model.dim=64 - - Then the final TrainArgs will have - - model: - dim: 64 - n_layers: 4 - - Plus all the default values in EvalArgs dataclass. - """ - cli_args = OmegaConf.from_cli() - file_cfg = OmegaConf.load(cli_args.config) - # We remove 'config' attribute from config as the underlying DataClass does not have it - del cli_args.config - - default_cfg = OmegaConf.structured(EvalArgs()) - cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) - cfg = OmegaConf.to_object(cfg) - launch_eval(cfg) + eval_args = parse_args(EvalArgs) + launch_eval(eval_args) if __name__ == "__main__": diff --git a/apps/main/generate.py b/bytelatent/generate.py similarity index 91% rename from apps/main/generate.py rename to bytelatent/generate.py index a3a8627..eb79d81 100644 --- a/apps/main/generate.py +++ b/bytelatent/generate.py @@ -1,20 +1,16 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
+import os import time -from dataclasses import dataclass, field -from pathlib import Path -from typing import List, Optional import torch -from lingua.args import dataclass_from_dict -from lingua.tokenizers.abstract_tokenizer import Tokenizer -from lingua.tokenizers.build_tokenizer import build_tokenizer from omegaconf import OmegaConf from torch import nn from torch.nn import functional as F from torch.nn.attention.flex_attention import create_block_mask from tqdm import tqdm +from bytelatent.args import PackedCausalTransformerGeneratorArgs, TrainArgs from bytelatent.base_transformer import ( Attention, causal_mask, @@ -23,7 +19,10 @@ from bytelatent.base_transformer import ( lengths_to_start_ids, ) from bytelatent.checkpoint import CONSOLIDATE_NAME -from bytelatent.transformer import LMTransformer, LMTransformerArgs +from bytelatent.data.file_util import get_fs +from bytelatent.model.blt import ByteLatentTransformer +from bytelatent.tokenizers.abstract_tokenizer import Tokenizer +from bytelatent.transformer import LMTransformer def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor: @@ -62,7 +61,7 @@ def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None): return next_token.view(shape[:-1]) -def pack_prompts(prompts: List[int]): +def pack_prompts(prompts: list[int]): res = [] lengths = [] for i, p in enumerate(prompts): @@ -120,22 +119,6 @@ class KVCache(nn.Module): return self.k_cache, self.v_cache -@dataclass -class PackedCausalTransformerGeneratorArgs: - temperature: float = 0.0 - top_p: Optional[float] = None - top_k: Optional[float] = None - max_gen_len: int = 512 # Maximum number of tokens to generate - max_tokens: int = 1024 # Maximum number of tokens that can go through the model - max_prompt_len: Optional[int] = None - until: List[str] = field(default_factory=list) - compile_prefilling: bool = False - reduce_generation_overhead: bool = False - show_progress: bool = False - dtype: Optional[str] = "bf16" - device: Optional[str] = "cuda" - - class PackedCausalTransformerGenerator: def __init__( self, @@ -401,25 +384,29 @@ class PackedCausalTransformerGenerator: def load_consolidated_model_and_tokenizer( consolidated_path, - model_cls=LMTransformer, - model_args_cls=LMTransformerArgs, ): - ckpt_path = Path(consolidated_path) - config = ckpt_path / "params.json" - config = OmegaConf.load(config) + train_args_path = os.path.join(consolidated_path, "params.json") + fs = get_fs(train_args_path) + with fs.open(train_args_path) as f: + train_args = TrainArgs.model_validate_json(f.read()) + + if train_args.train_entropy_model: + model_args = train_args.entropy_model + model = LMTransformer(model_args) + else: + model_args = train_args.model + model = ByteLatentTransformer(model_args) param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[ - config.distributed.model_dtype + train_args.distributed.model_dtype ] - model_args = dataclass_from_dict(model_args_cls, config.model, strict=False) - tokenizer = build_tokenizer(config.data.tokenizer.name, config.data.tokenizer.path) - model = model_cls(model_args) - st_dict = torch.load(ckpt_path / CONSOLIDATE_NAME, weights_only=True) + tokenizer = train_args.data.tokenizer_args.build() + st_dict = torch.load(consolidated_path / CONSOLIDATE_NAME, weights_only=True) model.load_state_dict(st_dict["model"]) model = model.cuda().eval() for param in model.parameters(): param.data = param.data.to(dtype=param_dtype) - return model, tokenizer, config + return model, tokenizer, train_args def main(): diff --git 
a/bytelatent/train.py b/bytelatent/train.py index 1d0fa40..6b20ecd 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -10,7 +10,7 @@ from copy import deepcopy from dataclasses import asdict, dataclass from pathlib import Path from timeit import default_timer as timer -from typing import Any, Dict, Type, TypeVar +from typing import Any, TypeVar import torch import torch.distributed @@ -23,9 +23,13 @@ from torch.distributed._tensor import DTensor from torch.distributed.checkpoint.stateful import Stateful from torch.optim import lr_scheduler -from bytelatent.args import TrainArgs +from bytelatent.args import TrainArgs, parse_args from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint -from bytelatent.data.data_types import DataLoaderState +from bytelatent.data.iterators.multiprocess_iterator import ( + MultiprocessIterator, + MultiprocessIteratorState, +) +from bytelatent.data.iterators.packing_iterator import PackingIteratorState from bytelatent.distributed import ( check_model_value_range, clean_env, @@ -39,6 +43,7 @@ from bytelatent.distributed import ( setup_env, setup_torch_distributed, ) +from bytelatent.eval import EVAL_FOLDER_NAME, launch_eval from bytelatent.logger import init_logger from bytelatent.metrics import GPUMemoryMonitor, MetricLogger, get_num_params from bytelatent.model.blt import ByteLatentTransformer @@ -70,36 +75,49 @@ def flatten_dict(d, parent_key="", sep="_"): return dict(items) -def dataclass_from_dict(cls: Type[T], data: dict, strict: bool = True) -> T: - """ - Converts a dictionary to a dataclass instance, recursively for nested structures. - """ - base = OmegaConf.structured(cls()) - OmegaConf.set_struct(base, strict) - override = OmegaConf.create(data) - return OmegaConf.to_object(OmegaConf.merge(base, override)) +def get_iterator_state_name(iterator_state): + if isinstance(iterator_state, MultiprocessIteratorState): + return "multiprocess" + elif isinstance(iterator_state, PackingIteratorState): + return "packing" + else: + raise ValueError(f"Unsupported iterator to get name from: {iterator_state}") +# TODO: Make this pydantic based instead of data class based +# TODO: Generalize this to any iterator state @dataclass class TrainState(Stateful): step: int # Nb of steps taken by the optimizer acc_step: int # Nb of accumulation steps done since last optimizer step scheduler: lr_scheduler.LambdaLR - data_loader_state: DataLoaderState + data_loader_state: MultiprocessIteratorState | PackingIteratorState scale: float = 1.0 + data_loader_class: str | None = None - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return { "step": self.step, "acc_step": self.acc_step, - "data_loader_state": self.data_loader_state.dict(), + "data_loader_state": self.data_loader_state.model_dump(), + "data_loader_class": get_iterator_state_name(self.data_loader_state), "scheduler": self.scheduler.state_dict(), } def load_state_dict(self, state_dict): self.step = state_dict["step"] self.acc_step = state_dict["acc_step"] - self.data_loader_state = DataLoaderState(**state_dict["data_loader_state"]) + self.data_loader_class = state_dict["data_loader_class"] + if self.data_loader_class == "multiprocess": + self.data_loader_state = MultiprocessIteratorState( + **state_dict["data_loader_state"] + ) + elif self.data_loader_class == "packing": + self.data_loader_state = PackingIteratorState( + **state_dict["data_loader_state"] + ) + else: + raise ValueError(f"invalid data loader class: {self.data_loader_class}") 
self.scheduler.load_state_dict(state_dict["scheduler"]) @@ -345,7 +363,10 @@ def train(args: TrainArgs): nwords_since_last_log = 0 time_last_log = timer() gc.collect() - while train_state.step < args.steps: + saved = False + while train_state.step < args.steps and ( + args.max_steps is None or train_state.step < args.max_steps + ): # We constrain train_state.acc_step to be in range 0 to args.grad_acc_steps - 1 train_state.acc_step += 1 train_state.acc_step = train_state.acc_step % args.grad_acc_steps @@ -552,7 +573,6 @@ def train(args: TrainArgs): f" pow: {gpu_mem_stats.power_draw/1000} W" ) - saved = False if every_n_steps( train_state, args.checkpoint.dump.every, acc_step=0 ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0): @@ -567,18 +587,14 @@ def train(args: TrainArgs): if args.eval is not None and every_n_steps( train_state, args.checkpoint.eval.every, acc_step=0 ): - from apps.main.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval - - eval_args = dataclass_from_dict(EvalArgs, args.eval) + eval_args = args.eval eval_args.global_step = train_state.step eval_args.ckpt_dir = str(checkpoint.existing_saves[-1]) - eval_args.dump_dir = str( - os.path.join( - args.dump_dir, - "evals", - EVAL_FOLDER_NAME.format(train_state.step), - ) + eval_args.dump_dir = os.path.join( + args.dump_dir, + "evals", + EVAL_FOLDER_NAME.format(train_state.step), ) eval_args.metric_log_dir = args.dump_dir if args.async_eval_gpus is None: @@ -619,6 +635,9 @@ def train(args: TrainArgs): args, device_mesh=world_mesh, ) + if isinstance(data_loader, MultiprocessIterator): + logger.info("Closing MP iterator before exiting") + data_loader.shutdown() gc.collect() @@ -661,15 +680,7 @@ def main(): Plus all the default values in TrainArgs dataclass. """ - cli_args = OmegaConf.from_cli() - file_cfg = OmegaConf.load(cli_args.config) - # We remove 'config' attribute from config as the underlying DataClass does not have it - del cli_args.config - - default_cfg = OmegaConf.create(TrainArgs().model_dump()) - cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) - cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) - train_args = TrainArgs.model_validate(cfg) + train_args = parse_args(TrainArgs) if train_args.debug_dynamo: import torch._dynamo From caf82b924e9f7fadc1b850cb30b59471c796fa8e Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Tue, 28 Jan 2025 00:38:16 +0000 Subject: [PATCH 16/18] This includes fixes that make checkpointing and reloading work correctly. 
It also batches in a first set of changes for fixing eval code Summary: Test Plan: --- apps/main/lingua_train.py | 2 +- bytelatent/args.py | 87 ++++++++- bytelatent/checkpoint.py | 9 +- bytelatent/configs/debug.yaml | 9 +- bytelatent/configs/entropy_model.yaml | 9 +- bytelatent/data/data_types.py | 10 - .../data/iterators/multiprocess_iterator.py | 15 ++ {apps/main => bytelatent}/eval.py | 179 ++++-------------- {apps/main => bytelatent}/generate.py | 57 +++--- bytelatent/train.py | 81 ++++---- 10 files changed, 221 insertions(+), 237 deletions(-) rename {apps/main => bytelatent}/eval.py (56%) rename {apps/main => bytelatent}/generate.py (91%) diff --git a/apps/main/lingua_train.py b/apps/main/lingua_train.py index bdb47da..7925ec6 100644 --- a/apps/main/lingua_train.py +++ b/apps/main/lingua_train.py @@ -544,7 +544,7 @@ def train(args: TrainArgs): if args.eval is not None and every_n_steps( train_state, args.checkpoint.eval.every, acc_step=0 ): - from apps.main.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval + from bytelatent.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval eval_args = dataclass_from_dict(EvalArgs, args.eval) diff --git a/bytelatent/args.py b/bytelatent/args.py index 56de22d..d1bac46 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -5,6 +5,7 @@ from typing import Any import numpy as np import yaml +from omegaconf import OmegaConf from pydantic import BaseModel, ConfigDict from bytelatent.checkpoint import CheckpointArgs @@ -39,6 +40,19 @@ def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]: return np.random.default_rng((seed, rank, world_size)).bit_generator.state +def parse_args(args_cls): + cli_args = OmegaConf.from_cli() + file_cfg = OmegaConf.load(cli_args.config) + # We remove 'config' attribute from config as the underlying DataClass does not have it + del cli_args.config + + default_cfg = OmegaConf.create(args_cls().model_dump()) + cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) + cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) + pydantic_args = args_cls.model_validate(cfg) + return pydantic_args + + def distribute_data_to_rank( *, dataset_path: str, @@ -71,6 +85,22 @@ def distribute_data_to_rank( return rank_to_arrow_iterator_params[rank] +class PackedCausalTransformerGeneratorArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + temperature: float = 0.0 + top_p: float | None = None + top_k: float | None = None + max_gen_len: int = 512 # Maximum number of tokens to generate + max_tokens: int = 1024 # Maximum number of tokens that can go through the model + max_prompt_len: int | None = None + until: list[str] = [] + compile_prefilling: bool = False + reduce_generation_overhead: bool = False + show_progress: bool = False + dtype: str | None = "bf16" + device: str | None = "cuda" + + class DataloaderArgs(BaseModel): model_config = ConfigDict(extra="forbid") s3_profile: str | None = None @@ -168,6 +198,58 @@ class DataloaderArgs(BaseModel): return packing_iterator +class LMHarnessArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + tasks: list[Any] | None = None + num_fewshot: int | None = None + device: str | None = None + use_cache: str | None = None + cache_requests: bool = False + rewrite_requests_cache: bool = False + delete_requests_cache: bool = False + limit: int | float | None = None + bootstrap_iters: int = 100000 + check_integrity: bool = False + write_out: bool = False + log_samples: bool = True + system_instruction: str | None = None + apply_chat_template: bool | str 
= False + fewshot_as_multiturn: bool = False + gen_kwargs: str | None = None + verbosity: str = "INFO" + predict_only: bool = False + random_seed: int = 0 + numpy_random_seed: int = 1234 + torch_random_seed: int = 1234 + fewshot_random_seed: int = 1234 + + +class ValidationArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + max_steps: int | None = ( + None # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu) + ) + use_val_from_train_src: bool = True # Use the validation set from training sources + root_dir: str = "" + sources: list[str] = [] # Other sources to eval on + + +class EvalArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + dump_dir: str + ckpt_dir: str + metric_log_dir: str | None = None + generator: PackedCausalTransformerGeneratorArgs = ( + PackedCausalTransformerGeneratorArgs() + ) + + harness: LMHarnessArgs | None = LMHarnessArgs() + validation: ValidationArgs | None = ValidationArgs() + + global_step: int | None = None # for in-training evaluation + s3_profile: str | None = None + + class TrainArgs(BaseModel): model_config = ConfigDict(extra="forbid") name: str = "lingua" @@ -186,6 +268,9 @@ class TrainArgs(BaseModel): # Nb optimizer steps to take steps: int = 1000 + # If not None, halt training after this many steps, + # useful for debugging + max_steps: int | None = None data: DataloaderArgs = DataloaderArgs() optim: OptimArgs = OptimArgs() @@ -203,7 +288,7 @@ class TrainArgs(BaseModel): # If set to None, eval is run locally otherwise it launches a new job with the given number of gpus async_eval_gpus: int | None = None - eval: Any | None = None + eval: EvalArgs | None = None eval_on_gpus: int | None = None def dump_to_yaml_file( diff --git a/bytelatent/checkpoint.py b/bytelatent/checkpoint.py index bcf591e..f213c84 100644 --- a/bytelatent/checkpoint.py +++ b/bytelatent/checkpoint.py @@ -7,6 +7,7 @@ import re from pathlib import Path from typing import List, Optional, Tuple +import fsspec import torch import torch.distributed as dist import torch.distributed.checkpoint as dcp @@ -21,6 +22,7 @@ from torch.distributed.checkpoint.state_dict import ( set_state_dict, ) +from bytelatent.data.file_util import get_fs from bytelatent.distributed import get_is_master logger = logging.getLogger("CHECKPOINT") @@ -51,13 +53,14 @@ class CheckpointArgs(BaseModel): path: str | None = None init_ckpt_path: str | None = None continue_training_from_init: bool = False + s3_profile: str | None = None def _get_key_step(name: str): return int(re.findall(RE_DIGITS, name)[-1]) -def consolidate_checkpoints(ckpt_dir: str): +def consolidate_checkpoints(fs: fsspec.AbstractFileSystem, ckpt_dir: str): """ Consolidates all FSDP checkpoints in a directory to a single file Consolidate checkpoint is saved in a subdirectory of ckpt_dir @@ -102,15 +105,17 @@ def load_from_checkpoint( dcp.load(state_dict, checkpoint_id=ckpt_dir) +# TODO: Rewrite the file operations here to use fsspec to enable s3 writing. 
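# Illustrative sketch, not part of the patch: the CheckpointManager changes just below
# route all checkpoint path handling through fsspec via get_fs (bytelatent.data.file_util),
# so the same code works for local and s3:// paths. The bucket path and profile name here
# are made-up placeholders.
from bytelatent.data.file_util import get_fs

ckpt_dir = "s3://my-bucket/checkpoints/run0"      # hypothetical; a local path works the same way
fs = get_fs(ckpt_dir, s3_profile="my-profile")    # hypothetical AWS profile; omit for local paths
fs.mkdirs(ckpt_dir, exist_ok=True)
with fs.open(f"{ckpt_dir}/params.json", "w") as f:
    f.write("{}")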
class CheckpointManager: def __init__(self, args: CheckpointArgs): self.path = args.path + self.fs = get_fs(self.path, s3_profile=args.s3_profile) self.dump_every = args.dump self.eval_every = args.eval self.init_ckpt_path = args.init_ckpt_path self.continue_training_from_init = args.continue_training_from_init - assert os.path.exists( + assert self.fs.exists( self.path ), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)" diff --git a/bytelatent/configs/debug.yaml b/bytelatent/configs/debug.yaml index 1098ff5..07d489f 100644 --- a/bytelatent/configs/debug.yaml +++ b/bytelatent/configs/debug.yaml @@ -98,11 +98,4 @@ logging: freq: 10 eval_on_gpus: 8 -eval: - dataset_dir: /checkpoint/amaia/codegen/datasets/eval - tasks: boolq,hellaswag,nq,piqa,siqa,tqa,winogrande,obqa,arc_easy,arc_challenge,race.middle,race.high,gsm8k,math,bbh,copa,human_eval_plus,mbpp,mmlu - generator: - max_tokens: 65536 - dtype: bf16 - - mp_size: 1 +eval: null diff --git a/bytelatent/configs/entropy_model.yaml b/bytelatent/configs/entropy_model.yaml index 51b65d4..d7c27b7 100644 --- a/bytelatent/configs/entropy_model.yaml +++ b/bytelatent/configs/entropy_model.yaml @@ -72,11 +72,4 @@ logging: freq: 10 eval_on_gpus: 8 -eval: - dataset_dir: ??? - tasks: ??? - generator: - max_tokens: 65536 - dtype: bf16 - - mp_size: 1 +eval: null diff --git a/bytelatent/data/data_types.py b/bytelatent/data/data_types.py index aa2daa9..f4bbc07 100644 --- a/bytelatent/data/data_types.py +++ b/bytelatent/data/data_types.py @@ -40,16 +40,6 @@ class BltPackTokensState(BaseModel): n_views: int = 2 -class DataLoaderState(BaseModel): - model_config = ConfigDict(extra="forbid") - multi_choice_state: MultiChoiceState - pack_tokens_state: BltPackTokensState - prefetch_state: PrefetchState - - -BltIterator = Iterator[tuple[BltExample, DataLoaderState]] - - class BltSequence(BaseModel): tokens: list[int] mask: list[bool] diff --git a/bytelatent/data/iterators/multiprocess_iterator.py b/bytelatent/data/iterators/multiprocess_iterator.py index f17ca6e..49d99ac 100644 --- a/bytelatent/data/iterators/multiprocess_iterator.py +++ b/bytelatent/data/iterators/multiprocess_iterator.py @@ -128,6 +128,13 @@ class MultiprocessIterator(StatefulIterator): self.producer = None self.stop_iterating_event = None self.state_dumped_event = None + self.force_shutdown = False + + def shutdown(self): + if self.producer is not None: + # This properly shuts things down + self.producer.kill() + self.force_shutdown = True def get_state(self) -> MultiprocessIteratorState: """ @@ -135,6 +142,10 @@ class MultiprocessIterator(StatefulIterator): to halt the background process and allow it to write the state to the main loop in order to not lose data """ + if self.force_shutdown: + raise ValueError( + "State will be invalid if shutdown was forced before state persisted." + ) if self.producer is None: serialized_prefetch_buffer = json.dumps( [b.to_python_dict() for b in self.prefetch_buffer] @@ -187,6 +198,10 @@ class MultiprocessIterator(StatefulIterator): ) def create_iter(self): + if self.force_shutdown: + raise ValueError( + "Iterator may be invalid if shutdown was forced before state persisted." + ) logging.info("Main thread: Creating MP iterator") # First yield from the stored prefetch buffer. 
if self.prefetch_buffer is not None: diff --git a/apps/main/eval.py b/bytelatent/eval.py similarity index 56% rename from apps/main/eval.py rename to bytelatent/eval.py index ed20f49..ae73066 100644 --- a/apps/main/eval.py +++ b/bytelatent/eval.py @@ -4,20 +4,20 @@ import json import logging import os from collections import defaultdict -from dataclasses import asdict, dataclass, field from datetime import datetime from pathlib import Path -from typing import Any, List, Optional, Tuple, Union +from typing import Any import torch -from lingua.args import dump_config -from lingua.data import init_choice_state, setup_sources from lm_eval import simple_evaluate from lm_eval.api.instance import Instance from lm_eval.api.model import LM from omegaconf import OmegaConf +from pydantic import BaseModel, ConfigDict +from bytelatent.args import EvalArgs, ValidationArgs, parse_args from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints +from bytelatent.data.file_util import get_fs from bytelatent.distributed import ( DistributedArgs, dist_mean_dict, @@ -25,72 +25,17 @@ from bytelatent.distributed import ( get_world_size, setup_torch_distributed, ) -from bytelatent.transformer import LMTransformer, LMTransformerArgs - -from apps.main.generate import ( +from bytelatent.generate import ( PackedCausalTransformerGenerator, - PackedCausalTransformerGeneratorArgs, load_consolidated_model_and_tokenizer, ) +from bytelatent.transformer import LMTransformer, LMTransformerArgs EVAL_FOLDER_NAME = "{:010d}" logger = logging.getLogger() -@dataclass -class LMHarnessArgs: - tasks: Optional[List[Any]] = None - num_fewshot: Optional[int] = None - device: Optional[str] = None - use_cache: Optional[str] = None - cache_requests: bool = False - rewrite_requests_cache: bool = False - delete_requests_cache: bool = False - limit: Optional[Union[int, float]] = None - bootstrap_iters: int = 100000 - check_integrity: bool = False - write_out: bool = False - log_samples: bool = True - system_instruction: Optional[str] = None - apply_chat_template: Union[bool, str] = False - fewshot_as_multiturn: bool = False - gen_kwargs: Optional[str] = None - verbosity: str = "INFO" - predict_only: bool = False - random_seed: int = 0 - numpy_random_seed: int = 1234 - torch_random_seed: int = 1234 - fewshot_random_seed: int = 1234 - - -@dataclass -class ValidationArgs: - max_steps: Optional[int] = ( - None # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu) - ) - use_val_from_train_src: bool = True # Use the validation set from training sources - root_dir: str = "" - sources: List[str] = field(default_factory=list) # Other sources to eval on - - -@dataclass -class EvalArgs: - name: str = "evals" - dump_dir: Optional[str] = None - metric_log_dir: Optional[str] = None - ckpt_dir: str = "" - generator: PackedCausalTransformerGeneratorArgs = field( - default_factory=PackedCausalTransformerGeneratorArgs - ) - harness: Optional[LMHarnessArgs] = field(default_factory=LMHarnessArgs) - validation: Optional[ValidationArgs] = field(default_factory=ValidationArgs) - - wandb: Optional[Any] = None - - global_step: Optional[int] = None # for in-training evaluation - - def all_dicts_same(dict_list): if not dict_list: # Check if the list is empty return True @@ -120,7 +65,7 @@ class EvalHarnessLM(LM): self._world_size = get_world_size() self.device = generator.device - def generate_until(self, requests: List[Instance]) -> List[str]: + def 
generate_until(self, requests: list[Instance]) -> list[str]: prompts, gen_args = zip(*[req.args for req in requests]) assert all_dicts_same(gen_args), "Doesn't support different gen args for now" gen_args = gen_args[0] @@ -141,7 +86,7 @@ class EvalHarnessLM(LM): filtered_gen.append(g) return filtered_gen - def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: + def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: prompts, continuations = zip(*[req.args for req in requests]) inputs = [req.args[0] + req.args[1] for req in requests] max_gen_len = self.generator.max_gen_len @@ -158,7 +103,7 @@ class EvalHarnessLM(LM): self.generator.max_gen_len = max_gen_len return results - def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]: + def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]: prompts = [req.args[0] for req in requests] max_gen_len = self.generator.max_gen_len # We temporarily lower max gen len @@ -232,68 +177,73 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg): return all_val_metrics -def launch_eval(cfg: EvalArgs): +def launch_eval(eval_args: EvalArgs): if not torch.distributed.is_initialized(): setup_torch_distributed(DistributedArgs()) + + fs = get_fs(eval_args.ckpt_dir, s3_profile=eval_args.s3_profile) if ( - Path(cfg.ckpt_dir).exists() - and (Path(cfg.ckpt_dir) / "params.json").exists() - and next(Path(cfg.ckpt_dir).glob("*.pth"), None) is not None + fs.exists(eval_args.ckpt_dir) + and fs.exists(os.path.join(eval_args.ckpt_dir, "params.json")) + and len(fs.glob(os.path.join(eval_args.ckpt_dir, "*.pth"))) != 0 ): - consolidate_path = Path(cfg.ckpt_dir) + consolidate_path = eval_args.ckpt_dir else: - consolidate_path = Path(cfg.ckpt_dir) / CONSOLIDATE_FOLDER - if not consolidate_path.exists() and get_global_rank() == 0: - consolidate_path = consolidate_checkpoints(cfg.ckpt_dir) + consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER) + if not fs.exists(consolidate_path) and get_global_rank() == 0: + consolidate_path = consolidate_checkpoints(eval_args.ckpt_dir) - Path(cfg.dump_dir).mkdir(parents=True, exist_ok=True) - dump_config(cfg, Path(cfg.dump_dir) / "config.yaml", log_config=False) + fs.mkdirs(eval_args.dump_dir, exist_ok=True) + with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f: + f.write(eval_args.model_dump_json()) - consolidate_path = str(consolidate_path) torch.distributed.barrier() logger.info("Loading model") + # TODO: Make this general so that it works with either + # LMTransformer or Blt, similar with args model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer( consolidate_path, - model_cls=LMTransformer, - model_args_cls=LMTransformerArgs, ) logger.info("Model loaded") model.eval() - generator = PackedCausalTransformerGenerator(cfg.generator, model, tokenizer) + generator = PackedCausalTransformerGenerator(eval_args.generator, model, tokenizer) wrap = EvalHarnessLM(generator) - results = simple_evaluate(wrap, **asdict(cfg.harness)) + # Redo + results = simple_evaluate(wrap, eval_args.harness.model_dump()) val_results = None - if cfg.validation: - val_results = eval_on_val(generator, cfg.validation, train_cfg) + if eval_args.validation: + val_results = eval_on_val(generator, eval_args.validation, train_cfg) if get_global_rank() == 0: - with open(Path(cfg.dump_dir) / "results.json", "w") as f: + with fs.open(os.path.join(eval_args.dump_dir, "results.json"), "w") as f: f.write(json.dumps(results)) 
logger.info(f"All evaluation results: {results['results']}") if val_results is not None: - with open(Path(cfg.dump_dir) / "validation.json", "w") as f: + with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f: f.write(json.dumps(val_results)) logger.info(f"All validation results: {val_results}") - if cfg.metric_log_dir and get_global_rank() == 0: - metric_log_path = Path(cfg.metric_log_dir) / "metrics.eval.jsonl" + if eval_args.metric_log_dir and get_global_rank() == 0: + metric_log_path = os.path.join(eval_args.metric_log_dir, "metrics.eval.jsonl") logger.info(f"Writing metric logs to {metric_log_path}") timestamp = { "created_at": datetime.utcnow().isoformat(), } - if cfg.global_step is not None: - timestamp["global_step"] = cfg.global_step + if eval_args.global_step is not None: + timestamp["global_step"] = eval_args.global_step print( json.dumps(timestamp | results["results"]), - file=open(metric_log_path, mode="a"), + file=fs.open(metric_log_path, mode="a"), flush=True, ) - val_log_path = Path(cfg.metric_log_dir) / "metrics.validation.jsonl" + val_log_path = os.path.join( + eval_args.metric_log_dir, "metrics.validation.jsonl" + ) if val_results is not None: print( json.dumps(timestamp | val_results), - file=open(val_log_path, mode="a"), + file=fs.open(val_log_path, mode="a"), flush=True, ) @@ -301,53 +251,8 @@ def launch_eval(cfg: EvalArgs): def main(): - """ - The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments - This accepts arguments as a dot list - So if the dataclass looks like - - @dataclass - class DummyArgs: - name: str - model: LMTransformerArgsgs - - @dataclass - class LMTransformerArgsgs: - dim: int - - Then you can pass model.dim=32 to change values in LMTransformerArgsgs - or just name=tictac for top level attributes. - - The behavior here is as follows: - 1. We instantiate EvalArgs with its default values - 2. We override those default values with the ones in the provided config file - 3. We override the result with the additional arguments provided through command line - - For example, if the config is the following - - model: - dim: 128 - n_layers: 4 - - and you call eval.py with eval.py model.dim=64 - - Then the final TrainArgs will have - - model: - dim: 64 - n_layers: 4 - - Plus all the default values in EvalArgs dataclass. - """ - cli_args = OmegaConf.from_cli() - file_cfg = OmegaConf.load(cli_args.config) - # We remove 'config' attribute from config as the underlying DataClass does not have it - del cli_args.config - - default_cfg = OmegaConf.structured(EvalArgs()) - cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) - cfg = OmegaConf.to_object(cfg) - launch_eval(cfg) + eval_args = parse_args(EvalArgs) + launch_eval(eval_args) if __name__ == "__main__": diff --git a/apps/main/generate.py b/bytelatent/generate.py similarity index 91% rename from apps/main/generate.py rename to bytelatent/generate.py index a3a8627..eb79d81 100644 --- a/apps/main/generate.py +++ b/bytelatent/generate.py @@ -1,20 +1,16 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
+import os import time -from dataclasses import dataclass, field -from pathlib import Path -from typing import List, Optional import torch -from lingua.args import dataclass_from_dict -from lingua.tokenizers.abstract_tokenizer import Tokenizer -from lingua.tokenizers.build_tokenizer import build_tokenizer from omegaconf import OmegaConf from torch import nn from torch.nn import functional as F from torch.nn.attention.flex_attention import create_block_mask from tqdm import tqdm +from bytelatent.args import PackedCausalTransformerGeneratorArgs, TrainArgs from bytelatent.base_transformer import ( Attention, causal_mask, @@ -23,7 +19,10 @@ from bytelatent.base_transformer import ( lengths_to_start_ids, ) from bytelatent.checkpoint import CONSOLIDATE_NAME -from bytelatent.transformer import LMTransformer, LMTransformerArgs +from bytelatent.data.file_util import get_fs +from bytelatent.model.blt import ByteLatentTransformer +from bytelatent.tokenizers.abstract_tokenizer import Tokenizer +from bytelatent.transformer import LMTransformer def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor: @@ -62,7 +61,7 @@ def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None): return next_token.view(shape[:-1]) -def pack_prompts(prompts: List[int]): +def pack_prompts(prompts: list[int]): res = [] lengths = [] for i, p in enumerate(prompts): @@ -120,22 +119,6 @@ class KVCache(nn.Module): return self.k_cache, self.v_cache -@dataclass -class PackedCausalTransformerGeneratorArgs: - temperature: float = 0.0 - top_p: Optional[float] = None - top_k: Optional[float] = None - max_gen_len: int = 512 # Maximum number of tokens to generate - max_tokens: int = 1024 # Maximum number of tokens that can go through the model - max_prompt_len: Optional[int] = None - until: List[str] = field(default_factory=list) - compile_prefilling: bool = False - reduce_generation_overhead: bool = False - show_progress: bool = False - dtype: Optional[str] = "bf16" - device: Optional[str] = "cuda" - - class PackedCausalTransformerGenerator: def __init__( self, @@ -401,25 +384,29 @@ class PackedCausalTransformerGenerator: def load_consolidated_model_and_tokenizer( consolidated_path, - model_cls=LMTransformer, - model_args_cls=LMTransformerArgs, ): - ckpt_path = Path(consolidated_path) - config = ckpt_path / "params.json" - config = OmegaConf.load(config) + train_args_path = os.path.join(consolidated_path, "params.json") + fs = get_fs(train_args_path) + with fs.open(train_args_path) as f: + train_args = TrainArgs.model_validate_json(f.read()) + + if train_args.train_entropy_model: + model_args = train_args.entropy_model + model = LMTransformer(model_args) + else: + model_args = train_args.model + model = ByteLatentTransformer(model_args) param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[ - config.distributed.model_dtype + train_args.distributed.model_dtype ] - model_args = dataclass_from_dict(model_args_cls, config.model, strict=False) - tokenizer = build_tokenizer(config.data.tokenizer.name, config.data.tokenizer.path) - model = model_cls(model_args) - st_dict = torch.load(ckpt_path / CONSOLIDATE_NAME, weights_only=True) + tokenizer = train_args.data.tokenizer_args.build() + st_dict = torch.load(consolidated_path / CONSOLIDATE_NAME, weights_only=True) model.load_state_dict(st_dict["model"]) model = model.cuda().eval() for param in model.parameters(): param.data = param.data.to(dtype=param_dtype) - return model, tokenizer, config + return model, tokenizer, train_args def main(): diff --git 
a/bytelatent/train.py b/bytelatent/train.py index 1d0fa40..6b20ecd 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -10,7 +10,7 @@ from copy import deepcopy from dataclasses import asdict, dataclass from pathlib import Path from timeit import default_timer as timer -from typing import Any, Dict, Type, TypeVar +from typing import Any, TypeVar import torch import torch.distributed @@ -23,9 +23,13 @@ from torch.distributed._tensor import DTensor from torch.distributed.checkpoint.stateful import Stateful from torch.optim import lr_scheduler -from bytelatent.args import TrainArgs +from bytelatent.args import TrainArgs, parse_args from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint -from bytelatent.data.data_types import DataLoaderState +from bytelatent.data.iterators.multiprocess_iterator import ( + MultiprocessIterator, + MultiprocessIteratorState, +) +from bytelatent.data.iterators.packing_iterator import PackingIteratorState from bytelatent.distributed import ( check_model_value_range, clean_env, @@ -39,6 +43,7 @@ from bytelatent.distributed import ( setup_env, setup_torch_distributed, ) +from bytelatent.eval import EVAL_FOLDER_NAME, launch_eval from bytelatent.logger import init_logger from bytelatent.metrics import GPUMemoryMonitor, MetricLogger, get_num_params from bytelatent.model.blt import ByteLatentTransformer @@ -70,36 +75,49 @@ def flatten_dict(d, parent_key="", sep="_"): return dict(items) -def dataclass_from_dict(cls: Type[T], data: dict, strict: bool = True) -> T: - """ - Converts a dictionary to a dataclass instance, recursively for nested structures. - """ - base = OmegaConf.structured(cls()) - OmegaConf.set_struct(base, strict) - override = OmegaConf.create(data) - return OmegaConf.to_object(OmegaConf.merge(base, override)) +def get_iterator_state_name(iterator_state): + if isinstance(iterator_state, MultiprocessIteratorState): + return "multiprocess" + elif isinstance(iterator_state, PackingIteratorState): + return "packing" + else: + raise ValueError(f"Unsupported iterator to get name from: {iterator_state}") +# TODO: Make this pydantic based instead of data class based +# TODO: Generalize this to any iterator state @dataclass class TrainState(Stateful): step: int # Nb of steps taken by the optimizer acc_step: int # Nb of accumulation steps done since last optimizer step scheduler: lr_scheduler.LambdaLR - data_loader_state: DataLoaderState + data_loader_state: MultiprocessIteratorState | PackingIteratorState scale: float = 1.0 + data_loader_class: str | None = None - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return { "step": self.step, "acc_step": self.acc_step, - "data_loader_state": self.data_loader_state.dict(), + "data_loader_state": self.data_loader_state.model_dump(), + "data_loader_class": get_iterator_state_name(self.data_loader_state), "scheduler": self.scheduler.state_dict(), } def load_state_dict(self, state_dict): self.step = state_dict["step"] self.acc_step = state_dict["acc_step"] - self.data_loader_state = DataLoaderState(**state_dict["data_loader_state"]) + self.data_loader_class = state_dict["data_loader_class"] + if self.data_loader_class == "multiprocess": + self.data_loader_state = MultiprocessIteratorState( + **state_dict["data_loader_state"] + ) + elif self.data_loader_class == "packing": + self.data_loader_state = PackingIteratorState( + **state_dict["data_loader_state"] + ) + else: + raise ValueError(f"invalid data loader class: {self.data_loader_class}") 
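# Illustrative sketch, not part of the patch: TrainState above tags the serialized
# dataloader state with a short class name so load_state_dict can rebuild the right
# pydantic state on resume. Round-trip, assuming an existing iterator state `state`:
from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIteratorState
from bytelatent.data.iterators.packing_iterator import PackingIteratorState
from bytelatent.train import get_iterator_state_name

payload = {
    "data_loader_state": state.model_dump(),                # plain dict, JSON-serializable
    "data_loader_class": get_iterator_state_name(state),    # "multiprocess" or "packing"
}
if payload["data_loader_class"] == "multiprocess":
    restored = MultiprocessIteratorState(**payload["data_loader_state"])
elif payload["data_loader_class"] == "packing":
    restored = PackingIteratorState(**payload["data_loader_state"])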
self.scheduler.load_state_dict(state_dict["scheduler"]) @@ -345,7 +363,10 @@ def train(args: TrainArgs): nwords_since_last_log = 0 time_last_log = timer() gc.collect() - while train_state.step < args.steps: + saved = False + while train_state.step < args.steps and ( + args.max_steps is None or train_state.step < args.max_steps + ): # We constrain train_state.acc_step to be in range 0 to args.grad_acc_steps - 1 train_state.acc_step += 1 train_state.acc_step = train_state.acc_step % args.grad_acc_steps @@ -552,7 +573,6 @@ def train(args: TrainArgs): f" pow: {gpu_mem_stats.power_draw/1000} W" ) - saved = False if every_n_steps( train_state, args.checkpoint.dump.every, acc_step=0 ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0): @@ -567,18 +587,14 @@ def train(args: TrainArgs): if args.eval is not None and every_n_steps( train_state, args.checkpoint.eval.every, acc_step=0 ): - from apps.main.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval - - eval_args = dataclass_from_dict(EvalArgs, args.eval) + eval_args = args.eval eval_args.global_step = train_state.step eval_args.ckpt_dir = str(checkpoint.existing_saves[-1]) - eval_args.dump_dir = str( - os.path.join( - args.dump_dir, - "evals", - EVAL_FOLDER_NAME.format(train_state.step), - ) + eval_args.dump_dir = os.path.join( + args.dump_dir, + "evals", + EVAL_FOLDER_NAME.format(train_state.step), ) eval_args.metric_log_dir = args.dump_dir if args.async_eval_gpus is None: @@ -619,6 +635,9 @@ def train(args: TrainArgs): args, device_mesh=world_mesh, ) + if isinstance(data_loader, MultiprocessIterator): + logger.info("Closing MP iterator before exiting") + data_loader.shutdown() gc.collect() @@ -661,15 +680,7 @@ def main(): Plus all the default values in TrainArgs dataclass. """ - cli_args = OmegaConf.from_cli() - file_cfg = OmegaConf.load(cli_args.config) - # We remove 'config' attribute from config as the underlying DataClass does not have it - del cli_args.config - - default_cfg = OmegaConf.create(TrainArgs().model_dump()) - cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) - cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) - train_args = TrainArgs.model_validate(cfg) + train_args = parse_args(TrainArgs) if train_args.debug_dynamo: import torch._dynamo From 11cad6c84d083645fbde617c0b125a615ecc094d Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Tue, 28 Jan 2025 00:57:06 +0000 Subject: [PATCH 17/18] WIP parallel copy script Summary: Test Plan: --- bytelatent/data/parallel_copy.py | 224 +++++++++++++++++++++++++++++++ 1 file changed, 224 insertions(+) create mode 100644 bytelatent/data/parallel_copy.py diff --git a/bytelatent/data/parallel_copy.py b/bytelatent/data/parallel_copy.py new file mode 100644 index 0000000..2820429 --- /dev/null +++ b/bytelatent/data/parallel_copy.py @@ -0,0 +1,224 @@ +import logging +import os +import shutil +import time +from enum import Enum + +import fsspec +import submitit +import typer +from rich.logging import RichHandler + +FORMAT = "%(message)s" +logging.basicConfig( + level="INFO", format=FORMAT, datefmt="[%X]", handlers=[RichHandler()] +) +logger = logging.getLogger("parallel_copy") + + +S3_PREFIX = "s3://" + + +def get_fs(path: str, s3_profile: str | None = None) -> fsspec.AbstractFileSystem: + if path.startswith("s3://"): + if s3_profile is None: + return fsspec.filesystem( + "s3", default_block_size=1000000 * 2**20, max_concurrency=1 + ) + else: + return fsspec.filesystem( + "s3", + profile=s3_profile, + default_block_size=1000000 * 2**20, + 
max_concurrency=1, + ) + else: + return fsspec.filesystem("file") + + +def strip_s3_prefix(path: str): + if path.startswith(S3_PREFIX): + return path[len(S3_PREFIX) :] + else: + return path + + +class OverwriteMode(str, Enum): + ALWAYS = "always" + SIZE_MISMATCH = "size_mismatch" + NEVER = "never" + + +class ParallelMode(str, Enum): + SLURM = "slurm" + MULTIPROCESS = "multiprocess" + + +def copy_src_to_dst( + src_fs: fsspec.AbstractFileSystem, + dst_fs: fsspec.AbstractFileSystem, + src_file: str, + dst_file: str, + dry_run: bool = False, +): + if dry_run: + logging.info("Dry run copy: %s -> %s", src_file, dst_file) + else: + dst_parent_directory = os.path.dirname(dst_file) + dst_fs.mkdirs(dst_parent_directory, exist_ok=True) + with src_fs.open(src_file, "rb") as src_pointer, dst_fs.open( + dst_file, "wb" + ) as dst_pointer: + shutil.copyfileobj(src_pointer, dst_pointer) + + +class CopyJob(submitit.helpers.Checkpointable): + def __call__( + self, + src_fs_dict: dict, + dst_fs_dict: dict, + src_file: str, + dst_file: str, + dry_run: bool = False, + validate_size: bool = True, + ): + src_fs = fsspec.AbstractFileSystem.from_dict(src_fs_dict) + dst_fs = fsspec.AbstractFileSystem.from_dict(dst_fs_dict) + copy_src_to_dst(src_fs, dst_fs, src_file, dst_file, dry_run=dry_run) + if validate_size and not dry_run: + src_size = src_fs.size(src_file) + dst_size = dst_fs.size(dst_file) + if src_size != dst_size: + raise ValueError( + f"Mismatched sizes for src={src_file} dst={dst_file} {src_size} != {dst_size}" + ) + return True + + +def main( + src_dir: str, + dst_dir: str, + src_s3_profile: str | None = None, + dst_s3_profile: str | None = None, + n_workers: int = 16, + cpus_per_task: int = 2, + overwrite_mode: OverwriteMode = OverwriteMode.SIZE_MISMATCH, + validate_size: bool = True, + parallel_mode: ParallelMode = ParallelMode.MULTIPROCESS, + dry_run: bool = False, + job_dir: str = "jobs_parallel-copy", + slurm_qos: str | None = None, + slurm_time_hours: int = 72, + slurm_memory: str = "0", + wait: bool = True, + wait_period: int = 5, +): + logging.info("Starting parallell copy: %s -> %s", src_dir, dst_dir) + logging.info("job_dir=%s", job_dir) + logging.info( + "Parallel=%s validate_size=%s overwrite_mode=%s qos=%s", + parallel_mode, + validate_size, + overwrite_mode, + slurm_qos, + ) + if parallel_mode == ParallelMode.MULTIPROCESS: + executor = submitit.LocalExecutor(folder=job_dir) + elif parallel_mode == ParallelMode.SLURM: + executor = submitit.SlurmExecutor(folder=job_dir) + executor.update_parameters( + time=slurm_time_hours * 60, + ntasks_per_node=1, + cpus_per_task=cpus_per_task, + array_parallelism=n_workers, + mem=slurm_memory, + gpus_per_node=0, + ) + if slurm_qos is not None: + executor.update_parameters(qos=slurm_qos) + else: + raise ValueError("Invalid parallel mode") + + assert src_dir.endswith("/"), "src_dir must end with a /" + assert dst_dir.endswith("/"), "dst_dir must end with a /" + src_fs = get_fs(src_dir, s3_profile=src_s3_profile) + dst_fs = get_fs(dst_dir, s3_profile=dst_s3_profile) + + src_dir = strip_s3_prefix(src_dir) + dst_dir = strip_s3_prefix(dst_dir) + logging.info("src: %s, dst: %s", src_dir, dst_dir) + + assert src_fs.isdir(src_dir), "src_dir must be a directory" + if dst_fs.exists(dst_dir): + assert dst_dir, "dst_dir must be a directory if it exists" + else: + dst_fs.mkdirs(dst_dir, exist_ok=True) + + files = src_fs.find(src_dir) + logging.info("Files found to check for transfer: %s", len(files)) + jobs = [] + with executor.batch(): + for src_file in files: + 
relative_src = src_file[len(src_dir) :] + dst_file_path = os.path.join(dst_dir, relative_src) + logging.debug("src: %s -> dst %s", src_file, dst_file_path) + if dst_fs.exists(dst_file_path): + if overwrite_mode == OverwriteMode.NEVER: + pass + elif overwrite_mode == OverwriteMode.ALWAYS: + logging.info("copy: %s -> %s", src_file, dst_file_path) + job = executor.submit( + CopyJob(), + src_fs.to_dict(), + dst_fs.to_dict(), + src_file, + dst_file_path, + dry_run=dry_run, + validate_size=validate_size, + ) + jobs.append(job) + elif overwrite_mode == OverwriteMode.SIZE_MISMATCH: + if src_fs.size(src_file) != dst_fs.size(dst_file_path): + logging.info("copy: %s -> %s", src_file, dst_file_path) + job = executor.submit( + CopyJob(), + src_fs.to_dict(), + dst_fs.to_dict(), + src_file, + dst_file_path, + dry_run=dry_run, + validate_size=validate_size, + ) + jobs.append(job) + else: + raise ValueError("Unknown overwrite_mode") + else: + logging.info("copy: %s -> %s", src_file, dst_file_path) + job = executor.submit( + CopyJob(), + src_fs.to_dict(), + dst_fs.to_dict(), + src_file, + dst_file_path, + dry_run=dry_run, + validate_size=validate_size, + ) + jobs.append(job) + if wait: + while True: + num_finished = sum(job.done() for job in jobs) + logging.info("Total Jobs: %s Completed Jobs: %s", len(jobs), num_finished) + if num_finished == len(jobs): + break + time.sleep(wait_period) + output = [job.result() for job in jobs] + if all(output): + logging.info("All copies succeeded") + else: + logging.info("Some copies failed") + else: + logging.info("Not waiting for job to complete before exiting submit program") + + +if __name__ == "__main__": + typer.run(main) From c6ef4285e2f289fe3e87216c9ca0f9eb58ea5c44 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Tue, 4 Feb 2025 18:03:18 +0000 Subject: [PATCH 18/18] Several changes to enable entropy model training/eval Summary: - Make arrow iterator able to read from jsonl files, the entropies are omitted in this case - Make the data/checkpoint code fsspec compatible - Fix issues with all reduce with non-bf16 in dist_sum and norm computation. - Minimal fixes to get eval to run, it is slow currently - Add bpb numbers during training Test Plan: --- bytelatent/args.py | 44 +++++++- bytelatent/checkpoint.py | 97 +++++++++-------- bytelatent/data/iterators/arrow_iterator.py | 100 ++++++++++-------- bytelatent/distributed.py | 18 +++- bytelatent/eval.py | 64 +++++++---- bytelatent/generate.py | 6 +- bytelatent/norms.py | 99 +++++++++++++++++ bytelatent/preprocess/preprocess_entropies.py | 26 +++-- bytelatent/train.py | 59 ++++++++++- 9 files changed, 382 insertions(+), 131 deletions(-) create mode 100644 bytelatent/norms.py diff --git a/bytelatent/args.py b/bytelatent/args.py index d1bac46..23c571a 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -1,8 +1,10 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
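# Illustrative sketch, not part of the patch: the summary above mentions adding
# bits-per-byte (bpb) numbers during training. The patch's own bpb code is not shown
# in this excerpt; the standard conversion from a nats-valued loss to bpb is:
import math

def bits_per_byte(total_nll_nats: float, total_bytes: int) -> float:
    # Convert nats to bits (divide by ln 2), then normalize by the number of raw bytes.
    return total_nll_nats / (math.log(2) * total_bytes)

# e.g. an average loss of 1.1 nats per byte corresponds to ~1.587 bpb.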
+import json import logging import os from typing import Any +import fsspec import numpy as np import yaml from omegaconf import OmegaConf @@ -10,10 +12,10 @@ from pydantic import BaseModel, ConfigDict from bytelatent.checkpoint import CheckpointArgs from bytelatent.data.data_types import Batch +from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import StatefulIterator from bytelatent.data.iterators.arrow_iterator import ( ArrowFileIterator, - find_and_sanitize_chunks, ) from bytelatent.data.iterators.looping_iterator import LoopingIterator from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator @@ -53,6 +55,43 @@ def parse_args(args_cls): return pydantic_args +def read_args_file(fs: fsspec.AbstractFileSystem, path: str) -> Any: + with fs.open(path, "rt") as f: + if path.endswith(".json"): + return json.load(f) + elif path.endswith(".yaml"): + return yaml.load(f) + else: + raise ValueError("Invalid args file format") + + +TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" + + +def find_and_sanitize_chunks( + dataset_path: str, + world_size: int, + file_pattern: str, + s3_profile: str | None = None, +): + fs = get_fs(dataset_path, s3_profile=s3_profile) + path_with_glob = os.path.join(dataset_path, file_pattern) + dataset_chunks = fs.glob(path_with_glob) + n_chunks = len(dataset_chunks) + + if n_chunks > world_size: + n_discard = n_chunks - world_size + dataset_chunks = dataset_chunks[:world_size] + else: + assert ( + world_size % n_chunks == 0 + ), "World size should be a multiple of number of chunks" + + assert n_chunks > 0, f"No valid chunks in {dataset_path}" + + return dataset_chunks + + def distribute_data_to_rank( *, dataset_path: str, @@ -62,9 +101,10 @@ def distribute_data_to_rank( rank: int, world_size: int, s3_profile: str | None = None, + file_pattern: str = TRAIN_DATA_FILE_PATTERN, ) -> ArrowFileIterator: dataset_chunks = find_and_sanitize_chunks( - dataset_path, world_size, s3_profile=s3_profile + dataset_path, world_size, file_pattern, s3_profile=s3_profile ) n_workers_per_chunk = world_size // len(dataset_chunks) rank_to_arrow_iterator_params = [] diff --git a/bytelatent/checkpoint.py b/bytelatent/checkpoint.py index f213c84..6631673 100644 --- a/bytelatent/checkpoint.py +++ b/bytelatent/checkpoint.py @@ -4,8 +4,6 @@ import json import logging import os import re -from pathlib import Path -from typing import List, Optional, Tuple import fsspec import torch @@ -70,26 +68,29 @@ def consolidate_checkpoints(fs: fsspec.AbstractFileSystem, ckpt_dir: str): Returns the path to the consolidated checkpoint """ - consolidate_path = Path(ckpt_dir) / CONSOLIDATE_FOLDER - if not (consolidate_path / CONSOLIDATE_NAME).exists(): - consolidate_path.mkdir(exist_ok=True) - logger.info(f"Consolidating to: {str(consolidate_path)}") - dcp_to_torch_save(ckpt_dir, str(consolidate_path / CONSOLIDATE_NAME)) - (consolidate_path / CONFIG_NAME).write_text( - (Path(ckpt_dir) / CONFIG_NAME).read_text() + consolidate_path = os.path.join(ckpt_dir, CONSOLIDATE_FOLDER) + consolidate_name = os.path.join(consolidate_path, CONSOLIDATE_NAME) + if not fs.exists(consolidate_name): + fs.mkdirs(consolidate_path, exist_ok=True) + logger.info(f"Consolidating to: {consolidate_path}") + dcp_to_torch_save(ckpt_dir, consolidate_name) + fs.write_text( + os.path.join(consolidate_path, CONFIG_NAME), + fs.read_text(os.path.join(ckpt_dir, CONFIG_NAME)), ) logger.info("Consolidated !") return consolidate_path def load_from_checkpoint( + fs: fsspec.AbstractFileSystem, 
ckpt_dir: str, model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, + optimizer: torch.optim.Optimizer | None = None, model_key: str = "model", optim_key: str = "optim", ): - if not (Path(ckpt_dir) / ".metadata").exists(): + if not fs.exists(os.path.join(ckpt_dir, ".metadata")): raise ValueError( f"Please convert the checkpoint distcp format using `torch.distributed.checkpoint.format_utils.torch_save_to_dcp` before loading it" ) @@ -121,13 +122,13 @@ class CheckpointManager: self.existing_saves = self.get_existing_saves() - def get_existing_saves(self) -> List[Path]: + def get_existing_saves(self) -> list[str]: folders = [ p - for p in Path(self.path).iterdir() - if p.is_dir() and re.match(RE_FOLDER, p.name) + for p in self.fs.ls(self.path) + if self.fs.isdir(p) and re.match(RE_FOLDER, os.path.basename(p)) ] - folders.sort(key=lambda p: _get_key_step(p.name)) + folders.sort(key=lambda p: _get_key_step(os.path.basename(p))) return folders def clean_up(self): @@ -136,8 +137,9 @@ class CheckpointManager: eval_folders = [] other_folders = [] for p in self.existing_saves: - is_dump = _get_key_step(p.name) % self.dump_every.every == 0 - is_eval = _get_key_step(p.name) % self.eval_every.every == 0 + assert isinstance(p, str), f"Base path type: {p}" + is_dump = _get_key_step(os.path.basename(p)) % self.dump_every.every == 0 + is_eval = _get_key_step(os.path.basename(p)) % self.eval_every.every == 0 if is_dump: dump_folders.append(p) if is_eval: @@ -161,40 +163,39 @@ class CheckpointManager: if dist.get_rank() == 0: for folder in folder_to_remove: - for file in folder.iterdir(): - if file.is_file(): - file.unlink() - elif file.is_dir(): - assert file.name in [CONSOLIDATE_FOLDER] - for f in file.iterdir(): - f.unlink() - file.rmdir() - folder.rmdir() + for file in self.fs.ls(folder): + if self.fs.isfile(file): + self.fs.rm_file(file) + elif self.fs.isdir(file): + assert os.path.name(file) in [CONSOLIDATE_FOLDER] + for f in self.fs.ls(file): + self.fs.rm(f) + self.fs.rmdir(file) + self.fs.rmdir(folder) dist.barrier() self.existing_saves = list(folder_to_keep) - self.existing_saves.sort(key=lambda p: _get_key_step(p.name)) + self.existing_saves.sort(key=lambda p: _get_key_step(os.path.basename(p))) - def get_last_step_path(self, dp_rank: int = 0) -> Optional[Path]: + def get_last_step_path(self, dp_rank: int = 0) -> str | None: path = None for p in reversed(self.existing_saves): - if (p / TRAIN_STATE_NAME.format(dp_rank)).is_file(): + + if self.fs.isfile(os.path.join(p, TRAIN_STATE_NAME.format(dp_rank))): path = p break return path - def _create_folder(self, base_path: Path, folder_name: str) -> Path: - folder = base_path / folder_name + def _create_folder(self, base_path: str, folder_name: str) -> str: + folder = os.path.join(base_path, folder_name) if get_is_master(): - folder.mkdir(parents=False, exist_ok=True) + self.fs.mkdirs(folder, exist_ok=True) if dist.is_initialized(): dist.barrier() return folder - def _get_dp_tp_mesh( - self, device_mesh: Optional[DeviceMesh] = None - ) -> Tuple[int, int]: + def _get_dp_tp_mesh(self, device_mesh: DeviceMesh | None = None) -> tuple[int, int]: dp_rank = 0 tp_rank = 0 if device_mesh is not None: @@ -222,14 +223,14 @@ class CheckpointManager: model, optimizer, train_state, - config, - device_mesh: Optional[DeviceMesh] = None, + config: BaseModel, + device_mesh: DeviceMesh | None = None, ) -> bool: # When creating directory check if only rank0 or is there other solution - path = Path(self.path) + path = self.path curr_save_dir = 
self._create_folder(path, FOLDER_NAME.format(train_state.step)) - logger.info(f"Saving to: {str(curr_save_dir)}") + logger.info(f"Saving to: {curr_save_dir}") if dist.is_initialized(): dist.barrier() @@ -242,17 +243,19 @@ class CheckpointManager: if dist.is_initialized(): dist.barrier() + print("config type", type(config)) if get_is_master(): - config.dump_to_yaml_file(curr_save_dir / CONFIG_NAME) + self.fs.write_text( + os.path.join(curr_save_dir, CONFIG_NAME), config.model_dump_json() + ) # Add json dump here dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh) if tp_rank == 0: train_state_name = TRAIN_STATE_NAME.format(dp_rank) - logger.info( - f"Saving train state to: {str(curr_save_dir / train_state_name)}" - ) - with open(curr_save_dir / train_state_name, "w") as f: + train_state_full_path = os.path.join(curr_save_dir, train_state_name) + logger.info(f"Saving train state to: {train_state_full_path}") + with self.fs.open(train_state_full_path, "w") as f: json.dump(train_state.state_dict(), f) logger.info("Train state saved !") @@ -271,7 +274,7 @@ class CheckpointManager: optimizer, train_state, device_mesh: DeviceMesh, - path: Optional[Path] = None, + path: str | None = None, ): dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh) # Loading tries to load the provided path, if not available the last saved step and finally from the init path @@ -284,12 +287,12 @@ class CheckpointManager: # Only load train state if it's provided, the files exist and we're not loading from init path train_state_name = TRAIN_STATE_NAME.format(dp_rank) logger.info("Reloading train state") - with open(path / train_state_name, "r") as f: + with self.fs.open(os.path.join(path, train_state_name), "r") as f: train_state_dict = json.load(f) train_state.load_state_dict(train_state_dict) logger.info("Train state reloaded") - logger.info(f"Loading from: {str(path)}") + logger.info(f"Loading from: {path}") state_dict = self.get_state_dict( model=model, optimizer=optimizer, diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index 4e7b99e..fa18bf1 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -16,6 +16,10 @@ from bytelatent import ByteLatentError from bytelatent.data.data_types import BltExample from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator +from bytelatent.preprocess.preprocess_entropies import ( + get_id_key, + get_text, +) logger = getLogger(__name__) @@ -32,6 +36,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState): arrow_batch_size: int = 100 s3_profile: str | None filesystem_type: str | None = None + file_format: str def build(self) -> "ArrowFileIterator": arrow_file = ArrowFileIterator( @@ -44,6 +49,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState): dataset_files=self.dataset_files, s3_profile=self.s3_profile, filesystem_type=self.filesystem_type, + file_format=self.file_format, ) if self.row_num != 0: arrow_file._set_row_num(self.row_num) @@ -70,6 +76,7 @@ class ArrowFileIterator(StatefulIterator): dataset_files: list[str] | None = None, s3_profile: str | None = None, filesystem_type: str | None = None, + file_format: str = "arrow", ): assert 0 <= worker_id < num_workers, (worker_id, num_workers) if file_path is None and dataset_files is None: @@ -87,12 +94,16 @@ class ArrowFileIterator(StatefulIterator): self.arrow_batch_size = arrow_batch_size self.s3_profile = s3_profile self.filesystem_type 
= filesystem_type + self.file_format = file_format self.fs = None if self.filesystem_type is not None: if self.filesystem_type == "file": self.fs = fsspec.filesystem("file") elif self.filesystem_type == "s3": self.fs = fsspec.filesystem("s3", profile=s3_profile) + else: + raise ValueError("Unknown filesystem") + logger.info("Arrow iterator using fs=%s", self.fs) if dataset_files is None: # Prepare arrow shards @@ -153,6 +164,7 @@ class ArrowFileIterator(StatefulIterator): dataset_files=self.dataset_files, s3_profile=self.s3_profile, filesystem_type=self.filesystem_type, + file_format=self.file_format, ) def create_iter( @@ -164,7 +176,7 @@ class ArrowFileIterator(StatefulIterator): else: filesystem = None self.dataset = pa.dataset.dataset( - self.dataset_files, format="arrow", filesystem=filesystem + self.dataset_files, format=self.file_format, filesystem=filesystem ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size @@ -173,13 +185,22 @@ class ArrowFileIterator(StatefulIterator): if self.batch_to_consume is not None: batch_columns: dict[str, list] = self.batch_to_consume self.batch_to_consume = None - sample_ids = batch_columns["sample_id"] - texts = batch_columns["text"] - entropies = batch_columns["entropies"] + if self.file_format == "arrow": + sample_ids = batch_columns["sample_id"] + texts = batch_columns["text"] + entropies = batch_columns["entropies"] + elif self.file_format == "json": + # This data hasn't been preprocessed to a uniform format, + # so we have to do it now and omit entropies + sample_ids = batch_columns[get_id_key(batch_columns)] + texts = get_text(batch_columns) + entropies = None + else: + raise ValueError(f"Unknown file format: {self.file_format}") for i in range(len(sample_ids)): out = BltExample( sample_id=sample_ids[i], - entropies=entropies[i], + entropies=entropies[i] if entropies is not None else None, text=texts[i], tokens=None, mask=None, @@ -191,13 +212,22 @@ class ArrowFileIterator(StatefulIterator): for batch in self.batch_iterator: batch_columns = batch.to_pydict() - sample_ids = batch_columns["sample_id"] - texts = batch_columns["text"] - entropies = batch_columns["entropies"] + if self.file_format == "arrow": + sample_ids = batch_columns["sample_id"] + texts = batch_columns["text"] + entropies = batch_columns["entropies"] + elif self.file_format == "json": + # This data hasn't been preprocessed to a uniform format, + # so we have to do it now and omit entropies + sample_ids = batch_columns[get_id_key(batch_columns)] + texts = get_text(batch_columns) + entropies = None + else: + raise ValueError(f"Unknown file format: {self.file_format}") for i in range(len(sample_ids)): out = BltExample( sample_id=sample_ids[i], - entropies=entropies[i], + entropies=entropies[i] if entropies is not None else None, text=texts[i], tokens=None, mask=None, @@ -231,13 +261,24 @@ class ArrowFileIterator(StatefulIterator): for batch in self.batch_iterator: if len(batch) > curr_remaining: batch_columns: dict[str, list] = batch.to_pydict() - batch_columns["sample_id"] = batch_columns["sample_id"][ - curr_remaining: - ] - batch_columns["entropies"] = batch_columns["entropies"][ - curr_remaining: - ] - batch_columns["text"] = batch_columns["text"][curr_remaining:] + if self.file_format == "arrow": + leftover_sample_ids = batch_columns["sample_id"][ + curr_remaining: + ] + leftover_entropies = batch_columns["entropies"][curr_remaining:] + leftover_texts = batch_columns["text"][curr_remaining:] + elif self.file_format == "json": + 
leftover_sample_ids = batch_columns[get_id_key(batch_columns)][ + curr_remaining: + ] + leftover_entropies = None + leftover_texts = get_text(batch_columns)[curr_remaining:] + else: + raise ValueError(f"Unknown file format: {self.file_format}") + + batch_columns["sample_id"] = leftover_sample_ids + batch_columns["entropies"] = leftover_entropies + batch_columns["text"] = leftover_texts self.batch_to_consume = batch_columns break elif len(batch) == curr_remaining: @@ -250,30 +291,3 @@ class ArrowFileIterator(StatefulIterator): logger.info( f"Finished setting arrow position to {target_row_num} for {self.dataset_files}" ) - - -TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" - - -def find_and_sanitize_chunks( - dataset_path: str, - world_size: int, - file_pattern: str = TRAIN_DATA_FILE_PATTERN, - s3_profile: str | None = None, -): - fs = get_fs(dataset_path, s3_profile=s3_profile) - path_with_glob = os.path.join(dataset_path, file_pattern) - dataset_chunks = fs.glob(path_with_glob) - n_chunks = len(dataset_chunks) - - if n_chunks > world_size: - n_discard = n_chunks - world_size - dataset_chunks = dataset_chunks[:world_size] - else: - assert ( - world_size % n_chunks == 0 - ), "World size should be a multiple of number of chunks" - - assert n_chunks > 0, f"No valid chunks in {dataset_path}" - - return dataset_chunks diff --git a/bytelatent/distributed.py b/bytelatent/distributed.py index 168cb7c..f4b57e2 100644 --- a/bytelatent/distributed.py +++ b/bytelatent/distributed.py @@ -127,6 +127,16 @@ def dist_max(x: Union[int, float], mesh: DeviceMesh = None): return tensor +def dist_sum( + x: Union[int, float], mesh: DeviceMesh = None, reduce_dtype: torch.dtype = None +): + tensor = torch.tensor(x).cuda() + if reduce_dtype is not None: + tensor = tensor.to(reduce_dtype) + dist.all_reduce(tensor, op=ReduceOp.SUM, group=mesh.get_group() if mesh else None) + return tensor + + def dist_mean(x: Union[int, float], mesh: DeviceMesh = None): tensor = torch.tensor(x).cuda() dist.all_reduce(tensor, op=ReduceOp.AVG, group=mesh.get_group() if mesh else None) @@ -236,7 +246,7 @@ def setup_env(env_args: EnvironmentArgs): logger.warning(f"WARNING: Setting {name} to {value}") -def setup_torch_distributed(dist_args): +def setup_torch_distributed(dist_args: DistributedArgs): """ Handle single and multi-GPU / multi-node / SLURM jobs. 
Initialize the following variables: @@ -388,14 +398,14 @@ def clean_env(): def parallelize_model( - model, + model: torch.nn.Module, device_mesh, model_args, distributed_args: DistributedArgs, fsdp_grouping_plan: Optional[List[Tuple[str, bool]]] = None, tp_parallelize=None, no_recompute_ops=None, -): +) -> torch.nn.Module: if distributed_args.tp_size > 1: assert ( distributed_args.fsdp_type == "full_shard" @@ -429,6 +439,8 @@ def parallelize_model( device_mesh["dp_shard"].size() == 1 ), "dp_shard must be 1 for no_shard fsdp_type" + # TODO: Remove with something better + # model = model.to(param_dtype) fsdp_config = dict( mp_policy=( MixedPrecisionPolicy( diff --git a/bytelatent/eval.py b/bytelatent/eval.py index ae73066..1943a33 100644 --- a/bytelatent/eval.py +++ b/bytelatent/eval.py @@ -15,9 +15,16 @@ from lm_eval.api.model import LM from omegaconf import OmegaConf from pydantic import BaseModel, ConfigDict -from bytelatent.args import EvalArgs, ValidationArgs, parse_args +from bytelatent.args import ( + EvalArgs, + TrainArgs, + ValidationArgs, + find_and_sanitize_chunks, + parse_args, +) from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints from bytelatent.data.file_util import get_fs +from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator from bytelatent.distributed import ( DistributedArgs, dist_mean_dict, @@ -117,19 +124,40 @@ class EvalHarnessLM(LM): return results -def eval_on_val(generator, val_args: ValidationArgs, train_cfg): - srcs = {} +def eval_on_val(generator, val_args: ValidationArgs, train_cfg: TrainArgs): + srcs = [] for src in val_args.sources: path = os.path.join(val_args.root_dir, src) - srcs[path] = 1.0 + srcs.append(path) + for src in train_cfg.data.sources: path = os.path.join(train_cfg.data.root_dir, src) - srcs[path] = 1.0 + srcs.append(path) - multi_state = init_choice_state( - "", srcs, 0, get_global_rank(), get_world_size(), "*.val.jsonl" - ) - path_to_iter = setup_sources(multi_state) + path_to_iter = {} + for path in srcs: + chunks = find_and_sanitize_chunks( + path, + world_size=1, + file_pattern="*.val.jsonl", + s3_profile=train_cfg.data.s3_profile, + ) + assert ( + len(chunks) == 1 + ), f"There should be only 1 chunk per validation file, but found: {chunks}" + chunk = chunks[0] + iterator = ArrowFileIterator( + dataset_files=[chunk], + file_path=None, + preprocess_dir=None, + entropy_model_name=None, + worker_id=0, + num_workers=1, + arrow_batch_size=train_cfg.data.arrow_batch_size, + s3_profile=train_cfg.data.s3_profile, + file_format="json", + ) + path_to_iter[path] = iterator max_gen_len = generator.max_gen_len # We temporarily lower max gen len @@ -137,16 +165,11 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg): all_val_metrics = {} for src in path_to_iter: - jsonl_iterator = path_to_iter[src] + example_iterator = path_to_iter[src].create_iter() texts = [] logger.info(f"Running validation on {src}...") - for step, (content, state) in enumerate(jsonl_iterator): - if state["current_iter"] > 0 or ( - val_args.max_steps is not None and step >= val_args.max_steps - ): - break - content_key = "text" if ("text" in content) else "content" - texts.append(content[content_key]) + for step, example in enumerate(example_iterator): + texts.append(example.text) _, loglikelihood, _ = generator.generate(texts) @@ -191,7 +214,7 @@ def launch_eval(eval_args: EvalArgs): else: consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER) if not fs.exists(consolidate_path) and get_global_rank() == 0: - 
consolidate_path = consolidate_checkpoints(eval_args.ckpt_dir) + consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir) fs.mkdirs(eval_args.dump_dir, exist_ok=True) with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f: @@ -210,10 +233,12 @@ def launch_eval(eval_args: EvalArgs): wrap = EvalHarnessLM(generator) # Redo - results = simple_evaluate(wrap, eval_args.harness.model_dump()) + results = simple_evaluate(wrap, **eval_args.harness.model_dump()) + val_results = None if eval_args.validation: val_results = eval_on_val(generator, eval_args.validation, train_cfg) + if get_global_rank() == 0: with fs.open(os.path.join(eval_args.dump_dir, "results.json"), "w") as f: f.write(json.dumps(results)) @@ -222,6 +247,7 @@ def launch_eval(eval_args: EvalArgs): with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f: f.write(json.dumps(val_results)) logger.info(f"All validation results: {val_results}") + if eval_args.metric_log_dir and get_global_rank() == 0: metric_log_path = os.path.join(eval_args.metric_log_dir, "metrics.eval.jsonl") diff --git a/bytelatent/generate.py b/bytelatent/generate.py index eb79d81..9d44f30 100644 --- a/bytelatent/generate.py +++ b/bytelatent/generate.py @@ -387,8 +387,7 @@ def load_consolidated_model_and_tokenizer( ): train_args_path = os.path.join(consolidated_path, "params.json") fs = get_fs(train_args_path) - with fs.open(train_args_path) as f: - train_args = TrainArgs.model_validate_json(f.read()) + train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path)) if train_args.train_entropy_model: model_args = train_args.entropy_model @@ -401,7 +400,8 @@ def load_consolidated_model_and_tokenizer( train_args.distributed.model_dtype ] tokenizer = train_args.data.tokenizer_args.build() - st_dict = torch.load(consolidated_path / CONSOLIDATE_NAME, weights_only=True) + with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f: + st_dict = torch.load(f, weights_only=True) model.load_state_dict(st_dict["model"]) model = model.cuda().eval() for param in model.parameters(): diff --git a/bytelatent/norms.py b/bytelatent/norms.py new file mode 100644 index 0000000..6028b13 --- /dev/null +++ b/bytelatent/norms.py @@ -0,0 +1,99 @@ +from typing import Optional, Tuple, Dict, List +import torch +from torch import Tensor +from torch.utils._foreach_utils import ( + _device_has_foreach_support, + _group_tensors_by_device_and_dtype, + _has_foreach_support, +) + + +@torch.no_grad() +def fixed_clip_grad_norm_( + parameters: torch.Tensor | list[torch.Tensor], + max_norm: float, + norm_type: float = 2.0, + error_if_nonfinite: bool = False, + foreach: Optional[bool] = None, +) -> torch.Tensor: + r"""Clip the gradient norm of an iterable of parameters. + + The norm is computed over the norms of the individual gradients of all parameters, + as if the norms of the individual gradients were concatenated into a single vector. + Gradients are modified in-place. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float): max norm of the gradients + norm_type (float): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + error_if_nonfinite (bool): if True, an error is thrown if the total + norm of the gradients from :attr:`parameters` is ``nan``, + ``inf``, or ``-inf``. Default: False (will switch to True in the future) + foreach (bool): use the faster foreach-based implementation. 
+ If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently + fall back to the slow implementation for other device types. + Default: ``None`` + + Returns: + Total norm of the parameter gradients (viewed as a single vector). + """ + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + grads = [p.grad.to(torch.bfloat16) for p in parameters if p.grad is not None] + max_norm = float(max_norm) + norm_type = float(norm_type) + if len(grads) == 0: + return torch.tensor(0.0) + first_device = grads[0].device + grouped_grads: Dict[ + Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]] + ] = _group_tensors_by_device_and_dtype( + [grads] + ) # type: ignore[assignment] + + norms: List[Tensor] = [] + for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment] + if (foreach is None and _has_foreach_support(device_grads, device)) or ( + foreach and _device_has_foreach_support(device) + ): + norms.extend(torch._foreach_norm(device_grads, norm_type)) + elif foreach: + raise RuntimeError( + f"foreach=True was passed, but can't use the foreach API on {device.type} tensors" + ) + else: + norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads]) + + total_norm = torch.linalg.vector_norm( + torch.stack([norm.to(first_device) for norm in norms]), norm_type + ) + + if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()): + raise RuntimeError( + f"The total norm of order {norm_type} for gradients from " + "`parameters` is non-finite, so it cannot be clipped. To disable " + "this error and scale the gradients by the non-finite norm anyway, " + "set `error_if_nonfinite=False`" + ) + clip_coef = max_norm / (total_norm + 1e-6) + # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so + # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization + # when the gradients do not reside in CPU memory. 
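+ # A clip coefficient >= 1 means the total norm is already within max_norm, so after the
+ # clamp below, multiplying by 1.0 leaves those gradients unchanged.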
+ clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment] + if (foreach is None and _has_foreach_support(device_grads, device)) or ( + foreach and _device_has_foreach_support(device) + ): + torch._foreach_mul_(device_grads, clip_coef_clamped.to(device)) + elif foreach: + raise RuntimeError( + f"foreach=True was passed, but can't use the foreach API on {device.type} tensors" + ) + else: + clip_coef_clamped_device = clip_coef_clamped.to(device) + for g in device_grads: + g.mul_(clip_coef_clamped_device) + + return total_norm diff --git a/bytelatent/preprocess/preprocess_entropies.py b/bytelatent/preprocess/preprocess_entropies.py index 519da94..31a4802 100644 --- a/bytelatent/preprocess/preprocess_entropies.py +++ b/bytelatent/preprocess/preprocess_entropies.py @@ -15,29 +15,37 @@ from bytelatent.entropy_model import load_entropy_model from bytelatent.tokenizers.build_tokenizer import TokenizerArgs -def get_id_from_doc(doc: dict) -> int: +def get_id_key(doc: dict) -> int: """ We need a reliable way to ensure that samples from jsonl and arrow are the same, but there is no unique id field, so derive the best possible """ if "sample_id" in doc: - sample_id = doc["sample_id"] + return "sample_id" elif "title" in doc: - sample_id = doc["title"] + return "title" elif "qid" in doc: - sample_id = doc["qid"] + return "qid" elif "paper_id" in doc: - sample_id = doc["paper_id"] + return "paper_id" elif "path" in doc: - sample_id = doc["path"] + return "path" elif "url" in doc: - sample_id = doc["url"] + return "url" elif "id" in doc: - sample_id = doc["id"] + return "id" else: raise ValueError(f"Could not find a id key from: {doc.keys()}") - return str(sample_id) + + +def get_id_from_doc(doc: dict) -> int: + """ + We need a reliable way to ensure that samples from jsonl + and arrow are the same, but there is no unique id field, + so derive the best possible + """ + return str(doc[get_id_key(doc)]) def get_text(doc: dict): diff --git a/bytelatent/train.py b/bytelatent/train.py index 6b20ecd..2d43270 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -2,6 +2,8 @@ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
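For clarity, a minimal sketch of how the get_id_key / get_id_from_doc helpers from the bytelatent/preprocess/preprocess_entropies.py hunk above are expected to behave; the sample documents here are hypothetical and only illustrate the key-priority fallback:

from bytelatent.preprocess.preprocess_entropies import get_id_from_doc, get_id_key

# Hypothetical documents; only the presence or absence of keys matters.
wiki_doc = {"title": "Example Article", "text": "..."}
web_doc = {"url": "https://example.com/page", "id": 42, "content": "..."}

assert get_id_key(wiki_doc) == "title"        # "title" is checked before "url" or "id"
assert get_id_from_doc(wiki_doc) == "Example Article"
assert get_id_key(web_doc) == "url"           # "url" precedes "id" in the fallback order
assert get_id_from_doc(web_doc) == "https://example.com/page"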
import gc +import math +import numpy as np import logging import os import sys @@ -25,6 +27,7 @@ from torch.optim import lr_scheduler from bytelatent.args import TrainArgs, parse_args from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint +from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.multiprocess_iterator import ( MultiprocessIterator, MultiprocessIteratorState, @@ -33,7 +36,9 @@ from bytelatent.data.iterators.packing_iterator import PackingIteratorState from bytelatent.distributed import ( check_model_value_range, clean_env, + dist_mean, dist_mean_dict, + dist_sum, get_device_mesh, get_is_master, get_world_size, @@ -47,6 +52,7 @@ from bytelatent.eval import EVAL_FOLDER_NAME, launch_eval from bytelatent.logger import init_logger from bytelatent.metrics import GPUMemoryMonitor, MetricLogger, get_num_params from bytelatent.model.blt import ByteLatentTransformer +from bytelatent.norms import fixed_clip_grad_norm_ from bytelatent.optim import build_optimizer from bytelatent.probe import AutoProbeD from bytelatent.profiling import maybe_run_profiler @@ -295,8 +301,11 @@ def train(args: TrainArgs): if args.checkpoint.init_ckpt_path: logger.info(f"Loading initial model from {args.checkpoint.init_ckpt_path}") + ckpt_fs = get_fs( + args.checkpoint.init_ckpt_path, s3_profile=args.checkpoint.s3_profile + ) load_from_checkpoint( - args.checkpoint.init_ckpt_path, model, model_key="model" + ckpt_fs, args.checkpoint.init_ckpt_path, model, model_key="model" ) # Put model_key="" if its directly the model checkpoint model.rope_embeddings.reset_parameters() # For RoPe initialization since it's a buffer it might not be loaded else: @@ -364,6 +373,9 @@ def train(args: TrainArgs): time_last_log = timer() gc.collect() saved = False + step_losses: list[float] = [] + step_tok_losses: list[float] = [] + n_bytes: int = 0 while train_state.step < args.steps and ( args.max_steps is None or train_state.step < args.max_steps ): @@ -385,6 +397,24 @@ def train(args: TrainArgs): batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda() mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda() + if args.data.tokenizer_args.name in ["bytes", "blt"]: + if mask is None: + n_bytes += batch_y.numel() + else: + n_bytes += mask.sum() + elif args.data.tokenizer_args.name in ["sp", "tiktoken"]: + for example in batch.y: + target_tokens = tokenizer.decode(example.tolist(), cut_at_eos=False) + n_bytes += ( + len(bytes(target_tokens, encoding="utf-8", errors="ignore")) + + sum(example == tokenizer.eos_id) + + sum(example == tokenizer.bos_id) + ) + else: + raise ValueError( + f"Unexpected tokenizer to count n_bytes for: {args.data.tokenizer_args.name}" + ) + if ( not args.train_entropy_model and args.model.encoder_enable_byte_ngrams @@ -459,7 +489,7 @@ def train(args: TrainArgs): batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids ) - loss, _ = compute_loss(pred, batch_y, mask, train_state.scale) + loss, tok_loss = compute_loss(pred, batch_y, mask, train_state.scale) # We scale loss with grad_acc_steps so the gradient is the same # regardless of grad_acc_steps @@ -470,8 +500,14 @@ def train(args: TrainArgs): # For logging we undo that scaling loss = loss.detach() * args.grad_acc_steps - grad_norm = torch.nn.utils.clip_grad_norm_( - model.parameters(), max_norm=args.optim.clip, foreach=True + # Undo loss scaling so downstream down't need to worry about it + step_losses.append((loss / train_state.scale).item()) + step_tok_losses.append(tok_loss / 
train_state.scale) + + # grad_norm = torch.nn.utils.clip_grad_norm_( + grad_norm = fixed_clip_grad_norm_( + model.parameters(), + max_norm=args.optim.clip, # foreach=True ) grad_norm = ( @@ -559,20 +595,33 @@ def train(args: TrainArgs): gpu_memory_monitor.reset_peak_stats() nwords_since_last_log = 0 time_last_log = timer() + stacked_tok_loss = torch.cat(step_tok_losses, dim=0) + total_tok_loss = dist_sum( + stacked_tok_loss.sum().item(), reduce_dtype=torch.bfloat16 + ) + total_n_bytes = dist_sum(n_bytes, reduce_dtype=torch.bfloat16) + avg_bpb = total_tok_loss / math.log(2) / total_n_bytes + avg_loss = dist_mean(np.mean(step_losses).item()) logger.info( f"step: {train_state.step}" f" acc: {train_state.acc_step}" - f" loss: {round(loss.item(),4):>7}" + f" loss: step={round(loss.item(),4):>7} avg={avg_loss}" + f" bpb: {avg_bpb:3f}" f" grad: {grad_norm:.2e}" f" flops: {FLOPS:.2e}" f" wps: {wps:.2e}" f" iter: {curr_iter_time:>7}" f" data: {data_load_time:>5}" f" lr: {curr_lr:.2e}" + f" n_bytes={total_n_bytes}" f" mem: {gpu_mem_stats.max_active_pct:.0f}%" f" pow: {gpu_mem_stats.power_draw/1000} W" ) + n_bytes = 0 + step_losses = [] + step_tok_losses = [] + if every_n_steps( train_state, args.checkpoint.dump.every, acc_step=0 ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
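For reference, a minimal single-process sketch of the bits-per-byte bookkeeping introduced above. The real code sums token losses and byte counts across ranks with dist_sum before dividing; the bits_per_byte helper below is hypothetical and only mirrors the math avg_bpb = total_tok_loss / log(2) / total_n_bytes:

import math

import torch


def bits_per_byte(token_nll_nats: torch.Tensor, n_bytes: int) -> float:
    # token_nll_nats: per-token cross-entropy losses in nats (natural log base)
    # n_bytes: number of target bytes covered by those tokens
    total_nats = token_nll_nats.sum().item()
    return total_nats / math.log(2) / n_bytes


# Toy values: three tokens covering three target bytes.
print(bits_per_byte(torch.tensor([2.1, 0.3, 1.7]), n_bytes=3))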