From 28016f144d5fbcc2f279955f0038463186f0de30 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Thu, 9 Jan 2025 20:06:18 +0000 Subject: [PATCH 01/59] Add plotting code from paper Summary: Test Plan: --- bytelatent/plotting/config_entropy_figure.yaml | 3 ++- bytelatent/plotting/entropy_figure.py | 17 ++++++++++++++++- plot_data/scores.json | 1 + 3 files changed, 19 insertions(+), 2 deletions(-) create mode 100644 plot_data/scores.json diff --git a/bytelatent/plotting/config_entropy_figure.yaml b/bytelatent/plotting/config_entropy_figure.yaml index 4d7bfd7..296ea07 100644 --- a/bytelatent/plotting/config_entropy_figure.yaml +++ b/bytelatent/plotting/config_entropy_figure.yaml @@ -1,3 +1,4 @@ data_path: plot_data/entropy_figure.json chart_path: figures/entropy_figure.pdf -# chart_path: figures/entropy_figure.pdf +threshold_override: 1.7171002626419067 +score_override_path: plot_data/scores.json diff --git a/bytelatent/plotting/entropy_figure.py b/bytelatent/plotting/entropy_figure.py index c1966a1..f401d7c 100644 --- a/bytelatent/plotting/entropy_figure.py +++ b/bytelatent/plotting/entropy_figure.py @@ -1,4 +1,5 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. +import json import os import sys from pathlib import Path @@ -12,6 +13,8 @@ from pydantic import BaseModel class PlotEntropiesConfig(BaseModel): data_path: str | None chart_path: str + score_override_path: str | None = None + threshold_override: float | None = None class Config: extra = "forbid" @@ -37,8 +40,20 @@ def main(): plot_config = PlotEntropiesConfig(**conf_dict) with open(plot_config.data_path) as f: json_data = f.read() + plot_data = PlotEntropiesData.model_validate_json(json_data) df = pd.read_json(plot_data.dataframe_json) + print("LEN", len(df)) + if plot_config.threshold_override is None: + threshold = plot_data.threshold + else: + threshold = plot_config.threshold_override + if plot_config.score_override_path is not None: + with open(plot_config.score_override_path) as f: + scores = json.load(f)["score"] + assert len(scores) == len(df) + df["entropies"] = scores + df["start"] = [1] + (df["entropies"] > threshold).values.tolist()[:-1] x_ticks = [] for row in df.itertuples(): @@ -65,7 +80,7 @@ def main(): ), ) rule = base.mark_rule(color="red", strokeDash=[4, 4]).encode( - y=alt.datum(plot_data.threshold), + y=alt.datum(threshold), ) patch_rules = ( alt.Chart(df[df["start"] > 0]) diff --git a/plot_data/scores.json b/plot_data/scores.json new file mode 100644 index 0000000..202cafc --- /dev/null +++ b/plot_data/scores.json @@ -0,0 +1 @@ +{"score": [3.3949153423309326, 2.1647746562957764, 2.3216569423675537, 2.8114914894104004, 1.505232334136963, 0.04055612534284592, 0.09150367230176926, 0.06008715182542801, 0.3453567624092102, 1.0483067035675049, 0.1967127025127411, 0.12737397849559784, 0.05923430994153023, 0.001597292022779584, 0.004362526815384626, 0.005547997076064348, 0.0011689786333590746, 0.0010273229563608766, 1.0228447914123535, 3.6863417625427246, 0.46605175733566284, 0.048645928502082825, 2.2544963359832764, 0.37329360842704773, 1.001160979270935, 2.9116122722625732, 1.8948925733566284, 1.4017235040664673, 0.3879640996456146, 0.2652309536933899, 1.780383825302124, 0.013964788988232613, 0.005456871818751097, 0.5426468253135681, 0.20666983723640442, 0.0051853349432349205, 0.0005802579107694328, 0.0007443525246344507, 0.0004390323010738939, 0.005452247802168131, 1.1932975053787231, 0.023798620328307152, 3.1230878829956055, 1.3915895223617554, 3.0489213466644287, 1.7018193006515503, 1.873910903930664, 
1.4662408828735352, 0.004920408595353365, 0.02599342167377472, 0.6620859503746033, 0.31743818521499634, 2.8409600257873535, 1.1354060173034668, 0.0520976223051548, 0.3519965708255768, 0.40707266330718994, 2.5438783168792725, 1.3343133926391602, 0.023993035778403282, 3.445943832397461, 1.8542104959487915, 0.7849258780479431, 0.6848396062850952, 0.06938046962022781, 0.20923230051994324, 0.10084306448698044, 0.18334199488162994, 0.4126923978328705, 0.5505472421646118, 0.1042013093829155, 0.019447727128863335, 0.0014866517158225179, 0.0009848219342529774, 0.00021391961490735412, 0.007746236398816109, 0.00038792978739365935, 0.0007933690212666988, 1.2369810342788696, 0.4436197578907013, 4.6366687456611544e-05]} \ No newline at end of file From 84854423c49fa59187855785c935d6cdbc9e0c45 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Fri, 10 Jan 2025 01:00:54 +0000 Subject: [PATCH 02/59] Replace regular filesystem calls with fsspec + add s3 support Summary: For compatibility with either local/nfs or S3 datasets, swap to fsspec. Add a tool to compare local and remote filesystems Test Plan: - Ran regular train script - Ran with config with data in S3 --- .gitignore | 2 +- bytelatent/args.py | 8 +- bytelatent/data/file_util.py | 116 ++++++++++++++++++++ bytelatent/data/iterators/arrow_iterator.py | 95 +++++++++++++--- bytelatent/logger.py | 2 +- requirements.txt | 1 + setup/download_prepare_hf_data.py | 17 ++- 7 files changed, 222 insertions(+), 19 deletions(-) create mode 100644 bytelatent/data/file_util.py diff --git a/.gitignore b/.gitignore index 56891a9..6c664b8 100644 --- a/.gitignore +++ b/.gitignore @@ -165,4 +165,4 @@ cython_debug/ figures/ .vscode/ .DS_Store - +internal/ diff --git a/bytelatent/args.py b/bytelatent/args.py index 4fba100..cfba3bf 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -46,8 +46,11 @@ def distribute_data_to_rank( arrow_batch_size: int, rank: int, world_size: int, + s3_profile: str | None = None, ) -> ArrowFileIterator: - dataset_chunks = find_and_sanitize_chunks(dataset_path, world_size) + dataset_chunks = find_and_sanitize_chunks( + dataset_path, world_size, s3_profile=s3_profile + ) n_workers_per_chunk = world_size // len(dataset_chunks) rank_to_arrow_iterator_params = [] for chunk_path in dataset_chunks: @@ -61,6 +64,7 @@ def distribute_data_to_rank( dataset_files=None, entropy_model_name=entropy_model_name, arrow_batch_size=arrow_batch_size, + s3_profile=s3_profile, ) ) return rank_to_arrow_iterator_params[rank] @@ -68,6 +72,7 @@ def distribute_data_to_rank( class DataloaderArgs(BaseModel): model_config = ConfigDict(extra="forbid") + s3_profile: str | None = None root_dir: str | None = None sources: dict[str, float] = {} batch_size: int = 2 @@ -107,6 +112,7 @@ class DataloaderArgs(BaseModel): arrow_batch_size=self.arrow_batch_size, rank=rank, world_size=world_size, + s3_profile=self.s3_profile, ) looping_iterator = LoopingIterator(arrow_iterator) preprocess_iterator = PreprocessIterator( diff --git a/bytelatent/data/file_util.py b/bytelatent/data/file_util.py new file mode 100644 index 0000000..312e1b4 --- /dev/null +++ b/bytelatent/data/file_util.py @@ -0,0 +1,116 @@ +import os +from rich.progress import track +import typer +import fsspec +import pyarrow as pa +from pyarrow.lib import ArrowInvalid + +# pyarrow needs the initialization from this import +import pyarrow.dataset # pyright: ignore + + +def is_valid_arrow_file(path: str): + try: + dataset = pa.dataset.dataset(path, format="arrow") + return True + except ArrowInvalid: + return False + 
+ +app = typer.Typer() + +S3_PREFIX = "s3://" + + +def get_fs(path: str, s3_profile: str | None = None) -> fsspec.AbstractFileSystem: + if path.startswith("s3://"): + if s3_profile is None: + return fsspec.filesystem("s3") + else: + return fsspec.filesystem("s3", profile=s3_profile) + else: + return fsspec.filesystem("file") + + +@app.command() +def print_local_to_delete( + blob_dir: str, local_dirs: list[str], s3_profile: str = "blt" +): + for s in local_dirs: + assert s.endswith("/"), "Dirs must end with /" + assert blob_dir.endswith("/"), "Dirs must end with /" + blob_fs = fsspec.filesystem("s3", profile=s3_profile) + blob_files = blob_fs.find(blob_dir) + for f in track(blob_files): + size = blob_fs.info(f)["Size"] + if not f.lower().endswith(".complete"): + assert size != 0, f"Size was invalidly zero for {f}" + + blob_relative_paths = {f[len(blob_dir) - len(S3_PREFIX) :] for f in blob_files} + local_fs = fsspec.filesystem("file") + + files_to_delete = [] + for local_dir in local_dirs: + local_files = local_fs.find(local_dir) + for f in local_files: + relative_path = f[len(local_dir) :] + if relative_path in blob_relative_paths and not os.path.islink(f): + files_to_delete.append(f) + print(len(files_to_delete)) + with open("/tmp/files_to_delete.txt", "w") as f: + for file in files_to_delete: + f.write(f"{file}\n") + + +@app.command() +def compare_local_to_blob( + source_dirs: list[str], dst_dir: str, s3_profile: str = "blt" +): + for s in source_dirs: + assert s.endswith("/"), "Dirs must end with /" + assert dst_dir.endswith("/"), "Dirs must end with /" + assert len(source_dirs) != 0 + assert dst_dir.startswith("s3://") + local_fs = fsspec.filesystem("file") + dst_fs = fsspec.filesystem("s3", profile=s3_profile) + source_to_files = {} + all_local_files = set() + for s in source_dirs: + skipped = [] + if s not in source_to_files: + source_to_files[s] = [] + for f in local_fs.find(s): + if os.path.islink(f): + continue + if f.endswith(".COMPLETE") or f.endswith(".complete"): + is_complete_file = True + assert os.path.getsize(f) == 0, ".COMPLETE files should be empty" + else: + is_complete_file = False + + if not is_complete_file and os.path.getsize(f) == 0: + skipped.append(f) + continue + if f.endswith(".arrow"): + if not is_valid_arrow_file(f): + skipped.append(f) + continue + + source_to_files[s].append(f) + all_local_files.add(f[len(s) :]) + print(s, len(source_to_files[s]), "skipped", len(skipped), skipped[:10]) + + dst_files = dst_fs.find(dst_dir) + print(dst_dir, len(dst_files)) + + dst_file_set = {f[len(dst_dir) - len(S3_PREFIX) :] for f in dst_files} + diff = all_local_files.symmetric_difference(dst_file_set) + print("Local files", len(all_local_files)) + print("DST Files", len(dst_file_set)) + print("Symmetric difference", len(diff)) + dst_only_files = dst_file_set - all_local_files + print("DST only", len(dst_only_files), list(dst_only_files)[:10]) + + +if __name__ == "__main__": + app() diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index df5f023..17569ee 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -1,7 +1,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
+import os import re from logging import getLogger -from pathlib import Path from typing import Any, Generator import pyarrow as pa @@ -9,9 +9,12 @@ import pyarrow as pa # pyarrow needs the initialization from this import import pyarrow.dataset # pyright: ignore from pydantic import BaseModel, ConfigDict +import fsspec +import s3fs from bytelatent import ByteLatentError from bytelatent.data.data_types import BltExample +from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator logger = getLogger(__name__) @@ -27,6 +30,8 @@ class ArrowFileIteratorState(BaseModel, IteratorState): dataset_files: list[str] | None entropy_model_name: str | None arrow_batch_size: int = 100 + s3_profile: str | None + filesystem_type: str | None = None def build(self) -> "ArrowFileIterator": arrow_file = ArrowFileIterator( @@ -37,14 +42,17 @@ class ArrowFileIteratorState(BaseModel, IteratorState): entropy_model_name=self.entropy_model_name, arrow_batch_size=self.arrow_batch_size, dataset_files=self.dataset_files, + s3_profile=self.s3_profile, + filesystem_type=self.filesystem_type, ) if self.row_num != 0: arrow_file._set_row_num(self.row_num) return arrow_file -def shard_sort_key(file: str | Path): - match = re.search(r".+\.shard_([0-9]+)\.arrow", str(file)) +def shard_sort_key(file: str): + assert isinstance(file, str) + match = re.search(r".+\.shard_([0-9]+)\.arrow", file) shard_number = int(match.group(1)) return shard_number @@ -60,6 +68,8 @@ class ArrowFileIterator(StatefulIterator): entropy_model_name: str | None, arrow_batch_size: int, dataset_files: list[str] | None = None, + s3_profile: str | None = None, + filesystem_type: str | None = None, ): assert 0 <= worker_id < num_workers, (worker_id, num_workers) if file_path is None and dataset_files is None: @@ -75,16 +85,44 @@ class ArrowFileIterator(StatefulIterator): self.preprocess_dir = preprocess_dir self.entropy_model_name = entropy_model_name self.arrow_batch_size = arrow_batch_size + self.s3_profile = s3_profile + self.filesystem_type = filesystem_type + self.fs = None + if self.filesystem_type is not None: + if self.filesystem_type == "file": + self.fs = fsspec.filesystem("file") + elif self.filesystem_type == "s3": + self.fs = fsspec.filesystem("s3", profile=s3_profile) + if dataset_files is None: # Prepare arrow shards - jsonl_file = Path(file_path) - parts = re.match(r"(.+)\.chunk\.[0-9]+\.jsonl", jsonl_file.name) + jsonl_file = file_path + parts = re.match( + r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file) + ) assert parts is not None dataset = parts.group(1) - data_dir = Path(preprocess_dir) / dataset / entropy_model_name - shard_files = list(data_dir.glob(f"{jsonl_file.name}.shard_*.arrow")) + data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name) + data_dir_with_glob = os.path.join( + data_dir, f"{os.path.basename(jsonl_file)}.shard_*.arrow" + ) + print("preprocess_dir", preprocess_dir) + print("data_dir_with_glob", data_dir_with_glob) + if self.fs is None: + self.fs = get_fs(data_dir_with_glob, s3_profile=s3_profile) + if isinstance(self.fs, s3fs.S3FileSystem): + self.filesystem_type = "s3" + else: + self.filesystem_type = "file" + print("fs", self.fs) + shard_files = self.fs.glob(data_dir_with_glob) + for s in shard_files: - if not (data_dir / f"{s.name}.complete").exists(): + complete_file = os.path.join( + data_dir, f"{os.path.basename(s)}.complete" + ) + + if not self.fs.exists(complete_file): raise ValueError(f"Missing .complete for input 
file: {s}") shard_files = sorted(shard_files, key=shard_sort_key) @@ -92,10 +130,19 @@ class ArrowFileIterator(StatefulIterator): raise ByteLatentError( f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow" ) - self.dataset_files = [str(f) for f in shard_files] + self.dataset_files = [f for f in shard_files] else: self.preprocess_dir = None self.dataset_files = dataset_files + if dataset_files[0].startswith("s3://"): + for f in dataset_files: + assert f.startswith("s3://") + if self.fs is None: + self.fs = get_fs(dataset_files[0], s3_profile=s3_profile) + if isinstance(self.fs, s3fs.S3FileSystem): + self.filesystem_type = "s3" + else: + self.filesystem_type = "file" def get_state(self) -> ArrowFileIteratorState: return ArrowFileIteratorState( @@ -107,13 +154,24 @@ class ArrowFileIterator(StatefulIterator): entropy_model_name=self.entropy_model_name, arrow_batch_size=self.arrow_batch_size, dataset_files=self.dataset_files, + s3_profile=self.s3_profile, + filesystem_type=self.filesystem_type, ) def create_iter( self, ) -> Generator[BltExample, Any, None]: if self.dataset is None: - self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow") + if isinstance(self.fs, s3fs.core.S3FileSystem): + print("Using s3fs") + filesystem = self.fs + else: + print("Using local fs") + filesystem = None + print("Iterating over", self.dataset_files) + self.dataset = pa.dataset.dataset( + self.dataset_files, format="arrow", filesystem=filesystem + ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size ) @@ -165,7 +223,13 @@ class ArrowFileIterator(StatefulIterator): self.batch_iterator = None self.batch_to_consume = None else: - self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow") + if isinstance(self.fs, s3fs.core.S3FileSystem): + filesystem = self.fs + else: + filesystem = None + self.dataset = pa.dataset.dataset( + self.dataset_files, format="arrow", filesystem=filesystem + ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size ) @@ -198,9 +262,14 @@ TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" def find_and_sanitize_chunks( - dataset_path: str, world_size: int, file_pattern: str = TRAIN_DATA_FILE_PATTERN + dataset_path: str, + world_size: int, + file_pattern: str = TRAIN_DATA_FILE_PATTERN, + s3_profile: str | None = None, ): - dataset_chunks = [str(p) for p in Path(dataset_path).glob(file_pattern)] + fs = get_fs(dataset_path, s3_profile=s3_profile) + path_with_glob = os.path.join(dataset_path, file_pattern) + dataset_chunks = fs.glob(path_with_glob) n_chunks = len(dataset_chunks) if n_chunks > world_size: diff --git a/bytelatent/logger.py b/bytelatent/logger.py index 6723a84..87f04cc 100644 --- a/bytelatent/logger.py +++ b/bytelatent/logger.py @@ -91,7 +91,7 @@ def init_logger( log_file: str | None = None, *, name: str | None = None, - level: str = "NOTSET", + level: str = "INFO", ): """ Setup logging. 
diff --git a/requirements.txt b/requirements.txt index c6d87f1..59192cd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,3 +20,4 @@ altair submitit typer rich +fsspec[full] diff --git a/setup/download_prepare_hf_data.py b/setup/download_prepare_hf_data.py index 1aacf8f..e194857 100644 --- a/setup/download_prepare_hf_data.py +++ b/setup/download_prepare_hf_data.py @@ -5,6 +5,7 @@ import os import subprocess import time +import fsspec import requests from huggingface_hub import snapshot_download @@ -38,11 +39,21 @@ def download_dataset(repo_id, local_dir, allow_patterns): print(f"Dataset downloaded to {local_dir}") -def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64): +def parquet_to_jsonl( + dataset, work_dir, src_dir, tgt_dir, ntasks=64, s3_profile: str | None = None +): from datatrove.executor import LocalPipelineExecutor from datatrove.pipeline.readers import ParquetReader from datatrove.pipeline.writers import JsonlWriter + if tgt_dir.startswith("s3//"): + if s3_profile is None: + out_spec = tgt_dir + else: + out_spec = (tgt_dir, fsspec.filesystem("s3", profile=s3_profile)) + else: + out_spec = tgt_dir + pipeline_exec = LocalPipelineExecutor( pipeline=[ ParquetReader( @@ -52,7 +63,7 @@ def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64): glob_pattern="**/*.parquet", ), JsonlWriter( - tgt_dir, + out_spec, output_filename=dataset + ".chunk.${rank}.jsonl", compression=None, ), @@ -77,7 +88,7 @@ def setup_terashuf(work_dir): return terashuf_dir -def main(dataset, memory, data_dir, seed=42, nchunks=32): +def main(dataset, memory, data_dir, seed=42, nchunks=32, s3_profile: str | None = None): # Configuration repo_id = { "fineweb_edu": "HuggingFaceFW/fineweb-edu", From c137f4e636ee85ef4110b416b404c455b86efba9 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Fri, 10 Jan 2025 01:02:25 +0000 Subject: [PATCH 03/59] Replace regular filesystem calls with fsspec + add s3 support Summary: For compatibility with either local/nfs or S3 datasets, swap to fsspec. 
Add a tool to compare local and remote filesystems Test Plan: - Ran regular train script - Ran with config with data in S3 --- .gitignore | 2 +- bytelatent/args.py | 8 +- bytelatent/data/file_util.py | 116 ++++++++++++++++++++ bytelatent/data/iterators/arrow_iterator.py | 95 +++++++++++++--- bytelatent/logger.py | 2 +- requirements.txt | 1 + setup/download_prepare_hf_data.py | 17 ++- 7 files changed, 222 insertions(+), 19 deletions(-) create mode 100644 bytelatent/data/file_util.py diff --git a/.gitignore b/.gitignore index 56891a9..6c664b8 100644 --- a/.gitignore +++ b/.gitignore @@ -165,4 +165,4 @@ cython_debug/ figures/ .vscode/ .DS_Store - +internal/ diff --git a/bytelatent/args.py b/bytelatent/args.py index 4fba100..cfba3bf 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -46,8 +46,11 @@ def distribute_data_to_rank( arrow_batch_size: int, rank: int, world_size: int, + s3_profile: str | None = None, ) -> ArrowFileIterator: - dataset_chunks = find_and_sanitize_chunks(dataset_path, world_size) + dataset_chunks = find_and_sanitize_chunks( + dataset_path, world_size, s3_profile=s3_profile + ) n_workers_per_chunk = world_size // len(dataset_chunks) rank_to_arrow_iterator_params = [] for chunk_path in dataset_chunks: @@ -61,6 +64,7 @@ def distribute_data_to_rank( dataset_files=None, entropy_model_name=entropy_model_name, arrow_batch_size=arrow_batch_size, + s3_profile=s3_profile, ) ) return rank_to_arrow_iterator_params[rank] @@ -68,6 +72,7 @@ def distribute_data_to_rank( class DataloaderArgs(BaseModel): model_config = ConfigDict(extra="forbid") + s3_profile: str | None = None root_dir: str | None = None sources: dict[str, float] = {} batch_size: int = 2 @@ -107,6 +112,7 @@ class DataloaderArgs(BaseModel): arrow_batch_size=self.arrow_batch_size, rank=rank, world_size=world_size, + s3_profile=self.s3_profile, ) looping_iterator = LoopingIterator(arrow_iterator) preprocess_iterator = PreprocessIterator( diff --git a/bytelatent/data/file_util.py b/bytelatent/data/file_util.py new file mode 100644 index 0000000..312e1b4 --- /dev/null +++ b/bytelatent/data/file_util.py @@ -0,0 +1,116 @@ +import os +from rich.progress import track +import typer +import fsspec +import pyarrow as pa +from pyarrow.lib import ArrowInvalid + +# pyarrow needs the initialization from this import +import pyarrow.dataset # pyright: ignore + + +def is_valid_arrow_file(path: str): + try: + dataset = pa.dataset.dataset(path, format="arrow") + return True + except ArrowInvalid: + return False + + +app = typer.Typer() + +S3_PREFIX = "s3://" + + +def get_fs(path: str, s3_profile: str | None = None) -> fsspec.AbstractFileSystem: + if path.startswith("s3://"): + if s3_profile is None: + return fsspec.filesystem("s3") + else: + return fsspec.filesystem("s3", profile=s3_profile) + else: + return fsspec.filesystem("file") + + +@app.command() +def print_local_to_delete( + blob_dir: str, local_dirs: list[str], s3_profile: str = "blt" +): + for s in local_dirs: + assert s.endswith("/"), "Dirs must end with /" + assert blob_dir.endswith("/"), "Dirs must end with /" + blob_fs = fsspec.filesystem("s3", profile=s3_profile) + blob_files = blob_fs.find(blob_dir) + for f in track(blob_files): + size = blob_fs.info(f)["Size"] + if not f.lower().endswith(".complete"): + assert size != 0, f"Size was invalidly zero for {f}" + + blob_relative_paths = {f[len(blob_dir) - len(S3_PREFIX) :] for f in blob_files} + local_fs = fsspec.filesystem("file") + + files_to_delete = [] + for local_dir in local_dirs: + local_files = 
local_fs.find(local_dir) + for f in local_files: + relative_path = f[len(local_dir) :] + if relative_path in blob_relative_paths and not os.path.islink(f): + files_to_delete.append(f) + print(len(files_to_delete)) + with open("/tmp/files_to_delete.txt", "w") as f: + for file in files_to_delete: + f.write(f"{file}\n") + + +@app.command() +def compare_local_to_blob( + source_dirs: list[str], dst_dir: str, s3_profile: str = "blt" +): + for s in source_dirs: + assert s.endswith("/"), "Dirs must end with /" + assert dst_dir.endswith("/"), "Dirs must end with /" + assert len(source_dirs) != 0 + assert dst_dir.startswith("s3://") + local_fs = fsspec.filesystem("file") + dst_fs = fsspec.filesystem("s3", profile=s3_profile) + source_to_files = {} + all_local_files = set() + for s in source_dirs: + skipped = [] + if s not in source_to_files: + source_to_files[s] = [] + for f in local_fs.find(s): + if os.path.islink(f): + continue + if f.endswith(".COMPLETE") or f.endswith(".complete"): + is_complete_file = True + assert os.path.getsize(f) == 0, ".COMPLETE files should be empty" + else: + is_complete_file = False + + if not is_complete_file and os.path.getsize(f) == 0: + skipped.append(f) + continue + if f.endswith(".arrow"): + if not is_valid_arrow_file(f): + skipped.append(f) + continue + + source_to_files[s].append(f) + all_local_files.add(f[len(s) :]) + print(s, len(source_to_files[s]), "skipped", len(skipped), skipped[:10]) + + dst_files = dst_fs.find(dst_dir) + print(dst_dir, len(dst_files)) + + dst_file_set = {f[len(dst_dir) - len(S3_PREFIX) :] for f in dst_files} + diff = all_local_files.symmetric_difference(dst_file_set) + print("Local files", len(all_local_files)) + print("DST Files", len(dst_file_set)) + print("Symmetric difference", len(diff)) + dst_only_files = dst_file_set - all_local_files + print("DST only", len(dst_only_files), list(dst_only_files)[:10]) + + +if __name__ == "__main__": + app() diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index df5f023..17569ee 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -1,7 +1,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
+import os import re from logging import getLogger -from pathlib import Path from typing import Any, Generator import pyarrow as pa @@ -9,9 +9,12 @@ import pyarrow as pa # pyarrow needs the initialization from this import import pyarrow.dataset # pyright: ignore from pydantic import BaseModel, ConfigDict +import fsspec +import s3fs from bytelatent import ByteLatentError from bytelatent.data.data_types import BltExample +from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator logger = getLogger(__name__) @@ -27,6 +30,8 @@ class ArrowFileIteratorState(BaseModel, IteratorState): dataset_files: list[str] | None entropy_model_name: str | None arrow_batch_size: int = 100 + s3_profile: str | None + filesystem_type: str | None = None def build(self) -> "ArrowFileIterator": arrow_file = ArrowFileIterator( @@ -37,14 +42,17 @@ class ArrowFileIteratorState(BaseModel, IteratorState): entropy_model_name=self.entropy_model_name, arrow_batch_size=self.arrow_batch_size, dataset_files=self.dataset_files, + s3_profile=self.s3_profile, + filesystem_type=self.filesystem_type, ) if self.row_num != 0: arrow_file._set_row_num(self.row_num) return arrow_file -def shard_sort_key(file: str | Path): - match = re.search(r".+\.shard_([0-9]+)\.arrow", str(file)) +def shard_sort_key(file: str): + assert isinstance(file, str) + match = re.search(r".+\.shard_([0-9]+)\.arrow", file) shard_number = int(match.group(1)) return shard_number @@ -60,6 +68,8 @@ class ArrowFileIterator(StatefulIterator): entropy_model_name: str | None, arrow_batch_size: int, dataset_files: list[str] | None = None, + s3_profile: str | None = None, + filesystem_type: str | None = None, ): assert 0 <= worker_id < num_workers, (worker_id, num_workers) if file_path is None and dataset_files is None: @@ -75,16 +85,44 @@ class ArrowFileIterator(StatefulIterator): self.preprocess_dir = preprocess_dir self.entropy_model_name = entropy_model_name self.arrow_batch_size = arrow_batch_size + self.s3_profile = s3_profile + self.filesystem_type = filesystem_type + self.fs = None + if self.filesystem_type is not None: + if self.filesystem_type == "file": + self.fs = fsspec.filesystem("file") + elif self.filesystem_type == "s3": + self.fs = fsspec.filesystem("s3", profile=s3_profile) + if dataset_files is None: # Prepare arrow shards - jsonl_file = Path(file_path) - parts = re.match(r"(.+)\.chunk\.[0-9]+\.jsonl", jsonl_file.name) + jsonl_file = file_path + parts = re.match( + r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file) + ) assert parts is not None dataset = parts.group(1) - data_dir = Path(preprocess_dir) / dataset / entropy_model_name - shard_files = list(data_dir.glob(f"{jsonl_file.name}.shard_*.arrow")) + data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name) + data_dir_with_glob = os.path.join( + data_dir, f"{os.path.basename(jsonl_file)}.shard_*.arrow" + ) + print("preprocess_dir", preprocess_dir) + print("data_dir_with_glob", data_dir_with_glob) + if self.fs is None: + self.fs = get_fs(data_dir_with_glob, s3_profile=s3_profile) + if isinstance(self.fs, s3fs.S3FileSystem): + self.filesystem_type = "s3" + else: + self.filesystem_type = "file" + print("fs", self.fs) + shard_files = self.fs.glob(data_dir_with_glob) + for s in shard_files: - if not (data_dir / f"{s.name}.complete").exists(): + complete_file = os.path.join( + data_dir, f"{os.path.basename(s)}.complete" + ) + + if not self.fs.exists(complete_file): raise ValueError(f"Missing .complete for input 
file: {s}") shard_files = sorted(shard_files, key=shard_sort_key) @@ -92,10 +130,19 @@ class ArrowFileIterator(StatefulIterator): raise ByteLatentError( f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow" ) - self.dataset_files = [str(f) for f in shard_files] + self.dataset_files = [f for f in shard_files] else: self.preprocess_dir = None self.dataset_files = dataset_files + if dataset_files[0].startswith("s3://"): + for f in dataset_files: + assert f.startswith("s3://") + if self.fs is None: + self.fs = get_fs(dataset_files[0], s3_profile=s3_profile) + if isinstance(self.fs, s3fs.S3FileSystem): + self.filesystem_type = "s3" + else: + self.filesystem_type = "file" def get_state(self) -> ArrowFileIteratorState: return ArrowFileIteratorState( @@ -107,13 +154,24 @@ class ArrowFileIterator(StatefulIterator): entropy_model_name=self.entropy_model_name, arrow_batch_size=self.arrow_batch_size, dataset_files=self.dataset_files, + s3_profile=self.s3_profile, + filesystem_type=self.filesystem_type, ) def create_iter( self, ) -> Generator[BltExample, Any, None]: if self.dataset is None: - self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow") + if isinstance(self.fs, s3fs.core.S3FileSystem): + print("Using s3fs") + filesystem = self.fs + else: + print("Using local fs") + filesystem = None + print("Iterating over", self.dataset_files) + self.dataset = pa.dataset.dataset( + self.dataset_files, format="arrow", filesystem=filesystem + ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size ) @@ -165,7 +223,13 @@ class ArrowFileIterator(StatefulIterator): self.batch_iterator = None self.batch_to_consume = None else: - self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow") + if isinstance(self.fs, s3fs.core.S3FileSystem): + filesystem = self.fs + else: + filesystem = None + self.dataset = pa.dataset.dataset( + self.dataset_files, format="arrow", filesystem=filesystem + ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size ) @@ -198,9 +262,14 @@ TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" def find_and_sanitize_chunks( - dataset_path: str, world_size: int, file_pattern: str = TRAIN_DATA_FILE_PATTERN + dataset_path: str, + world_size: int, + file_pattern: str = TRAIN_DATA_FILE_PATTERN, + s3_profile: str | None = None, ): - dataset_chunks = [str(p) for p in Path(dataset_path).glob(file_pattern)] + fs = get_fs(dataset_path, s3_profile=s3_profile) + path_with_glob = os.path.join(dataset_path, file_pattern) + dataset_chunks = fs.glob(path_with_glob) n_chunks = len(dataset_chunks) if n_chunks > world_size: diff --git a/bytelatent/logger.py b/bytelatent/logger.py index 6723a84..87f04cc 100644 --- a/bytelatent/logger.py +++ b/bytelatent/logger.py @@ -91,7 +91,7 @@ def init_logger( log_file: str | None = None, *, name: str | None = None, - level: str = "NOTSET", + level: str = "INFO", ): """ Setup logging. 
diff --git a/requirements.txt b/requirements.txt index c6d87f1..59192cd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,3 +20,4 @@ altair submitit typer rich +fsspec[full] diff --git a/setup/download_prepare_hf_data.py b/setup/download_prepare_hf_data.py index 1aacf8f..e194857 100644 --- a/setup/download_prepare_hf_data.py +++ b/setup/download_prepare_hf_data.py @@ -5,6 +5,7 @@ import os import subprocess import time +import fsspec import requests from huggingface_hub import snapshot_download @@ -38,11 +39,21 @@ def download_dataset(repo_id, local_dir, allow_patterns): print(f"Dataset downloaded to {local_dir}") -def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64): +def parquet_to_jsonl( + dataset, work_dir, src_dir, tgt_dir, ntasks=64, s3_profile: str | None = None +): from datatrove.executor import LocalPipelineExecutor from datatrove.pipeline.readers import ParquetReader from datatrove.pipeline.writers import JsonlWriter + if tgt_dir.startswith("s3//"): + if s3_profile is None: + out_spec = tgt_dir + else: + out_spec = (tgt_dir, fsspec.filesystem("s3", profile=s3_profile)) + else: + out_spec = tgt_dir + pipeline_exec = LocalPipelineExecutor( pipeline=[ ParquetReader( @@ -52,7 +63,7 @@ def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64): glob_pattern="**/*.parquet", ), JsonlWriter( - tgt_dir, + out_spec, output_filename=dataset + ".chunk.${rank}.jsonl", compression=None, ), @@ -77,7 +88,7 @@ def setup_terashuf(work_dir): return terashuf_dir -def main(dataset, memory, data_dir, seed=42, nchunks=32): +def main(dataset, memory, data_dir, seed=42, nchunks=32, s3_profile: str | None = None): # Configuration repo_id = { "fineweb_edu": "HuggingFaceFW/fineweb-edu", From 1bf6d15e5a3ff76f128288fc547e2604f1981e2f Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Fri, 10 Jan 2025 01:02:25 +0000 Subject: [PATCH 04/59] Replace regular filesystem calls with fsspec + add s3 support Summary: For compatibility with either local/nfs or S3 datasets, swap to fsspec. 
Add a tool to compare local and remote filesystems Test Plan: - Ran regular train script - Ran with config with data in S3 --- .gitignore | 2 +- bytelatent/args.py | 8 +- bytelatent/data/file_util.py | 116 ++++++++++++++++++++ bytelatent/data/iterators/arrow_iterator.py | 89 ++++++++++++--- bytelatent/logger.py | 2 +- requirements.txt | 1 + setup/download_prepare_hf_data.py | 17 ++- 7 files changed, 216 insertions(+), 19 deletions(-) create mode 100644 bytelatent/data/file_util.py diff --git a/.gitignore b/.gitignore index 56891a9..6c664b8 100644 --- a/.gitignore +++ b/.gitignore @@ -165,4 +165,4 @@ cython_debug/ figures/ .vscode/ .DS_Store - +internal/ diff --git a/bytelatent/args.py b/bytelatent/args.py index 4fba100..cfba3bf 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -46,8 +46,11 @@ def distribute_data_to_rank( arrow_batch_size: int, rank: int, world_size: int, + s3_profile: str | None = None, ) -> ArrowFileIterator: - dataset_chunks = find_and_sanitize_chunks(dataset_path, world_size) + dataset_chunks = find_and_sanitize_chunks( + dataset_path, world_size, s3_profile=s3_profile + ) n_workers_per_chunk = world_size // len(dataset_chunks) rank_to_arrow_iterator_params = [] for chunk_path in dataset_chunks: @@ -61,6 +64,7 @@ def distribute_data_to_rank( dataset_files=None, entropy_model_name=entropy_model_name, arrow_batch_size=arrow_batch_size, + s3_profile=s3_profile, ) ) return rank_to_arrow_iterator_params[rank] @@ -68,6 +72,7 @@ def distribute_data_to_rank( class DataloaderArgs(BaseModel): model_config = ConfigDict(extra="forbid") + s3_profile: str | None = None root_dir: str | None = None sources: dict[str, float] = {} batch_size: int = 2 @@ -107,6 +112,7 @@ class DataloaderArgs(BaseModel): arrow_batch_size=self.arrow_batch_size, rank=rank, world_size=world_size, + s3_profile=self.s3_profile, ) looping_iterator = LoopingIterator(arrow_iterator) preprocess_iterator = PreprocessIterator( diff --git a/bytelatent/data/file_util.py b/bytelatent/data/file_util.py new file mode 100644 index 0000000..312e1b4 --- /dev/null +++ b/bytelatent/data/file_util.py @@ -0,0 +1,116 @@ +import os +from rich.progress import track +import typer +import fsspec +import pyarrow as pa +from pyarrow.lib import ArrowInvalid + +# pyarrow needs the initialization from this import +import pyarrow.dataset # pyright: ignore + + +def is_valid_arrow_file(path: str): + try: + dataset = pa.dataset.dataset(path, format="arrow") + return True + except ArrowInvalid: + return False + + +app = typer.Typer() + +S3_PREFIX = "s3://" + + +def get_fs(path: str, s3_profile: str | None = None) -> fsspec.AbstractFileSystem: + if path.startswith("s3://"): + if s3_profile is None: + return fsspec.filesystem("s3") + else: + return fsspec.filesystem("s3", profile=s3_profile) + else: + return fsspec.filesystem("file") + + +@app.command() +def print_local_to_delete( + blob_dir: str, local_dirs: list[str], s3_profile: str = "blt" +): + for s in local_dirs: + assert s.endswith("/"), "Dirs must end with /" + assert blob_dir.endswith("/"), "Dirs must end with /" + blob_fs = fsspec.filesystem("s3", profile=s3_profile) + blob_files = blob_fs.find(blob_dir) + for f in track(blob_files): + size = blob_fs.info(f)["Size"] + if not f.lower().endswith(".complete"): + assert size != 0, f"Size was invalidly zero for {f}" + + blob_relative_paths = {f[len(blob_dir) - len(S3_PREFIX) :] for f in blob_files} + local_fs = fsspec.filesystem("file") + + files_to_delete = [] + for local_dir in local_dirs: + local_files = 
local_fs.find(local_dir) + for f in local_files: + relative_path = f[len(local_dir) :] + if relative_path in blob_relative_paths and not os.path.islink(f): + files_to_delete.append(f) + print(len(files_to_delete)) + with open("/tmp/files_to_delete.txt", "w") as f: + for file in files_to_delete: + f.write(f"{file}\n") + + +@app.command() +def compare_local_to_blob( + source_dirs: list[str], dst_dir: str, s3_profile: str = "blt" +): + for s in source_dirs: + assert s.endswith("/"), "Dirs must end with /" + assert dst_dir.endswith("/"), "Dirs must end with /" + assert len(source_dirs) != 0 + assert dst_dir.startswith("s3://") + local_fs = fsspec.filesystem("file") + dst_fs = fsspec.filesystem("s3", profile=s3_profile) + source_to_files = {} + all_local_files = set() + for s in source_dirs: + skipped = [] + if s not in source_to_files: + source_to_files[s] = [] + for f in local_fs.find(s): + if os.path.islink(f): + continue + if f.endswith(".COMPLETE") or f.endswith(".complete"): + is_complete_file = True + assert os.path.getsize(f) == 0, ".COMPLETE files should be empty" + else: + is_complete_file = False + + if not is_complete_file and os.path.getsize(f) == 0: + skipped.append(f) + continue + if f.endswith(".arrow"): + if not is_valid_arrow_file(f): + skipped.append(f) + continue + + source_to_files[s].append(f) + all_local_files.add(f[len(s) :]) + print(s, len(source_to_files[s]), "skipped", len(skipped), skipped[:10]) + + dst_files = dst_fs.find(dst_dir) + print(dst_dir, len(dst_files)) + + dst_file_set = {f[len(dst_dir) - len(S3_PREFIX) :] for f in dst_files} + diff = all_local_files.symmetric_difference(dst_file_set) + print("Local files", len(all_local_files)) + print("DST Files", len(dst_file_set)) + print("Symmetric difference", len(diff)) + dst_only_files = dst_file_set - all_local_files + print("DST only", len(dst_only_files), list(dst_only_files)[:10]) + + +if __name__ == "__main__": + app() diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index df5f023..ad72091 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -1,7 +1,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
+import os import re from logging import getLogger -from pathlib import Path from typing import Any, Generator import pyarrow as pa @@ -9,9 +9,12 @@ import pyarrow as pa # pyarrow needs the initialization from this import import pyarrow.dataset # pyright: ignore from pydantic import BaseModel, ConfigDict +import fsspec +import s3fs from bytelatent import ByteLatentError from bytelatent.data.data_types import BltExample +from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator logger = getLogger(__name__) @@ -27,6 +30,8 @@ class ArrowFileIteratorState(BaseModel, IteratorState): dataset_files: list[str] | None entropy_model_name: str | None arrow_batch_size: int = 100 + s3_profile: str | None + filesystem_type: str | None = None def build(self) -> "ArrowFileIterator": arrow_file = ArrowFileIterator( @@ -37,14 +42,17 @@ class ArrowFileIteratorState(BaseModel, IteratorState): entropy_model_name=self.entropy_model_name, arrow_batch_size=self.arrow_batch_size, dataset_files=self.dataset_files, + s3_profile=self.s3_profile, + filesystem_type=self.filesystem_type, ) if self.row_num != 0: arrow_file._set_row_num(self.row_num) return arrow_file -def shard_sort_key(file: str | Path): - match = re.search(r".+\.shard_([0-9]+)\.arrow", str(file)) +def shard_sort_key(file: str): + assert isinstance(file, str) + match = re.search(r".+\.shard_([0-9]+)\.arrow", file) shard_number = int(match.group(1)) return shard_number @@ -60,6 +68,8 @@ class ArrowFileIterator(StatefulIterator): entropy_model_name: str | None, arrow_batch_size: int, dataset_files: list[str] | None = None, + s3_profile: str | None = None, + filesystem_type: str | None = None, ): assert 0 <= worker_id < num_workers, (worker_id, num_workers) if file_path is None and dataset_files is None: @@ -75,16 +85,41 @@ class ArrowFileIterator(StatefulIterator): self.preprocess_dir = preprocess_dir self.entropy_model_name = entropy_model_name self.arrow_batch_size = arrow_batch_size + self.s3_profile = s3_profile + self.filesystem_type = filesystem_type + self.fs = None + if self.filesystem_type is not None: + if self.filesystem_type == "file": + self.fs = fsspec.filesystem("file") + elif self.filesystem_type == "s3": + self.fs = fsspec.filesystem("s3", profile=s3_profile) + if dataset_files is None: # Prepare arrow shards - jsonl_file = Path(file_path) - parts = re.match(r"(.+)\.chunk\.[0-9]+\.jsonl", jsonl_file.name) + jsonl_file = file_path + parts = re.match( + r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file) + ) assert parts is not None dataset = parts.group(1) - data_dir = Path(preprocess_dir) / dataset / entropy_model_name - shard_files = list(data_dir.glob(f"{jsonl_file.name}.shard_*.arrow")) + data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name) + data_dir_with_glob = os.path.join( + data_dir, f"{os.path.basename(jsonl_file)}.shard_*.arrow" + ) + if self.fs is None: + self.fs = get_fs(data_dir_with_glob, s3_profile=s3_profile) + if isinstance(self.fs, s3fs.S3FileSystem): + self.filesystem_type = "s3" + else: + self.filesystem_type = "file" + shard_files = self.fs.glob(data_dir_with_glob) + for s in shard_files: - if not (data_dir / f"{s.name}.complete").exists(): + complete_file = os.path.join( + data_dir, f"{os.path.basename(s)}.complete" + ) + + if not self.fs.exists(complete_file): raise ValueError(f"Missing .complete for input file: {s}") shard_files = sorted(shard_files, key=shard_sort_key) @@ -92,10 +127,19 @@ class 
ArrowFileIterator(StatefulIterator): raise ByteLatentError( f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow" ) - self.dataset_files = [str(f) for f in shard_files] + self.dataset_files = [f for f in shard_files] else: self.preprocess_dir = None self.dataset_files = dataset_files + if dataset_files[0].startswith("s3://"): + for f in dataset_files: + assert f.startswith("s3://") + if self.fs is None: + self.fs = get_fs(dataset_files[0], s3_profile=s3_profile) + if isinstance(self.fs, s3fs.S3FileSystem): + self.filesystem_type = "s3" + else: + self.filesystem_type = "file" def get_state(self) -> ArrowFileIteratorState: return ArrowFileIteratorState( @@ -107,13 +151,21 @@ class ArrowFileIterator(StatefulIterator): entropy_model_name=self.entropy_model_name, arrow_batch_size=self.arrow_batch_size, dataset_files=self.dataset_files, + s3_profile=self.s3_profile, + filesystem_type=self.filesystem_type, ) def create_iter( self, ) -> Generator[BltExample, Any, None]: if self.dataset is None: - self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow") + if isinstance(self.fs, s3fs.core.S3FileSystem): + filesystem = self.fs + else: + filesystem = None + self.dataset = pa.dataset.dataset( + self.dataset_files, format="arrow", filesystem=filesystem + ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size ) @@ -165,7 +217,13 @@ class ArrowFileIterator(StatefulIterator): self.batch_iterator = None self.batch_to_consume = None else: - self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow") + if isinstance(self.fs, s3fs.core.S3FileSystem): + filesystem = self.fs + else: + filesystem = None + self.dataset = pa.dataset.dataset( + self.dataset_files, format="arrow", filesystem=filesystem + ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size ) @@ -198,9 +256,14 @@ TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" def find_and_sanitize_chunks( - dataset_path: str, world_size: int, file_pattern: str = TRAIN_DATA_FILE_PATTERN + dataset_path: str, + world_size: int, + file_pattern: str = TRAIN_DATA_FILE_PATTERN, + s3_profile: str | None = None, ): - dataset_chunks = [str(p) for p in Path(dataset_path).glob(file_pattern)] + fs = get_fs(dataset_path, s3_profile=s3_profile) + path_with_glob = os.path.join(dataset_path, file_pattern) + dataset_chunks = fs.glob(path_with_glob) n_chunks = len(dataset_chunks) if n_chunks > world_size: diff --git a/bytelatent/logger.py b/bytelatent/logger.py index 6723a84..87f04cc 100644 --- a/bytelatent/logger.py +++ b/bytelatent/logger.py @@ -91,7 +91,7 @@ def init_logger( log_file: str | None = None, *, name: str | None = None, - level: str = "NOTSET", + level: str = "INFO", ): """ Setup logging. 
diff --git a/requirements.txt b/requirements.txt index c6d87f1..59192cd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,3 +20,4 @@ altair submitit typer rich +fsspec[full] diff --git a/setup/download_prepare_hf_data.py b/setup/download_prepare_hf_data.py index 1aacf8f..e194857 100644 --- a/setup/download_prepare_hf_data.py +++ b/setup/download_prepare_hf_data.py @@ -5,6 +5,7 @@ import os import subprocess import time +import fsspec import requests from huggingface_hub import snapshot_download @@ -38,11 +39,21 @@ def download_dataset(repo_id, local_dir, allow_patterns): print(f"Dataset downloaded to {local_dir}") -def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64): +def parquet_to_jsonl( + dataset, work_dir, src_dir, tgt_dir, ntasks=64, s3_profile: str | None = None +): from datatrove.executor import LocalPipelineExecutor from datatrove.pipeline.readers import ParquetReader from datatrove.pipeline.writers import JsonlWriter + if tgt_dir.startswith("s3//"): + if s3_profile is None: + out_spec = tgt_dir + else: + out_spec = (tgt_dir, fsspec.filesystem("s3", profile=s3_profile)) + else: + out_spec = tgt_dir + pipeline_exec = LocalPipelineExecutor( pipeline=[ ParquetReader( @@ -52,7 +63,7 @@ def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64): glob_pattern="**/*.parquet", ), JsonlWriter( - tgt_dir, + out_spec, output_filename=dataset + ".chunk.${rank}.jsonl", compression=None, ), @@ -77,7 +88,7 @@ def setup_terashuf(work_dir): return terashuf_dir -def main(dataset, memory, data_dir, seed=42, nchunks=32): +def main(dataset, memory, data_dir, seed=42, nchunks=32, s3_profile: str | None = None): # Configuration repo_id = { "fineweb_edu": "HuggingFaceFW/fineweb-edu", From a1d05403b42cdb111be2cf3034a044e63f45b91b Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Fri, 10 Jan 2025 01:02:25 +0000 Subject: [PATCH 05/59] Replace regular filesystem calls with fsspec + add s3 support Summary: For compatibility with either local/nfs or S3 datasets, swap to fsspec. 
Add a tool to compare local and remote filesystems Test Plan: - Ran regular train script - Ran with config with data in S3 --- .gitignore | 2 +- bytelatent/args.py | 8 +- bytelatent/data/file_util.py | 117 ++++++++++++++++++++ bytelatent/data/iterators/arrow_iterator.py | 89 ++++++++++++--- bytelatent/logger.py | 2 +- requirements.txt | 1 + setup/download_prepare_hf_data.py | 17 ++- 7 files changed, 217 insertions(+), 19 deletions(-) create mode 100644 bytelatent/data/file_util.py diff --git a/.gitignore b/.gitignore index 56891a9..6c664b8 100644 --- a/.gitignore +++ b/.gitignore @@ -165,4 +165,4 @@ cython_debug/ figures/ .vscode/ .DS_Store - +internal/ diff --git a/bytelatent/args.py b/bytelatent/args.py index 4fba100..cfba3bf 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -46,8 +46,11 @@ def distribute_data_to_rank( arrow_batch_size: int, rank: int, world_size: int, + s3_profile: str | None = None, ) -> ArrowFileIterator: - dataset_chunks = find_and_sanitize_chunks(dataset_path, world_size) + dataset_chunks = find_and_sanitize_chunks( + dataset_path, world_size, s3_profile=s3_profile + ) n_workers_per_chunk = world_size // len(dataset_chunks) rank_to_arrow_iterator_params = [] for chunk_path in dataset_chunks: @@ -61,6 +64,7 @@ def distribute_data_to_rank( dataset_files=None, entropy_model_name=entropy_model_name, arrow_batch_size=arrow_batch_size, + s3_profile=s3_profile, ) ) return rank_to_arrow_iterator_params[rank] @@ -68,6 +72,7 @@ def distribute_data_to_rank( class DataloaderArgs(BaseModel): model_config = ConfigDict(extra="forbid") + s3_profile: str | None = None root_dir: str | None = None sources: dict[str, float] = {} batch_size: int = 2 @@ -107,6 +112,7 @@ class DataloaderArgs(BaseModel): arrow_batch_size=self.arrow_batch_size, rank=rank, world_size=world_size, + s3_profile=self.s3_profile, ) looping_iterator = LoopingIterator(arrow_iterator) preprocess_iterator = PreprocessIterator( diff --git a/bytelatent/data/file_util.py b/bytelatent/data/file_util.py new file mode 100644 index 0000000..d67b6db --- /dev/null +++ b/bytelatent/data/file_util.py @@ -0,0 +1,117 @@ +import os + +import fsspec +import pyarrow as pa + +# pyarrow needs the initialization from this import +import pyarrow.dataset # pyright: ignore +import typer +from pyarrow.lib import ArrowInvalid +from rich.progress import track + + +def is_valid_arrow_file(path: str): + try: + dataset = pa.dataset.dataset(path, format="arrow") + return True + except ArrowInvalid: + return False + + +app = typer.Typer() + +S3_PREFIX = "s3://" + + +def get_fs(path: str, s3_profile: str | None = None) -> fsspec.AbstractFileSystem: + if path.startswith("s3://"): + if s3_profile is None: + return fsspec.filesystem("s3") + else: + return fsspec.filesystem("s3", profile=s3_profile) + else: + return fsspec.filesystem("file") + + +@app.command() +def print_local_to_delete( + blob_dir: str, local_dirs: list[str], s3_profile: str = "blt" +): + for s in local_dirs: + assert s.endswith("/"), "Dirs must end with /" + assert blob_dir.endswith("/"), "Dirs must end with /" + blob_fs = fsspec.filesystem("s3", profile=s3_profile) + blob_files = blob_fs.find(blob_dir) + for f in track(blob_files): + size = blob_fs.info(f)["Size"] + if not f.lower().endswith(".complete"): + assert size != 0, f"Size was invalidly zero for {f}" + + blob_relative_paths = {f[len(blob_dir) - len(S3_PREFIX) :] for f in blob_files} + local_fs = fsspec.filesystem("file") + + files_to_delete = [] + for local_dir in local_dirs: + local_files = 
local_fs.find(local_dir) + for f in local_files: + relative_path = f[len(local_dir) :] + if relative_path in blob_relative_paths and not os.path.islink(f): + files_to_delete.append(f) + print(len(files_to_delete)) + with open("/tmp/files_to_delete.txt", "w") as f: + for file in files_to_delete: + f.write(f"{file}\n") + + +@app.command() +def compare_local_to_blob( + source_dirs: list[str], dst_dir: str, s3_profile: str = "blt" +): + for s in source_dirs: + assert s.endswith("/"), "Dirs must end with /" + assert dst_dir.endswith("/"), "Dirs must end with /" + assert len(source_dirs) != 0 + assert dst_dir.startswith("s3://") + local_fs = fsspec.filesystem("file") + dst_fs = fsspec.filesystem("s3", profile=s3_profile) + source_to_files = {} + all_local_files = set() + for s in source_dirs: + skipped = [] + if s not in source_to_files: + source_to_files[s] = [] + for f in local_fs.find(s): + if os.path.islink(f): + continue + if f.endswith(".COMPLETE") or f.endswith(".complete"): + is_complete_file = True + assert os.path.getsize(f) == 0, ".COMPLETE files should be empty" + else: + is_complete_file = False + + if not is_complete_file and os.path.getsize(f) == 0: + skipped.append(f) + continue + if f.endswith(".arrow"): + if not is_valid_arrow_file(f): + skipped.append(f) + continue + + source_to_files[s].append(f) + all_local_files.add(f[len(s) :]) + print(s, len(source_to_files[s]), "skipped", len(skipped), skipped[:10]) + + dst_files = dst_fs.find(dst_dir) + print(dst_dir, len(dst_files)) + + dst_file_set = {f[len(dst_dir) - len(S3_PREFIX) :] for f in dst_files} + diff = all_local_files.symmetric_difference(dst_file_set) + print("Local files", len(all_local_files)) + print("DST Files", len(dst_file_set)) + print("Symmetric difference", len(diff)) + dst_only_files = dst_file_set - all_local_files + print("DST only", len(dst_only_files), list(dst_only_files)[:10]) + + +if __name__ == "__main__": + app() diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index df5f023..4e7b99e 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -1,17 +1,20 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
+import os import re from logging import getLogger -from pathlib import Path from typing import Any, Generator +import fsspec import pyarrow as pa # pyarrow needs the initialization from this import import pyarrow.dataset # pyright: ignore +import s3fs from pydantic import BaseModel, ConfigDict from bytelatent import ByteLatentError from bytelatent.data.data_types import BltExample +from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator logger = getLogger(__name__) @@ -27,6 +30,8 @@ class ArrowFileIteratorState(BaseModel, IteratorState): dataset_files: list[str] | None entropy_model_name: str | None arrow_batch_size: int = 100 + s3_profile: str | None + filesystem_type: str | None = None def build(self) -> "ArrowFileIterator": arrow_file = ArrowFileIterator( @@ -37,14 +42,17 @@ class ArrowFileIteratorState(BaseModel, IteratorState): entropy_model_name=self.entropy_model_name, arrow_batch_size=self.arrow_batch_size, dataset_files=self.dataset_files, + s3_profile=self.s3_profile, + filesystem_type=self.filesystem_type, ) if self.row_num != 0: arrow_file._set_row_num(self.row_num) return arrow_file -def shard_sort_key(file: str | Path): - match = re.search(r".+\.shard_([0-9]+)\.arrow", str(file)) +def shard_sort_key(file: str): + assert isinstance(file, str) + match = re.search(r".+\.shard_([0-9]+)\.arrow", file) shard_number = int(match.group(1)) return shard_number @@ -60,6 +68,8 @@ class ArrowFileIterator(StatefulIterator): entropy_model_name: str | None, arrow_batch_size: int, dataset_files: list[str] | None = None, + s3_profile: str | None = None, + filesystem_type: str | None = None, ): assert 0 <= worker_id < num_workers, (worker_id, num_workers) if file_path is None and dataset_files is None: @@ -75,16 +85,41 @@ class ArrowFileIterator(StatefulIterator): self.preprocess_dir = preprocess_dir self.entropy_model_name = entropy_model_name self.arrow_batch_size = arrow_batch_size + self.s3_profile = s3_profile + self.filesystem_type = filesystem_type + self.fs = None + if self.filesystem_type is not None: + if self.filesystem_type == "file": + self.fs = fsspec.filesystem("file") + elif self.filesystem_type == "s3": + self.fs = fsspec.filesystem("s3", profile=s3_profile) + if dataset_files is None: # Prepare arrow shards - jsonl_file = Path(file_path) - parts = re.match(r"(.+)\.chunk\.[0-9]+\.jsonl", jsonl_file.name) + jsonl_file = file_path + parts = re.match( + r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file) + ) assert parts is not None dataset = parts.group(1) - data_dir = Path(preprocess_dir) / dataset / entropy_model_name - shard_files = list(data_dir.glob(f"{jsonl_file.name}.shard_*.arrow")) + data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name) + data_dir_with_glob = os.path.join( + data_dir, f"{os.path.basename(jsonl_file)}.shard_*.arrow" + ) + if self.fs is None: + self.fs = get_fs(data_dir_with_glob, s3_profile=s3_profile) + if isinstance(self.fs, s3fs.S3FileSystem): + self.filesystem_type = "s3" + else: + self.filesystem_type = "file" + shard_files = self.fs.glob(data_dir_with_glob) + for s in shard_files: - if not (data_dir / f"{s.name}.complete").exists(): + complete_file = os.path.join( + data_dir, f"{os.path.basename(s)}.complete" + ) + + if not self.fs.exists(complete_file): raise ValueError(f"Missing .complete for input file: {s}") shard_files = sorted(shard_files, key=shard_sort_key) @@ -92,10 +127,19 @@ class ArrowFileIterator(StatefulIterator): raise ByteLatentError( 
f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow" ) - self.dataset_files = [str(f) for f in shard_files] + self.dataset_files = [f for f in shard_files] else: self.preprocess_dir = None self.dataset_files = dataset_files + if dataset_files[0].startswith("s3://"): + for f in dataset_files: + assert f.startswith("s3://") + if self.fs is None: + self.fs = get_fs(dataset_files[0], s3_profile=s3_profile) + if isinstance(self.fs, s3fs.S3FileSystem): + self.filesystem_type = "s3" + else: + self.filesystem_type = "file" def get_state(self) -> ArrowFileIteratorState: return ArrowFileIteratorState( @@ -107,13 +151,21 @@ class ArrowFileIterator(StatefulIterator): entropy_model_name=self.entropy_model_name, arrow_batch_size=self.arrow_batch_size, dataset_files=self.dataset_files, + s3_profile=self.s3_profile, + filesystem_type=self.filesystem_type, ) def create_iter( self, ) -> Generator[BltExample, Any, None]: if self.dataset is None: - self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow") + if isinstance(self.fs, s3fs.core.S3FileSystem): + filesystem = self.fs + else: + filesystem = None + self.dataset = pa.dataset.dataset( + self.dataset_files, format="arrow", filesystem=filesystem + ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size ) @@ -165,7 +217,13 @@ class ArrowFileIterator(StatefulIterator): self.batch_iterator = None self.batch_to_consume = None else: - self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow") + if isinstance(self.fs, s3fs.core.S3FileSystem): + filesystem = self.fs + else: + filesystem = None + self.dataset = pa.dataset.dataset( + self.dataset_files, format="arrow", filesystem=filesystem + ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size ) @@ -198,9 +256,14 @@ TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" def find_and_sanitize_chunks( - dataset_path: str, world_size: int, file_pattern: str = TRAIN_DATA_FILE_PATTERN + dataset_path: str, + world_size: int, + file_pattern: str = TRAIN_DATA_FILE_PATTERN, + s3_profile: str | None = None, ): - dataset_chunks = [str(p) for p in Path(dataset_path).glob(file_pattern)] + fs = get_fs(dataset_path, s3_profile=s3_profile) + path_with_glob = os.path.join(dataset_path, file_pattern) + dataset_chunks = fs.glob(path_with_glob) n_chunks = len(dataset_chunks) if n_chunks > world_size: diff --git a/bytelatent/logger.py b/bytelatent/logger.py index 6723a84..87f04cc 100644 --- a/bytelatent/logger.py +++ b/bytelatent/logger.py @@ -91,7 +91,7 @@ def init_logger( log_file: str | None = None, *, name: str | None = None, - level: str = "NOTSET", + level: str = "INFO", ): """ Setup logging. 
diff --git a/requirements.txt b/requirements.txt index c6d87f1..59192cd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,3 +20,4 @@ altair submitit typer rich +fsspec[full] diff --git a/setup/download_prepare_hf_data.py b/setup/download_prepare_hf_data.py index 1aacf8f..e194857 100644 --- a/setup/download_prepare_hf_data.py +++ b/setup/download_prepare_hf_data.py @@ -5,6 +5,7 @@ import os import subprocess import time +import fsspec import requests from huggingface_hub import snapshot_download @@ -38,11 +39,21 @@ def download_dataset(repo_id, local_dir, allow_patterns): print(f"Dataset downloaded to {local_dir}") -def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64): +def parquet_to_jsonl( + dataset, work_dir, src_dir, tgt_dir, ntasks=64, s3_profile: str | None = None +): from datatrove.executor import LocalPipelineExecutor from datatrove.pipeline.readers import ParquetReader from datatrove.pipeline.writers import JsonlWriter + if tgt_dir.startswith("s3://"): + if s3_profile is None: + out_spec = tgt_dir + else: + out_spec = (tgt_dir, fsspec.filesystem("s3", profile=s3_profile)) + else: + out_spec = tgt_dir + pipeline_exec = LocalPipelineExecutor( pipeline=[ ParquetReader( @@ -52,7 +63,7 @@ def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64): glob_pattern="**/*.parquet", ), JsonlWriter( - tgt_dir, + out_spec, output_filename=dataset + ".chunk.${rank}.jsonl", compression=None, ), @@ -77,7 +88,7 @@ def setup_terashuf(work_dir): return terashuf_dir -def main(dataset, memory, data_dir, seed=42, nchunks=32): +def main(dataset, memory, data_dir, seed=42, nchunks=32, s3_profile: str | None = None): # Configuration repo_id = { "fineweb_edu": "HuggingFaceFW/fineweb-edu", From 3f045f11238af657269e2051d870515e6ac7913b Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Mon, 13 Jan 2025 23:13:49 +0000 Subject: [PATCH 06/59] Update preprocess_entropies script to blt inference + add fsspec support Summary: Test Plan: --- bytelatent/data/patcher.py | 8 +- bytelatent/preprocess/preprocess_entropies.py | 105 ++++++++++++------ requirements.txt | 1 + 3 files changed, 74 insertions(+), 40 deletions(-) diff --git a/bytelatent/data/patcher.py b/bytelatent/data/patcher.py index ede8b06..f8477a3 100644 --- a/bytelatent/data/patcher.py +++ b/bytelatent/data/patcher.py @@ -82,16 +82,16 @@ def calculate_entropies( if device is not None: split = split.to(device) assert torch.all(split >= 0) and torch.all(split < 260) - pred, _ = entropy_model(split) + pred = entropy_model(split) pred = pred.reshape(-1, pred.shape[-1])[ : split.numel() - pad_size, : ] # [batch_size * seq_len, vocab] pred_entropies = entropy(pred) entropies.append(pred_entropies) - entropies = torch.cat(entropies, dim=0) - entropies = entropies.reshape(tokens.shape) - return entropies + concat_entropies = torch.cat(entropies, dim=0) + concat_entropies = concat_entropies.reshape(tokens.shape) + return concat_entropies def patch_start_mask_from_entropy_with_monotonicity(entropies, t): diff --git a/bytelatent/preprocess/preprocess_entropies.py b/bytelatent/preprocess/preprocess_entropies.py index 20d1e0c..45c081c 100644 --- a/bytelatent/preprocess/preprocess_entropies.py +++ b/bytelatent/preprocess/preprocess_entropies.py @@ -1,14 +1,59 @@ # Copyright (c) Meta Platforms, Inc. and affiliates.
+import jsonlines import time -from pathlib import Path +import fsspec import numpy as np import pyarrow as pa import torch import typer from rich.progress import Progress, TextColumn -from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator +from bytelatent.data.file_util import get_fs +from bytelatent.data.patcher import calculate_entropies +from bytelatent.entropy_model import load_entropy_model +from bytelatent.tokenizers.build_tokenizer import TokenizerArgs + + +def get_id_from_doc(doc: dict) -> int: + """ + We need a reliable way to ensure that samples from jsonl + and arrow are the same, but there is no unique id field, + so derive the best possible + """ + if "sample_id" in doc: + sample_id = doc["sample_id"] + elif "title" in doc: + sample_id = doc["title"] + elif "qid" in doc: + sample_id = doc["qid"] + elif "paper_id" in doc: + sample_id = doc["paper_id"] + elif "path" in doc: + sample_id = doc["path"] + elif "url" in doc: + sample_id = doc["url"] + elif "id" in doc: + sample_id = doc["id"] + else: + raise ValueError(f"Could not find a id key from: {doc.keys()}") + return str(sample_id) + + +def get_text(doc: dict): + if "text" in doc: + text = doc["text"] + elif "content" in doc: + text = doc["content"] + else: + raise ValueError(f"Could not find a text key from: {doc.keys()}") + return text + + +def jsonl_file_iterator(fs: fsspec.AbstractFileSystem, path: str): + with fs.open(path) as f: + reader = jsonlines.Reader(f) + yield from reader def main( @@ -16,39 +61,32 @@ def main( output_file: str, patching_device: str = "cuda", log_step: int = 10_000, - entropy_model_checkpoint_dir: str = "entropy_checkpoint_dir", + entropy_model_checkpoint_dir: str = "public_data/entropy_checkpoint", + entropy_model_state_dict_path: str = "public_data/entropy_model.pth", + bpe_tokenizer_path: str = "public_data/tokenizer.model", dry_run: bool = False, + s3_profile: str | None = None, ): - # TODO: Modify this to work with the new code - raise NotImplementedError() - iterator = ArrowFileIterator( - file_path=input_file, - worker_id=0, - num_workers=1, - ) - tokenization_mode = "bytes" print(f"Preprocessing entropies, input: {input_file}, output: {output_file}") print("Loading entropy model", entropy_model_checkpoint_dir) + input_fs = get_fs(input_file, s3_profile=s3_profile) + input_doc_iterator = jsonl_file_iterator(input_fs, input_file) + if dry_run: return entropy_model = load_entropy_model( - entropy_model_checkpoint_dir, device=patching_device + entropy_model_checkpoint_dir, + entropy_model_state_dict_path, + device=patching_device, ) - entropy_model, _ = to_device(entropy_model, patching_device) + print("Creating patcher") patching_batch_size = 32 print("Creating tokenizer") - tokenizer = Tokenizer( - model_path="/home/artidoro/tokenizers/llama_v2.tokenizer.model", - tokenization_mode=tokenization_mode, - # BYTE_UNITS - vocab_size_unit_1=256, - bos=True, - eos=True, - bpe_delim=False, - # This isn't used, just stores a reference for other calls we don't use - patcher=None, + tokenizer_args = TokenizerArgs( + name="blt", init_kwargs={"bpe_tokenizer_path": bpe_tokenizer_path} ) + tokenizer = tokenizer_args.build() step = 0 print("starting") start_time = time.time() @@ -59,8 +97,10 @@ def main( schema = pa.schema([sample_id_field, text_field, entropy_field]) arrow_batch_size = 1_000 + output_fs = get_fs(output_file, s3_profile=s3_profile) + try: - with pa.OSFile(output_file, "wb") as sink: + with output_fs.open(output_file, "wb") as sink: with pa.ipc.new_file(sink, schema) as 
writer: id_buffer = [] entropies_buffer = [] @@ -72,17 +112,9 @@ def main( task = progress.add_task( "[green]Calculating entropies...", total=None ) - for doc in iterator: + for doc in input_doc_iterator: sample_id = get_id_from_doc(doc) - - if "text" in doc: - text = doc["text"] - elif "content" in doc: - text = doc["content"] - else: - raise ValueError( - f"Could not find a text key from: {doc.keys()}" - ) + text = get_text(doc) tokens = torch.tensor(tokenizer.encode(text)) patch_start = time.time() scores = calculate_entropies( @@ -128,9 +160,10 @@ def main( entropies_buffer = [] id_buffer = [] text_buffer = [] - Path(f"{output_file}.complete").touch() + output_fs.touch(f"{output_file}.complete") except: - Path(output_file).unlink(missing_ok=True) + if output_fs.exists(output_file): + output_fs.rm(output_file) raise elapsed = time.time() - start_time print("steps", step) diff --git a/requirements.txt b/requirements.txt index 59192cd..8490556 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,3 +21,4 @@ submitit typer rich fsspec[full] +orjson From d718cfa9a105b37cd8c20ed47081f5d7bc4f0f9d Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Mon, 13 Jan 2025 23:26:26 +0000 Subject: [PATCH 07/59] Update preprocess_entropies script to blt inference + add fsspec support Summary: Test Plan: --- bytelatent/data/patcher.py | 8 +- bytelatent/preprocess/preprocess_entropies.py | 105 ++++++++++++------ requirements.txt | 1 + 3 files changed, 74 insertions(+), 40 deletions(-) diff --git a/bytelatent/data/patcher.py b/bytelatent/data/patcher.py index ede8b06..f8477a3 100644 --- a/bytelatent/data/patcher.py +++ b/bytelatent/data/patcher.py @@ -82,16 +82,16 @@ def calculate_entropies( if device is not None: split = split.to(device) assert torch.all(split >= 0) and torch.all(split < 260) - pred, _ = entropy_model(split) + pred = entropy_model(split) pred = pred.reshape(-1, pred.shape[-1])[ : split.numel() - pad_size, : ] # [batch_size * seq_len, vocab] pred_entropies = entropy(pred) entropies.append(pred_entropies) - entropies = torch.cat(entropies, dim=0) - entropies = entropies.reshape(tokens.shape) - return entropies + concat_entropies = torch.cat(entropies, dim=0) + concat_entropies = concat_entropies.reshape(tokens.shape) + return concat_entropies def patch_start_mask_from_entropy_with_monotonicity(entropies, t): diff --git a/bytelatent/preprocess/preprocess_entropies.py b/bytelatent/preprocess/preprocess_entropies.py index 20d1e0c..1c19a5a 100644 --- a/bytelatent/preprocess/preprocess_entropies.py +++ b/bytelatent/preprocess/preprocess_entropies.py @@ -1,14 +1,59 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
import time -from pathlib import Path +import fsspec +import jsonlines import numpy as np import pyarrow as pa import torch import typer from rich.progress import Progress, TextColumn -from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator +from bytelatent.data.file_util import get_fs +from bytelatent.data.patcher import calculate_entropies +from bytelatent.entropy_model import load_entropy_model +from bytelatent.tokenizers.build_tokenizer import TokenizerArgs + + +def get_id_from_doc(doc: dict) -> int: + """ + We need a reliable way to ensure that samples from jsonl + and arrow are the same, but there is no unique id field, + so derive the best possible + """ + if "sample_id" in doc: + sample_id = doc["sample_id"] + elif "title" in doc: + sample_id = doc["title"] + elif "qid" in doc: + sample_id = doc["qid"] + elif "paper_id" in doc: + sample_id = doc["paper_id"] + elif "path" in doc: + sample_id = doc["path"] + elif "url" in doc: + sample_id = doc["url"] + elif "id" in doc: + sample_id = doc["id"] + else: + raise ValueError(f"Could not find a id key from: {doc.keys()}") + return str(sample_id) + + +def get_text(doc: dict): + if "text" in doc: + text = doc["text"] + elif "content" in doc: + text = doc["content"] + else: + raise ValueError(f"Could not find a text key from: {doc.keys()}") + return text + + +def jsonl_file_iterator(fs: fsspec.AbstractFileSystem, path: str): + with fs.open(path) as f: + reader = jsonlines.Reader(f) + yield from reader def main( @@ -16,39 +61,32 @@ def main( output_file: str, patching_device: str = "cuda", log_step: int = 10_000, - entropy_model_checkpoint_dir: str = "entropy_checkpoint_dir", + entropy_model_checkpoint_dir: str = "public_data/entropy_checkpoint", + entropy_model_state_dict_path: str = "public_data/entropy_model.pth", + bpe_tokenizer_path: str = "public_data/tokenizer.model", dry_run: bool = False, + s3_profile: str | None = None, ): - # TODO: Modify this to work with the new code - raise NotImplementedError() - iterator = ArrowFileIterator( - file_path=input_file, - worker_id=0, - num_workers=1, - ) - tokenization_mode = "bytes" print(f"Preprocessing entropies, input: {input_file}, output: {output_file}") print("Loading entropy model", entropy_model_checkpoint_dir) + input_fs = get_fs(input_file, s3_profile=s3_profile) + input_doc_iterator = jsonl_file_iterator(input_fs, input_file) + if dry_run: return entropy_model = load_entropy_model( - entropy_model_checkpoint_dir, device=patching_device + entropy_model_checkpoint_dir, + entropy_model_state_dict_path, + device=patching_device, ) - entropy_model, _ = to_device(entropy_model, patching_device) + print("Creating patcher") patching_batch_size = 32 print("Creating tokenizer") - tokenizer = Tokenizer( - model_path="/home/artidoro/tokenizers/llama_v2.tokenizer.model", - tokenization_mode=tokenization_mode, - # BYTE_UNITS - vocab_size_unit_1=256, - bos=True, - eos=True, - bpe_delim=False, - # This isn't used, just stores a reference for other calls we don't use - patcher=None, + tokenizer_args = TokenizerArgs( + name="blt", init_kwargs={"bpe_tokenizer_path": bpe_tokenizer_path} ) + tokenizer = tokenizer_args.build() step = 0 print("starting") start_time = time.time() @@ -59,8 +97,10 @@ def main( schema = pa.schema([sample_id_field, text_field, entropy_field]) arrow_batch_size = 1_000 + output_fs = get_fs(output_file, s3_profile=s3_profile) + try: - with pa.OSFile(output_file, "wb") as sink: + with output_fs.open(output_file, "wb") as sink: with pa.ipc.new_file(sink, schema) as 
writer: id_buffer = [] entropies_buffer = [] @@ -72,17 +112,9 @@ def main( task = progress.add_task( "[green]Calculating entropies...", total=None ) - for doc in iterator: + for doc in input_doc_iterator: sample_id = get_id_from_doc(doc) - - if "text" in doc: - text = doc["text"] - elif "content" in doc: - text = doc["content"] - else: - raise ValueError( - f"Could not find a text key from: {doc.keys()}" - ) + text = get_text(doc) tokens = torch.tensor(tokenizer.encode(text)) patch_start = time.time() scores = calculate_entropies( @@ -128,9 +160,10 @@ def main( entropies_buffer = [] id_buffer = [] text_buffer = [] - Path(f"{output_file}.complete").touch() + output_fs.touch(f"{output_file}.complete") except: - Path(output_file).unlink(missing_ok=True) + if output_fs.exists(output_file): + output_fs.rm(output_file) raise elapsed = time.time() - start_time print("steps", step) diff --git a/requirements.txt b/requirements.txt index 59192cd..8490556 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,3 +21,4 @@ submitit typer rich fsspec[full] +orjson From 38022ac06e3f5b32daf12dfe3f02c6bbc6d1e840 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Thu, 16 Jan 2025 21:51:04 +0000 Subject: [PATCH 08/59] [WIP] Changes for training entropy model and correcting attention in local models Summary: - Refactor local model configs to be separate and clearer - Add attention arguments and correct which attention is used in local models - Preparation for being able to have an entropy train script - Fix failing unit tests Test Plan: --- bytelatent/args.py | 5 + bytelatent/base_transformer.py | 44 +++++- bytelatent/configs/debug.yaml | 2 +- .../data/iterators/test_arrow_iterator.py | 3 + bytelatent/model/blt.py | 128 ++++++++++-------- .../{transformer.py => global_transformer.py} | 17 ++- bytelatent/model/local_models.py | 94 +++++++++---- bytelatent/model/utils.py | 75 +++++++++- bytelatent/preprocess/fsspec_target.py | 38 ++++++ bytelatent/test_blt.py | 22 +-- bytelatent/test_entropy_model.py | 1 + bytelatent/transformer.py | 31 ++--- 12 files changed, 331 insertions(+), 129 deletions(-) rename bytelatent/model/{transformer.py => global_transformer.py} (93%) create mode 100644 bytelatent/preprocess/fsspec_target.py diff --git a/bytelatent/args.py b/bytelatent/args.py index cfba3bf..c6f714a 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -30,6 +30,7 @@ from bytelatent.model.blt import ByteLatentTransformerArgs from bytelatent.optim import OptimArgs from bytelatent.profiling import ProfilerArgs from bytelatent.tokenizers.build_tokenizer import TokenizerArgs +from bytelatent.transformer import LMTransformerArgs logger = logging.getLogger() @@ -176,6 +177,10 @@ class TrainArgs(BaseModel): data: DataloaderArgs = DataloaderArgs() optim: OptimArgs = OptimArgs() model: ByteLatentTransformerArgs = ByteLatentTransformerArgs() + # This is only needed for training the entropy model + entropy_model: LMTransformerArgs | None = None + # Instead of training main model, train entropy model + train_entropy_model: bool = False distributed: DistributedArgs = DistributedArgs() env: EnvironmentArgs = EnvironmentArgs() diff --git a/bytelatent/base_transformer.py b/bytelatent/base_transformer.py index 45cb7c5..969fc4b 100644 --- a/bytelatent/base_transformer.py +++ b/bytelatent/base_transformer.py @@ -4,7 +4,7 @@ from enum import Enum from typing import Optional, Tuple, Union import torch -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from torch import nn from torch.nn import 
functional as F from torch.nn.attention.flex_attention import ( @@ -15,6 +15,7 @@ from torch.nn.attention.flex_attention import ( from xformers.ops import AttentionBias, fmha from bytelatent import probe +from bytelatent.tokenizers.constants import EOS_ID if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0: flex_attention_comp = torch.compile(flex_attention) @@ -30,13 +31,14 @@ class InitStdFactor(Enum): class BaseTransformerArgs(BaseModel): + model_config = ConfigDict(extra="forbid") dim: int = 512 n_layers: int = 8 - head_dim: Optional[int] = None - n_heads: Optional[int] = None - n_kv_heads: Optional[int] = None + head_dim: int | None = None + n_heads: int | None = None + n_kv_heads: int | None = None - ffn_dim_multiplier: Optional[float] = None + ffn_dim_multiplier: float | None = None multiple_of: int = 256 @@ -44,11 +46,16 @@ class BaseTransformerArgs(BaseModel): rope_theta: float = 10000.0 - init_base_std: Optional[float] = None + init_base_std: float | None = None init_std_factor: InitStdFactor = InitStdFactor.DISABLED max_seqlen: int = 1024 + attn_impl: str | None = "sdpa" + attn_bias_type: str | None = None + # Special token config + eos_id: int | None = EOS_ID + def cross_entropy(pred, target, **kwargs): return F.nll_loss( @@ -294,6 +301,18 @@ class RMSNorm(nn.Module): torch.nn.init.ones_(self.weight) # type: ignore +def _reshape_for_attn_bias( + attn_bias: AttentionBias | None, + *tensors: torch.Tensor, +) -> list[torch.Tensor]: + to_transform = list(tensors) + if isinstance(attn_bias, fmha.attn_bias.BlockDiagonalCausalMask): + # could be `view` instead of reshape during training, but for inference + # have to reshape due to strides mismatch + to_transform = [t.reshape(1, -1, *t.shape[2:]) for t in to_transform] + return to_transform + + class Attention(nn.Module): def __init__( self, @@ -371,9 +390,17 @@ class Attention(nn.Module): output = flex_attention_comp(xq, xk, xv, block_mask=mask) output = output.transpose(1, 2).contiguous() # B H S D -> B S H D - elif attn_impl == "fmha": + elif attn_impl == "xformers": assert mask is None or isinstance(mask, AttentionBias) + query_shape = xq.shape + print("Before reshape", "xq", xq.shape, "xk", xk.shape, "xv", xv.shape) + xq, xk, xv = _reshape_for_attn_bias(mask, xq, xk, xv) + print("Before reshape", "xq", xq.shape, "xk", xk.shape, "xv", xv.shape) output = fmha.memory_efficient_attention(xq, xk, xv, attn_bias=mask) + print("attn out", output.shape, "query_reshape", query_shape) + output_original_shape = output.view(query_shape) + print("Reshape success") + return output_original_shape # This uses B S H D instead of B H S D of pytorch elif attn_impl == "sdpa": @@ -545,6 +572,8 @@ class BaseTransformer(nn.Module): super().__init__() self.dim = args.dim self.init_base_std = args.init_base_std + self.attn_impl = args.attn_impl + self.attn_bias_type = args.attn_bias_type self.init_std_factor = InitStdFactor(args.init_std_factor) self.max_seqlen = args.max_seqlen self.rope_embeddings = RotaryEmbedding( @@ -552,6 +581,7 @@ class BaseTransformer(nn.Module): head_dim=args.head_dim or args.dim // args.n_heads, max_seqlen=args.max_seqlen, ) + self.eos_id = args.eos_id self.layers = nn.ModuleList() for _ in range(args.n_layers): diff --git a/bytelatent/configs/debug.yaml b/bytelatent/configs/debug.yaml index 5f6debb..fc8b943 100644 --- a/bytelatent/configs/debug.yaml +++ b/bytelatent/configs/debug.yaml @@ -58,13 +58,13 @@ model: recompute_attn: false custom_bwd: false layer_ckpt: "none" - efficient_attn: "sdpa" 
patch_only_encoder: false patch_only_decoder: false use_local_encoder_transformer: true init_use_gaussian: true init_use_depth: "current" attn_bias_type: "block_causal" + attn_impl: "xformers" alpha_depth: "disabled" max_length: 256 local_attention_window_len: 512 diff --git a/bytelatent/data/iterators/test_arrow_iterator.py b/bytelatent/data/iterators/test_arrow_iterator.py index 4266427..fd448eb 100644 --- a/bytelatent/data/iterators/test_arrow_iterator.py +++ b/bytelatent/data/iterators/test_arrow_iterator.py @@ -27,6 +27,7 @@ def test_basic_arrow_file(): dataset_files=[ARROW_TEST_DATA_1], row_num=0, arrow_batch_size=100, + s3_profile=None, ) arrow_file = initial_state.build() start_state = arrow_file.get_state() @@ -55,6 +56,7 @@ def test_basic_arrow_file(): dataset_files=[ARROW_TEST_DATA_1], row_num=251, arrow_batch_size=100, + s3_profile=None, ) arrow_file = resumed_state.build() for example in arrow_file.create_iter(): @@ -74,6 +76,7 @@ def test_basic_arrow_file(): dataset_files=[ARROW_TEST_DATA_1], row_num=0, arrow_batch_size=100, + s3_profile=None, ) arrow_file = rank_state.build() expected_ids = [] diff --git a/bytelatent/model/blt.py b/bytelatent/model/blt.py index 9332d19..3e405da 100644 --- a/bytelatent/model/blt.py +++ b/bytelatent/model/blt.py @@ -15,8 +15,8 @@ from bytelatent.base_transformer import ( TransformerBlock, ) from bytelatent.data.patcher import Patcher, PatcherArgs -from bytelatent.model.local_models import LocalDecoder, LocalEncoder -from bytelatent.model.transformer import GlobalTransformer +from bytelatent.model.local_models import LocalDecoder, LocalEncoder, LocalModelArgs +from bytelatent.model.global_transformer import GlobalTransformer from bytelatent.model.utils import downsample from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID @@ -403,7 +403,6 @@ def patch_ids_from_lengths(patch_lengths, seq_len): class ByteLatentTransformerArgs(BaseTransformerArgs): - model_config = ConfigDict(extra="forbid") # Basic model configuration seed: int = 42 vocab_size: int = -1 @@ -412,7 +411,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): n_heads: int = 8 # TODO: What is the purpose of this parameter? 
weight_tying: bool = False - sliding_window: Optional[int] = None # Architecture and dimensions dim_token: int = 256 @@ -471,11 +469,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): recompute_attn: bool = True custom_bwd: bool = False layer_ckpt: str = "all" - efficient_attn: str | None = None - - # Architecture options - patch_only_encoder: bool = False - patch_only_decoder: bool = False # Initialization and attention init_use_gaussian: bool = True @@ -541,9 +534,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): # Logging full_logging_n_layers: int = 4 - # Special token config - eos_id: int | None = None - @model_validator(mode="after") def check_hash_byte_sizes(self) -> Self: if ( @@ -558,22 +548,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): return self -class LocalEncoderArgs(ByteLatentTransformerArgs): - # Local encoder specific dimensions - n_heads_local_encoder: int = 8 - dim_token_emb: int | None = None - dim_patch_emb: int | None = None - - def __post_init__(self): - # Override base args with local encoder specific values - self.dim = self.dim_local_encoder - self.n_layers = self.n_layers_local_encoder - self.n_heads = self.n_heads_local_encoder - self.cross_attn_decoder = False - self.cross_attn_k = self.cross_attn_k if self.cross_attn_encoder else None - self.attn_bias_type = "local_block_causal" - - class GlobalTransformerArgs(ByteLatentTransformerArgs): # Global encoder specific dimensions dim_token_emb: int | None = None @@ -625,20 +599,42 @@ def create_global_transformer(args: ByteLatentTransformerArgs) -> GlobalTransfor def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder: - # First deep copy the original args - # Replace with local encoder specific values - local_encoder_args = args.model_copy( - deep=True, - update=dict( - dim=args.dim_local_encoder, - n_layers=args.n_layers_local_encoder, - n_heads=args.n_heads_local_encoder, - dim_token_emb=get_encoder_dim_token_emb(args), - dim_patch_emb=get_encoder_dim_patch_emb(args), - cross_attn_decoder=False, - cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None, - attn_bias_type="local_block_causal", - ), + local_encoder_args = LocalModelArgs( + # Updated args + dim=args.dim_local_encoder, + n_layers=args.n_layers_local_encoder, + n_heads=args.n_heads_local_encoder, + dim_token_emb=get_encoder_dim_token_emb(args), + dim_patch_emb=get_encoder_dim_patch_emb(args), + cross_attn_encoder=args.cross_attn_encoder, + cross_attn_decoder=False, + cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None, + cross_attn_init_by_pooling=args.cross_attn_init_by_pooling, + # Defaults + head_dim=args.head_dim, + max_seqlen=args.max_encoder_seq_length, + dropout=args.dropout, + vocab_size=args.vocab_size + args.pm_size, + norm_eps=args.norm_eps, + patch_size=args.patch_size, + sliding_window=args.local_attention_window_len, + use_rope=args.use_rope, + rope_theta=args.rope_theta, + init_base_std=args.init_base_std, + init_std_factor=args.init_std_factor, + n_kv_heads=args.n_kv_heads, + attn_impl=args.attn_impl, + attn_bias_type="local_block_causal", + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + patching_mode=args.patching_mode, + use_local_encoder_transformer=args.use_local_encoder_transformer, + downsampling_by_pooling=args.downsampling_by_pooling, + encoder_hash_byte_group_size=args.encoder_hash_byte_group_size, + cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder, + 
cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder, + cross_attn_nheads=args.cross_attn_nheads, + eos_id=args.eos_id, ) return LocalEncoder(local_encoder_args) @@ -646,18 +642,41 @@ def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder: def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder: # First deep copy the original args - local_decoder_args = args.model_copy( - deep=True, - update=dict( - dim=args.dim_local_decoder, - n_layers=args.n_layers_local_decoder, - n_heads=args.n_heads_local_decoder, - cross_attn_encoder=False, - cross_attn_init_by_pooling=False, # states are already defined - dim_token_emb=get_decoder_dim_token_emb(args), - dim_patch_emb=args.dim_global, - cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None, - ), + local_decoder_args = LocalModelArgs( + dim=args.dim_local_decoder, + n_layers=args.n_layers_local_decoder, + n_heads=args.n_heads_local_decoder, + dim_token_emb=get_decoder_dim_token_emb(args), + dim_patch_emb=args.dim_global, + cross_attn_encoder=False, + cross_attn_decoder=args.cross_attn_decoder, + cross_attn_init_by_pooling=False, # states are already defined + cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None, + # Defaults + head_dim=args.head_dim, + max_seqlen=args.max_encoder_seq_length, + dropout=args.dropout, + vocab_size=args.vocab_size + args.pm_size, + norm_eps=args.norm_eps, + patch_size=args.patch_size, + sliding_window=args.local_attention_window_len, + use_rope=args.use_rope, + rope_theta=args.rope_theta, + init_base_std=args.init_base_std, + init_std_factor=args.init_std_factor, + n_kv_heads=args.n_kv_heads, + attn_impl=args.attn_impl, + attn_bias_type="local_block_causal", + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + patching_mode=args.patching_mode, + use_local_encoder_transformer=args.use_local_encoder_transformer, + downsampling_by_pooling=args.downsampling_by_pooling, + encoder_hash_byte_group_size=args.encoder_hash_byte_group_size, + cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder, + cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder, + cross_attn_nheads=args.cross_attn_nheads, + eos_id=args.eos_id, ) return LocalDecoder(local_decoder_args) @@ -763,7 +782,6 @@ class ByteLatentTransformer(nn.Module): # General configuration self.weight_tying = args.weight_tying - self.sliding_window = args.sliding_window self.patch_size = args.patch_size self.patching_mode = args.patching_mode self.boe_id, self.bos_id, self.pad_id, self.eos_id = ( diff --git a/bytelatent/model/transformer.py b/bytelatent/model/global_transformer.py similarity index 93% rename from bytelatent/model/transformer.py rename to bytelatent/model/global_transformer.py index 24dc057..21c3f0c 100644 --- a/bytelatent/model/transformer.py +++ b/bytelatent/model/global_transformer.py @@ -11,6 +11,7 @@ from xformers.ops import AttentionBias from bytelatent.base_transformer import ( BaseTransformer, + BaseTransformerArgs, RMSNorm, flex_attention_comp, repeat_kv, @@ -142,11 +143,10 @@ class CrossAttention(nn.Module): class GlobalTransformer(BaseTransformer): - def __init__(self, args): + def __init__(self, args: BaseTransformerArgs): super().__init__(args) self.dropout = args.dropout - self.sliding_window = args.sliding_window - self.efficient_attn = args.efficient_attn + self.eos_id = args.eos_id self.token_embedding_projection = None if args.dim_token_emb is not None and args.dim_token_emb != self.dim: @@ -169,14 +169,19 @@ class 
GlobalTransformer(BaseTransformer): and projection to the token space. """ bs, seqlen = tokens.shape - attn_impl = self.efficient_attn h = embeds mask = ( mask if mask is not None - else create_causal_mask(seqlen, attn_impl, self.sliding_window) + else create_causal_mask( + seqlen, + self.attn_impl, + self.attn_bias_type, + tokens=tokens, + eos_id=self.eos_id, + ) ) if self.token_embedding_projection is not None and h.shape[-1] != self.dim: @@ -184,7 +189,7 @@ class GlobalTransformer(BaseTransformer): h = F.dropout(h, p=self.dropout, training=self.training) - h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl) + h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl) return h, cache def init_weights(self, init_base_std: float): diff --git a/bytelatent/model/local_models.py b/bytelatent/model/local_models.py index 8255504..4ce8fb5 100644 --- a/bytelatent/model/local_models.py +++ b/bytelatent/model/local_models.py @@ -1,8 +1,9 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. import logging -from typing import List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union +from pydantic import BaseModel, ConfigDict import torch import torch.nn import torch.nn as nn @@ -16,29 +17,69 @@ from bytelatent.base_transformer import ( RotaryEmbedding, TransformerBlock, ) -from bytelatent.model.transformer import CrossAttention +from bytelatent.model.global_transformer import CrossAttention from bytelatent.model.utils import create_causal_mask, downsample from bytelatent.tokenizers.blt_tokenizer import BOE_ID logger = logging.getLogger() +class LocalModelArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + # Local encoder specific dimensions + head_dim: int | None + dim: int + dropout: float + vocab_size: int + patch_size: int + sliding_window: int | None + use_rope: bool + init_base_std: float | None = None + init_std_factor: InitStdFactor + cross_attn_encoder: bool | None + cross_attn_decoder: bool | None + cross_attn_k: int | None + cross_attn_init_by_pooling: bool + norm_eps: float + rope_theta: float + max_seqlen: int + ffn_dim_multiplier: float | None = None + patching_mode: str + use_local_encoder_transformer: bool + downsampling_by_pooling: str | None + encoder_hash_byte_group_size: Any | None = None + cross_attn_all_layers_encoder: bool = False + cross_attn_all_layers_decoder: bool = False + cross_attn_nheads: int | None + + n_layers: int + n_heads: int + n_kv_heads: int | None = None + dim_token_emb: int + dim_patch_emb: int | None + attn_impl: str | None = "xformers" + attn_bias_type: str | None = "local_block_causal" + multiple_of: int = 256 + eos_id: int | None = None + + class LocalModelBase(nn.Module): - def __init__(self, args): + def __init__(self, args: LocalModelArgs): super().__init__() self.dim = args.dim self.dropout = args.dropout - self.vocab_size = args.vocab_size + args.pm_size + self.vocab_size = args.vocab_size self.patch_size = args.patch_size - self.efficient_attn = args.efficient_attn + self.attn_impl = args.attn_impl self.sliding_window = args.sliding_window self.use_rope = args.use_rope self.init_std_factor = args.init_std_factor self.cross_attn_encoder = getattr(args, "cross_attn_encoder", None) self.cross_attn_decoder = getattr(args, "cross_attn_decoder", None) self.cross_attn_k = getattr(args, "cross_attn_k", None) + self.eos_id = args.eos_id self.boe_id = BOE_ID @@ -54,7 +95,7 @@ class LocalModelBase(nn.Module): self.rope = RotaryEmbedding( theta=args.rope_theta, head_dim=args.head_dim or 
args.dim // args.n_heads, - max_seqlen=getattr(args, "max_encoder_seq_length", args.max_length), + max_seqlen=args.max_seqlen, ) self.pos_embeddings = None @@ -66,21 +107,15 @@ class LocalModelBase(nn.Module): self.patch_embedding_projection = self._create_patch_projection(args) - def _should_create_patch_projection(self, args): + def _should_create_patch_projection(self, args: LocalModelArgs): dimension_mismatch = ( getattr(args, "dim_patch_emb") and args.dim_patch_emb != self.dim ) # Check cross attention conditions cross_attn_conditions = ( - hasattr(args, "cross_attn_encoder") - and args.cross_attn_encoder - and getattr(args, "cross_attn_init_by_pooling") - ) or ( - hasattr(args, "cross_attn_decoder") - and args.cross_attn_decoder - and getattr(args, "cross_attn_init_by_pooling") - ) + args.cross_attn_encoder and args.cross_attn_init_by_pooling + ) or (args.cross_attn_decoder and args.cross_attn_init_by_pooling) return dimension_mismatch or cross_attn_conditions @@ -172,7 +207,7 @@ class LocalModelBase(nn.Module): class LocalEncoder(LocalModelBase): - def __init__(self, args): + def __init__(self, args: LocalModelArgs): super().__init__(args) self.output_proj = ( args.patching_mode in ["entropy", "probmax"] @@ -180,7 +215,6 @@ class LocalEncoder(LocalModelBase): self.apply_transformer = args.use_local_encoder_transformer self.downsampling_by_pooling = args.downsampling_by_pooling - self.patch_only = args.patch_only_encoder self.expects_hash_embeddings = args.encoder_hash_byte_group_size is not None self.cross_attn_encoder = args.cross_attn_encoder self.cross_attn_all_layers_encoder = args.cross_attn_all_layers_encoder @@ -224,7 +258,14 @@ class LocalEncoder(LocalModelBase): """ """ bs, seqlen = tokens.shape if mask is None: - mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window) + mask = create_causal_mask( + seqlen, + self.attn_impl, + "local_block_causal", + sliding_window=self.sliding_window, + tokens=tokens, + eos_id=self.eos_id, + ) h = self.apply_embedding(tokens, embeds) freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None @@ -232,7 +273,7 @@ class LocalEncoder(LocalModelBase): h = F.dropout(h, p=self.dropout, training=self.training) for i, layer in enumerate(self.layers): - h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn) + h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl) # check if cross attention should be applied to either all layer or only the last layer if self.cross_attn_encoder and ( i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder @@ -273,12 +314,10 @@ class LocalEncoder(LocalModelBase): class LocalDecoder(LocalModelBase): - def __init__(self, args): + def __init__(self, args: LocalModelArgs): super().__init__(args) # Model configuration flags - self.patch_only = args.patch_only_decoder - self.expects_embeddings = args.share_encoder_decoder_emb self.cross_attn_decoder = args.cross_attn_decoder self.cross_attn_all_layers_decoder = args.cross_attn_all_layers_decoder self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling @@ -317,7 +356,14 @@ class LocalDecoder(LocalModelBase): assert embeds is not None, "Embeddings must be provided" if mask is None: - mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window) + mask = create_causal_mask( + seqlen, + self.attn_impl, + "local_block_causal", + sliding_window=self.sliding_window, + tokens=tokens, + eos_id=self.eos_id, + ) h = embeds @@ -347,7 +393,7 @@ class LocalDecoder(LocalModelBase): ) h = h + h_cross 
- h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn) + h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl) h_preds = self.norm(h) h_preds = F.dropout(h_preds, p=self.dropout, training=self.training) diff --git a/bytelatent/model/utils.py b/bytelatent/model/utils.py index ce52a30..eac2c68 100644 --- a/bytelatent/model/utils.py +++ b/bytelatent/model/utils.py @@ -1,8 +1,11 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. +import logging import torch from torch.nn.attention.flex_attention import create_block_mask from xformers.ops import fmha +logger = logging.getLogger() + def patch_reduce(h, max_num_patches, reduction, patch_ids): """ @@ -97,14 +100,72 @@ def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx -def create_causal_mask(seqlen, attn_impl, sliding_window): - if sliding_window is not None and attn_impl == "xformers": - return fmha.attn_bias.LocalAttentionFromBottomRightMask( - window_left=sliding_window - 1, window_right=0 - ) - elif attn_impl == "xformers": - return fmha.attn_bias.LowerTriangularMask() +def tokens_to_seqlen(batch: torch.Tensor, eos_id: int): + """ + 0 0 0 1 0 0 0 1 0 0 0 + 0 1 0 0 0 1 0 0 0 0 0 + -> 4 4 3 2 4 5 + """ + mask = batch == eos_id + mask[:, -1] = True # virtual eos at the end of each row + + # 0 0 0 1 0 0 0 1 0 0 X + # 0 1 0 0 0 1 0 0 0 0 X + row, col = torch.where(mask) + + # row = 0, 0, 0, 1, 1, 1 + # col = 3, 7, 10, 1, 5, 10 + seqlens = (col[1:] - col[:-1]) + (row[1:] - row[:-1]) * mask.shape[1] + # seqlens = (4, 3, -9, 4, 5) + (0, 0, 11, 0, 0) = (4, 3, 2, 4, 5) + return [int(col[0].item() + 1)] + seqlens.tolist() + + +WARNED_SDPA = False + + +def create_causal_mask( + seqlen, + attn_impl: str, + attn_bias_type: str | None, + *, + eos_id: int | None = None, + tokens: torch.Tensor | None = None, + sliding_window: int | None = None, +): + if attn_impl == "xformers": + if attn_bias_type is None: + return fmha.attn_bias.LowerTriangularMask() + elif attn_bias_type == "causal": + assert sliding_window is None + print("attn: causal") + return fmha.attn_bias.LowerTriangularMask() + elif attn_bias_type == "block_causal": + assert sliding_window is None + assert eos_id is not None + assert tokens is not None + print("attn: block_causal") + return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens( + q_seqlen=tokens_to_seqlen(tokens, eos_id) + ) + elif attn_bias_type == "local_block_causal": + assert sliding_window is not None + assert eos_id is not None + assert tokens is not None + print("attn: local_block_causal") + return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens( + q_seqlen=tokens_to_seqlen(tokens, eos_id) + ).make_local_attention(sliding_window) + else: + return fmha.attn_bias.LocalAttentionFromBottomRightMask( + window_left=sliding_window - 1, window_right=0 + ) elif attn_impl == "sdpa": + global WARNED_SDPA + if not WARNED_SDPA: + logging.warning( + "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention." 
+ ) + WARNED_SDPA = True return "causal" elif attn_impl == "flex_attention": return create_block_mask(causal_mask, None, None, seqlen, seqlen) diff --git a/bytelatent/preprocess/fsspec_target.py b/bytelatent/preprocess/fsspec_target.py new file mode 100644 index 0000000..eacb101 --- /dev/null +++ b/bytelatent/preprocess/fsspec_target.py @@ -0,0 +1,38 @@ +import fsspec +from luigi.target import FileSystem, FileSystemTarget + + +class FSSpecFileSystem(FileSystem): + def __init__(self, fs: fsspec.AbstractFileSystem): + self.fs = fs + + def exists(self, path): + return self.fs.exists(path) + + def remove(self, path, recursive=True, skip_trash=True): + raise NotImplementedError() + + def isdir(self, path): + return self.fs.isdir(path) + + def listdir(self, path): + return self.fs.ls(path) + + +class FSSpecTarget(FileSystemTarget): + def __init__(self, path, fs: fsspec.AbstractFileSystem | None = None): + self.path = path + if fs is None: + self.fsspec_fs = fsspec.filesystem("file") + else: + self.fsspec_fs = fs + self._fs = None + + @property + def fs(self): + if self._fs is None: + self._fs = FSSpecFileSystem(self.fsspec_fs) + return self._fs + + def open(self, mode): + return self.fsspec_fs.open(self.path, mode=mode) diff --git a/bytelatent/test_blt.py b/bytelatent/test_blt.py index 73ad9f7..d7cda05 100644 --- a/bytelatent/test_blt.py +++ b/bytelatent/test_blt.py @@ -23,9 +23,10 @@ from bytelatent.model.blt import ( init_embeddings, patch_ids_from_lengths, ) -from bytelatent.model.transformer import CrossAttention +from bytelatent.model.global_transformer import CrossAttention from bytelatent.model.utils import create_causal_mask from bytelatent.optim import OptimArgs, build_optimizer +from bytelatent.tokenizers.constants import EOS_ID from bytelatent.train import compute_loss @@ -51,7 +52,7 @@ def batch_to_tensors_and_gpu(batch): def fake_batch(): - batch_dict = torch.load(os.path.join(BLT_DATA, "test_batch.pt")) + batch_dict = torch.load(os.path.join(BLT_DATA, "test_batch.pt"), weights_only=False) del batch_dict["x2"] del batch_dict["y2"] del batch_dict["src_names"] @@ -98,18 +99,17 @@ def create_args(cross_attention=False): recompute_attn=False, custom_bwd=False, layer_ckpt="none", - efficient_attn="sdpa", - patch_only_encoder=False, - patch_only_decoder=False, use_local_encoder_transformer=True, init_use_gaussian=True, init_use_depth="current", attn_bias_type="block_causal", + attn_impl="xformers", alpha_depth="disabled", max_length=256, local_attention_window_len=512, max_seqlen=12288, downsampling_by_pooling="max", + eos_id=EOS_ID, ) return transformer_args @@ -341,10 +341,10 @@ class TestByteLatentTransformer: model = ByteLatentTransformer(args) assert model is not None - @pytest.mark.parametrize("attn_type", ["fmha", "sdpa"]) - def test_blt_transformer_forward(self, attn_type): + @pytest.mark.parametrize("attn_impl", ["sdpa", "xformers"]) + def test_blt_transformer_forward(self, attn_impl): args = create_args() - args = args.model_copy(update=dict(efficient_attn=attn_type)) + args = args.model_copy(update=dict(attn_impl=attn_impl)) model = ByteLatentTransformer(args) model = model.cuda() batch = fake_batch() @@ -393,7 +393,9 @@ class TestByteLatentTransformer: n_kv_heads=4, norm_eps=1e-6, ).to("cuda") - mask = create_causal_mask(x.shape[1], "flex_attention", sliding_window=None) + mask = create_causal_mask( + x.shape[1], "flex_attention", None, sliding_window=None + ) output = cross_attention(x, kv, mask) assert output is not None assert output.shape == (2, 256, 512) @@ -440,7 +442,7 @@
class TestByteLatentTransformer: def test_loss_backward(self): args = create_args() - args = args.model_copy(update=dict(efficient_attn="sdpa")) + args = args.model_copy(update=dict(attn_impl="sdpa")) batch = fake_batch() model = ByteLatentTransformer(args) steps = 10 diff --git a/bytelatent/test_entropy_model.py b/bytelatent/test_entropy_model.py index 3acc42d..af81638 100644 --- a/bytelatent/test_entropy_model.py +++ b/bytelatent/test_entropy_model.py @@ -24,6 +24,7 @@ def test_entropy_model(): dataset_files=[ARROW_TEST_DATA], row_num=0, arrow_batch_size=100, + s3_profile=None, ) arrow_file = initial_state.build() tokenizer_args = TokenizerArgs( diff --git a/bytelatent/transformer.py b/bytelatent/transformer.py index 432f7df..92c5ff5 100644 --- a/bytelatent/transformer.py +++ b/bytelatent/transformer.py @@ -22,23 +22,7 @@ from bytelatent.base_transformer import ( RMSNorm, cross_entropy, ) - - -def create_causal_mask(seqlen, attn_impl, sliding_window): - if sliding_window is not None and attn_impl == "xformers": - return fmha.attn_bias.LocalAttentionFromBottomRightMask( - window_left=sliding_window - 1, window_right=0 - ) - elif attn_impl == "xformers": - return fmha.attn_bias.LowerTriangularMask() - elif attn_impl == "sdpa": - return "causal" - elif attn_impl == "flex_attention": - return create_block_mask(causal_mask, None, None, seqlen, seqlen) - else: - raise NotImplementedError( - f"Attention {attn_impl} with {sliding_window} sliding window not implemented" - ) +from bytelatent.model.utils import create_causal_mask def attention_flops_per_token(n_layers, seq_len, dim, causal): @@ -94,8 +78,10 @@ class LMTransformer(BaseTransformer): target: Optional[torch.Tensor] = None, tok_idx: Optional[torch.Tensor] = None, mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None, - attn_impl: str = "sdpa", + attn_impl: str | None = None, ): + if attn_impl is None: + attn_impl = self.attn_impl bsz, seqlen = token_values.shape h = self.tok_embeddings(token_values) @@ -103,7 +89,14 @@ class LMTransformer(BaseTransformer): mask = ( mask if mask is not None - else create_causal_mask(seqlen, attn_impl, self.sliding_window) + else create_causal_mask( + seqlen, + attn_impl, + self.attn_bias_type, + sliding_window=self.sliding_window, + tokens=token_values, + eos_id=self.eos_id, + ) ) h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl) From 374409fa3b90e856f3b4c306c16fff21ac210ca7 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Fri, 17 Jan 2025 01:01:29 +0000 Subject: [PATCH 09/59] [WIP] Changes for training entropy model and correcting attention in local models Summary: - Refactor local model configs to be separate and clearer - Add attention arguments and correct which attention is used in local models - Preparation for being able to have an entropy train script - Fix failing unit tests Test Plan: --- bytelatent/args.py | 7 + bytelatent/base_transformer.py | 45 ++++-- bytelatent/configs/debug.yaml | 3 +- .../data/iterators/test_arrow_iterator.py | 3 + bytelatent/distributed.py | 1 - bytelatent/model/blt.py | 128 ++++++++++-------- .../{transformer.py => global_transformer.py} | 17 ++- bytelatent/model/local_models.py | 94 +++++++++---- bytelatent/model/utils.py | 73 +++++++++- bytelatent/preprocess/fsspec_target.py | 38 ++++++ bytelatent/test_blt.py | 22 +-- bytelatent/test_entropy_model.py | 1 + bytelatent/train.py | 4 + bytelatent/transformer.py | 31 ++--- 14 files changed, 334 insertions(+), 133 deletions(-) rename bytelatent/model/{transformer.py => 
global_transformer.py} (93%) create mode 100644 bytelatent/preprocess/fsspec_target.py diff --git a/bytelatent/args.py b/bytelatent/args.py index cfba3bf..b9144c6 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -30,6 +30,7 @@ from bytelatent.model.blt import ByteLatentTransformerArgs from bytelatent.optim import OptimArgs from bytelatent.profiling import ProfilerArgs from bytelatent.tokenizers.build_tokenizer import TokenizerArgs +from bytelatent.transformer import LMTransformerArgs logger = logging.getLogger() @@ -163,6 +164,8 @@ class TrainArgs(BaseModel): seed: int = 42 + debug_dynamo: bool = False + # Number of gradient accumulation steps # Total batch size is batch_size*grad_acc_steps grad_acc_steps: int = 1 @@ -176,6 +179,10 @@ class TrainArgs(BaseModel): data: DataloaderArgs = DataloaderArgs() optim: OptimArgs = OptimArgs() model: ByteLatentTransformerArgs = ByteLatentTransformerArgs() + # This is only needed for training the entropy model + entropy_model: LMTransformerArgs | None = None + # Instead of training main model, train entropy model + train_entropy_model: bool = False distributed: DistributedArgs = DistributedArgs() env: EnvironmentArgs = EnvironmentArgs() diff --git a/bytelatent/base_transformer.py b/bytelatent/base_transformer.py index 45cb7c5..dd0cce6 100644 --- a/bytelatent/base_transformer.py +++ b/bytelatent/base_transformer.py @@ -4,7 +4,7 @@ from enum import Enum from typing import Optional, Tuple, Union import torch -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from torch import nn from torch.nn import functional as F from torch.nn.attention.flex_attention import ( @@ -15,6 +15,7 @@ from torch.nn.attention.flex_attention import ( from xformers.ops import AttentionBias, fmha from bytelatent import probe +from bytelatent.tokenizers.constants import EOS_ID if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0: flex_attention_comp = torch.compile(flex_attention) @@ -30,13 +31,14 @@ class InitStdFactor(Enum): class BaseTransformerArgs(BaseModel): + model_config = ConfigDict(extra="forbid") dim: int = 512 n_layers: int = 8 - head_dim: Optional[int] = None - n_heads: Optional[int] = None - n_kv_heads: Optional[int] = None + head_dim: int | None = None + n_heads: int | None = None + n_kv_heads: int | None = None - ffn_dim_multiplier: Optional[float] = None + ffn_dim_multiplier: float | None = None multiple_of: int = 256 @@ -44,11 +46,16 @@ class BaseTransformerArgs(BaseModel): rope_theta: float = 10000.0 - init_base_std: Optional[float] = None + init_base_std: float | None = None init_std_factor: InitStdFactor = InitStdFactor.DISABLED max_seqlen: int = 1024 + attn_impl: str | None = "sdpa" + attn_bias_type: str | None = None + # Special token config + eos_id: int | None = EOS_ID + def cross_entropy(pred, target, **kwargs): return F.nll_loss( @@ -294,6 +301,18 @@ class RMSNorm(nn.Module): torch.nn.init.ones_(self.weight) # type: ignore +def _reshape_for_attn_bias( + attn_bias: AttentionBias | None, + *tensors: torch.Tensor, +) -> list[torch.Tensor]: + to_transform = list(tensors) + if isinstance(attn_bias, fmha.attn_bias.BlockDiagonalCausalMask): + # could be `view` instead of reshape during training, but for inference + # have to reshape due to strides mismatch + to_transform = [t.reshape(1, -1, *t.shape[2:]) for t in to_transform] + return to_transform + + class Attention(nn.Module): def __init__( self, @@ -371,9 +390,12 @@ class Attention(nn.Module): output = flex_attention_comp(xq, xk, xv, block_mask=mask) 
output = output.transpose(1, 2).contiguous() # B H S D -> B S H D - elif attn_impl == "fmha": + elif attn_impl == "xformers": assert mask is None or isinstance(mask, AttentionBias) + query_shape = xq.shape + xq, xk, xv = _reshape_for_attn_bias(mask, xq, xk, xv) output = fmha.memory_efficient_attention(xq, xk, xv, attn_bias=mask) + output = output.view(query_shape) # This uses B S H D instead of B H S D of pytorch elif attn_impl == "sdpa": @@ -522,14 +544,16 @@ class TransformerBlock(nn.Module): mask: Optional[Union[BlockMask, AttentionBias, str]] = None, attn_impl: str = "sdpa", ) -> torch.Tensor: - h = x + self.attention( + attn_out = self.attention( self.attention_norm(x), freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl, ) - out = h + self.feed_forward(self.ffn_norm(h)) + h = x + attn_out + h_norm = self.ffn_norm(h) + out = h + self.feed_forward(h_norm) return out def init_weights(self, init_std=None, factor=1.0): @@ -545,6 +569,8 @@ class BaseTransformer(nn.Module): super().__init__() self.dim = args.dim self.init_base_std = args.init_base_std + self.attn_impl = args.attn_impl + self.attn_bias_type = args.attn_bias_type self.init_std_factor = InitStdFactor(args.init_std_factor) self.max_seqlen = args.max_seqlen self.rope_embeddings = RotaryEmbedding( @@ -552,6 +578,7 @@ class BaseTransformer(nn.Module): head_dim=args.head_dim or args.dim // args.n_heads, max_seqlen=args.max_seqlen, ) + self.eos_id = args.eos_id self.layers = nn.ModuleList() for _ in range(args.n_layers): diff --git a/bytelatent/configs/debug.yaml b/bytelatent/configs/debug.yaml index 5f6debb..4ae4459 100644 --- a/bytelatent/configs/debug.yaml +++ b/bytelatent/configs/debug.yaml @@ -15,7 +15,6 @@ optim: distributed: fsdp_type: full_shard - compile: true model_dtype: bf16 matmul_allow_tf32: false selective_activation_checkpointing: false @@ -58,13 +57,13 @@ model: recompute_attn: false custom_bwd: false layer_ckpt: "none" - efficient_attn: "sdpa" patch_only_encoder: false patch_only_decoder: false use_local_encoder_transformer: true init_use_gaussian: true init_use_depth: "current" attn_bias_type: "block_causal" + attn_impl: "xformers" alpha_depth: "disabled" max_length: 256 local_attention_window_len: 512 diff --git a/bytelatent/data/iterators/test_arrow_iterator.py b/bytelatent/data/iterators/test_arrow_iterator.py index 4266427..fd448eb 100644 --- a/bytelatent/data/iterators/test_arrow_iterator.py +++ b/bytelatent/data/iterators/test_arrow_iterator.py @@ -27,6 +27,7 @@ def test_basic_arrow_file(): dataset_files=[ARROW_TEST_DATA_1], row_num=0, arrow_batch_size=100, + s3_profile=None, ) arrow_file = initial_state.build() start_state = arrow_file.get_state() @@ -55,6 +56,7 @@ def test_basic_arrow_file(): dataset_files=[ARROW_TEST_DATA_1], row_num=251, arrow_batch_size=100, + s3_profile=None, ) arrow_file = resumed_state.build() for example in arrow_file.create_iter(): @@ -74,6 +76,7 @@ def test_basic_arrow_file(): dataset_files=[ARROW_TEST_DATA_1], row_num=0, arrow_batch_size=100, + s3_profile=None, ) arrow_file = rank_state.build() expected_ids = [] diff --git a/bytelatent/distributed.py b/bytelatent/distributed.py index b211858..168cb7c 100644 --- a/bytelatent/distributed.py +++ b/bytelatent/distributed.py @@ -11,7 +11,6 @@ import socket import subprocess import sys import tempfile -from dataclasses import asdict, dataclass from functools import lru_cache, partial, reduce from itertools import chain from typing import List, Optional, Tuple, Union diff --git a/bytelatent/model/blt.py b/bytelatent/model/blt.py 
index 9332d19..1d20cfa 100644 --- a/bytelatent/model/blt.py +++ b/bytelatent/model/blt.py @@ -15,8 +15,8 @@ from bytelatent.base_transformer import ( TransformerBlock, ) from bytelatent.data.patcher import Patcher, PatcherArgs -from bytelatent.model.local_models import LocalDecoder, LocalEncoder -from bytelatent.model.transformer import GlobalTransformer +from bytelatent.model.global_transformer import GlobalTransformer +from bytelatent.model.local_models import LocalDecoder, LocalEncoder, LocalModelArgs from bytelatent.model.utils import downsample from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID @@ -403,7 +403,6 @@ def patch_ids_from_lengths(patch_lengths, seq_len): class ByteLatentTransformerArgs(BaseTransformerArgs): - model_config = ConfigDict(extra="forbid") # Basic model configuration seed: int = 42 vocab_size: int = -1 @@ -412,7 +411,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): n_heads: int = 8 # TODO: What is the purpose of this parameter? weight_tying: bool = False - sliding_window: Optional[int] = None # Architecture and dimensions dim_token: int = 256 @@ -471,11 +469,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): recompute_attn: bool = True custom_bwd: bool = False layer_ckpt: str = "all" - efficient_attn: str | None = None - - # Architecture options - patch_only_encoder: bool = False - patch_only_decoder: bool = False # Initialization and attention init_use_gaussian: bool = True @@ -541,9 +534,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): # Logging full_logging_n_layers: int = 4 - # Special token config - eos_id: int | None = None - @model_validator(mode="after") def check_hash_byte_sizes(self) -> Self: if ( @@ -558,22 +548,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): return self -class LocalEncoderArgs(ByteLatentTransformerArgs): - # Local encoder specific dimensions - n_heads_local_encoder: int = 8 - dim_token_emb: int | None = None - dim_patch_emb: int | None = None - - def __post_init__(self): - # Override base args with local encoder specific values - self.dim = self.dim_local_encoder - self.n_layers = self.n_layers_local_encoder - self.n_heads = self.n_heads_local_encoder - self.cross_attn_decoder = False - self.cross_attn_k = self.cross_attn_k if self.cross_attn_encoder else None - self.attn_bias_type = "local_block_causal" - - class GlobalTransformerArgs(ByteLatentTransformerArgs): # Global encoder specific dimensions dim_token_emb: int | None = None @@ -625,20 +599,42 @@ def create_global_transformer(args: ByteLatentTransformerArgs) -> GlobalTransfor def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder: - # First deep copy the original args - # Replace with local encoder specific values - local_encoder_args = args.model_copy( - deep=True, - update=dict( - dim=args.dim_local_encoder, - n_layers=args.n_layers_local_encoder, - n_heads=args.n_heads_local_encoder, - dim_token_emb=get_encoder_dim_token_emb(args), - dim_patch_emb=get_encoder_dim_patch_emb(args), - cross_attn_decoder=False, - cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None, - attn_bias_type="local_block_causal", - ), + local_encoder_args = LocalModelArgs( + # Updated args + dim=args.dim_local_encoder, + n_layers=args.n_layers_local_encoder, + n_heads=args.n_heads_local_encoder, + dim_token_emb=get_encoder_dim_token_emb(args), + dim_patch_emb=get_encoder_dim_patch_emb(args), + cross_attn_encoder=args.cross_attn_encoder, + cross_attn_decoder=False, + cross_attn_k=args.cross_attn_k if 
args.cross_attn_encoder else None, + cross_attn_init_by_pooling=args.cross_attn_init_by_pooling, + # Defaults + head_dim=args.head_dim, + max_seqlen=args.max_encoder_seq_length, + dropout=args.dropout, + vocab_size=args.vocab_size + args.pm_size, + norm_eps=args.norm_eps, + patch_size=args.patch_size, + sliding_window=args.local_attention_window_len, + use_rope=args.use_rope, + rope_theta=args.rope_theta, + init_base_std=args.init_base_std, + init_std_factor=args.init_std_factor, + n_kv_heads=args.n_kv_heads, + attn_impl=args.attn_impl, + attn_bias_type="local_block_causal", + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + patching_mode=args.patching_mode, + use_local_encoder_transformer=args.use_local_encoder_transformer, + downsampling_by_pooling=args.downsampling_by_pooling, + encoder_hash_byte_group_size=args.encoder_hash_byte_group_size, + cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder, + cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder, + cross_attn_nheads=args.cross_attn_nheads, + eos_id=args.eos_id, ) return LocalEncoder(local_encoder_args) @@ -646,18 +642,41 @@ def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder: def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder: # First deep copy the original args - local_decoder_args = args.model_copy( - deep=True, - update=dict( - dim=args.dim_local_decoder, - n_layers=args.n_layers_local_decoder, - n_heads=args.n_heads_local_decoder, - cross_attn_encoder=False, - cross_attn_init_by_pooling=False, # states are already defined - dim_token_emb=get_decoder_dim_token_emb(args), - dim_patch_emb=args.dim_global, - cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None, - ), + local_decoder_args = LocalModelArgs( + dim=args.dim_local_decoder, + n_layers=args.n_layers_local_decoder, + n_heads=args.n_heads_local_decoder, + dim_token_emb=get_decoder_dim_token_emb(args), + dim_patch_emb=args.dim_global, + cross_attn_encoder=False, + cross_attn_decoder=args.cross_attn_decoder, + cross_attn_init_by_pooling=False, # states are already defined + cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None, + # Defaults + head_dim=args.head_dim, + max_seqlen=args.max_encoder_seq_length, + dropout=args.dropout, + vocab_size=args.vocab_size + args.pm_size, + norm_eps=args.norm_eps, + patch_size=args.patch_size, + sliding_window=args.local_attention_window_len, + use_rope=args.use_rope, + rope_theta=args.rope_theta, + init_base_std=args.init_base_std, + init_std_factor=args.init_std_factor, + n_kv_heads=args.n_kv_heads, + attn_impl=args.attn_impl, + attn_bias_type="local_block_causal", + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + patching_mode=args.patching_mode, + use_local_encoder_transformer=args.use_local_encoder_transformer, + downsampling_by_pooling=args.downsampling_by_pooling, + encoder_hash_byte_group_size=args.encoder_hash_byte_group_size, + cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder, + cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder, + cross_attn_nheads=args.cross_attn_nheads, + eos_id=args.eos_id, ) return LocalDecoder(local_decoder_args) @@ -763,7 +782,6 @@ class ByteLatentTransformer(nn.Module): # General configuration self.weight_tying = args.weight_tying - self.sliding_window = args.sliding_window self.patch_size = args.patch_size self.patching_mode = args.patching_mode self.boe_id, self.bos_id, self.pad_id, self.eos_id = ( diff --git 
a/bytelatent/model/transformer.py b/bytelatent/model/global_transformer.py similarity index 93% rename from bytelatent/model/transformer.py rename to bytelatent/model/global_transformer.py index 24dc057..21c3f0c 100644 --- a/bytelatent/model/transformer.py +++ b/bytelatent/model/global_transformer.py @@ -11,6 +11,7 @@ from xformers.ops import AttentionBias from bytelatent.base_transformer import ( BaseTransformer, + BaseTransformerArgs, RMSNorm, flex_attention_comp, repeat_kv, @@ -142,11 +143,10 @@ class CrossAttention(nn.Module): class GlobalTransformer(BaseTransformer): - def __init__(self, args): + def __init__(self, args: BaseTransformerArgs): super().__init__(args) self.dropout = args.dropout - self.sliding_window = args.sliding_window - self.efficient_attn = args.efficient_attn + self.eos_id = args.eos_id self.token_embedding_projection = None if args.dim_token_emb is not None and args.dim_token_emb != self.dim: @@ -169,14 +169,19 @@ class GlobalTransformer(BaseTransformer): and projection to the token space. """ bs, seqlen = tokens.shape - attn_impl = self.efficient_attn h = embeds mask = ( mask if mask is not None - else create_causal_mask(seqlen, attn_impl, self.sliding_window) + else create_causal_mask( + seqlen, + self.attn_impl, + self.attn_bias_type, + tokens=tokens, + eos_id=self.eos_id, + ) ) if self.token_embedding_projection is not None and h.shape[-1] != self.dim: @@ -184,7 +189,7 @@ class GlobalTransformer(BaseTransformer): h = F.dropout(h, p=self.dropout, training=self.training) - h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl) + h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl) return h, cache def init_weights(self, init_base_std: float): diff --git a/bytelatent/model/local_models.py b/bytelatent/model/local_models.py index 8255504..f182780 100644 --- a/bytelatent/model/local_models.py +++ b/bytelatent/model/local_models.py @@ -1,11 +1,12 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
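The GlobalTransformer hunk above is representative of how every caller migrates to the new create_causal_mask signature: the attention bias type becomes an explicit positional argument, the sliding window moves behind a keyword, and the mask builder now also receives the raw tokens plus eos_id so it can derive per-document sequence lengths. A hedged before/after sketch of the call-site change (variable names only illustrate the pattern):

# Old signature: implementation plus optional sliding window.
mask = create_causal_mask(seqlen, attn_impl, sliding_window)

# New signature: bias type is explicit, and document boundaries come from the tokens.
mask = create_causal_mask(
    seqlen,
    attn_impl,
    attn_bias_type,              # e.g. "causal", "block_causal", "local_block_causal"
    sliding_window=sliding_window,
    tokens=tokens,               # used to locate eos_id positions
    eos_id=eos_id,
)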
import logging -from typing import List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union import torch import torch.nn import torch.nn as nn +from pydantic import BaseModel, ConfigDict from torch.nn import functional as F from torch.nn.attention.flex_attention import BlockMask from xformers.ops import AttentionBias @@ -16,29 +17,69 @@ from bytelatent.base_transformer import ( RotaryEmbedding, TransformerBlock, ) -from bytelatent.model.transformer import CrossAttention +from bytelatent.model.global_transformer import CrossAttention from bytelatent.model.utils import create_causal_mask, downsample from bytelatent.tokenizers.blt_tokenizer import BOE_ID logger = logging.getLogger() +class LocalModelArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + # Local encoder specific dimensions + head_dim: int | None + dim: int + dropout: float + vocab_size: int + patch_size: int + sliding_window: int | None + use_rope: bool + init_base_std: float | None = None + init_std_factor: InitStdFactor + cross_attn_encoder: bool | None + cross_attn_decoder: bool | None + cross_attn_k: int | None + cross_attn_init_by_pooling: bool + norm_eps: float + rope_theta: float + max_seqlen: int + ffn_dim_multiplier: float | None = None + patching_mode: str + use_local_encoder_transformer: bool + downsampling_by_pooling: str | None + encoder_hash_byte_group_size: Any | None = None + cross_attn_all_layers_encoder: bool = False + cross_attn_all_layers_decoder: bool = False + cross_attn_nheads: int | None + + n_layers: int + n_heads: int + n_kv_heads: int | None = None + dim_token_emb: int + dim_patch_emb: int | None + attn_impl: str | None = "xformers" + attn_bias_type: str | None = "local_block_causal" + multiple_of: int = 256 + eos_id: int | None = None + + class LocalModelBase(nn.Module): - def __init__(self, args): + def __init__(self, args: LocalModelArgs): super().__init__() self.dim = args.dim self.dropout = args.dropout - self.vocab_size = args.vocab_size + args.pm_size + self.vocab_size = args.vocab_size self.patch_size = args.patch_size - self.efficient_attn = args.efficient_attn + self.attn_impl = args.attn_impl self.sliding_window = args.sliding_window self.use_rope = args.use_rope self.init_std_factor = args.init_std_factor self.cross_attn_encoder = getattr(args, "cross_attn_encoder", None) self.cross_attn_decoder = getattr(args, "cross_attn_decoder", None) self.cross_attn_k = getattr(args, "cross_attn_k", None) + self.eos_id = args.eos_id self.boe_id = BOE_ID @@ -54,7 +95,7 @@ class LocalModelBase(nn.Module): self.rope = RotaryEmbedding( theta=args.rope_theta, head_dim=args.head_dim or args.dim // args.n_heads, - max_seqlen=getattr(args, "max_encoder_seq_length", args.max_length), + max_seqlen=args.max_seqlen, ) self.pos_embeddings = None @@ -66,21 +107,15 @@ class LocalModelBase(nn.Module): self.patch_embedding_projection = self._create_patch_projection(args) - def _should_create_patch_projection(self, args): + def _should_create_patch_projection(self, args: LocalModelArgs): dimension_mismatch = ( getattr(args, "dim_patch_emb") and args.dim_patch_emb != self.dim ) # Check cross attention conditions cross_attn_conditions = ( - hasattr(args, "cross_attn_encoder") - and args.cross_attn_encoder - and getattr(args, "cross_attn_init_by_pooling") - ) or ( - hasattr(args, "cross_attn_decoder") - and args.cross_attn_decoder - and getattr(args, "cross_attn_init_by_pooling") - ) + args.cross_attn_encoder and args.cross_attn_init_by_pooling + ) or (args.cross_attn_decoder and 
args.cross_attn_init_by_pooling) return dimension_mismatch or cross_attn_conditions @@ -172,7 +207,7 @@ class LocalModelBase(nn.Module): class LocalEncoder(LocalModelBase): - def __init__(self, args): + def __init__(self, args: LocalModelArgs): super().__init__(args) self.output_proj = ( args.patching_mode in ["entropy", "probmax"] @@ -180,7 +215,6 @@ class LocalEncoder(LocalModelBase): self.apply_transformer = args.use_local_encoder_transformer self.downsampling_by_pooling = args.downsampling_by_pooling - self.patch_only = args.patch_only_encoder self.expects_hash_embeddings = args.encoder_hash_byte_group_size is not None self.cross_attn_encoder = args.cross_attn_encoder self.cross_attn_all_layers_encoder = args.cross_attn_all_layers_encoder @@ -224,7 +258,14 @@ class LocalEncoder(LocalModelBase): """ """ bs, seqlen = tokens.shape if mask is None: - mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window) + mask = create_causal_mask( + seqlen, + self.attn_impl, + "local_block_causal", + sliding_window=self.sliding_window, + tokens=tokens, + eos_id=self.eos_id, + ) h = self.apply_embedding(tokens, embeds) freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None @@ -232,7 +273,7 @@ class LocalEncoder(LocalModelBase): h = F.dropout(h, p=self.dropout, training=self.training) for i, layer in enumerate(self.layers): - h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn) + h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl) # check if cross attention should be applied to either all layer or only the last layer if self.cross_attn_encoder and ( i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder @@ -273,12 +314,10 @@ class LocalEncoder(LocalModelBase): class LocalDecoder(LocalModelBase): - def __init__(self, args): + def __init__(self, args: LocalModelArgs): super().__init__(args) # Model configuration flags - self.patch_only = args.patch_only_decoder - self.expects_embeddings = args.share_encoder_decoder_emb self.cross_attn_decoder = args.cross_attn_decoder self.cross_attn_all_layers_decoder = args.cross_attn_all_layers_decoder self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling @@ -317,7 +356,14 @@ class LocalDecoder(LocalModelBase): assert embeds is not None, "Embeddings must be provided" if mask is None: - mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window) + mask = create_causal_mask( + seqlen, + self.attn_impl, + "local_block_causal", + sliding_window=self.sliding_window, + tokens=tokens, + eos_id=self.eos_id, + ) h = embeds @@ -347,7 +393,7 @@ class LocalDecoder(LocalModelBase): ) h = h + h_cross - h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn) + h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl) h_preds = self.norm(h) h_preds = F.dropout(h_preds, p=self.dropout, training=self.training) diff --git a/bytelatent/model/utils.py b/bytelatent/model/utils.py index ce52a30..42eb185 100644 --- a/bytelatent/model/utils.py +++ b/bytelatent/model/utils.py @@ -1,8 +1,12 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
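One payoff of introducing an explicit LocalModelArgs with extra="forbid" (instead of deep-copying ByteLatentTransformerArgs and patching fields via model_copy) is that removed or misspelled options now fail loudly at construction time rather than being carried along silently. A self-contained illustration of the pydantic v2 behaviour, using a stand-in model rather than LocalModelArgs itself:

from pydantic import BaseModel, ConfigDict, ValidationError

class StrictArgs(BaseModel):
    # Same setting as LocalModelArgs: unknown fields are rejected.
    model_config = ConfigDict(extra="forbid")
    dim: int = 512

try:
    StrictArgs(dim=256, patch_only_encoder=False)  # field removed in this series
except ValidationError as err:
    print(err)  # pydantic v2 reports the extra field as not permitted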
+import logging + import torch from torch.nn.attention.flex_attention import create_block_mask from xformers.ops import fmha +logger = logging.getLogger() + def patch_reduce(h, max_num_patches, reduction, patch_ids): """ @@ -97,14 +101,69 @@ def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx -def create_causal_mask(seqlen, attn_impl, sliding_window): - if sliding_window is not None and attn_impl == "xformers": - return fmha.attn_bias.LocalAttentionFromBottomRightMask( - window_left=sliding_window - 1, window_right=0 - ) - elif attn_impl == "xformers": - return fmha.attn_bias.LowerTriangularMask() +def tokens_to_seqlen(batch: torch.Tensor, eos_id: int): + """ + 0 0 0 1 0 0 0 1 0 0 0 + 0 1 0 0 0 1 0 0 0 0 0 + -> 4 4 3 2 4 5 + """ + mask = batch == eos_id + mask[:, -1] = True # virtual eos at the end of each row + + # 0 0 0 1 0 0 0 1 0 0 X + # 0 1 0 0 0 1 0 0 0 0 X + row, col = torch.where(mask) + + # row = 0, 0, 0, 1, 1, 1 + # col = 3, 7, 10, 1, 5, 10 + seqlens = (col[1:] - col[:-1]) + (row[1:] - row[:-1]) * mask.shape[1] + # seqlens = (4, 3, -9, 4, 5) + (0, 0, 11, 0, 0) = (4, 3, 2, 4, 5) + return [int(col[0].item() + 1)] + seqlens.tolist() + + +WARNED_SDPA = False + + +def create_causal_mask( + seqlen, + attn_impl: str, + attn_bias_type: str | None, + *, + eos_id: int | None = None, + tokens: torch.Tensor | None = None, + sliding_window: int | None = None, +): + if attn_impl == "xformers": + if attn_bias_type is None: + return fmha.attn_bias.LowerTriangularMask() + elif attn_bias_type == "causal": + assert sliding_window is None + return fmha.attn_bias.LowerTriangularMask() + elif attn_bias_type == "block_causal": + assert sliding_window is None + assert eos_id is not None + assert tokens is not None + return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens( + q_seqlen=tokens_to_seqlen(tokens, eos_id) + ) + elif attn_bias_type == "local_block_causal": + assert sliding_window is not None + assert eos_id is not None + assert tokens is not None + return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens( + q_seqlen=tokens_to_seqlen(tokens, eos_id) + ).make_local_attention(sliding_window) + else: + return fmha.attn_bias.LocalAttentionFromBottomRightMask( + window_left=sliding_window - 1, window_right=0 + ) elif attn_impl == "sdpa": + global WARNED_SDPA + if not WARNED_SDPA: + logging.warning( + "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention." 
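The tokens_to_seqlen helper above is what lets the xformers block-causal biases respect document boundaries inside a packed batch: each row contributes the gaps between its EOS positions, with a virtual EOS appended at the end of the row, and the resulting lengths feed BlockDiagonalCausalMask.from_seqlens. A small worked example mirroring the docstring, assuming eos_id=1 and that the helper is importable from bytelatent.model.utils once this patch is applied:

import torch
from xformers.ops import fmha

from bytelatent.model.utils import tokens_to_seqlen

batch = torch.tensor(
    [
        [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
    ]
)
seqlens = tokens_to_seqlen(batch, eos_id=1)
print(seqlens)  # [4, 4, 3, 2, 4, 5]: six documents packed into 2 rows of 11 tokens

# block_causal: attention never crosses a document boundary
bias = fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(q_seqlen=seqlens)
# local_block_causal: additionally restrict each query to a sliding window
local_bias = bias.make_local_attention(512)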
+ ) + WARNED_SDPA = True return "causal" elif attn_impl == "flex_attention": return create_block_mask(causal_mask, None, None, seqlen, seqlen) diff --git a/bytelatent/preprocess/fsspec_target.py b/bytelatent/preprocess/fsspec_target.py new file mode 100644 index 0000000..eacb101 --- /dev/null +++ b/bytelatent/preprocess/fsspec_target.py @@ -0,0 +1,38 @@ +import fsspec +from luigi.target import FileSystem, FileSystemTarget + + +class FSSpecFileSystem(FileSystem): + def __init__(self, fs: fsspec.AbstractFileSystem): + self.fs = fs + + def exists(self, path): + return self.fs.exists() + + def remove(self, path, recursive=True, skip_trash=True): + raise NotImplementedError() + + def isdir(self, path): + return self.fs.isdir(path) + + def listdir(self, path): + return self.fs.ls(path) + + +class FSSpecTarget(FileSystemTarget): + def __init__(self, path, fs: fsspec.AbstractFileSystem | None = None): + self.path = path + if fs is None: + self.fsspec_fs = fsspec.filesystem("file") + else: + self.fsspec_fs = fs + self._fs = None + + @property + def fs(self): + if self._fs is None: + self._fs = FSSpecFileSystem(self.fsspec_fs) + return self._fs + + def open(self, mode): + return self.fs.open(self.path, mode=mode) diff --git a/bytelatent/test_blt.py b/bytelatent/test_blt.py index 73ad9f7..4d8e9c7 100644 --- a/bytelatent/test_blt.py +++ b/bytelatent/test_blt.py @@ -23,9 +23,10 @@ from bytelatent.model.blt import ( init_embeddings, patch_ids_from_lengths, ) -from bytelatent.model.transformer import CrossAttention +from bytelatent.model.global_transformer import CrossAttention from bytelatent.model.utils import create_causal_mask from bytelatent.optim import OptimArgs, build_optimizer +from bytelatent.tokenizers.constants import EOS_ID from bytelatent.train import compute_loss @@ -51,7 +52,7 @@ def batch_to_tensors_and_gpu(batch): def fake_batch(): - batch_dict = torch.load(os.path.join(BLT_DATA, "test_batch.pt")) + batch_dict = torch.load(os.path.join(BLT_DATA, "test_batch.pt"), weights_only=False) del batch_dict["x2"] del batch_dict["y2"] del batch_dict["src_names"] @@ -98,18 +99,17 @@ def create_args(cross_attention=False): recompute_attn=False, custom_bwd=False, layer_ckpt="none", - efficient_attn="sdpa", - patch_only_encoder=False, - patch_only_decoder=False, use_local_encoder_transformer=True, init_use_gaussian=True, init_use_depth="current", attn_bias_type="block_causal", + attn_impl="xformers", alpha_depth="disabled", max_length=256, local_attention_window_len=512, max_seqlen=12288, downsampling_by_pooling="max", + eos_id=EOS_ID, ) return transformer_args @@ -341,10 +341,10 @@ class TestByteLatentTransformer: model = ByteLatentTransformer(args) assert model is not None - @pytest.mark.parametrize("attn_type", ["fmha", "sdpa"]) - def test_blt_transformer_forward(self, attn_type): + @pytest.mark.parametrize("attn_impl", ["sdpa", "xformers"]) + def test_blt_transformer_forward(self, attn_impl): args = create_args() - args = args.model_copy(update=dict(efficient_attn=attn_type)) + args = args.model_copy(update=dict(attn_impl=attn_impl)) model = ByteLatentTransformer(args) model = model.cuda() batch = fake_batch() @@ -393,7 +393,9 @@ class TestByteLatentTransformer: n_kv_heads=4, norm_eps=1e-6, ).to("cuda") - mask = create_causal_mask(x.shape[1], "flex_attention", sliding_window=None) + mask = create_causal_mask( + x.shape[1], "flex_attention", None, sliding_window=None + ) output = cross_attention(x, kv, mask) assert output is not None assert output.shape == (2, 256, 512) @@ -440,7 +442,7 @@ class 
TestByteLatentTransformer: def test_loss_backward(self): args = create_args() - args = args.model_copy(update=dict(efficient_attn="sdpa")) + args = args.model_copy(update=dict(attn_impl="sdpa")) batch = fake_batch() model = ByteLatentTransformer(args) steps = 10 diff --git a/bytelatent/test_entropy_model.py b/bytelatent/test_entropy_model.py index 3acc42d..af81638 100644 --- a/bytelatent/test_entropy_model.py +++ b/bytelatent/test_entropy_model.py @@ -24,6 +24,7 @@ def test_entropy_model(): dataset_files=[ARROW_TEST_DATA], row_num=0, arrow_batch_size=100, + s3_profile=None, ) arrow_file = initial_state.build() tokenizer_args = TokenizerArgs( diff --git a/bytelatent/train.py b/bytelatent/train.py index 6cb13b9..80bd393 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -644,6 +644,10 @@ def main(): cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) train_args = TrainArgs.model_validate(cfg) + if train_args.debug_dynamo: + import torch._dynamo + + torch._dynamo.config.suppress_errors = True train(train_args) diff --git a/bytelatent/transformer.py b/bytelatent/transformer.py index 432f7df..92c5ff5 100644 --- a/bytelatent/transformer.py +++ b/bytelatent/transformer.py @@ -22,23 +22,7 @@ from bytelatent.base_transformer import ( RMSNorm, cross_entropy, ) - - -def create_causal_mask(seqlen, attn_impl, sliding_window): - if sliding_window is not None and attn_impl == "xformers": - return fmha.attn_bias.LocalAttentionFromBottomRightMask( - window_left=sliding_window - 1, window_right=0 - ) - elif attn_impl == "xformers": - return fmha.attn_bias.LowerTriangularMask() - elif attn_impl == "sdpa": - return "causal" - elif attn_impl == "flex_attention": - return create_block_mask(causal_mask, None, None, seqlen, seqlen) - else: - raise NotImplementedError( - f"Attention {attn_impl} with {sliding_window} sliding window not implemented" - ) +from bytelatent.model.utils import create_causal_mask def attention_flops_per_token(n_layers, seq_len, dim, causal): @@ -94,8 +78,10 @@ class LMTransformer(BaseTransformer): target: Optional[torch.Tensor] = None, tok_idx: Optional[torch.Tensor] = None, mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None, - attn_impl: str = "sdpa", + attn_impl: str | None = None, ): + if attn_impl is None: + attn_impl = self.attn_impl bsz, seqlen = token_values.shape h = self.tok_embeddings(token_values) @@ -103,7 +89,14 @@ class LMTransformer(BaseTransformer): mask = ( mask if mask is not None - else create_causal_mask(seqlen, attn_impl, self.sliding_window) + else create_causal_mask( + seqlen, + attn_impl, + self.attn_bias_type, + sliding_window=self.sliding_window, + tokens=token_values, + eos_id=self.eos_id, + ) ) h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl) From 7f305b38719c8ffb788753f9bc2f41df22834162 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Fri, 17 Jan 2025 22:21:50 +0000 Subject: [PATCH 10/59] [WIP] Changes for training entropy model and correcting attention in local models Summary: - Refactor local model configs to be separate and clearer - Add attention arguments and correct which attention is used in local models - Preparation for being able to have an entropy train script - Fix failing unit tests Test Plan: --- bytelatent/args.py | 7 + bytelatent/base_transformer.py | 45 ++++-- bytelatent/configs/debug.yaml | 3 +- .../data/iterators/test_arrow_iterator.py | 3 + bytelatent/distributed.py | 1 - bytelatent/entropy_model.py | 10 +- 
bytelatent/model/blt.py | 128 ++++++++++-------- .../{transformer.py => latent_transformer.py} | 17 ++- bytelatent/model/local_models.py | 84 ++++++++---- bytelatent/model/utils.py | 80 +++++++++-- bytelatent/preprocess/fsspec_target.py | 38 ++++++ bytelatent/test_blt.py | 27 ++-- bytelatent/test_entropy_model.py | 9 +- bytelatent/train.py | 4 + bytelatent/transformer.py | 31 ++--- 15 files changed, 349 insertions(+), 138 deletions(-) rename bytelatent/model/{transformer.py => latent_transformer.py} (93%) create mode 100644 bytelatent/preprocess/fsspec_target.py diff --git a/bytelatent/args.py b/bytelatent/args.py index cfba3bf..b9144c6 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -30,6 +30,7 @@ from bytelatent.model.blt import ByteLatentTransformerArgs from bytelatent.optim import OptimArgs from bytelatent.profiling import ProfilerArgs from bytelatent.tokenizers.build_tokenizer import TokenizerArgs +from bytelatent.transformer import LMTransformerArgs logger = logging.getLogger() @@ -163,6 +164,8 @@ class TrainArgs(BaseModel): seed: int = 42 + debug_dynamo: bool = False + # Number of gradient accumulation steps # Total batch size is batch_size*grad_acc_steps grad_acc_steps: int = 1 @@ -176,6 +179,10 @@ class TrainArgs(BaseModel): data: DataloaderArgs = DataloaderArgs() optim: OptimArgs = OptimArgs() model: ByteLatentTransformerArgs = ByteLatentTransformerArgs() + # This is only needed for training the entropy model + entropy_model: LMTransformerArgs | None = None + # Instead of training main model, train entropy model + train_entropy_model: bool = False distributed: DistributedArgs = DistributedArgs() env: EnvironmentArgs = EnvironmentArgs() diff --git a/bytelatent/base_transformer.py b/bytelatent/base_transformer.py index 45cb7c5..dd0cce6 100644 --- a/bytelatent/base_transformer.py +++ b/bytelatent/base_transformer.py @@ -4,7 +4,7 @@ from enum import Enum from typing import Optional, Tuple, Union import torch -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from torch import nn from torch.nn import functional as F from torch.nn.attention.flex_attention import ( @@ -15,6 +15,7 @@ from torch.nn.attention.flex_attention import ( from xformers.ops import AttentionBias, fmha from bytelatent import probe +from bytelatent.tokenizers.constants import EOS_ID if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0: flex_attention_comp = torch.compile(flex_attention) @@ -30,13 +31,14 @@ class InitStdFactor(Enum): class BaseTransformerArgs(BaseModel): + model_config = ConfigDict(extra="forbid") dim: int = 512 n_layers: int = 8 - head_dim: Optional[int] = None - n_heads: Optional[int] = None - n_kv_heads: Optional[int] = None + head_dim: int | None = None + n_heads: int | None = None + n_kv_heads: int | None = None - ffn_dim_multiplier: Optional[float] = None + ffn_dim_multiplier: float | None = None multiple_of: int = 256 @@ -44,11 +46,16 @@ class BaseTransformerArgs(BaseModel): rope_theta: float = 10000.0 - init_base_std: Optional[float] = None + init_base_std: float | None = None init_std_factor: InitStdFactor = InitStdFactor.DISABLED max_seqlen: int = 1024 + attn_impl: str | None = "sdpa" + attn_bias_type: str | None = None + # Special token config + eos_id: int | None = EOS_ID + def cross_entropy(pred, target, **kwargs): return F.nll_loss( @@ -294,6 +301,18 @@ class RMSNorm(nn.Module): torch.nn.init.ones_(self.weight) # type: ignore +def _reshape_for_attn_bias( + attn_bias: AttentionBias | None, + *tensors: torch.Tensor, +) -> 
list[torch.Tensor]: + to_transform = list(tensors) + if isinstance(attn_bias, fmha.attn_bias.BlockDiagonalCausalMask): + # could be `view` instead of reshape during training, but for inference + # have to reshape due to strides mismatch + to_transform = [t.reshape(1, -1, *t.shape[2:]) for t in to_transform] + return to_transform + + class Attention(nn.Module): def __init__( self, @@ -371,9 +390,12 @@ class Attention(nn.Module): output = flex_attention_comp(xq, xk, xv, block_mask=mask) output = output.transpose(1, 2).contiguous() # B H S D -> B S H D - elif attn_impl == "fmha": + elif attn_impl == "xformers": assert mask is None or isinstance(mask, AttentionBias) + query_shape = xq.shape + xq, xk, xv = _reshape_for_attn_bias(mask, xq, xk, xv) output = fmha.memory_efficient_attention(xq, xk, xv, attn_bias=mask) + output = output.view(query_shape) # This uses B S H D instead of B H S D of pytorch elif attn_impl == "sdpa": @@ -522,14 +544,16 @@ class TransformerBlock(nn.Module): mask: Optional[Union[BlockMask, AttentionBias, str]] = None, attn_impl: str = "sdpa", ) -> torch.Tensor: - h = x + self.attention( + attn_out = self.attention( self.attention_norm(x), freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl, ) - out = h + self.feed_forward(self.ffn_norm(h)) + h = x + attn_out + h_norm = self.ffn_norm(h) + out = h + self.feed_forward(h_norm) return out def init_weights(self, init_std=None, factor=1.0): @@ -545,6 +569,8 @@ class BaseTransformer(nn.Module): super().__init__() self.dim = args.dim self.init_base_std = args.init_base_std + self.attn_impl = args.attn_impl + self.attn_bias_type = args.attn_bias_type self.init_std_factor = InitStdFactor(args.init_std_factor) self.max_seqlen = args.max_seqlen self.rope_embeddings = RotaryEmbedding( @@ -552,6 +578,7 @@ class BaseTransformer(nn.Module): head_dim=args.head_dim or args.dim // args.n_heads, max_seqlen=args.max_seqlen, ) + self.eos_id = args.eos_id self.layers = nn.ModuleList() for _ in range(args.n_layers): diff --git a/bytelatent/configs/debug.yaml b/bytelatent/configs/debug.yaml index 5f6debb..4ae4459 100644 --- a/bytelatent/configs/debug.yaml +++ b/bytelatent/configs/debug.yaml @@ -15,7 +15,6 @@ optim: distributed: fsdp_type: full_shard - compile: true model_dtype: bf16 matmul_allow_tf32: false selective_activation_checkpointing: false @@ -58,13 +57,13 @@ model: recompute_attn: false custom_bwd: false layer_ckpt: "none" - efficient_attn: "sdpa" patch_only_encoder: false patch_only_decoder: false use_local_encoder_transformer: true init_use_gaussian: true init_use_depth: "current" attn_bias_type: "block_causal" + attn_impl: "xformers" alpha_depth: "disabled" max_length: 256 local_attention_window_len: 512 diff --git a/bytelatent/data/iterators/test_arrow_iterator.py b/bytelatent/data/iterators/test_arrow_iterator.py index 4266427..fd448eb 100644 --- a/bytelatent/data/iterators/test_arrow_iterator.py +++ b/bytelatent/data/iterators/test_arrow_iterator.py @@ -27,6 +27,7 @@ def test_basic_arrow_file(): dataset_files=[ARROW_TEST_DATA_1], row_num=0, arrow_batch_size=100, + s3_profile=None, ) arrow_file = initial_state.build() start_state = arrow_file.get_state() @@ -55,6 +56,7 @@ def test_basic_arrow_file(): dataset_files=[ARROW_TEST_DATA_1], row_num=251, arrow_batch_size=100, + s3_profile=None, ) arrow_file = resumed_state.build() for example in arrow_file.create_iter(): @@ -74,6 +76,7 @@ def test_basic_arrow_file(): dataset_files=[ARROW_TEST_DATA_1], row_num=0, arrow_batch_size=100, + s3_profile=None, ) arrow_file = 
rank_state.build() expected_ids = [] diff --git a/bytelatent/distributed.py b/bytelatent/distributed.py index b211858..168cb7c 100644 --- a/bytelatent/distributed.py +++ b/bytelatent/distributed.py @@ -11,7 +11,6 @@ import socket import subprocess import sys import tempfile -from dataclasses import asdict, dataclass from functools import lru_cache, partial, reduce from itertools import chain from typing import List, Optional, Tuple, Union diff --git a/bytelatent/entropy_model.py b/bytelatent/entropy_model.py index 1bd1766..30754ee 100644 --- a/bytelatent/entropy_model.py +++ b/bytelatent/entropy_model.py @@ -1,12 +1,14 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. import json +import logging import os -import re import torch from bytelatent.transformer import LMTransformer, LMTransformerArgs +logger = logging.getLogger() + def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cpu"): with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr: @@ -14,6 +16,9 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp torch.set_default_dtype(torch.bfloat16) model_params = reloaded["model"] + logger.warning( + "Update checkpoint to load attn and sliding window args from checkpoint" + ) entropy_model = LMTransformer( LMTransformerArgs( dim=model_params["dim"], @@ -22,6 +27,9 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp max_seqlen=model_params["max_length"], ffn_dim_multiplier=model_params["ffn_dim_multiplier"], vocab_size=model_params["vocab_size"], + attn_bias_type="local_block_causal", + attn_impl="xformers", + sliding_window=512, ) ) diff --git a/bytelatent/model/blt.py b/bytelatent/model/blt.py index 9332d19..843ad34 100644 --- a/bytelatent/model/blt.py +++ b/bytelatent/model/blt.py @@ -15,8 +15,8 @@ from bytelatent.base_transformer import ( TransformerBlock, ) from bytelatent.data.patcher import Patcher, PatcherArgs -from bytelatent.model.local_models import LocalDecoder, LocalEncoder -from bytelatent.model.transformer import GlobalTransformer +from bytelatent.model.latent_transformer import GlobalTransformer +from bytelatent.model.local_models import LocalDecoder, LocalEncoder, LocalModelArgs from bytelatent.model.utils import downsample from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID @@ -403,7 +403,6 @@ def patch_ids_from_lengths(patch_lengths, seq_len): class ByteLatentTransformerArgs(BaseTransformerArgs): - model_config = ConfigDict(extra="forbid") # Basic model configuration seed: int = 42 vocab_size: int = -1 @@ -412,7 +411,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): n_heads: int = 8 # TODO: What is the purpose of this parameter? 
weight_tying: bool = False - sliding_window: Optional[int] = None # Architecture and dimensions dim_token: int = 256 @@ -471,11 +469,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): recompute_attn: bool = True custom_bwd: bool = False layer_ckpt: str = "all" - efficient_attn: str | None = None - - # Architecture options - patch_only_encoder: bool = False - patch_only_decoder: bool = False # Initialization and attention init_use_gaussian: bool = True @@ -541,9 +534,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): # Logging full_logging_n_layers: int = 4 - # Special token config - eos_id: int | None = None - @model_validator(mode="after") def check_hash_byte_sizes(self) -> Self: if ( @@ -558,22 +548,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): return self -class LocalEncoderArgs(ByteLatentTransformerArgs): - # Local encoder specific dimensions - n_heads_local_encoder: int = 8 - dim_token_emb: int | None = None - dim_patch_emb: int | None = None - - def __post_init__(self): - # Override base args with local encoder specific values - self.dim = self.dim_local_encoder - self.n_layers = self.n_layers_local_encoder - self.n_heads = self.n_heads_local_encoder - self.cross_attn_decoder = False - self.cross_attn_k = self.cross_attn_k if self.cross_attn_encoder else None - self.attn_bias_type = "local_block_causal" - - class GlobalTransformerArgs(ByteLatentTransformerArgs): # Global encoder specific dimensions dim_token_emb: int | None = None @@ -625,20 +599,42 @@ def create_global_transformer(args: ByteLatentTransformerArgs) -> GlobalTransfor def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder: - # First deep copy the original args - # Replace with local encoder specific values - local_encoder_args = args.model_copy( - deep=True, - update=dict( - dim=args.dim_local_encoder, - n_layers=args.n_layers_local_encoder, - n_heads=args.n_heads_local_encoder, - dim_token_emb=get_encoder_dim_token_emb(args), - dim_patch_emb=get_encoder_dim_patch_emb(args), - cross_attn_decoder=False, - cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None, - attn_bias_type="local_block_causal", - ), + local_encoder_args = LocalModelArgs( + # Updated args + dim=args.dim_local_encoder, + n_layers=args.n_layers_local_encoder, + n_heads=args.n_heads_local_encoder, + dim_token_emb=get_encoder_dim_token_emb(args), + dim_patch_emb=get_encoder_dim_patch_emb(args), + cross_attn_encoder=args.cross_attn_encoder, + cross_attn_decoder=False, + cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None, + cross_attn_init_by_pooling=args.cross_attn_init_by_pooling, + # Defaults + head_dim=args.head_dim, + max_seqlen=args.max_encoder_seq_length, + dropout=args.dropout, + vocab_size=args.vocab_size + args.pm_size, + norm_eps=args.norm_eps, + patch_size=args.patch_size, + sliding_window=args.local_attention_window_len, + use_rope=args.use_rope, + rope_theta=args.rope_theta, + init_base_std=args.init_base_std, + init_std_factor=args.init_std_factor, + n_kv_heads=args.n_kv_heads, + attn_impl=args.attn_impl, + attn_bias_type="local_block_causal", + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + patching_mode=args.patching_mode, + use_local_encoder_transformer=args.use_local_encoder_transformer, + downsampling_by_pooling=args.downsampling_by_pooling, + encoder_hash_byte_group_size=args.encoder_hash_byte_group_size, + cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder, + 
cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder, + cross_attn_nheads=args.cross_attn_nheads, + eos_id=args.eos_id, ) return LocalEncoder(local_encoder_args) @@ -646,18 +642,41 @@ def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder: def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder: # First deep copy the original args - local_decoder_args = args.model_copy( - deep=True, - update=dict( - dim=args.dim_local_decoder, - n_layers=args.n_layers_local_decoder, - n_heads=args.n_heads_local_decoder, - cross_attn_encoder=False, - cross_attn_init_by_pooling=False, # states are already defined - dim_token_emb=get_decoder_dim_token_emb(args), - dim_patch_emb=args.dim_global, - cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None, - ), + local_decoder_args = LocalModelArgs( + dim=args.dim_local_decoder, + n_layers=args.n_layers_local_decoder, + n_heads=args.n_heads_local_decoder, + dim_token_emb=get_decoder_dim_token_emb(args), + dim_patch_emb=args.dim_global, + cross_attn_encoder=False, + cross_attn_decoder=args.cross_attn_decoder, + cross_attn_init_by_pooling=False, # states are already defined + cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None, + # Defaults + head_dim=args.head_dim, + max_seqlen=args.max_encoder_seq_length, + dropout=args.dropout, + vocab_size=args.vocab_size + args.pm_size, + norm_eps=args.norm_eps, + patch_size=args.patch_size, + sliding_window=args.local_attention_window_len, + use_rope=args.use_rope, + rope_theta=args.rope_theta, + init_base_std=args.init_base_std, + init_std_factor=args.init_std_factor, + n_kv_heads=args.n_kv_heads, + attn_impl=args.attn_impl, + attn_bias_type="local_block_causal", + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + patching_mode=args.patching_mode, + use_local_encoder_transformer=args.use_local_encoder_transformer, + downsampling_by_pooling=args.downsampling_by_pooling, + encoder_hash_byte_group_size=args.encoder_hash_byte_group_size, + cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder, + cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder, + cross_attn_nheads=args.cross_attn_nheads, + eos_id=args.eos_id, ) return LocalDecoder(local_decoder_args) @@ -763,7 +782,6 @@ class ByteLatentTransformer(nn.Module): # General configuration self.weight_tying = args.weight_tying - self.sliding_window = args.sliding_window self.patch_size = args.patch_size self.patching_mode = args.patching_mode self.boe_id, self.bos_id, self.pad_id, self.eos_id = ( diff --git a/bytelatent/model/transformer.py b/bytelatent/model/latent_transformer.py similarity index 93% rename from bytelatent/model/transformer.py rename to bytelatent/model/latent_transformer.py index 24dc057..21c3f0c 100644 --- a/bytelatent/model/transformer.py +++ b/bytelatent/model/latent_transformer.py @@ -11,6 +11,7 @@ from xformers.ops import AttentionBias from bytelatent.base_transformer import ( BaseTransformer, + BaseTransformerArgs, RMSNorm, flex_attention_comp, repeat_kv, @@ -142,11 +143,10 @@ class CrossAttention(nn.Module): class GlobalTransformer(BaseTransformer): - def __init__(self, args): + def __init__(self, args: BaseTransformerArgs): super().__init__(args) self.dropout = args.dropout - self.sliding_window = args.sliding_window - self.efficient_attn = args.efficient_attn + self.eos_id = args.eos_id self.token_embedding_projection = None if args.dim_token_emb is not None and args.dim_token_emb != self.dim: @@ -169,14 +169,19 @@ class 
GlobalTransformer(BaseTransformer): and projection to the token space. """ bs, seqlen = tokens.shape - attn_impl = self.efficient_attn h = embeds mask = ( mask if mask is not None - else create_causal_mask(seqlen, attn_impl, self.sliding_window) + else create_causal_mask( + seqlen, + self.attn_impl, + self.attn_bias_type, + tokens=tokens, + eos_id=self.eos_id, + ) ) if self.token_embedding_projection is not None and h.shape[-1] != self.dim: @@ -184,7 +189,7 @@ class GlobalTransformer(BaseTransformer): h = F.dropout(h, p=self.dropout, training=self.training) - h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl) + h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl) return h, cache def init_weights(self, init_base_std: float): diff --git a/bytelatent/model/local_models.py b/bytelatent/model/local_models.py index 8255504..59fa76d 100644 --- a/bytelatent/model/local_models.py +++ b/bytelatent/model/local_models.py @@ -1,44 +1,75 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. import logging -from typing import List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union import torch import torch.nn import torch.nn as nn +from pydantic import BaseModel, ConfigDict from torch.nn import functional as F from torch.nn.attention.flex_attention import BlockMask from xformers.ops import AttentionBias from bytelatent.base_transformer import ( + BaseTransformerArgs, InitStdFactor, RMSNorm, RotaryEmbedding, TransformerBlock, ) -from bytelatent.model.transformer import CrossAttention +from bytelatent.model.latent_transformer import CrossAttention from bytelatent.model.utils import create_causal_mask, downsample from bytelatent.tokenizers.blt_tokenizer import BOE_ID logger = logging.getLogger() +class LocalModelArgs(BaseTransformerArgs): + model_config = ConfigDict(extra="forbid") + # Override defaults + attn_impl: str | None = "xformers" + attn_bias_type: str | None = "local_block_causal" + + # Local encoder specific dimensions + dropout: float + vocab_size: int + patch_size: int + sliding_window: int | None + use_rope: bool + cross_attn_encoder: bool | None + cross_attn_decoder: bool | None + cross_attn_k: int | None + cross_attn_init_by_pooling: bool + patching_mode: str + use_local_encoder_transformer: bool + downsampling_by_pooling: str | None + encoder_hash_byte_group_size: Any | None = None + cross_attn_all_layers_encoder: bool = False + cross_attn_all_layers_decoder: bool = False + cross_attn_nheads: int | None + + dim_token_emb: int + dim_patch_emb: int | None + + class LocalModelBase(nn.Module): - def __init__(self, args): + def __init__(self, args: LocalModelArgs): super().__init__() self.dim = args.dim self.dropout = args.dropout - self.vocab_size = args.vocab_size + args.pm_size + self.vocab_size = args.vocab_size self.patch_size = args.patch_size - self.efficient_attn = args.efficient_attn + self.attn_impl = args.attn_impl self.sliding_window = args.sliding_window self.use_rope = args.use_rope self.init_std_factor = args.init_std_factor self.cross_attn_encoder = getattr(args, "cross_attn_encoder", None) self.cross_attn_decoder = getattr(args, "cross_attn_decoder", None) self.cross_attn_k = getattr(args, "cross_attn_k", None) + self.eos_id = args.eos_id self.boe_id = BOE_ID @@ -54,7 +85,7 @@ class LocalModelBase(nn.Module): self.rope = RotaryEmbedding( theta=args.rope_theta, head_dim=args.head_dim or args.dim // args.n_heads, - max_seqlen=getattr(args, "max_encoder_seq_length", args.max_length), + 
max_seqlen=args.max_seqlen, ) self.pos_embeddings = None @@ -66,21 +97,15 @@ class LocalModelBase(nn.Module): self.patch_embedding_projection = self._create_patch_projection(args) - def _should_create_patch_projection(self, args): + def _should_create_patch_projection(self, args: LocalModelArgs): dimension_mismatch = ( getattr(args, "dim_patch_emb") and args.dim_patch_emb != self.dim ) # Check cross attention conditions cross_attn_conditions = ( - hasattr(args, "cross_attn_encoder") - and args.cross_attn_encoder - and getattr(args, "cross_attn_init_by_pooling") - ) or ( - hasattr(args, "cross_attn_decoder") - and args.cross_attn_decoder - and getattr(args, "cross_attn_init_by_pooling") - ) + args.cross_attn_encoder and args.cross_attn_init_by_pooling + ) or (args.cross_attn_decoder and args.cross_attn_init_by_pooling) return dimension_mismatch or cross_attn_conditions @@ -172,7 +197,7 @@ class LocalModelBase(nn.Module): class LocalEncoder(LocalModelBase): - def __init__(self, args): + def __init__(self, args: LocalModelArgs): super().__init__(args) self.output_proj = ( args.patching_mode in ["entropy", "probmax"] @@ -180,7 +205,6 @@ class LocalEncoder(LocalModelBase): self.apply_transformer = args.use_local_encoder_transformer self.downsampling_by_pooling = args.downsampling_by_pooling - self.patch_only = args.patch_only_encoder self.expects_hash_embeddings = args.encoder_hash_byte_group_size is not None self.cross_attn_encoder = args.cross_attn_encoder self.cross_attn_all_layers_encoder = args.cross_attn_all_layers_encoder @@ -224,7 +248,14 @@ class LocalEncoder(LocalModelBase): """ """ bs, seqlen = tokens.shape if mask is None: - mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window) + mask = create_causal_mask( + seqlen, + self.attn_impl, + "local_block_causal", + sliding_window=self.sliding_window, + tokens=tokens, + eos_id=self.eos_id, + ) h = self.apply_embedding(tokens, embeds) freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None @@ -232,7 +263,7 @@ class LocalEncoder(LocalModelBase): h = F.dropout(h, p=self.dropout, training=self.training) for i, layer in enumerate(self.layers): - h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn) + h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl) # check if cross attention should be applied to either all layer or only the last layer if self.cross_attn_encoder and ( i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder @@ -273,12 +304,10 @@ class LocalEncoder(LocalModelBase): class LocalDecoder(LocalModelBase): - def __init__(self, args): + def __init__(self, args: LocalModelArgs): super().__init__(args) # Model configuration flags - self.patch_only = args.patch_only_decoder - self.expects_embeddings = args.share_encoder_decoder_emb self.cross_attn_decoder = args.cross_attn_decoder self.cross_attn_all_layers_decoder = args.cross_attn_all_layers_decoder self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling @@ -317,7 +346,14 @@ class LocalDecoder(LocalModelBase): assert embeds is not None, "Embeddings must be provided" if mask is None: - mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window) + mask = create_causal_mask( + seqlen, + self.attn_impl, + "local_block_causal", + sliding_window=self.sliding_window, + tokens=tokens, + eos_id=self.eos_id, + ) h = embeds @@ -347,7 +383,7 @@ class LocalDecoder(LocalModelBase): ) h = h + h_cross - h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn) + h = layer(h, 
mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl) h_preds = self.norm(h) h_preds = F.dropout(h_preds, p=self.dropout, training=self.training) diff --git a/bytelatent/model/utils.py b/bytelatent/model/utils.py index ce52a30..7ca979d 100644 --- a/bytelatent/model/utils.py +++ b/bytelatent/model/utils.py @@ -1,8 +1,13 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. +import logging +import os + import torch from torch.nn.attention.flex_attention import create_block_mask from xformers.ops import fmha +logger = logging.getLogger() + def patch_reduce(h, max_num_patches, reduction, patch_ids): """ @@ -97,15 +102,74 @@ def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx -def create_causal_mask(seqlen, attn_impl, sliding_window): - if sliding_window is not None and attn_impl == "xformers": - return fmha.attn_bias.LocalAttentionFromBottomRightMask( - window_left=sliding_window - 1, window_right=0 - ) - elif attn_impl == "xformers": - return fmha.attn_bias.LowerTriangularMask() +def tokens_to_seqlen(batch: torch.Tensor, eos_id: int): + """ + 0 0 0 1 0 0 0 1 0 0 0 + 0 1 0 0 0 1 0 0 0 0 0 + -> 4 4 3 2 4 5 + """ + mask = batch == eos_id + mask[:, -1] = True # virtual eos at the end of each row + + # 0 0 0 1 0 0 0 1 0 0 X + # 0 1 0 0 0 1 0 0 0 0 X + row, col = torch.where(mask) + + # row = 0, 0, 0, 1, 1, 1 + # col = 3, 7, 10, 1, 5, 10 + seqlens = (col[1:] - col[:-1]) + (row[1:] - row[:-1]) * mask.shape[1] + # seqlens = (4, 3, -9, 4, 5) + (0, 0, 11, 0, 0) = (4, 3, 2, 4, 5) + return [int(col[0].item() + 1)] + seqlens.tolist() + + +def create_causal_mask( + seqlen, + attn_impl: str, + attn_bias_type: str | None, + *, + eos_id: int | None = None, + tokens: torch.Tensor | None = None, + sliding_window: int | None = None, +): + if attn_impl == "xformers": + if attn_bias_type is None: + return fmha.attn_bias.LowerTriangularMask() + elif attn_bias_type == "causal": + assert sliding_window is None + return fmha.attn_bias.LowerTriangularMask() + elif attn_bias_type == "block_causal": + assert sliding_window is None + assert eos_id is not None + assert tokens is not None + return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens( + q_seqlen=tokens_to_seqlen(tokens, eos_id) + ) + elif attn_bias_type == "local_block_causal": + assert sliding_window is not None + assert eos_id is not None + assert tokens is not None + return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens( + q_seqlen=tokens_to_seqlen(tokens, eos_id) + ).make_local_attention(sliding_window) + else: + return fmha.attn_bias.LocalAttentionFromBottomRightMask( + window_left=sliding_window - 1, window_right=0 + ) elif attn_impl == "sdpa": - return "causal" + BLT_SUPPRESS_ATTN_ERROR = int(os.environ.get("BLT_SUPPRESS_ATTN_ERROR", 0)) + + if attn_bias_type == "causal": + return "causal" + + if BLT_SUPPRESS_ATTN_ERROR == 1: + logging.warning( + "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. Allowing model to run since BLT_SUPPRESS_ATTN_ERROR=1" + ) + return "causal" + else: + raise ValueError( + "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. 
To suppress this error and run the model anyway, set the environment variable BLT_SUPPRESS_ATTN_ERROR=1" + ) elif attn_impl == "flex_attention": return create_block_mask(causal_mask, None, None, seqlen, seqlen) elif attn_impl == "fmha": diff --git a/bytelatent/preprocess/fsspec_target.py b/bytelatent/preprocess/fsspec_target.py new file mode 100644 index 0000000..eacb101 --- /dev/null +++ b/bytelatent/preprocess/fsspec_target.py @@ -0,0 +1,38 @@ +import fsspec +from luigi.target import FileSystem, FileSystemTarget + + +class FSSpecFileSystem(FileSystem): + def __init__(self, fs: fsspec.AbstractFileSystem): + self.fs = fs + + def exists(self, path): + return self.fs.exists() + + def remove(self, path, recursive=True, skip_trash=True): + raise NotImplementedError() + + def isdir(self, path): + return self.fs.isdir(path) + + def listdir(self, path): + return self.fs.ls(path) + + +class FSSpecTarget(FileSystemTarget): + def __init__(self, path, fs: fsspec.AbstractFileSystem | None = None): + self.path = path + if fs is None: + self.fsspec_fs = fsspec.filesystem("file") + else: + self.fsspec_fs = fs + self._fs = None + + @property + def fs(self): + if self._fs is None: + self._fs = FSSpecFileSystem(self.fsspec_fs) + return self._fs + + def open(self, mode): + return self.fs.open(self.path, mode=mode) diff --git a/bytelatent/test_blt.py b/bytelatent/test_blt.py index 73ad9f7..36a9882 100644 --- a/bytelatent/test_blt.py +++ b/bytelatent/test_blt.py @@ -23,9 +23,10 @@ from bytelatent.model.blt import ( init_embeddings, patch_ids_from_lengths, ) -from bytelatent.model.transformer import CrossAttention +from bytelatent.model.latent_transformer import CrossAttention from bytelatent.model.utils import create_causal_mask from bytelatent.optim import OptimArgs, build_optimizer +from bytelatent.tokenizers.constants import EOS_ID from bytelatent.train import compute_loss @@ -51,7 +52,7 @@ def batch_to_tensors_and_gpu(batch): def fake_batch(): - batch_dict = torch.load(os.path.join(BLT_DATA, "test_batch.pt")) + batch_dict = torch.load(os.path.join(BLT_DATA, "test_batch.pt"), weights_only=False) del batch_dict["x2"] del batch_dict["y2"] del batch_dict["src_names"] @@ -98,18 +99,17 @@ def create_args(cross_attention=False): recompute_attn=False, custom_bwd=False, layer_ckpt="none", - efficient_attn="sdpa", - patch_only_encoder=False, - patch_only_decoder=False, use_local_encoder_transformer=True, init_use_gaussian=True, init_use_depth="current", attn_bias_type="block_causal", + attn_impl="xformers", alpha_depth="disabled", max_length=256, local_attention_window_len=512, max_seqlen=12288, downsampling_by_pooling="max", + eos_id=EOS_ID, ) return transformer_args @@ -341,10 +341,15 @@ class TestByteLatentTransformer: model = ByteLatentTransformer(args) assert model is not None - @pytest.mark.parametrize("attn_type", ["fmha", "sdpa"]) - def test_blt_transformer_forward(self, attn_type): + @pytest.mark.parametrize("attn_impl", ["sdpa", "xformers"]) + def test_blt_transformer_forward(self, attn_impl): args = create_args() - args = args.model_copy(update=dict(efficient_attn=attn_type)) + if attn_impl == "sdpa": + os.environ["BLT_SUPPRESS_ATTN_ERROR"] = "1" + else: + os.environ["BLT_SUPPRESS_ATTN_ERROR"] = "0" + + args = args.model_copy(update=dict(attn_impl=attn_impl)) model = ByteLatentTransformer(args) model = model.cuda() batch = fake_batch() @@ -393,7 +398,9 @@ class TestByteLatentTransformer: n_kv_heads=4, norm_eps=1e-6, ).to("cuda") - mask = create_causal_mask(x.shape[1], "flex_attention", 
sliding_window=None) + mask = create_causal_mask( + x.shape[1], "flex_attention", None, sliding_window=None + ) output = cross_attention(x, kv, mask) assert output is not None assert output.shape == (2, 256, 512) @@ -440,7 +447,7 @@ class TestByteLatentTransformer: def test_loss_backward(self): args = create_args() - args = args.model_copy(update=dict(efficient_attn="sdpa")) + args = args.model_copy(update=dict(attn_impl="xformers")) batch = fake_batch() model = ByteLatentTransformer(args) steps = 10 diff --git a/bytelatent/test_entropy_model.py b/bytelatent/test_entropy_model.py index 3acc42d..9db7ff6 100644 --- a/bytelatent/test_entropy_model.py +++ b/bytelatent/test_entropy_model.py @@ -24,6 +24,7 @@ def test_entropy_model(): dataset_files=[ARROW_TEST_DATA], row_num=0, arrow_batch_size=100, + s3_profile=None, ) arrow_file = initial_state.build() tokenizer_args = TokenizerArgs( @@ -38,7 +39,7 @@ def test_entropy_model(): BLT_DATA, "entropy_model.pth", ), - ) + ).cuda() preprocess_iter = PreprocessIterator( arrow_file, tokenizer_args=tokenizer_args, @@ -48,8 +49,10 @@ def test_entropy_model(): for example in preprocess_iter.create_iter(): tokens = torch.tensor(example.tokens).unsqueeze(0) expected_entropies = torch.tensor(example.entropies).unsqueeze(0) - preds = entropy_model(tokens) + preds = entropy_model(tokens.cuda()) pred_entropies = entropy(preds) assert pred_entropies.shape == expected_entropies.shape - assert torch.allclose(pred_entropies, expected_entropies, rtol=1.0, atol=3.5) + assert torch.allclose( + pred_entropies.cpu(), expected_entropies, rtol=1.0, atol=3.5 + ) break diff --git a/bytelatent/train.py b/bytelatent/train.py index 6cb13b9..80bd393 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -644,6 +644,10 @@ def main(): cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) train_args = TrainArgs.model_validate(cfg) + if train_args.debug_dynamo: + import torch._dynamo + + torch._dynamo.config.suppress_errors = True train(train_args) diff --git a/bytelatent/transformer.py b/bytelatent/transformer.py index 432f7df..92c5ff5 100644 --- a/bytelatent/transformer.py +++ b/bytelatent/transformer.py @@ -22,23 +22,7 @@ from bytelatent.base_transformer import ( RMSNorm, cross_entropy, ) - - -def create_causal_mask(seqlen, attn_impl, sliding_window): - if sliding_window is not None and attn_impl == "xformers": - return fmha.attn_bias.LocalAttentionFromBottomRightMask( - window_left=sliding_window - 1, window_right=0 - ) - elif attn_impl == "xformers": - return fmha.attn_bias.LowerTriangularMask() - elif attn_impl == "sdpa": - return "causal" - elif attn_impl == "flex_attention": - return create_block_mask(causal_mask, None, None, seqlen, seqlen) - else: - raise NotImplementedError( - f"Attention {attn_impl} with {sliding_window} sliding window not implemented" - ) +from bytelatent.model.utils import create_causal_mask def attention_flops_per_token(n_layers, seq_len, dim, causal): @@ -94,8 +78,10 @@ class LMTransformer(BaseTransformer): target: Optional[torch.Tensor] = None, tok_idx: Optional[torch.Tensor] = None, mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None, - attn_impl: str = "sdpa", + attn_impl: str | None = None, ): + if attn_impl is None: + attn_impl = self.attn_impl bsz, seqlen = token_values.shape h = self.tok_embeddings(token_values) @@ -103,7 +89,14 @@ class LMTransformer(BaseTransformer): mask = ( mask if mask is not None - else create_causal_mask(seqlen, 
attn_impl, self.sliding_window) + else create_causal_mask( + seqlen, + attn_impl, + self.attn_bias_type, + sliding_window=self.sliding_window, + tokens=token_values, + eos_id=self.eos_id, + ) ) h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl) From 8a3084c346863cba88836aba087027cd96fa1c3c Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Wed, 22 Jan 2025 19:57:54 +0000 Subject: [PATCH 11/59] Update file check script to check sizes Summary: Test Plan: --- bytelatent/data/file_util.py | 51 +++++++++++++++++++++++++++++++++--- 1 file changed, 48 insertions(+), 3 deletions(-) diff --git a/bytelatent/data/file_util.py b/bytelatent/data/file_util.py index d67b6db..5165eab 100644 --- a/bytelatent/data/file_util.py +++ b/bytelatent/data/file_util.py @@ -65,7 +65,10 @@ def print_local_to_delete( @app.command() def compare_local_to_blob( - source_dirs: list[str], dst_dir: str, s3_profile: str = "blt" + source_dirs: list[str], + dst_dir: str, + s3_profile: str = "blt", + print_sizes: bool = False, ): for s in source_dirs: assert s.endswith("/"), "Dirs must end with /" @@ -75,6 +78,7 @@ def compare_local_to_blob( local_fs = fsspec.filesystem("file") dst_fs = fsspec.filesystem("s3", profile=s3_profile) source_to_files = {} + source_file_to_size = {} all_local_files = set() for s in source_dirs: skipped = [] @@ -97,14 +101,28 @@ def compare_local_to_blob( skipped.append(f) continue + file_without_prefix = f[len(s) :] + if file_without_prefix not in source_file_to_size: + source_file_to_size[file_without_prefix] = os.path.getsize(f) + else: + source_file_to_size[file_without_prefix] = max( + source_file_to_size[file_without_prefix], os.path.getsize(f) + ) + source_to_files[s].append(f) - all_local_files.add(f[len(s) :]) + all_local_files.add(file_without_prefix) print(s, len(source_to_files[s]), "skipped", len(skipped), skipped[:10]) dst_files = dst_fs.find(dst_dir) print(dst_dir, len(dst_files)) - dst_file_set = {f[len(dst_dir) - len(S3_PREFIX) :] for f in dst_files} + dst_file_to_size = {} + dst_file_set = set() + for f in dst_files: + dst_file_without_prefix = f[len(dst_dir) - len(S3_PREFIX) :] + dst_file_set.add(dst_file_without_prefix) + dst_file_to_size[dst_file_without_prefix] = dst_fs.size(f) + diff = all_local_files.symmetric_difference(dst_file_set) print("Local files", len(all_local_files)) print("DST Files", len(dst_file_set)) @@ -112,6 +130,33 @@ def compare_local_to_blob( dst_only_files = dst_file_set - all_local_files print("DST only", len(dst_only_files), list(dst_only_files)[:10]) + all_files = dst_file_set | all_local_files + print("Check that files match") + size_success = True + for f in sorted(all_files): + if f in source_file_to_size and f in dst_file_to_size: + if source_file_to_size[f] != dst_file_to_size[f]: + size_success = False + print( + f"Mismatch file size for {f}, Local: {source_file_to_size[f]} Blob: {dst_file_to_size[f]}" + ) + else: + if print_sizes: + print(f"Matching file size: {dst_file_to_size[f]} for {f}") + elif f not in source_file_to_size: + size_success = False + print(f"Missing file in source: {f}") + elif f not in dst_file_to_size: + size_success = False + print(f"missing file in dst: {f}") + else: + raise ValueError("Unexpected to be missing file in src and dst") + + if size_success: + print("All files pass size check") + else: + raise ValueError("At least one file failed size comparison check") + if __name__ == "__main__": app() From bd461af91a86b572442fb4170c288603dbd5e6e5 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Fri, 
24 Jan 2025 18:56:18 +0000 Subject: [PATCH 12/59] Use load_async flag to not start MP iterator Summary: Test Plan: --- bytelatent/args.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/bytelatent/args.py b/bytelatent/args.py index b9144c6..a332c89 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -150,11 +150,13 @@ class DataloaderArgs(BaseModel): enable_byte_ngrams=self.enable_byte_ngrams, ) packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args) - mp_iterator = MultiprocessIterator( - packing_iterator, n_batches_to_prefetch=self.prefetch_size - ) - - return mp_iterator + if self.load_async: + mp_iterator = MultiprocessIterator( + packing_iterator, n_batches_to_prefetch=self.prefetch_size + ) + return mp_iterator + else: + return packing_iterator class TrainArgs(BaseModel): From fb09022e5e40706d901dc1f0eb19f0333a6127e3 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Fri, 24 Jan 2025 21:55:24 +0000 Subject: [PATCH 13/59] Initial codes and scripts for training entropy model Summary: Test Plan: --- .gitignore | 1 + bytelatent/args.py | 13 ++- bytelatent/configs/debug.yaml | 3 +- bytelatent/configs/entropy_model.yaml | 82 +++++++++++++++++++ bytelatent/data/data_types.py | 2 +- bytelatent/data/iterators/packing_iterator.py | 42 ++++++++++ .../data/iterators/sequence_iterator.py | 30 +++++-- bytelatent/data/patcher.py | 10 ++- bytelatent/model/blt.py | 5 +- bytelatent/test_blt.py | 3 +- bytelatent/train.py | 52 +++++++++--- 11 files changed, 209 insertions(+), 34 deletions(-) create mode 100644 bytelatent/configs/entropy_model.yaml diff --git a/.gitignore b/.gitignore index 6c664b8..d1d7c2a 100644 --- a/.gitignore +++ b/.gitignore @@ -166,3 +166,4 @@ figures/ .vscode/ .DS_Store internal/ +jobs_parallel-copy/ diff --git a/bytelatent/args.py b/bytelatent/args.py index a332c89..56de22d 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -93,6 +93,8 @@ class DataloaderArgs(BaseModel): max_encoder_seq_length: int = 12288 enable_byte_ngrams: bool = False + add_patches: bool = True + tokenizer_args: TokenizerArgs = TokenizerArgs() patcher_args: PatcherArgs = PatcherArgs() @@ -120,6 +122,7 @@ class DataloaderArgs(BaseModel): looping_iterator, patcher_args=self.patcher_args, tokenizer_args=self.tokenizer_args, + add_patches=self.add_patches, ) sequence_iterator = SequenceIterator( preprocess_iterator, @@ -141,13 +144,19 @@ class DataloaderArgs(BaseModel): source_to_iterator=source_to_sequence_iterators, ) tokenizer = self.tokenizer_args.build() + if self.tokenizer_args.name == "bytes": + # TODO: Check this with Artidoro + pad_id = 0 + else: + pad_id = tokenizer.boe_id packing_args = PackingArgs( batch_size=self.batch_size, seq_len=self.seq_len, - pad_id=tokenizer.boe_id, + pad_id=pad_id, max_length=self.max_encoder_seq_length, pad_to_max_length=self.pad_to_max_length, enable_byte_ngrams=self.enable_byte_ngrams, + tokenizer_name=self.tokenizer_args.name, ) packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args) if self.load_async: @@ -180,7 +189,7 @@ class TrainArgs(BaseModel): data: DataloaderArgs = DataloaderArgs() optim: OptimArgs = OptimArgs() - model: ByteLatentTransformerArgs = ByteLatentTransformerArgs() + model: ByteLatentTransformerArgs | None = ByteLatentTransformerArgs() # This is only needed for training the entropy model entropy_model: LMTransformerArgs | None = None # Instead of training main model, train entropy model diff --git a/bytelatent/configs/debug.yaml 
b/bytelatent/configs/debug.yaml index 4ae4459..1098ff5 100644 --- a/bytelatent/configs/debug.yaml +++ b/bytelatent/configs/debug.yaml @@ -26,10 +26,9 @@ model: vocab_size: 260 dim_token: 256 patch_size: 6 - tokenization_mode: "bytes" patching_mode: "space" tie_local_encoder_decoder_logits: false - data_loader_patching: true + patch_in_forward: false max_encoder_seq_length: 12288 pad_to_max_length: true patching_threshold: 3.1439168453216553 diff --git a/bytelatent/configs/entropy_model.yaml b/bytelatent/configs/entropy_model.yaml new file mode 100644 index 0000000..51b65d4 --- /dev/null +++ b/bytelatent/configs/entropy_model.yaml @@ -0,0 +1,82 @@ +# Template config, need to change dump_dir, data.root_dir and tokenizer.path +# Evals can be activated by uncommenting its config +# python -m launchers.stool config=apps/main/configs/debug.yaml nodes=8 account=fair_amaia_cw_codegen qos=lowest + +dump_dir: /tmp/ +name: "debug" +steps: 100_000 +probe_freq: null +seed: 777 +optim: + lr: 4e-04 + warmup: 500 + lr_min_ratio: 0.1 + clip: 10.0 + +distributed: + fsdp_type: full_shard + model_dtype: bf16 + matmul_allow_tf32: false + selective_activation_checkpointing: false + tp_size: 1 + +train_entropy_model: true +model: null +entropy_model: + dim: 768 + n_layers: 14 + n_heads: 12 + max_seqlen: 8192 + # vocab_size: -1 + vocab_size: 260 + ffn_dim_multiplier: 1.0 + sliding_window: 512 + attn_bias_type: "local_block_causal" + attn_impl: "xformers" + +data: + s3_profile: blt + root_dir: ??? + sources: + dclm_baseline_1.0: 1.0 + batch_size: 2 + prefetch_size: 64 + # seqlen is in terms of patches and + # max_encoder_seq_length is in terms of bytes. + # For entropy model, these are the same since 1 patch=1 byte + seq_len: 8192 + max_encoder_seq_length: 8192 + load_async: true + preprocess_dir: ??? + # We don't need patches for this model + add_patches: false + patcher_args: + # This doesn't matter since byte entropy model doesn't use patching, + # so pick the most efficient, so static + patching_mode: byte + tokenizer_args: + name: bytes + +profiling: + run: false + +checkpoint: + dump: + every: 500 + keep: 3 + eval: + every: 1000 + keep: -1 + +logging: + freq: 10 + +eval_on_gpus: 8 +eval: + dataset_dir: ??? + tasks: ??? 
+ generator: + max_tokens: 65536 + dtype: bf16 + + mp_size: 1 diff --git a/bytelatent/data/data_types.py b/bytelatent/data/data_types.py index 7e142e4..aa2daa9 100644 --- a/bytelatent/data/data_types.py +++ b/bytelatent/data/data_types.py @@ -53,7 +53,7 @@ BltIterator = Iterator[tuple[BltExample, DataLoaderState]] class BltSequence(BaseModel): tokens: list[int] mask: list[bool] - patch_lengths: list[int] + patch_lengths: list[int] | None @dataclass diff --git a/bytelatent/data/iterators/packing_iterator.py b/bytelatent/data/iterators/packing_iterator.py index 361fc03..fa29149 100644 --- a/bytelatent/data/iterators/packing_iterator.py +++ b/bytelatent/data/iterators/packing_iterator.py @@ -17,6 +17,7 @@ class PackingArgs(BaseModel): max_length: int | None pad_to_max_length: bool enable_byte_ngrams: bool + tokenizer_name: str class PackingIteratorState(BaseModel, IteratorState): @@ -151,6 +152,43 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]): ) def create_iter(self): + if self.packing_args.tokenizer_name == "bytes": + return self._create_iter_from_bytes() + else: + return self._create_iter_from_patch_lengths() + + def _create_iter_from_bytes(self): + sequence_iter = self.sequence_iterator.create_iter() + batch_size = self.packing_args.batch_size + pad_id = self.packing_args.pad_id + seq_len = self.packing_args.seq_len + while True: + tokens: list[list[int]] = [] + masks: list[list[bool]] = [] + + for _ in range(self.packing_args.batch_size): + sequence = next(sequence_iter) + _tokens = sequence.tokens + _mask = sequence.mask + assert ( + sequence.patch_lengths is None + ), "patch_lengths should not be used in byte packing" + tokens.append(_tokens) + masks.append(_mask) + + x = np.full((batch_size, seq_len), fill_value=pad_id) + y = np.full((batch_size, seq_len), fill_value=pad_id) + + for i, tok_seq in enumerate(tokens): + x[i, : len(tok_seq)] = tok_seq + y[i, : len(tok_seq) - 1] = tok_seq[1:] + batch = Batch(x=x, y=y) + assert ( + batch.mask is None or np.sum(x != pad_id) == batch.mask.sum() + ), f"{np.sum(x != pad_id)} != {batch.mask.sum()}" + yield batch + + def _create_iter_from_patch_lengths(self): sequence_iter = self.sequence_iterator.create_iter() batch_size = self.packing_args.batch_size pad_id = self.packing_args.pad_id @@ -168,6 +206,10 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]): _tokens = sequence.tokens _mask = sequence.mask _patch_lengths = sequence.patch_lengths + assert ( + _patch_lengths is not None + ), "patch lengths are required for packing based on patches." 
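# A tiny worked example of the byte-mode packing in _create_iter_from_bytes
# above (made-up values, assuming pad_id=0 as set for the "bytes" tokenizer
# in args.py, and seq_len=6): for tok_seq = [65, 66, 67, 68],
#   x row -> [65, 66, 67, 68, 0, 0]
#   y row -> [66, 67, 68,  0, 0, 0]
# i.e. y is x shifted left by one, so y[t] is the next-byte target for x[t].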
+ # Reminder: seq_len is in terms of patches assert len(sequence.patch_lengths) == self.packing_args.seq_len last_patch_length = 0 if _patch_lengths[0] > 1: diff --git a/bytelatent/data/iterators/sequence_iterator.py b/bytelatent/data/iterators/sequence_iterator.py index 14e3747..d90ea31 100644 --- a/bytelatent/data/iterators/sequence_iterator.py +++ b/bytelatent/data/iterators/sequence_iterator.py @@ -70,15 +70,22 @@ class SequenceIterator(StatefulIterator): for example in example_iter: assert example.tokens is not None assert example.mask is not None - assert example.patch_lengths is not None + if self.preprocess_iterator.add_patches: + assert example.patch_lengths is not None + assert len(example.tokens) == sum(example.patch_lengths) + else: + assert example.patch_lengths is None assert len(example.tokens) != 0 assert len(example.mask) != 0 assert len(example.tokens) == len(example.mask) - assert len(example.tokens) == sum(example.patch_lengths) tokens.extend(example.tokens) mask.extend(example.mask) - patch_lengths.extend(example.patch_lengths) + if self.preprocess_iterator.add_patches: + patch_lengths.extend(example.patch_lengths) + else: + # This lets the rest of the code work as expected and just yield byte seqs + patch_lengths.extend([1] * len(example.tokens)) while len(patch_lengths) >= n_buffer_patches: if first: @@ -115,8 +122,15 @@ class SequenceIterator(StatefulIterator): == len(seq_mask[idx]) ), f"{sum(seq_patch_lengths[idx])}, {len(seq_tokens[idx])} {len(seq_mask[idx])}, idx={idx}" assert seq_patch_lengths[idx][0] > 0, f"{seq_patch_lengths[idx]}" - yield BltSequence( - tokens=seq_tokens[idx], - mask=seq_mask[idx], - patch_lengths=seq_patch_lengths[idx], - ) + if self.preprocess_iterator.add_patches: + yield BltSequence( + tokens=seq_tokens[idx], + mask=seq_mask[idx], + patch_lengths=seq_patch_lengths[idx], + ) + else: + yield BltSequence( + tokens=seq_tokens[idx], + mask=seq_mask[idx], + patch_lengths=None, + ) diff --git a/bytelatent/data/patcher.py b/bytelatent/data/patcher.py index afcfa2e..44ff5e9 100644 --- a/bytelatent/data/patcher.py +++ b/bytelatent/data/patcher.py @@ -22,6 +22,8 @@ class PatchingModeEnum(str, Enum): bpe = "bpe" bpe_patcher = "bpe_patcher" space = "space" + static = "static" + byte = "byte" class PatcherArgs(BaseModel): @@ -34,7 +36,6 @@ class PatcherArgs(BaseModel): max_patch_length: int | None = None patch_size: float = 4.5 patching_batch_size: int = 1 - data_loader_patching: bool = False device: str = "cuda" monotonicity: bool = False log_time: bool = False @@ -486,7 +487,6 @@ class Patcher: self.max_patch_length = patcher_args.max_patch_length self.patch_size = patcher_args.patch_size self.patching_batch_size = patcher_args.patching_batch_size - self.data_loader_patching = patcher_args.data_loader_patching self.device = patcher_args.device self.monotonicity = patcher_args.monotonicity self.log_time = patcher_args.log_time @@ -528,7 +528,7 @@ class Patcher: seq_len_next_tok = seq_len + 1 if include_next_token else seq_len scores = None # STATIC - if self.patching_mode is None: + if self.patching_mode == PatchingModeEnum.static: patch_lengths = torch.zeros( (bs, math.ceil(seq_len_next_tok / self.patch_size)), dtype=tokens.dtype, @@ -536,6 +536,10 @@ class Patcher: ).fill_(self.patch_size) if seq_len_next_tok % self.patch_size != 0: patch_lengths[:, -1] = seq_len_next_tok % self.patch_size + elif self.patching_mode == PatchingModeEnum.byte: + patch_lengths = torch.ones( + (bs, seq_len_next_tok), dtype=tokens.dtype, device=tokens.device + ) # ENTROPY 
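# Note on the branches above: static patching is now selected explicitly via
# patching_mode == "static" instead of patching_mode is None, and the new
# "byte" mode emits one patch per byte. For example (made-up numbers), with
# patch_size=4 and seq_len_next_tok=10, "static" yields per-row patch_lengths
# of [4, 4, 2], while "byte" yields ten 1s. The entropy-based branch follows
# below.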
elif self.patching_mode == PatchingModeEnum.entropy: if self.log_time: diff --git a/bytelatent/model/blt.py b/bytelatent/model/blt.py index 843ad34..a62be23 100644 --- a/bytelatent/model/blt.py +++ b/bytelatent/model/blt.py @@ -411,6 +411,7 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): n_heads: int = 8 # TODO: What is the purpose of this parameter? weight_tying: bool = False + patch_in_forward: bool = False # Architecture and dimensions dim_token: int = 256 @@ -422,7 +423,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): n_layers_local_encoder: int = 8 # Tokenization and patching - tokenization_mode: str = "bpe" patch_size: float | None = None patching_mode: str | None = None patching_threshold: float | None = None @@ -430,7 +430,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): monotonicity: bool = False patching_batch_size: int = 1 patching_device: str = "cuda" - data_loader_patching: bool = False max_patch_length: int | None = None # Encoder/Decoder configuration @@ -856,7 +855,7 @@ class ByteLatentTransformer(nn.Module): self.output.weight = self.tok_embeddings.weight # Patcher module - if not args.data_loader_patching: + if args.patch_in_forward: self.patcher = Patcher( PatcherArgs( patch_size=args.patch_size, diff --git a/bytelatent/test_blt.py b/bytelatent/test_blt.py index 36a9882..eb94df3 100644 --- a/bytelatent/test_blt.py +++ b/bytelatent/test_blt.py @@ -68,10 +68,9 @@ def create_args(cross_attention=False): # Additional args from command line dim_token=256, patch_size=6, - tokenization_mode="bytes", patching_mode="space", tie_local_encoder_decoder_logits=False, - data_loader_patching=True, + patch_in_forward=False, max_encoder_seq_length=12288, pad_to_max_length=True, encoder_lm_loss=False, diff --git a/bytelatent/train.py b/bytelatent/train.py index 80bd393..a7ca405 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -51,6 +51,7 @@ from bytelatent.transformer import ( get_no_recompute_ops, get_num_flop_per_token, tp_parallelize, + LMTransformer, ) logger = logging.getLogger() @@ -103,10 +104,15 @@ class TrainState(Stateful): def validate_train_args(args: TrainArgs, output_size: int): - if args.model.vocab_size < 0: + assert args.model is not None or args.entropy_model is not None + if args.model is not None: logger.info(f"Setting model output size to {args.model.vocab_size}") args.model.vocab_size = output_size + if args.entropy_model is not None: + logger.info(f"Setting model output size to {args.entropy_model.vocab_size}") + args.entropy_model.vocab_size = output_size + assert args.dump_dir, "Dump dir not set" if args.checkpoint.path is None: @@ -147,7 +153,10 @@ def validate_train_args(args: TrainArgs, output_size: int): and args.distributed.dp_replicate == get_world_size() ) - args.model.max_seqlen = args.data.seq_len + if args.model is not None: + args.model.max_seqlen = args.data.seq_len + if args.entropy_model is not None: + args.entropy_model.max_seqlen = args.data.seq_len if args.distributed.tp_size == 1: logger.warning( @@ -237,7 +246,14 @@ def train(args: TrainArgs): # Initializing Model in meta device allows us to initialize models much bigger than 1 gpu's memory with torch.device("meta"): - model = ByteLatentTransformer(args.model) + if args.train_entropy_model: + assert args.entropy_model is not None + model = LMTransformer(args.entropy_model) + model_args = args.entropy_model + else: + assert args.model is not None + model = ByteLatentTransformer(args.model) + model_args = args.model logger.info("Model is built !") 
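# model_args now points at whichever config actually built the model
# (args.entropy_model or args.model); the FSDP grouping plan, the seed used
# for init_weights, and the FLOP estimate later in this function all read
# from model_args so that both training modes work.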
model_param_count = get_num_params(model) @@ -247,7 +263,7 @@ def train(args: TrainArgs): world_mesh, args.model, args.distributed, - fsdp_grouping_plan=build_fsdp_grouping_plan(args.model), + fsdp_grouping_plan=build_fsdp_grouping_plan(model_args), tp_parallelize=tp_parallelize, no_recompute_ops=get_no_recompute_ops(), ) @@ -267,7 +283,7 @@ def train(args: TrainArgs): model.rope_embeddings.reset_parameters() # For RoPe initialization since it's a buffer it might not be loaded else: with torch.random.fork_rng(devices=[torch.cuda.current_device()]): - torch.manual_seed(args.model.seed) + torch.manual_seed(model_args.seed) model.init_weights() check_model_value_range(model, range=10.0, std=1.0) @@ -342,10 +358,17 @@ def train(args: TrainArgs): batch.x, ).cuda() batch_y = torch.from_numpy(batch.y).cuda() - batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda() + if batch.patch_lengths is None: + batch_patch_lengths = None + else: + batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda() mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda() - if args.model.encoder_enable_byte_ngrams and batch.ngram_ids is None: + if ( + not args.train_entropy_model + and args.model.encoder_enable_byte_ngrams + and batch.ngram_ids is None + ): raise ValueError( "Cannot enable byte ngrams and have batch.ngram_ids be None" ) @@ -408,9 +431,12 @@ def train(args: TrainArgs): next(probe_mod.parameters()).grad is None ), "Probe model shouldn't have grads at this point" - pred = model( - batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids - ) + if args.train_entropy_model: + pred = model(batch_x) + else: + pred = model( + batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids + ) loss, _ = compute_loss(pred, batch_y, mask, train_state.scale) @@ -474,9 +500,9 @@ def train(args: TrainArgs): # Use xformer's analyze profile trace to get actual measurement FLOPS = ( get_num_flop_per_token( - model_param_count - args.model.vocab_size * args.model.dim, - args.model.n_layers, - args.model.dim, + model_param_count - model_args.vocab_size * model_args.dim, + model_args.n_layers, + model_args.dim, args.data.seq_len, ) * wps From 34ca1f7d4b40030bbaba2cf51715434c902018bd Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Fri, 24 Jan 2025 21:55:24 +0000 Subject: [PATCH 14/59] Initial codes and scripts for training entropy model Summary: Test Plan: --- .gitignore | 1 + bytelatent/args.py | 13 ++- bytelatent/configs/debug.yaml | 3 +- bytelatent/configs/entropy_model.yaml | 82 +++++++++++++++++++ bytelatent/data/data_types.py | 2 +- bytelatent/data/iterators/packing_iterator.py | 42 ++++++++++ .../data/iterators/sequence_iterator.py | 30 +++++-- bytelatent/data/patcher.py | 10 ++- bytelatent/model/blt.py | 5 +- bytelatent/test_blt.py | 3 +- bytelatent/train.py | 52 +++++++++--- 11 files changed, 209 insertions(+), 34 deletions(-) create mode 100644 bytelatent/configs/entropy_model.yaml diff --git a/.gitignore b/.gitignore index 6c664b8..d1d7c2a 100644 --- a/.gitignore +++ b/.gitignore @@ -166,3 +166,4 @@ figures/ .vscode/ .DS_Store internal/ +jobs_parallel-copy/ diff --git a/bytelatent/args.py b/bytelatent/args.py index a332c89..56de22d 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -93,6 +93,8 @@ class DataloaderArgs(BaseModel): max_encoder_seq_length: int = 12288 enable_byte_ngrams: bool = False + add_patches: bool = True + tokenizer_args: TokenizerArgs = TokenizerArgs() patcher_args: PatcherArgs = PatcherArgs() @@ -120,6 +122,7 @@ class 
DataloaderArgs(BaseModel): looping_iterator, patcher_args=self.patcher_args, tokenizer_args=self.tokenizer_args, + add_patches=self.add_patches, ) sequence_iterator = SequenceIterator( preprocess_iterator, @@ -141,13 +144,19 @@ class DataloaderArgs(BaseModel): source_to_iterator=source_to_sequence_iterators, ) tokenizer = self.tokenizer_args.build() + if self.tokenizer_args.name == "bytes": + # TODO: Check this with Artidoro + pad_id = 0 + else: + pad_id = tokenizer.boe_id packing_args = PackingArgs( batch_size=self.batch_size, seq_len=self.seq_len, - pad_id=tokenizer.boe_id, + pad_id=pad_id, max_length=self.max_encoder_seq_length, pad_to_max_length=self.pad_to_max_length, enable_byte_ngrams=self.enable_byte_ngrams, + tokenizer_name=self.tokenizer_args.name, ) packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args) if self.load_async: @@ -180,7 +189,7 @@ class TrainArgs(BaseModel): data: DataloaderArgs = DataloaderArgs() optim: OptimArgs = OptimArgs() - model: ByteLatentTransformerArgs = ByteLatentTransformerArgs() + model: ByteLatentTransformerArgs | None = ByteLatentTransformerArgs() # This is only needed for training the entropy model entropy_model: LMTransformerArgs | None = None # Instead of training main model, train entropy model diff --git a/bytelatent/configs/debug.yaml b/bytelatent/configs/debug.yaml index 4ae4459..1098ff5 100644 --- a/bytelatent/configs/debug.yaml +++ b/bytelatent/configs/debug.yaml @@ -26,10 +26,9 @@ model: vocab_size: 260 dim_token: 256 patch_size: 6 - tokenization_mode: "bytes" patching_mode: "space" tie_local_encoder_decoder_logits: false - data_loader_patching: true + patch_in_forward: false max_encoder_seq_length: 12288 pad_to_max_length: true patching_threshold: 3.1439168453216553 diff --git a/bytelatent/configs/entropy_model.yaml b/bytelatent/configs/entropy_model.yaml new file mode 100644 index 0000000..51b65d4 --- /dev/null +++ b/bytelatent/configs/entropy_model.yaml @@ -0,0 +1,82 @@ +# Template config, need to change dump_dir, data.root_dir and tokenizer.path +# Evals can be activated by uncommenting its config +# python -m launchers.stool config=apps/main/configs/debug.yaml nodes=8 account=fair_amaia_cw_codegen qos=lowest + +dump_dir: /tmp/ +name: "debug" +steps: 100_000 +probe_freq: null +seed: 777 +optim: + lr: 4e-04 + warmup: 500 + lr_min_ratio: 0.1 + clip: 10.0 + +distributed: + fsdp_type: full_shard + model_dtype: bf16 + matmul_allow_tf32: false + selective_activation_checkpointing: false + tp_size: 1 + +train_entropy_model: true +model: null +entropy_model: + dim: 768 + n_layers: 14 + n_heads: 12 + max_seqlen: 8192 + # vocab_size: -1 + vocab_size: 260 + ffn_dim_multiplier: 1.0 + sliding_window: 512 + attn_bias_type: "local_block_causal" + attn_impl: "xformers" + +data: + s3_profile: blt + root_dir: ??? + sources: + dclm_baseline_1.0: 1.0 + batch_size: 2 + prefetch_size: 64 + # seqlen is in terms of patches and + # max_encoder_seq_length is in terms of bytes. + # For entropy model, these are the same since 1 patch=1 byte + seq_len: 8192 + max_encoder_seq_length: 8192 + load_async: true + preprocess_dir: ??? + # We don't need patches for this model + add_patches: false + patcher_args: + # This doesn't matter since byte entropy model doesn't use patching, + # so pick the most efficient, so static + patching_mode: byte + tokenizer_args: + name: bytes + +profiling: + run: false + +checkpoint: + dump: + every: 500 + keep: 3 + eval: + every: 1000 + keep: -1 + +logging: + freq: 10 + +eval_on_gpus: 8 +eval: + dataset_dir: ??? 
+ tasks: ??? + generator: + max_tokens: 65536 + dtype: bf16 + + mp_size: 1 diff --git a/bytelatent/data/data_types.py b/bytelatent/data/data_types.py index 7e142e4..aa2daa9 100644 --- a/bytelatent/data/data_types.py +++ b/bytelatent/data/data_types.py @@ -53,7 +53,7 @@ BltIterator = Iterator[tuple[BltExample, DataLoaderState]] class BltSequence(BaseModel): tokens: list[int] mask: list[bool] - patch_lengths: list[int] + patch_lengths: list[int] | None @dataclass diff --git a/bytelatent/data/iterators/packing_iterator.py b/bytelatent/data/iterators/packing_iterator.py index 361fc03..fa29149 100644 --- a/bytelatent/data/iterators/packing_iterator.py +++ b/bytelatent/data/iterators/packing_iterator.py @@ -17,6 +17,7 @@ class PackingArgs(BaseModel): max_length: int | None pad_to_max_length: bool enable_byte_ngrams: bool + tokenizer_name: str class PackingIteratorState(BaseModel, IteratorState): @@ -151,6 +152,43 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]): ) def create_iter(self): + if self.packing_args.tokenizer_name == "bytes": + return self._create_iter_from_bytes() + else: + return self._create_iter_from_patch_lengths() + + def _create_iter_from_bytes(self): + sequence_iter = self.sequence_iterator.create_iter() + batch_size = self.packing_args.batch_size + pad_id = self.packing_args.pad_id + seq_len = self.packing_args.seq_len + while True: + tokens: list[list[int]] = [] + masks: list[list[bool]] = [] + + for _ in range(self.packing_args.batch_size): + sequence = next(sequence_iter) + _tokens = sequence.tokens + _mask = sequence.mask + assert ( + sequence.patch_lengths is None + ), "patch_lengths should not be used in byte packing" + tokens.append(_tokens) + masks.append(_mask) + + x = np.full((batch_size, seq_len), fill_value=pad_id) + y = np.full((batch_size, seq_len), fill_value=pad_id) + + for i, tok_seq in enumerate(tokens): + x[i, : len(tok_seq)] = tok_seq + y[i, : len(tok_seq) - 1] = tok_seq[1:] + batch = Batch(x=x, y=y) + assert ( + batch.mask is None or np.sum(x != pad_id) == batch.mask.sum() + ), f"{np.sum(x != pad_id)} != {batch.mask.sum()}" + yield batch + + def _create_iter_from_patch_lengths(self): sequence_iter = self.sequence_iterator.create_iter() batch_size = self.packing_args.batch_size pad_id = self.packing_args.pad_id @@ -168,6 +206,10 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]): _tokens = sequence.tokens _mask = sequence.mask _patch_lengths = sequence.patch_lengths + assert ( + _patch_lengths is not None + ), "patch lengths are required for packing based on patches." 
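# Example (made-up numbers): a row with patch_lengths [2, 3, 1, 2] satisfies
# seq_len=4 while carrying 8 bytes of tokens, since seq_len here counts
# patches rather than bytes.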
+ # Reminder: seq_len is in terms of patches assert len(sequence.patch_lengths) == self.packing_args.seq_len last_patch_length = 0 if _patch_lengths[0] > 1: diff --git a/bytelatent/data/iterators/sequence_iterator.py b/bytelatent/data/iterators/sequence_iterator.py index 14e3747..d90ea31 100644 --- a/bytelatent/data/iterators/sequence_iterator.py +++ b/bytelatent/data/iterators/sequence_iterator.py @@ -70,15 +70,22 @@ class SequenceIterator(StatefulIterator): for example in example_iter: assert example.tokens is not None assert example.mask is not None - assert example.patch_lengths is not None + if self.preprocess_iterator.add_patches: + assert example.patch_lengths is not None + assert len(example.tokens) == sum(example.patch_lengths) + else: + assert example.patch_lengths is None assert len(example.tokens) != 0 assert len(example.mask) != 0 assert len(example.tokens) == len(example.mask) - assert len(example.tokens) == sum(example.patch_lengths) tokens.extend(example.tokens) mask.extend(example.mask) - patch_lengths.extend(example.patch_lengths) + if self.preprocess_iterator.add_patches: + patch_lengths.extend(example.patch_lengths) + else: + # This lets the rest of the code work as expected and just yield byte seqs + patch_lengths.extend([1] * len(example.tokens)) while len(patch_lengths) >= n_buffer_patches: if first: @@ -115,8 +122,15 @@ class SequenceIterator(StatefulIterator): == len(seq_mask[idx]) ), f"{sum(seq_patch_lengths[idx])}, {len(seq_tokens[idx])} {len(seq_mask[idx])}, idx={idx}" assert seq_patch_lengths[idx][0] > 0, f"{seq_patch_lengths[idx]}" - yield BltSequence( - tokens=seq_tokens[idx], - mask=seq_mask[idx], - patch_lengths=seq_patch_lengths[idx], - ) + if self.preprocess_iterator.add_patches: + yield BltSequence( + tokens=seq_tokens[idx], + mask=seq_mask[idx], + patch_lengths=seq_patch_lengths[idx], + ) + else: + yield BltSequence( + tokens=seq_tokens[idx], + mask=seq_mask[idx], + patch_lengths=None, + ) diff --git a/bytelatent/data/patcher.py b/bytelatent/data/patcher.py index afcfa2e..44ff5e9 100644 --- a/bytelatent/data/patcher.py +++ b/bytelatent/data/patcher.py @@ -22,6 +22,8 @@ class PatchingModeEnum(str, Enum): bpe = "bpe" bpe_patcher = "bpe_patcher" space = "space" + static = "static" + byte = "byte" class PatcherArgs(BaseModel): @@ -34,7 +36,6 @@ class PatcherArgs(BaseModel): max_patch_length: int | None = None patch_size: float = 4.5 patching_batch_size: int = 1 - data_loader_patching: bool = False device: str = "cuda" monotonicity: bool = False log_time: bool = False @@ -486,7 +487,6 @@ class Patcher: self.max_patch_length = patcher_args.max_patch_length self.patch_size = patcher_args.patch_size self.patching_batch_size = patcher_args.patching_batch_size - self.data_loader_patching = patcher_args.data_loader_patching self.device = patcher_args.device self.monotonicity = patcher_args.monotonicity self.log_time = patcher_args.log_time @@ -528,7 +528,7 @@ class Patcher: seq_len_next_tok = seq_len + 1 if include_next_token else seq_len scores = None # STATIC - if self.patching_mode is None: + if self.patching_mode == PatchingModeEnum.static: patch_lengths = torch.zeros( (bs, math.ceil(seq_len_next_tok / self.patch_size)), dtype=tokens.dtype, @@ -536,6 +536,10 @@ class Patcher: ).fill_(self.patch_size) if seq_len_next_tok % self.patch_size != 0: patch_lengths[:, -1] = seq_len_next_tok % self.patch_size + elif self.patching_mode == PatchingModeEnum.byte: + patch_lengths = torch.ones( + (bs, seq_len_next_tok), dtype=tokens.dtype, device=tokens.device + ) # ENTROPY 
elif self.patching_mode == PatchingModeEnum.entropy: if self.log_time: diff --git a/bytelatent/model/blt.py b/bytelatent/model/blt.py index 843ad34..a62be23 100644 --- a/bytelatent/model/blt.py +++ b/bytelatent/model/blt.py @@ -411,6 +411,7 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): n_heads: int = 8 # TODO: What is the purpose of this parameter? weight_tying: bool = False + patch_in_forward: bool = False # Architecture and dimensions dim_token: int = 256 @@ -422,7 +423,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): n_layers_local_encoder: int = 8 # Tokenization and patching - tokenization_mode: str = "bpe" patch_size: float | None = None patching_mode: str | None = None patching_threshold: float | None = None @@ -430,7 +430,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): monotonicity: bool = False patching_batch_size: int = 1 patching_device: str = "cuda" - data_loader_patching: bool = False max_patch_length: int | None = None # Encoder/Decoder configuration @@ -856,7 +855,7 @@ class ByteLatentTransformer(nn.Module): self.output.weight = self.tok_embeddings.weight # Patcher module - if not args.data_loader_patching: + if args.patch_in_forward: self.patcher = Patcher( PatcherArgs( patch_size=args.patch_size, diff --git a/bytelatent/test_blt.py b/bytelatent/test_blt.py index 36a9882..eb94df3 100644 --- a/bytelatent/test_blt.py +++ b/bytelatent/test_blt.py @@ -68,10 +68,9 @@ def create_args(cross_attention=False): # Additional args from command line dim_token=256, patch_size=6, - tokenization_mode="bytes", patching_mode="space", tie_local_encoder_decoder_logits=False, - data_loader_patching=True, + patch_in_forward=False, max_encoder_seq_length=12288, pad_to_max_length=True, encoder_lm_loss=False, diff --git a/bytelatent/train.py b/bytelatent/train.py index 80bd393..1d0fa40 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -47,6 +47,7 @@ from bytelatent.probe import AutoProbeD from bytelatent.profiling import maybe_run_profiler from bytelatent.stool import StoolArgs, launch_job from bytelatent.transformer import ( + LMTransformer, build_fsdp_grouping_plan, get_no_recompute_ops, get_num_flop_per_token, @@ -103,10 +104,15 @@ class TrainState(Stateful): def validate_train_args(args: TrainArgs, output_size: int): - if args.model.vocab_size < 0: + assert args.model is not None or args.entropy_model is not None + if args.model is not None: logger.info(f"Setting model output size to {args.model.vocab_size}") args.model.vocab_size = output_size + if args.entropy_model is not None: + logger.info(f"Setting model output size to {args.entropy_model.vocab_size}") + args.entropy_model.vocab_size = output_size + assert args.dump_dir, "Dump dir not set" if args.checkpoint.path is None: @@ -147,7 +153,10 @@ def validate_train_args(args: TrainArgs, output_size: int): and args.distributed.dp_replicate == get_world_size() ) - args.model.max_seqlen = args.data.seq_len + if args.model is not None: + args.model.max_seqlen = args.data.seq_len + if args.entropy_model is not None: + args.entropy_model.max_seqlen = args.data.seq_len if args.distributed.tp_size == 1: logger.warning( @@ -237,7 +246,14 @@ def train(args: TrainArgs): # Initializing Model in meta device allows us to initialize models much bigger than 1 gpu's memory with torch.device("meta"): - model = ByteLatentTransformer(args.model) + if args.train_entropy_model: + assert args.entropy_model is not None + model = LMTransformer(args.entropy_model) + model_args = args.entropy_model + else: + assert args.model 
is not None + model = ByteLatentTransformer(args.model) + model_args = args.model logger.info("Model is built !") model_param_count = get_num_params(model) @@ -247,7 +263,7 @@ def train(args: TrainArgs): world_mesh, args.model, args.distributed, - fsdp_grouping_plan=build_fsdp_grouping_plan(args.model), + fsdp_grouping_plan=build_fsdp_grouping_plan(model_args), tp_parallelize=tp_parallelize, no_recompute_ops=get_no_recompute_ops(), ) @@ -267,7 +283,7 @@ def train(args: TrainArgs): model.rope_embeddings.reset_parameters() # For RoPe initialization since it's a buffer it might not be loaded else: with torch.random.fork_rng(devices=[torch.cuda.current_device()]): - torch.manual_seed(args.model.seed) + torch.manual_seed(model_args.seed) model.init_weights() check_model_value_range(model, range=10.0, std=1.0) @@ -342,10 +358,17 @@ def train(args: TrainArgs): batch.x, ).cuda() batch_y = torch.from_numpy(batch.y).cuda() - batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda() + if batch.patch_lengths is None: + batch_patch_lengths = None + else: + batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda() mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda() - if args.model.encoder_enable_byte_ngrams and batch.ngram_ids is None: + if ( + not args.train_entropy_model + and args.model.encoder_enable_byte_ngrams + and batch.ngram_ids is None + ): raise ValueError( "Cannot enable byte ngrams and have batch.ngram_ids be None" ) @@ -408,9 +431,12 @@ def train(args: TrainArgs): next(probe_mod.parameters()).grad is None ), "Probe model shouldn't have grads at this point" - pred = model( - batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids - ) + if args.train_entropy_model: + pred = model(batch_x) + else: + pred = model( + batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids + ) loss, _ = compute_loss(pred, batch_y, mask, train_state.scale) @@ -474,9 +500,9 @@ def train(args: TrainArgs): # Use xformer's analyze profile trace to get actual measurement FLOPS = ( get_num_flop_per_token( - model_param_count - args.model.vocab_size * args.model.dim, - args.model.n_layers, - args.model.dim, + model_param_count - model_args.vocab_size * model_args.dim, + model_args.n_layers, + model_args.dim, args.data.seq_len, ) * wps From e02ba763b00aa7075752a3615f59243c210ed188 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Tue, 28 Jan 2025 00:38:16 +0000 Subject: [PATCH 15/59] This includes fixes that make checkpointing and reloading work correctly. 
It also batches in a first set of changes for fixing eval code Summary: Test Plan: --- apps/main/lingua_train.py | 2 +- bytelatent/args.py | 87 ++++++++- bytelatent/checkpoint.py | 9 +- bytelatent/data/data_types.py | 10 - .../data/iterators/multiprocess_iterator.py | 15 ++ {apps/main => bytelatent}/eval.py | 179 ++++-------------- {apps/main => bytelatent}/generate.py | 57 +++--- bytelatent/train.py | 81 ++++---- 8 files changed, 219 insertions(+), 221 deletions(-) rename {apps/main => bytelatent}/eval.py (56%) rename {apps/main => bytelatent}/generate.py (91%) diff --git a/apps/main/lingua_train.py b/apps/main/lingua_train.py index bdb47da..7925ec6 100644 --- a/apps/main/lingua_train.py +++ b/apps/main/lingua_train.py @@ -544,7 +544,7 @@ def train(args: TrainArgs): if args.eval is not None and every_n_steps( train_state, args.checkpoint.eval.every, acc_step=0 ): - from apps.main.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval + from bytelatent.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval eval_args = dataclass_from_dict(EvalArgs, args.eval) diff --git a/bytelatent/args.py b/bytelatent/args.py index 56de22d..d1bac46 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -5,6 +5,7 @@ from typing import Any import numpy as np import yaml +from omegaconf import OmegaConf from pydantic import BaseModel, ConfigDict from bytelatent.checkpoint import CheckpointArgs @@ -39,6 +40,19 @@ def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]: return np.random.default_rng((seed, rank, world_size)).bit_generator.state +def parse_args(args_cls): + cli_args = OmegaConf.from_cli() + file_cfg = OmegaConf.load(cli_args.config) + # We remove 'config' attribute from config as the underlying DataClass does not have it + del cli_args.config + + default_cfg = OmegaConf.create(args_cls().model_dump()) + cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) + cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) + pydantic_args = args_cls.model_validate(cfg) + return pydantic_args + + def distribute_data_to_rank( *, dataset_path: str, @@ -71,6 +85,22 @@ def distribute_data_to_rank( return rank_to_arrow_iterator_params[rank] +class PackedCausalTransformerGeneratorArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + temperature: float = 0.0 + top_p: float | None = None + top_k: float | None = None + max_gen_len: int = 512 # Maximum number of tokens to generate + max_tokens: int = 1024 # Maximum number of tokens that can go through the model + max_prompt_len: int | None = None + until: list[str] = [] + compile_prefilling: bool = False + reduce_generation_overhead: bool = False + show_progress: bool = False + dtype: str | None = "bf16" + device: str | None = "cuda" + + class DataloaderArgs(BaseModel): model_config = ConfigDict(extra="forbid") s3_profile: str | None = None @@ -168,6 +198,58 @@ class DataloaderArgs(BaseModel): return packing_iterator +class LMHarnessArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + tasks: list[Any] | None = None + num_fewshot: int | None = None + device: str | None = None + use_cache: str | None = None + cache_requests: bool = False + rewrite_requests_cache: bool = False + delete_requests_cache: bool = False + limit: int | float | None = None + bootstrap_iters: int = 100000 + check_integrity: bool = False + write_out: bool = False + log_samples: bool = True + system_instruction: str | None = None + apply_chat_template: bool | str = False + fewshot_as_multiturn: bool = False + gen_kwargs: str | None = None + 
verbosity: str = "INFO" + predict_only: bool = False + random_seed: int = 0 + numpy_random_seed: int = 1234 + torch_random_seed: int = 1234 + fewshot_random_seed: int = 1234 + + +class ValidationArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + max_steps: int | None = ( + None # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu) + ) + use_val_from_train_src: bool = True # Use the validation set from training sources + root_dir: str = "" + sources: list[str] = [] # Other sources to eval on + + +class EvalArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + dump_dir: str + ckpt_dir: str + metric_log_dir: str | None = None + generator: PackedCausalTransformerGeneratorArgs = ( + PackedCausalTransformerGeneratorArgs() + ) + + harness: LMHarnessArgs | None = LMHarnessArgs() + validation: ValidationArgs | None = ValidationArgs() + + global_step: int | None = None # for in-training evaluation + s3_profile: str | None = None + + class TrainArgs(BaseModel): model_config = ConfigDict(extra="forbid") name: str = "lingua" @@ -186,6 +268,9 @@ class TrainArgs(BaseModel): # Nb optimizer steps to take steps: int = 1000 + # If not None, halt training after this many steps, + # useful for debugging + max_steps: int | None = None data: DataloaderArgs = DataloaderArgs() optim: OptimArgs = OptimArgs() @@ -203,7 +288,7 @@ class TrainArgs(BaseModel): # If set to None, eval is run locally otherwise it launches a new job with the given number of gpus async_eval_gpus: int | None = None - eval: Any | None = None + eval: EvalArgs | None = None eval_on_gpus: int | None = None def dump_to_yaml_file( diff --git a/bytelatent/checkpoint.py b/bytelatent/checkpoint.py index bcf591e..f213c84 100644 --- a/bytelatent/checkpoint.py +++ b/bytelatent/checkpoint.py @@ -7,6 +7,7 @@ import re from pathlib import Path from typing import List, Optional, Tuple +import fsspec import torch import torch.distributed as dist import torch.distributed.checkpoint as dcp @@ -21,6 +22,7 @@ from torch.distributed.checkpoint.state_dict import ( set_state_dict, ) +from bytelatent.data.file_util import get_fs from bytelatent.distributed import get_is_master logger = logging.getLogger("CHECKPOINT") @@ -51,13 +53,14 @@ class CheckpointArgs(BaseModel): path: str | None = None init_ckpt_path: str | None = None continue_training_from_init: bool = False + s3_profile: str | None = None def _get_key_step(name: str): return int(re.findall(RE_DIGITS, name)[-1]) -def consolidate_checkpoints(ckpt_dir: str): +def consolidate_checkpoints(fs: fsspec.AbstractFileSystem, ckpt_dir: str): """ Consolidates all FSDP checkpoints in a directory to a single file Consolidate checkpoint is saved in a subdirectory of ckpt_dir @@ -102,15 +105,17 @@ def load_from_checkpoint( dcp.load(state_dict, checkpoint_id=ckpt_dir) +# TODO: Rewrite the file operations here to use fsspec to enable s3 writing. 
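# A minimal sketch of what the get_fs() helper imported above plausibly does
# (an assumption -- its real implementation lives in
# bytelatent/data/file_util.py and is not shown in this patch): dispatch on
# the path scheme so the same code works for local and S3 storage.
#
#   def get_fs(path: str, s3_profile: str | None = None) -> fsspec.AbstractFileSystem:
#       if path.startswith("s3://"):
#           # Hypothetical: S3 URIs get an s3fs filesystem, optionally
#           # bound to a named AWS profile.
#           return fsspec.filesystem("s3", profile=s3_profile)
#       return fsspec.filesystem("file")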
class CheckpointManager: def __init__(self, args: CheckpointArgs): self.path = args.path + self.fs = get_fs(self.path, s3_profile=args.s3_profile) self.dump_every = args.dump self.eval_every = args.eval self.init_ckpt_path = args.init_ckpt_path self.continue_training_from_init = args.continue_training_from_init - assert os.path.exists( + assert self.fs.exists( self.path ), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)" diff --git a/bytelatent/data/data_types.py b/bytelatent/data/data_types.py index aa2daa9..f4bbc07 100644 --- a/bytelatent/data/data_types.py +++ b/bytelatent/data/data_types.py @@ -40,16 +40,6 @@ class BltPackTokensState(BaseModel): n_views: int = 2 -class DataLoaderState(BaseModel): - model_config = ConfigDict(extra="forbid") - multi_choice_state: MultiChoiceState - pack_tokens_state: BltPackTokensState - prefetch_state: PrefetchState - - -BltIterator = Iterator[tuple[BltExample, DataLoaderState]] - - class BltSequence(BaseModel): tokens: list[int] mask: list[bool] diff --git a/bytelatent/data/iterators/multiprocess_iterator.py b/bytelatent/data/iterators/multiprocess_iterator.py index f17ca6e..49d99ac 100644 --- a/bytelatent/data/iterators/multiprocess_iterator.py +++ b/bytelatent/data/iterators/multiprocess_iterator.py @@ -128,6 +128,13 @@ class MultiprocessIterator(StatefulIterator): self.producer = None self.stop_iterating_event = None self.state_dumped_event = None + self.force_shutdown = False + + def shutdown(self): + if self.producer is not None: + # This properly shuts things down + self.producer.kill() + self.force_shutdown = True def get_state(self) -> MultiprocessIteratorState: """ @@ -135,6 +142,10 @@ class MultiprocessIterator(StatefulIterator): to halt the background process and allow it to write the state to the main loop in order to not lose data """ + if self.force_shutdown: + raise ValueError( + "State will be invalid if shutdown was forced before state persisted." + ) if self.producer is None: serialized_prefetch_buffer = json.dumps( [b.to_python_dict() for b in self.prefetch_buffer] @@ -187,6 +198,10 @@ class MultiprocessIterator(StatefulIterator): ) def create_iter(self): + if self.force_shutdown: + raise ValueError( + "Iterator may be invalid if shutdown was forced before state persisted." + ) logging.info("Main thread: Creating MP iterator") # First yield from the stored prefetch buffer. 
if self.prefetch_buffer is not None: diff --git a/apps/main/eval.py b/bytelatent/eval.py similarity index 56% rename from apps/main/eval.py rename to bytelatent/eval.py index ed20f49..ae73066 100644 --- a/apps/main/eval.py +++ b/bytelatent/eval.py @@ -4,20 +4,20 @@ import json import logging import os from collections import defaultdict -from dataclasses import asdict, dataclass, field from datetime import datetime from pathlib import Path -from typing import Any, List, Optional, Tuple, Union +from typing import Any import torch -from lingua.args import dump_config -from lingua.data import init_choice_state, setup_sources from lm_eval import simple_evaluate from lm_eval.api.instance import Instance from lm_eval.api.model import LM from omegaconf import OmegaConf +from pydantic import BaseModel, ConfigDict +from bytelatent.args import EvalArgs, ValidationArgs, parse_args from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints +from bytelatent.data.file_util import get_fs from bytelatent.distributed import ( DistributedArgs, dist_mean_dict, @@ -25,72 +25,17 @@ from bytelatent.distributed import ( get_world_size, setup_torch_distributed, ) -from bytelatent.transformer import LMTransformer, LMTransformerArgs - -from apps.main.generate import ( +from bytelatent.generate import ( PackedCausalTransformerGenerator, - PackedCausalTransformerGeneratorArgs, load_consolidated_model_and_tokenizer, ) +from bytelatent.transformer import LMTransformer, LMTransformerArgs EVAL_FOLDER_NAME = "{:010d}" logger = logging.getLogger() -@dataclass -class LMHarnessArgs: - tasks: Optional[List[Any]] = None - num_fewshot: Optional[int] = None - device: Optional[str] = None - use_cache: Optional[str] = None - cache_requests: bool = False - rewrite_requests_cache: bool = False - delete_requests_cache: bool = False - limit: Optional[Union[int, float]] = None - bootstrap_iters: int = 100000 - check_integrity: bool = False - write_out: bool = False - log_samples: bool = True - system_instruction: Optional[str] = None - apply_chat_template: Union[bool, str] = False - fewshot_as_multiturn: bool = False - gen_kwargs: Optional[str] = None - verbosity: str = "INFO" - predict_only: bool = False - random_seed: int = 0 - numpy_random_seed: int = 1234 - torch_random_seed: int = 1234 - fewshot_random_seed: int = 1234 - - -@dataclass -class ValidationArgs: - max_steps: Optional[int] = ( - None # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu) - ) - use_val_from_train_src: bool = True # Use the validation set from training sources - root_dir: str = "" - sources: List[str] = field(default_factory=list) # Other sources to eval on - - -@dataclass -class EvalArgs: - name: str = "evals" - dump_dir: Optional[str] = None - metric_log_dir: Optional[str] = None - ckpt_dir: str = "" - generator: PackedCausalTransformerGeneratorArgs = field( - default_factory=PackedCausalTransformerGeneratorArgs - ) - harness: Optional[LMHarnessArgs] = field(default_factory=LMHarnessArgs) - validation: Optional[ValidationArgs] = field(default_factory=ValidationArgs) - - wandb: Optional[Any] = None - - global_step: Optional[int] = None # for in-training evaluation - - def all_dicts_same(dict_list): if not dict_list: # Check if the list is empty return True @@ -120,7 +65,7 @@ class EvalHarnessLM(LM): self._world_size = get_world_size() self.device = generator.device - def generate_until(self, requests: List[Instance]) -> List[str]: + def 
generate_until(self, requests: list[Instance]) -> list[str]: prompts, gen_args = zip(*[req.args for req in requests]) assert all_dicts_same(gen_args), "Doesn't support different gen args for now" gen_args = gen_args[0] @@ -141,7 +86,7 @@ class EvalHarnessLM(LM): filtered_gen.append(g) return filtered_gen - def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: + def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: prompts, continuations = zip(*[req.args for req in requests]) inputs = [req.args[0] + req.args[1] for req in requests] max_gen_len = self.generator.max_gen_len @@ -158,7 +103,7 @@ class EvalHarnessLM(LM): self.generator.max_gen_len = max_gen_len return results - def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]: + def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]: prompts = [req.args[0] for req in requests] max_gen_len = self.generator.max_gen_len # We temporarily lower max gen len @@ -232,68 +177,73 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg): return all_val_metrics -def launch_eval(cfg: EvalArgs): +def launch_eval(eval_args: EvalArgs): if not torch.distributed.is_initialized(): setup_torch_distributed(DistributedArgs()) + + fs = get_fs(eval_args.ckpt_dir, s3_profile=eval_args.s3_profile) if ( - Path(cfg.ckpt_dir).exists() - and (Path(cfg.ckpt_dir) / "params.json").exists() - and next(Path(cfg.ckpt_dir).glob("*.pth"), None) is not None + fs.exists(eval_args.ckpt_dir) + and fs.exists(os.path.join(eval_args.ckpt_dir, "params.json")) + and len(fs.glob(os.path.join(eval_args.ckpt_dir, "*.pth"))) != 0 ): - consolidate_path = Path(cfg.ckpt_dir) + consolidate_path = eval_args.ckpt_dir else: - consolidate_path = Path(cfg.ckpt_dir) / CONSOLIDATE_FOLDER - if not consolidate_path.exists() and get_global_rank() == 0: - consolidate_path = consolidate_checkpoints(cfg.ckpt_dir) + consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER) + if not fs.exists(consolidate_path) and get_global_rank() == 0: + consolidate_path = consolidate_checkpoints(eval_args.ckpt_dir) - Path(cfg.dump_dir).mkdir(parents=True, exist_ok=True) - dump_config(cfg, Path(cfg.dump_dir) / "config.yaml", log_config=False) + fs.mkdirs(eval_args.dump_dir, exist_ok=True) + with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f: + f.write(eval_args.model_dump_json()) - consolidate_path = str(consolidate_path) torch.distributed.barrier() logger.info("Loading model") + # TODO: Make this general so that it works with either + # LMTransformer or Blt, similar with args model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer( consolidate_path, - model_cls=LMTransformer, - model_args_cls=LMTransformerArgs, ) logger.info("Model loaded") model.eval() - generator = PackedCausalTransformerGenerator(cfg.generator, model, tokenizer) + generator = PackedCausalTransformerGenerator(eval_args.generator, model, tokenizer) wrap = EvalHarnessLM(generator) - results = simple_evaluate(wrap, **asdict(cfg.harness)) + # Redo + results = simple_evaluate(wrap, eval_args.harness.model_dump()) val_results = None - if cfg.validation: - val_results = eval_on_val(generator, cfg.validation, train_cfg) + if eval_args.validation: + val_results = eval_on_val(generator, eval_args.validation, train_cfg) if get_global_rank() == 0: - with open(Path(cfg.dump_dir) / "results.json", "w") as f: + with fs.open(os.path.join(eval_args.dump_dir, "results.json"), "w") as f: f.write(json.dumps(results)) 
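# simple_evaluate() returns a dict whose "results" entry maps each task name
# to its metrics; that sub-dict is what gets logged just below and later
# appended to metrics.eval.jsonl. Also note the pre-rename code expanded the
# harness options with **asdict(cfg.harness), whereas the call above passes
# harness.model_dump() positionally; presumably that is what the "# Redo"
# marker is flagging for follow-up.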
logger.info(f"All evaluation results: {results['results']}") if val_results is not None: - with open(Path(cfg.dump_dir) / "validation.json", "w") as f: + with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f: f.write(json.dumps(val_results)) logger.info(f"All validation results: {val_results}") - if cfg.metric_log_dir and get_global_rank() == 0: - metric_log_path = Path(cfg.metric_log_dir) / "metrics.eval.jsonl" + if eval_args.metric_log_dir and get_global_rank() == 0: + metric_log_path = os.path.join(eval_args.metric_log_dir, "metrics.eval.jsonl") logger.info(f"Writing metric logs to {metric_log_path}") timestamp = { "created_at": datetime.utcnow().isoformat(), } - if cfg.global_step is not None: - timestamp["global_step"] = cfg.global_step + if eval_args.global_step is not None: + timestamp["global_step"] = eval_args.global_step print( json.dumps(timestamp | results["results"]), - file=open(metric_log_path, mode="a"), + file=fs.open(metric_log_path, mode="a"), flush=True, ) - val_log_path = Path(cfg.metric_log_dir) / "metrics.validation.jsonl" + val_log_path = os.path.join( + eval_args.metric_log_dir, "metrics.validation.jsonl" + ) if val_results is not None: print( json.dumps(timestamp | val_results), - file=open(val_log_path, mode="a"), + file=fs.open(val_log_path, mode="a"), flush=True, ) @@ -301,53 +251,8 @@ def launch_eval(cfg: EvalArgs): def main(): - """ - The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments - This accepts arguments as a dot list - So if the dataclass looks like - - @dataclass - class DummyArgs: - name: str - model: LMTransformerArgsgs - - @dataclass - class LMTransformerArgsgs: - dim: int - - Then you can pass model.dim=32 to change values in LMTransformerArgsgs - or just name=tictac for top level attributes. - - The behavior here is as follows: - 1. We instantiate EvalArgs with its default values - 2. We override those default values with the ones in the provided config file - 3. We override the result with the additional arguments provided through command line - - For example, if the config is the following - - model: - dim: 128 - n_layers: 4 - - and you call eval.py with eval.py model.dim=64 - - Then the final TrainArgs will have - - model: - dim: 64 - n_layers: 4 - - Plus all the default values in EvalArgs dataclass. - """ - cli_args = OmegaConf.from_cli() - file_cfg = OmegaConf.load(cli_args.config) - # We remove 'config' attribute from config as the underlying DataClass does not have it - del cli_args.config - - default_cfg = OmegaConf.structured(EvalArgs()) - cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) - cfg = OmegaConf.to_object(cfg) - launch_eval(cfg) + eval_args = parse_args(EvalArgs) + launch_eval(eval_args) if __name__ == "__main__": diff --git a/apps/main/generate.py b/bytelatent/generate.py similarity index 91% rename from apps/main/generate.py rename to bytelatent/generate.py index a3a8627..eb79d81 100644 --- a/apps/main/generate.py +++ b/bytelatent/generate.py @@ -1,20 +1,16 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
+import os import time -from dataclasses import dataclass, field -from pathlib import Path -from typing import List, Optional import torch -from lingua.args import dataclass_from_dict -from lingua.tokenizers.abstract_tokenizer import Tokenizer -from lingua.tokenizers.build_tokenizer import build_tokenizer from omegaconf import OmegaConf from torch import nn from torch.nn import functional as F from torch.nn.attention.flex_attention import create_block_mask from tqdm import tqdm +from bytelatent.args import PackedCausalTransformerGeneratorArgs, TrainArgs from bytelatent.base_transformer import ( Attention, causal_mask, @@ -23,7 +19,10 @@ from bytelatent.base_transformer import ( lengths_to_start_ids, ) from bytelatent.checkpoint import CONSOLIDATE_NAME -from bytelatent.transformer import LMTransformer, LMTransformerArgs +from bytelatent.data.file_util import get_fs +from bytelatent.model.blt import ByteLatentTransformer +from bytelatent.tokenizers.abstract_tokenizer import Tokenizer +from bytelatent.transformer import LMTransformer def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor: @@ -62,7 +61,7 @@ def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None): return next_token.view(shape[:-1]) -def pack_prompts(prompts: List[int]): +def pack_prompts(prompts: list[int]): res = [] lengths = [] for i, p in enumerate(prompts): @@ -120,22 +119,6 @@ class KVCache(nn.Module): return self.k_cache, self.v_cache -@dataclass -class PackedCausalTransformerGeneratorArgs: - temperature: float = 0.0 - top_p: Optional[float] = None - top_k: Optional[float] = None - max_gen_len: int = 512 # Maximum number of tokens to generate - max_tokens: int = 1024 # Maximum number of tokens that can go through the model - max_prompt_len: Optional[int] = None - until: List[str] = field(default_factory=list) - compile_prefilling: bool = False - reduce_generation_overhead: bool = False - show_progress: bool = False - dtype: Optional[str] = "bf16" - device: Optional[str] = "cuda" - - class PackedCausalTransformerGenerator: def __init__( self, @@ -401,25 +384,29 @@ class PackedCausalTransformerGenerator: def load_consolidated_model_and_tokenizer( consolidated_path, - model_cls=LMTransformer, - model_args_cls=LMTransformerArgs, ): - ckpt_path = Path(consolidated_path) - config = ckpt_path / "params.json" - config = OmegaConf.load(config) + train_args_path = os.path.join(consolidated_path, "params.json") + fs = get_fs(train_args_path) + with fs.open(train_args_path) as f: + train_args = TrainArgs.model_validate_json(f.read()) + + if train_args.train_entropy_model: + model_args = train_args.entropy_model + model = LMTransformer(model_args) + else: + model_args = train_args.model + model = ByteLatentTransformer(model_args) param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[ - config.distributed.model_dtype + train_args.distributed.model_dtype ] - model_args = dataclass_from_dict(model_args_cls, config.model, strict=False) - tokenizer = build_tokenizer(config.data.tokenizer.name, config.data.tokenizer.path) - model = model_cls(model_args) - st_dict = torch.load(ckpt_path / CONSOLIDATE_NAME, weights_only=True) + tokenizer = train_args.data.tokenizer_args.build() + st_dict = torch.load(consolidated_path / CONSOLIDATE_NAME, weights_only=True) model.load_state_dict(st_dict["model"]) model = model.cuda().eval() for param in model.parameters(): param.data = param.data.to(dtype=param_dtype) - return model, tokenizer, config + return model, tokenizer, train_args def main(): diff --git 
a/bytelatent/train.py b/bytelatent/train.py index 1d0fa40..6b20ecd 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -10,7 +10,7 @@ from copy import deepcopy from dataclasses import asdict, dataclass from pathlib import Path from timeit import default_timer as timer -from typing import Any, Dict, Type, TypeVar +from typing import Any, TypeVar import torch import torch.distributed @@ -23,9 +23,13 @@ from torch.distributed._tensor import DTensor from torch.distributed.checkpoint.stateful import Stateful from torch.optim import lr_scheduler -from bytelatent.args import TrainArgs +from bytelatent.args import TrainArgs, parse_args from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint -from bytelatent.data.data_types import DataLoaderState +from bytelatent.data.iterators.multiprocess_iterator import ( + MultiprocessIterator, + MultiprocessIteratorState, +) +from bytelatent.data.iterators.packing_iterator import PackingIteratorState from bytelatent.distributed import ( check_model_value_range, clean_env, @@ -39,6 +43,7 @@ from bytelatent.distributed import ( setup_env, setup_torch_distributed, ) +from bytelatent.eval import EVAL_FOLDER_NAME, launch_eval from bytelatent.logger import init_logger from bytelatent.metrics import GPUMemoryMonitor, MetricLogger, get_num_params from bytelatent.model.blt import ByteLatentTransformer @@ -70,36 +75,49 @@ def flatten_dict(d, parent_key="", sep="_"): return dict(items) -def dataclass_from_dict(cls: Type[T], data: dict, strict: bool = True) -> T: - """ - Converts a dictionary to a dataclass instance, recursively for nested structures. - """ - base = OmegaConf.structured(cls()) - OmegaConf.set_struct(base, strict) - override = OmegaConf.create(data) - return OmegaConf.to_object(OmegaConf.merge(base, override)) +def get_iterator_state_name(iterator_state): + if isinstance(iterator_state, MultiprocessIteratorState): + return "multiprocess" + elif isinstance(iterator_state, PackingIteratorState): + return "packing" + else: + raise ValueError(f"Unsupported iterator to get name from: {iterator_state}") +# TODO: Make this pydantic based instead of data class based +# TODO: Generalize this to any iterator state @dataclass class TrainState(Stateful): step: int # Nb of steps taken by the optimizer acc_step: int # Nb of accumulation steps done since last optimizer step scheduler: lr_scheduler.LambdaLR - data_loader_state: DataLoaderState + data_loader_state: MultiprocessIteratorState | PackingIteratorState scale: float = 1.0 + data_loader_class: str | None = None - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return { "step": self.step, "acc_step": self.acc_step, - "data_loader_state": self.data_loader_state.dict(), + "data_loader_state": self.data_loader_state.model_dump(), + "data_loader_class": get_iterator_state_name(self.data_loader_state), "scheduler": self.scheduler.state_dict(), } def load_state_dict(self, state_dict): self.step = state_dict["step"] self.acc_step = state_dict["acc_step"] - self.data_loader_state = DataLoaderState(**state_dict["data_loader_state"]) + self.data_loader_class = state_dict["data_loader_class"] + if self.data_loader_class == "multiprocess": + self.data_loader_state = MultiprocessIteratorState( + **state_dict["data_loader_state"] + ) + elif self.data_loader_class == "packing": + self.data_loader_state = PackingIteratorState( + **state_dict["data_loader_state"] + ) + else: + raise ValueError(f"invalid data loader class: {self.data_loader_class}") 
self.scheduler.load_state_dict(state_dict["scheduler"]) @@ -345,7 +363,10 @@ def train(args: TrainArgs): nwords_since_last_log = 0 time_last_log = timer() gc.collect() - while train_state.step < args.steps: + saved = False + while train_state.step < args.steps and ( + args.max_steps is None or train_state.step < args.max_steps + ): # We constrain train_state.acc_step to be in range 0 to args.grad_acc_steps - 1 train_state.acc_step += 1 train_state.acc_step = train_state.acc_step % args.grad_acc_steps @@ -552,7 +573,6 @@ def train(args: TrainArgs): f" pow: {gpu_mem_stats.power_draw/1000} W" ) - saved = False if every_n_steps( train_state, args.checkpoint.dump.every, acc_step=0 ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0): @@ -567,18 +587,14 @@ def train(args: TrainArgs): if args.eval is not None and every_n_steps( train_state, args.checkpoint.eval.every, acc_step=0 ): - from apps.main.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval - - eval_args = dataclass_from_dict(EvalArgs, args.eval) + eval_args = args.eval eval_args.global_step = train_state.step eval_args.ckpt_dir = str(checkpoint.existing_saves[-1]) - eval_args.dump_dir = str( - os.path.join( - args.dump_dir, - "evals", - EVAL_FOLDER_NAME.format(train_state.step), - ) + eval_args.dump_dir = os.path.join( + args.dump_dir, + "evals", + EVAL_FOLDER_NAME.format(train_state.step), ) eval_args.metric_log_dir = args.dump_dir if args.async_eval_gpus is None: @@ -619,6 +635,9 @@ def train(args: TrainArgs): args, device_mesh=world_mesh, ) + if isinstance(data_loader, MultiprocessIterator): + logger.info("Closing MP iterator before exiting") + data_loader.shutdown() gc.collect() @@ -661,15 +680,7 @@ def main(): Plus all the default values in TrainArgs dataclass. """ - cli_args = OmegaConf.from_cli() - file_cfg = OmegaConf.load(cli_args.config) - # We remove 'config' attribute from config as the underlying DataClass does not have it - del cli_args.config - - default_cfg = OmegaConf.create(TrainArgs().model_dump()) - cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) - cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) - train_args = TrainArgs.model_validate(cfg) + train_args = parse_args(TrainArgs) if train_args.debug_dynamo: import torch._dynamo From caf82b924e9f7fadc1b850cb30b59471c796fa8e Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Tue, 28 Jan 2025 00:38:16 +0000 Subject: [PATCH 16/59] This includes fixes that make checkpointing and reloading work correctly. 
It also batches in a first set of changes for fixing eval code Summary: Test Plan: --- apps/main/lingua_train.py | 2 +- bytelatent/args.py | 87 ++++++++- bytelatent/checkpoint.py | 9 +- bytelatent/configs/debug.yaml | 9 +- bytelatent/configs/entropy_model.yaml | 9 +- bytelatent/data/data_types.py | 10 - .../data/iterators/multiprocess_iterator.py | 15 ++ {apps/main => bytelatent}/eval.py | 179 ++++-------------- {apps/main => bytelatent}/generate.py | 57 +++--- bytelatent/train.py | 81 ++++---- 10 files changed, 221 insertions(+), 237 deletions(-) rename {apps/main => bytelatent}/eval.py (56%) rename {apps/main => bytelatent}/generate.py (91%) diff --git a/apps/main/lingua_train.py b/apps/main/lingua_train.py index bdb47da..7925ec6 100644 --- a/apps/main/lingua_train.py +++ b/apps/main/lingua_train.py @@ -544,7 +544,7 @@ def train(args: TrainArgs): if args.eval is not None and every_n_steps( train_state, args.checkpoint.eval.every, acc_step=0 ): - from apps.main.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval + from bytelatent.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval eval_args = dataclass_from_dict(EvalArgs, args.eval) diff --git a/bytelatent/args.py b/bytelatent/args.py index 56de22d..d1bac46 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -5,6 +5,7 @@ from typing import Any import numpy as np import yaml +from omegaconf import OmegaConf from pydantic import BaseModel, ConfigDict from bytelatent.checkpoint import CheckpointArgs @@ -39,6 +40,19 @@ def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]: return np.random.default_rng((seed, rank, world_size)).bit_generator.state +def parse_args(args_cls): + cli_args = OmegaConf.from_cli() + file_cfg = OmegaConf.load(cli_args.config) + # We remove 'config' attribute from config as the underlying DataClass does not have it + del cli_args.config + + default_cfg = OmegaConf.create(args_cls().model_dump()) + cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) + cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) + pydantic_args = args_cls.model_validate(cfg) + return pydantic_args + + def distribute_data_to_rank( *, dataset_path: str, @@ -71,6 +85,22 @@ def distribute_data_to_rank( return rank_to_arrow_iterator_params[rank] +class PackedCausalTransformerGeneratorArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + temperature: float = 0.0 + top_p: float | None = None + top_k: float | None = None + max_gen_len: int = 512 # Maximum number of tokens to generate + max_tokens: int = 1024 # Maximum number of tokens that can go through the model + max_prompt_len: int | None = None + until: list[str] = [] + compile_prefilling: bool = False + reduce_generation_overhead: bool = False + show_progress: bool = False + dtype: str | None = "bf16" + device: str | None = "cuda" + + class DataloaderArgs(BaseModel): model_config = ConfigDict(extra="forbid") s3_profile: str | None = None @@ -168,6 +198,58 @@ class DataloaderArgs(BaseModel): return packing_iterator +class LMHarnessArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + tasks: list[Any] | None = None + num_fewshot: int | None = None + device: str | None = None + use_cache: str | None = None + cache_requests: bool = False + rewrite_requests_cache: bool = False + delete_requests_cache: bool = False + limit: int | float | None = None + bootstrap_iters: int = 100000 + check_integrity: bool = False + write_out: bool = False + log_samples: bool = True + system_instruction: str | None = None + apply_chat_template: bool | str 
= False + fewshot_as_multiturn: bool = False + gen_kwargs: str | None = None + verbosity: str = "INFO" + predict_only: bool = False + random_seed: int = 0 + numpy_random_seed: int = 1234 + torch_random_seed: int = 1234 + fewshot_random_seed: int = 1234 + + +class ValidationArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + max_steps: int | None = ( + None # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu) + ) + use_val_from_train_src: bool = True # Use the validation set from training sources + root_dir: str = "" + sources: list[str] = [] # Other sources to eval on + + +class EvalArgs(BaseModel): + model_config = ConfigDict(extra="forbid") + dump_dir: str + ckpt_dir: str + metric_log_dir: str | None = None + generator: PackedCausalTransformerGeneratorArgs = ( + PackedCausalTransformerGeneratorArgs() + ) + + harness: LMHarnessArgs | None = LMHarnessArgs() + validation: ValidationArgs | None = ValidationArgs() + + global_step: int | None = None # for in-training evaluation + s3_profile: str | None = None + + class TrainArgs(BaseModel): model_config = ConfigDict(extra="forbid") name: str = "lingua" @@ -186,6 +268,9 @@ class TrainArgs(BaseModel): # Nb optimizer steps to take steps: int = 1000 + # If not None, halt training after this many steps, + # useful for debugging + max_steps: int | None = None data: DataloaderArgs = DataloaderArgs() optim: OptimArgs = OptimArgs() @@ -203,7 +288,7 @@ class TrainArgs(BaseModel): # If set to None, eval is run locally otherwise it launches a new job with the given number of gpus async_eval_gpus: int | None = None - eval: Any | None = None + eval: EvalArgs | None = None eval_on_gpus: int | None = None def dump_to_yaml_file( diff --git a/bytelatent/checkpoint.py b/bytelatent/checkpoint.py index bcf591e..f213c84 100644 --- a/bytelatent/checkpoint.py +++ b/bytelatent/checkpoint.py @@ -7,6 +7,7 @@ import re from pathlib import Path from typing import List, Optional, Tuple +import fsspec import torch import torch.distributed as dist import torch.distributed.checkpoint as dcp @@ -21,6 +22,7 @@ from torch.distributed.checkpoint.state_dict import ( set_state_dict, ) +from bytelatent.data.file_util import get_fs from bytelatent.distributed import get_is_master logger = logging.getLogger("CHECKPOINT") @@ -51,13 +53,14 @@ class CheckpointArgs(BaseModel): path: str | None = None init_ckpt_path: str | None = None continue_training_from_init: bool = False + s3_profile: str | None = None def _get_key_step(name: str): return int(re.findall(RE_DIGITS, name)[-1]) -def consolidate_checkpoints(ckpt_dir: str): +def consolidate_checkpoints(fs: fsspec.AbstractFileSystem, ckpt_dir: str): """ Consolidates all FSDP checkpoints in a directory to a single file Consolidate checkpoint is saved in a subdirectory of ckpt_dir @@ -102,15 +105,17 @@ def load_from_checkpoint( dcp.load(state_dict, checkpoint_id=ckpt_dir) +# TODO: Rewrite the file operations here to use fsspec to enable s3 writing. 
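# A minimal usage sketch of the fsspec-style path handling that the TODO above
# refers to, assuming the get_fs helper from bytelatent.data.file_util; the
# bucket and profile names below are hypothetical.
import os

from bytelatent.data.file_util import get_fs

ckpt_dir = "s3://my-bucket/checkpoints/0000001000"  # hypothetical location
fs = get_fs(ckpt_dir, s3_profile="my-profile")  # hypothetical AWS profile
if fs.exists(os.path.join(ckpt_dir, "params.json")):
    print("found checkpoint config:", os.path.join(ckpt_dir, "params.json"))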
class CheckpointManager: def __init__(self, args: CheckpointArgs): self.path = args.path + self.fs = get_fs(self.path, s3_profile=args.s3_profile) self.dump_every = args.dump self.eval_every = args.eval self.init_ckpt_path = args.init_ckpt_path self.continue_training_from_init = args.continue_training_from_init - assert os.path.exists( + assert self.fs.exists( self.path ), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)" diff --git a/bytelatent/configs/debug.yaml b/bytelatent/configs/debug.yaml index 1098ff5..07d489f 100644 --- a/bytelatent/configs/debug.yaml +++ b/bytelatent/configs/debug.yaml @@ -98,11 +98,4 @@ logging: freq: 10 eval_on_gpus: 8 -eval: - dataset_dir: /checkpoint/amaia/codegen/datasets/eval - tasks: boolq,hellaswag,nq,piqa,siqa,tqa,winogrande,obqa,arc_easy,arc_challenge,race.middle,race.high,gsm8k,math,bbh,copa,human_eval_plus,mbpp,mmlu - generator: - max_tokens: 65536 - dtype: bf16 - - mp_size: 1 +eval: null diff --git a/bytelatent/configs/entropy_model.yaml b/bytelatent/configs/entropy_model.yaml index 51b65d4..d7c27b7 100644 --- a/bytelatent/configs/entropy_model.yaml +++ b/bytelatent/configs/entropy_model.yaml @@ -72,11 +72,4 @@ logging: freq: 10 eval_on_gpus: 8 -eval: - dataset_dir: ??? - tasks: ??? - generator: - max_tokens: 65536 - dtype: bf16 - - mp_size: 1 +eval: null diff --git a/bytelatent/data/data_types.py b/bytelatent/data/data_types.py index aa2daa9..f4bbc07 100644 --- a/bytelatent/data/data_types.py +++ b/bytelatent/data/data_types.py @@ -40,16 +40,6 @@ class BltPackTokensState(BaseModel): n_views: int = 2 -class DataLoaderState(BaseModel): - model_config = ConfigDict(extra="forbid") - multi_choice_state: MultiChoiceState - pack_tokens_state: BltPackTokensState - prefetch_state: PrefetchState - - -BltIterator = Iterator[tuple[BltExample, DataLoaderState]] - - class BltSequence(BaseModel): tokens: list[int] mask: list[bool] diff --git a/bytelatent/data/iterators/multiprocess_iterator.py b/bytelatent/data/iterators/multiprocess_iterator.py index f17ca6e..49d99ac 100644 --- a/bytelatent/data/iterators/multiprocess_iterator.py +++ b/bytelatent/data/iterators/multiprocess_iterator.py @@ -128,6 +128,13 @@ class MultiprocessIterator(StatefulIterator): self.producer = None self.stop_iterating_event = None self.state_dumped_event = None + self.force_shutdown = False + + def shutdown(self): + if self.producer is not None: + # This properly shuts things down + self.producer.kill() + self.force_shutdown = True def get_state(self) -> MultiprocessIteratorState: """ @@ -135,6 +142,10 @@ class MultiprocessIterator(StatefulIterator): to halt the background process and allow it to write the state to the main loop in order to not lose data """ + if self.force_shutdown: + raise ValueError( + "State will be invalid if shutdown was forced before state persisted." + ) if self.producer is None: serialized_prefetch_buffer = json.dumps( [b.to_python_dict() for b in self.prefetch_buffer] @@ -187,6 +198,10 @@ class MultiprocessIterator(StatefulIterator): ) def create_iter(self): + if self.force_shutdown: + raise ValueError( + "Iterator may be invalid if shutdown was forced before state persisted." + ) logging.info("Main thread: Creating MP iterator") # First yield from the stored prefetch buffer. 
if self.prefetch_buffer is not None: diff --git a/apps/main/eval.py b/bytelatent/eval.py similarity index 56% rename from apps/main/eval.py rename to bytelatent/eval.py index ed20f49..ae73066 100644 --- a/apps/main/eval.py +++ b/bytelatent/eval.py @@ -4,20 +4,20 @@ import json import logging import os from collections import defaultdict -from dataclasses import asdict, dataclass, field from datetime import datetime from pathlib import Path -from typing import Any, List, Optional, Tuple, Union +from typing import Any import torch -from lingua.args import dump_config -from lingua.data import init_choice_state, setup_sources from lm_eval import simple_evaluate from lm_eval.api.instance import Instance from lm_eval.api.model import LM from omegaconf import OmegaConf +from pydantic import BaseModel, ConfigDict +from bytelatent.args import EvalArgs, ValidationArgs, parse_args from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints +from bytelatent.data.file_util import get_fs from bytelatent.distributed import ( DistributedArgs, dist_mean_dict, @@ -25,72 +25,17 @@ from bytelatent.distributed import ( get_world_size, setup_torch_distributed, ) -from bytelatent.transformer import LMTransformer, LMTransformerArgs - -from apps.main.generate import ( +from bytelatent.generate import ( PackedCausalTransformerGenerator, - PackedCausalTransformerGeneratorArgs, load_consolidated_model_and_tokenizer, ) +from bytelatent.transformer import LMTransformer, LMTransformerArgs EVAL_FOLDER_NAME = "{:010d}" logger = logging.getLogger() -@dataclass -class LMHarnessArgs: - tasks: Optional[List[Any]] = None - num_fewshot: Optional[int] = None - device: Optional[str] = None - use_cache: Optional[str] = None - cache_requests: bool = False - rewrite_requests_cache: bool = False - delete_requests_cache: bool = False - limit: Optional[Union[int, float]] = None - bootstrap_iters: int = 100000 - check_integrity: bool = False - write_out: bool = False - log_samples: bool = True - system_instruction: Optional[str] = None - apply_chat_template: Union[bool, str] = False - fewshot_as_multiturn: bool = False - gen_kwargs: Optional[str] = None - verbosity: str = "INFO" - predict_only: bool = False - random_seed: int = 0 - numpy_random_seed: int = 1234 - torch_random_seed: int = 1234 - fewshot_random_seed: int = 1234 - - -@dataclass -class ValidationArgs: - max_steps: Optional[int] = ( - None # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu) - ) - use_val_from_train_src: bool = True # Use the validation set from training sources - root_dir: str = "" - sources: List[str] = field(default_factory=list) # Other sources to eval on - - -@dataclass -class EvalArgs: - name: str = "evals" - dump_dir: Optional[str] = None - metric_log_dir: Optional[str] = None - ckpt_dir: str = "" - generator: PackedCausalTransformerGeneratorArgs = field( - default_factory=PackedCausalTransformerGeneratorArgs - ) - harness: Optional[LMHarnessArgs] = field(default_factory=LMHarnessArgs) - validation: Optional[ValidationArgs] = field(default_factory=ValidationArgs) - - wandb: Optional[Any] = None - - global_step: Optional[int] = None # for in-training evaluation - - def all_dicts_same(dict_list): if not dict_list: # Check if the list is empty return True @@ -120,7 +65,7 @@ class EvalHarnessLM(LM): self._world_size = get_world_size() self.device = generator.device - def generate_until(self, requests: List[Instance]) -> List[str]: + def 
generate_until(self, requests: list[Instance]) -> list[str]: prompts, gen_args = zip(*[req.args for req in requests]) assert all_dicts_same(gen_args), "Doesn't support different gen args for now" gen_args = gen_args[0] @@ -141,7 +86,7 @@ class EvalHarnessLM(LM): filtered_gen.append(g) return filtered_gen - def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: + def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: prompts, continuations = zip(*[req.args for req in requests]) inputs = [req.args[0] + req.args[1] for req in requests] max_gen_len = self.generator.max_gen_len @@ -158,7 +103,7 @@ class EvalHarnessLM(LM): self.generator.max_gen_len = max_gen_len return results - def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]: + def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]: prompts = [req.args[0] for req in requests] max_gen_len = self.generator.max_gen_len # We temporarily lower max gen len @@ -232,68 +177,73 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg): return all_val_metrics -def launch_eval(cfg: EvalArgs): +def launch_eval(eval_args: EvalArgs): if not torch.distributed.is_initialized(): setup_torch_distributed(DistributedArgs()) + + fs = get_fs(eval_args.ckpt_dir, s3_profile=eval_args.s3_profile) if ( - Path(cfg.ckpt_dir).exists() - and (Path(cfg.ckpt_dir) / "params.json").exists() - and next(Path(cfg.ckpt_dir).glob("*.pth"), None) is not None + fs.exists(eval_args.ckpt_dir) + and fs.exists(os.path.join(eval_args.ckpt_dir, "params.json")) + and len(fs.glob(os.path.join(eval_args.ckpt_dir, "*.pth"))) != 0 ): - consolidate_path = Path(cfg.ckpt_dir) + consolidate_path = eval_args.ckpt_dir else: - consolidate_path = Path(cfg.ckpt_dir) / CONSOLIDATE_FOLDER - if not consolidate_path.exists() and get_global_rank() == 0: - consolidate_path = consolidate_checkpoints(cfg.ckpt_dir) + consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER) + if not fs.exists(consolidate_path) and get_global_rank() == 0: + consolidate_path = consolidate_checkpoints(eval_args.ckpt_dir) - Path(cfg.dump_dir).mkdir(parents=True, exist_ok=True) - dump_config(cfg, Path(cfg.dump_dir) / "config.yaml", log_config=False) + fs.mkdirs(eval_args.dump_dir, exist_ok=True) + with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f: + f.write(eval_args.model_dump_json()) - consolidate_path = str(consolidate_path) torch.distributed.barrier() logger.info("Loading model") + # TODO: Make this general so that it works with either + # LMTransformer or Blt, similar with args model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer( consolidate_path, - model_cls=LMTransformer, - model_args_cls=LMTransformerArgs, ) logger.info("Model loaded") model.eval() - generator = PackedCausalTransformerGenerator(cfg.generator, model, tokenizer) + generator = PackedCausalTransformerGenerator(eval_args.generator, model, tokenizer) wrap = EvalHarnessLM(generator) - results = simple_evaluate(wrap, **asdict(cfg.harness)) + # Redo + results = simple_evaluate(wrap, eval_args.harness.model_dump()) val_results = None - if cfg.validation: - val_results = eval_on_val(generator, cfg.validation, train_cfg) + if eval_args.validation: + val_results = eval_on_val(generator, eval_args.validation, train_cfg) if get_global_rank() == 0: - with open(Path(cfg.dump_dir) / "results.json", "w") as f: + with fs.open(os.path.join(eval_args.dump_dir, "results.json"), "w") as f: f.write(json.dumps(results)) 
logger.info(f"All evaluation results: {results['results']}") if val_results is not None: - with open(Path(cfg.dump_dir) / "validation.json", "w") as f: + with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f: f.write(json.dumps(val_results)) logger.info(f"All validation results: {val_results}") - if cfg.metric_log_dir and get_global_rank() == 0: - metric_log_path = Path(cfg.metric_log_dir) / "metrics.eval.jsonl" + if eval_args.metric_log_dir and get_global_rank() == 0: + metric_log_path = os.path.join(eval_args.metric_log_dir, "metrics.eval.jsonl") logger.info(f"Writing metric logs to {metric_log_path}") timestamp = { "created_at": datetime.utcnow().isoformat(), } - if cfg.global_step is not None: - timestamp["global_step"] = cfg.global_step + if eval_args.global_step is not None: + timestamp["global_step"] = eval_args.global_step print( json.dumps(timestamp | results["results"]), - file=open(metric_log_path, mode="a"), + file=fs.open(metric_log_path, mode="a"), flush=True, ) - val_log_path = Path(cfg.metric_log_dir) / "metrics.validation.jsonl" + val_log_path = os.path.join( + eval_args.metric_log_dir, "metrics.validation.jsonl" + ) if val_results is not None: print( json.dumps(timestamp | val_results), - file=open(val_log_path, mode="a"), + file=fs.open(val_log_path, mode="a"), flush=True, ) @@ -301,53 +251,8 @@ def launch_eval(cfg: EvalArgs): def main(): - """ - The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments - This accepts arguments as a dot list - So if the dataclass looks like - - @dataclass - class DummyArgs: - name: str - model: LMTransformerArgsgs - - @dataclass - class LMTransformerArgsgs: - dim: int - - Then you can pass model.dim=32 to change values in LMTransformerArgsgs - or just name=tictac for top level attributes. - - The behavior here is as follows: - 1. We instantiate EvalArgs with its default values - 2. We override those default values with the ones in the provided config file - 3. We override the result with the additional arguments provided through command line - - For example, if the config is the following - - model: - dim: 128 - n_layers: 4 - - and you call eval.py with eval.py model.dim=64 - - Then the final TrainArgs will have - - model: - dim: 64 - n_layers: 4 - - Plus all the default values in EvalArgs dataclass. - """ - cli_args = OmegaConf.from_cli() - file_cfg = OmegaConf.load(cli_args.config) - # We remove 'config' attribute from config as the underlying DataClass does not have it - del cli_args.config - - default_cfg = OmegaConf.structured(EvalArgs()) - cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) - cfg = OmegaConf.to_object(cfg) - launch_eval(cfg) + eval_args = parse_args(EvalArgs) + launch_eval(eval_args) if __name__ == "__main__": diff --git a/apps/main/generate.py b/bytelatent/generate.py similarity index 91% rename from apps/main/generate.py rename to bytelatent/generate.py index a3a8627..eb79d81 100644 --- a/apps/main/generate.py +++ b/bytelatent/generate.py @@ -1,20 +1,16 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
+import os
 import time
-from dataclasses import dataclass, field
-from pathlib import Path
-from typing import List, Optional
 
 import torch
-from lingua.args import dataclass_from_dict
-from lingua.tokenizers.abstract_tokenizer import Tokenizer
-from lingua.tokenizers.build_tokenizer import build_tokenizer
 from omegaconf import OmegaConf
 from torch import nn
 from torch.nn import functional as F
 from torch.nn.attention.flex_attention import create_block_mask
 from tqdm import tqdm
 
+from bytelatent.args import PackedCausalTransformerGeneratorArgs, TrainArgs
 from bytelatent.base_transformer import (
     Attention,
     causal_mask,
@@ -23,7 +19,10 @@ from bytelatent.base_transformer import (
     lengths_to_start_ids,
 )
 from bytelatent.checkpoint import CONSOLIDATE_NAME
-from bytelatent.transformer import LMTransformer, LMTransformerArgs
+from bytelatent.data.file_util import get_fs
+from bytelatent.model.blt import ByteLatentTransformer
+from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
+from bytelatent.transformer import LMTransformer
 
 
 def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
@@ -62,7 +61,7 @@ def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None):
     return next_token.view(shape[:-1])
 
 
-def pack_prompts(prompts: List[int]):
+def pack_prompts(prompts: list[int]):
     res = []
     lengths = []
     for i, p in enumerate(prompts):
@@ -120,22 +119,6 @@ class KVCache(nn.Module):
         return self.k_cache, self.v_cache
 
 
-@dataclass
-class PackedCausalTransformerGeneratorArgs:
-    temperature: float = 0.0
-    top_p: Optional[float] = None
-    top_k: Optional[float] = None
-    max_gen_len: int = 512  # Maximum number of tokens to generate
-    max_tokens: int = 1024  # Maximum number of tokens that can go through the model
-    max_prompt_len: Optional[int] = None
-    until: List[str] = field(default_factory=list)
-    compile_prefilling: bool = False
-    reduce_generation_overhead: bool = False
-    show_progress: bool = False
-    dtype: Optional[str] = "bf16"
-    device: Optional[str] = "cuda"
-
-
 class PackedCausalTransformerGenerator:
     def __init__(
         self,
@@ -401,25 +384,29 @@ class PackedCausalTransformerGenerator:
 
 def load_consolidated_model_and_tokenizer(
     consolidated_path,
-    model_cls=LMTransformer,
-    model_args_cls=LMTransformerArgs,
 ):
-    ckpt_path = Path(consolidated_path)
-    config = ckpt_path / "params.json"
-    config = OmegaConf.load(config)
+    train_args_path = os.path.join(consolidated_path, "params.json")
+    fs = get_fs(train_args_path)
+    with fs.open(train_args_path) as f:
+        train_args = TrainArgs.model_validate_json(f.read())
+
+    if train_args.train_entropy_model:
+        model_args = train_args.entropy_model
+        model = LMTransformer(model_args)
+    else:
+        model_args = train_args.model
+        model = ByteLatentTransformer(model_args)
 
     param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
-        config.distributed.model_dtype
+        train_args.distributed.model_dtype
     ]
-    model_args = dataclass_from_dict(model_args_cls, config.model, strict=False)
-    tokenizer = build_tokenizer(config.data.tokenizer.name, config.data.tokenizer.path)
-    model = model_cls(model_args)
-    st_dict = torch.load(ckpt_path / CONSOLIDATE_NAME, weights_only=True)
+    tokenizer = train_args.data.tokenizer_args.build()
+    st_dict = torch.load(os.path.join(consolidated_path, CONSOLIDATE_NAME), weights_only=True)
     model.load_state_dict(st_dict["model"])
     model = model.cuda().eval()
     for param in model.parameters():
         param.data = param.data.to(dtype=param_dtype)
-    return model, tokenizer, config
+    return model, tokenizer, train_args
 
 
 def main():
diff --git
a/bytelatent/train.py b/bytelatent/train.py index 1d0fa40..6b20ecd 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -10,7 +10,7 @@ from copy import deepcopy from dataclasses import asdict, dataclass from pathlib import Path from timeit import default_timer as timer -from typing import Any, Dict, Type, TypeVar +from typing import Any, TypeVar import torch import torch.distributed @@ -23,9 +23,13 @@ from torch.distributed._tensor import DTensor from torch.distributed.checkpoint.stateful import Stateful from torch.optim import lr_scheduler -from bytelatent.args import TrainArgs +from bytelatent.args import TrainArgs, parse_args from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint -from bytelatent.data.data_types import DataLoaderState +from bytelatent.data.iterators.multiprocess_iterator import ( + MultiprocessIterator, + MultiprocessIteratorState, +) +from bytelatent.data.iterators.packing_iterator import PackingIteratorState from bytelatent.distributed import ( check_model_value_range, clean_env, @@ -39,6 +43,7 @@ from bytelatent.distributed import ( setup_env, setup_torch_distributed, ) +from bytelatent.eval import EVAL_FOLDER_NAME, launch_eval from bytelatent.logger import init_logger from bytelatent.metrics import GPUMemoryMonitor, MetricLogger, get_num_params from bytelatent.model.blt import ByteLatentTransformer @@ -70,36 +75,49 @@ def flatten_dict(d, parent_key="", sep="_"): return dict(items) -def dataclass_from_dict(cls: Type[T], data: dict, strict: bool = True) -> T: - """ - Converts a dictionary to a dataclass instance, recursively for nested structures. - """ - base = OmegaConf.structured(cls()) - OmegaConf.set_struct(base, strict) - override = OmegaConf.create(data) - return OmegaConf.to_object(OmegaConf.merge(base, override)) +def get_iterator_state_name(iterator_state): + if isinstance(iterator_state, MultiprocessIteratorState): + return "multiprocess" + elif isinstance(iterator_state, PackingIteratorState): + return "packing" + else: + raise ValueError(f"Unsupported iterator to get name from: {iterator_state}") +# TODO: Make this pydantic based instead of data class based +# TODO: Generalize this to any iterator state @dataclass class TrainState(Stateful): step: int # Nb of steps taken by the optimizer acc_step: int # Nb of accumulation steps done since last optimizer step scheduler: lr_scheduler.LambdaLR - data_loader_state: DataLoaderState + data_loader_state: MultiprocessIteratorState | PackingIteratorState scale: float = 1.0 + data_loader_class: str | None = None - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return { "step": self.step, "acc_step": self.acc_step, - "data_loader_state": self.data_loader_state.dict(), + "data_loader_state": self.data_loader_state.model_dump(), + "data_loader_class": get_iterator_state_name(self.data_loader_state), "scheduler": self.scheduler.state_dict(), } def load_state_dict(self, state_dict): self.step = state_dict["step"] self.acc_step = state_dict["acc_step"] - self.data_loader_state = DataLoaderState(**state_dict["data_loader_state"]) + self.data_loader_class = state_dict["data_loader_class"] + if self.data_loader_class == "multiprocess": + self.data_loader_state = MultiprocessIteratorState( + **state_dict["data_loader_state"] + ) + elif self.data_loader_class == "packing": + self.data_loader_state = PackingIteratorState( + **state_dict["data_loader_state"] + ) + else: + raise ValueError(f"invalid data loader class: {self.data_loader_class}") 
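        # Note: state_dict() above records the iterator flavor via
        # get_iterator_state_name(), so on resume the matching pydantic state
        # class is rebuilt here before the scheduler state is restored below.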
self.scheduler.load_state_dict(state_dict["scheduler"]) @@ -345,7 +363,10 @@ def train(args: TrainArgs): nwords_since_last_log = 0 time_last_log = timer() gc.collect() - while train_state.step < args.steps: + saved = False + while train_state.step < args.steps and ( + args.max_steps is None or train_state.step < args.max_steps + ): # We constrain train_state.acc_step to be in range 0 to args.grad_acc_steps - 1 train_state.acc_step += 1 train_state.acc_step = train_state.acc_step % args.grad_acc_steps @@ -552,7 +573,6 @@ def train(args: TrainArgs): f" pow: {gpu_mem_stats.power_draw/1000} W" ) - saved = False if every_n_steps( train_state, args.checkpoint.dump.every, acc_step=0 ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0): @@ -567,18 +587,14 @@ def train(args: TrainArgs): if args.eval is not None and every_n_steps( train_state, args.checkpoint.eval.every, acc_step=0 ): - from apps.main.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval - - eval_args = dataclass_from_dict(EvalArgs, args.eval) + eval_args = args.eval eval_args.global_step = train_state.step eval_args.ckpt_dir = str(checkpoint.existing_saves[-1]) - eval_args.dump_dir = str( - os.path.join( - args.dump_dir, - "evals", - EVAL_FOLDER_NAME.format(train_state.step), - ) + eval_args.dump_dir = os.path.join( + args.dump_dir, + "evals", + EVAL_FOLDER_NAME.format(train_state.step), ) eval_args.metric_log_dir = args.dump_dir if args.async_eval_gpus is None: @@ -619,6 +635,9 @@ def train(args: TrainArgs): args, device_mesh=world_mesh, ) + if isinstance(data_loader, MultiprocessIterator): + logger.info("Closing MP iterator before exiting") + data_loader.shutdown() gc.collect() @@ -661,15 +680,7 @@ def main(): Plus all the default values in TrainArgs dataclass. """ - cli_args = OmegaConf.from_cli() - file_cfg = OmegaConf.load(cli_args.config) - # We remove 'config' attribute from config as the underlying DataClass does not have it - del cli_args.config - - default_cfg = OmegaConf.create(TrainArgs().model_dump()) - cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) - cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) - train_args = TrainArgs.model_validate(cfg) + train_args = parse_args(TrainArgs) if train_args.debug_dynamo: import torch._dynamo From 11cad6c84d083645fbde617c0b125a615ecc094d Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Tue, 28 Jan 2025 00:57:06 +0000 Subject: [PATCH 17/59] WIP parallel copy script Summary: Test Plan: --- bytelatent/data/parallel_copy.py | 224 +++++++++++++++++++++++++++++++ 1 file changed, 224 insertions(+) create mode 100644 bytelatent/data/parallel_copy.py diff --git a/bytelatent/data/parallel_copy.py b/bytelatent/data/parallel_copy.py new file mode 100644 index 0000000..2820429 --- /dev/null +++ b/bytelatent/data/parallel_copy.py @@ -0,0 +1,224 @@ +import logging +import os +import shutil +import time +from enum import Enum + +import fsspec +import submitit +import typer +from rich.logging import RichHandler + +FORMAT = "%(message)s" +logging.basicConfig( + level="INFO", format=FORMAT, datefmt="[%X]", handlers=[RichHandler()] +) +logger = logging.getLogger("parallel_copy") + + +S3_PREFIX = "s3://" + + +def get_fs(path: str, s3_profile: str | None = None) -> fsspec.AbstractFileSystem: + if path.startswith("s3://"): + if s3_profile is None: + return fsspec.filesystem( + "s3", default_block_size=1000000 * 2**20, max_concurrency=1 + ) + else: + return fsspec.filesystem( + "s3", + profile=s3_profile, + default_block_size=1000000 * 2**20, + 
max_concurrency=1, + ) + else: + return fsspec.filesystem("file") + + +def strip_s3_prefix(path: str): + if path.startswith(S3_PREFIX): + return path[len(S3_PREFIX) :] + else: + return path + + +class OverwriteMode(str, Enum): + ALWAYS = "always" + SIZE_MISMATCH = "size_mismatch" + NEVER = "never" + + +class ParallelMode(str, Enum): + SLURM = "slurm" + MULTIPROCESS = "multiprocess" + + +def copy_src_to_dst( + src_fs: fsspec.AbstractFileSystem, + dst_fs: fsspec.AbstractFileSystem, + src_file: str, + dst_file: str, + dry_run: bool = False, +): + if dry_run: + logging.info("Dry run copy: %s -> %s", src_file, dst_file) + else: + dst_parent_directory = os.path.dirname(dst_file) + dst_fs.mkdirs(dst_parent_directory, exist_ok=True) + with src_fs.open(src_file, "rb") as src_pointer, dst_fs.open( + dst_file, "wb" + ) as dst_pointer: + shutil.copyfileobj(src_pointer, dst_pointer) + + +class CopyJob(submitit.helpers.Checkpointable): + def __call__( + self, + src_fs_dict: dict, + dst_fs_dict: dict, + src_file: str, + dst_file: str, + dry_run: bool = False, + validate_size: bool = True, + ): + src_fs = fsspec.AbstractFileSystem.from_dict(src_fs_dict) + dst_fs = fsspec.AbstractFileSystem.from_dict(dst_fs_dict) + copy_src_to_dst(src_fs, dst_fs, src_file, dst_file, dry_run=dry_run) + if validate_size and not dry_run: + src_size = src_fs.size(src_file) + dst_size = dst_fs.size(dst_file) + if src_size != dst_size: + raise ValueError( + f"Mismatched sizes for src={src_file} dst={dst_file} {src_size} != {dst_size}" + ) + return True + + +def main( + src_dir: str, + dst_dir: str, + src_s3_profile: str | None = None, + dst_s3_profile: str | None = None, + n_workers: int = 16, + cpus_per_task: int = 2, + overwrite_mode: OverwriteMode = OverwriteMode.SIZE_MISMATCH, + validate_size: bool = True, + parallel_mode: ParallelMode = ParallelMode.MULTIPROCESS, + dry_run: bool = False, + job_dir: str = "jobs_parallel-copy", + slurm_qos: str | None = None, + slurm_time_hours: int = 72, + slurm_memory: str = "0", + wait: bool = True, + wait_period: int = 5, +): + logging.info("Starting parallell copy: %s -> %s", src_dir, dst_dir) + logging.info("job_dir=%s", job_dir) + logging.info( + "Parallel=%s validate_size=%s overwrite_mode=%s qos=%s", + parallel_mode, + validate_size, + overwrite_mode, + slurm_qos, + ) + if parallel_mode == ParallelMode.MULTIPROCESS: + executor = submitit.LocalExecutor(folder=job_dir) + elif parallel_mode == ParallelMode.SLURM: + executor = submitit.SlurmExecutor(folder=job_dir) + executor.update_parameters( + time=slurm_time_hours * 60, + ntasks_per_node=1, + cpus_per_task=cpus_per_task, + array_parallelism=n_workers, + mem=slurm_memory, + gpus_per_node=0, + ) + if slurm_qos is not None: + executor.update_parameters(qos=slurm_qos) + else: + raise ValueError("Invalid parallel mode") + + assert src_dir.endswith("/"), "src_dir must end with a /" + assert dst_dir.endswith("/"), "dst_dir must end with a /" + src_fs = get_fs(src_dir, s3_profile=src_s3_profile) + dst_fs = get_fs(dst_dir, s3_profile=dst_s3_profile) + + src_dir = strip_s3_prefix(src_dir) + dst_dir = strip_s3_prefix(dst_dir) + logging.info("src: %s, dst: %s", src_dir, dst_dir) + + assert src_fs.isdir(src_dir), "src_dir must be a directory" + if dst_fs.exists(dst_dir): + assert dst_dir, "dst_dir must be a directory if it exists" + else: + dst_fs.mkdirs(dst_dir, exist_ok=True) + + files = src_fs.find(src_dir) + logging.info("Files found to check for transfer: %s", len(files)) + jobs = [] + with executor.batch(): + for src_file in files: + 
relative_src = src_file[len(src_dir) :] + dst_file_path = os.path.join(dst_dir, relative_src) + logging.debug("src: %s -> dst %s", src_file, dst_file_path) + if dst_fs.exists(dst_file_path): + if overwrite_mode == OverwriteMode.NEVER: + pass + elif overwrite_mode == OverwriteMode.ALWAYS: + logging.info("copy: %s -> %s", src_file, dst_file_path) + job = executor.submit( + CopyJob(), + src_fs.to_dict(), + dst_fs.to_dict(), + src_file, + dst_file_path, + dry_run=dry_run, + validate_size=validate_size, + ) + jobs.append(job) + elif overwrite_mode == OverwriteMode.SIZE_MISMATCH: + if src_fs.size(src_file) != dst_fs.size(dst_file_path): + logging.info("copy: %s -> %s", src_file, dst_file_path) + job = executor.submit( + CopyJob(), + src_fs.to_dict(), + dst_fs.to_dict(), + src_file, + dst_file_path, + dry_run=dry_run, + validate_size=validate_size, + ) + jobs.append(job) + else: + raise ValueError("Unknown overwrite_mode") + else: + logging.info("copy: %s -> %s", src_file, dst_file_path) + job = executor.submit( + CopyJob(), + src_fs.to_dict(), + dst_fs.to_dict(), + src_file, + dst_file_path, + dry_run=dry_run, + validate_size=validate_size, + ) + jobs.append(job) + if wait: + while True: + num_finished = sum(job.done() for job in jobs) + logging.info("Total Jobs: %s Completed Jobs: %s", len(jobs), num_finished) + if num_finished == len(jobs): + break + time.sleep(wait_period) + output = [job.result() for job in jobs] + if all(output): + logging.info("All copies succeeded") + else: + logging.info("Some copies failed") + else: + logging.info("Not waiting for job to complete before exiting submit program") + + +if __name__ == "__main__": + typer.run(main) From c6ef4285e2f289fe3e87216c9ca0f9eb58ea5c44 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Tue, 4 Feb 2025 18:03:18 +0000 Subject: [PATCH 18/59] Several changes to enable entropy model training/eval Summary: - Make arrow iterator able to read from jsonl files, the entropies are omitted in this case - Make the data/checkpoint code fsspec compatible - Fix issues with all reduce with non-bf16 in dist_sum and norm computation. - Minimal fixes to get eval to run, it is slow currently - Add bpb numbers during training Test Plan: --- bytelatent/args.py | 44 +++++++- bytelatent/checkpoint.py | 97 +++++++++-------- bytelatent/data/iterators/arrow_iterator.py | 100 ++++++++++-------- bytelatent/distributed.py | 18 +++- bytelatent/eval.py | 64 +++++++---- bytelatent/generate.py | 6 +- bytelatent/norms.py | 99 +++++++++++++++++ bytelatent/preprocess/preprocess_entropies.py | 26 +++-- bytelatent/train.py | 59 ++++++++++- 9 files changed, 382 insertions(+), 131 deletions(-) create mode 100644 bytelatent/norms.py diff --git a/bytelatent/args.py b/bytelatent/args.py index d1bac46..23c571a 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -1,8 +1,10 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
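# A hedged sketch of the jsonl reading path this commit describes ("the
# entropies are omitted in this case"): validation files are globbed with
# find_and_sanitize_chunks and streamed through ArrowFileIterator with
# file_format="json". The dataset directory below is hypothetical.
from bytelatent.args import find_and_sanitize_chunks
from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator

chunks = find_and_sanitize_chunks(
    "data/my_dataset",  # hypothetical directory containing *.val.jsonl files
    world_size=1,
    file_pattern="*.val.jsonl",
)
iterator = ArrowFileIterator(
    dataset_files=[chunks[0]],
    file_path=None,
    preprocess_dir=None,
    entropy_model_name=None,
    worker_id=0,
    num_workers=1,
    arrow_batch_size=100,
    file_format="json",
)
for example in iterator.create_iter():
    # entropies is None for jsonl inputs; only sample_id and text are populated
    print(example.sample_id, example.text[:40], example.entropies)
    break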
+import json import logging import os from typing import Any +import fsspec import numpy as np import yaml from omegaconf import OmegaConf @@ -10,10 +12,10 @@ from pydantic import BaseModel, ConfigDict from bytelatent.checkpoint import CheckpointArgs from bytelatent.data.data_types import Batch +from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import StatefulIterator from bytelatent.data.iterators.arrow_iterator import ( ArrowFileIterator, - find_and_sanitize_chunks, ) from bytelatent.data.iterators.looping_iterator import LoopingIterator from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator @@ -53,6 +55,43 @@ def parse_args(args_cls): return pydantic_args +def read_args_file(fs: fsspec.AbstractFileSystem, path: str) -> Any: + with fs.open(path, "rt") as f: + if path.endswith(".json"): + return json.load(f) + elif path.endswith(".yaml"): + return yaml.load(f) + else: + raise ValueError("Invalid args file format") + + +TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" + + +def find_and_sanitize_chunks( + dataset_path: str, + world_size: int, + file_pattern: str, + s3_profile: str | None = None, +): + fs = get_fs(dataset_path, s3_profile=s3_profile) + path_with_glob = os.path.join(dataset_path, file_pattern) + dataset_chunks = fs.glob(path_with_glob) + n_chunks = len(dataset_chunks) + + if n_chunks > world_size: + n_discard = n_chunks - world_size + dataset_chunks = dataset_chunks[:world_size] + else: + assert ( + world_size % n_chunks == 0 + ), "World size should be a multiple of number of chunks" + + assert n_chunks > 0, f"No valid chunks in {dataset_path}" + + return dataset_chunks + + def distribute_data_to_rank( *, dataset_path: str, @@ -62,9 +101,10 @@ def distribute_data_to_rank( rank: int, world_size: int, s3_profile: str | None = None, + file_pattern: str = TRAIN_DATA_FILE_PATTERN, ) -> ArrowFileIterator: dataset_chunks = find_and_sanitize_chunks( - dataset_path, world_size, s3_profile=s3_profile + dataset_path, world_size, file_pattern, s3_profile=s3_profile ) n_workers_per_chunk = world_size // len(dataset_chunks) rank_to_arrow_iterator_params = [] diff --git a/bytelatent/checkpoint.py b/bytelatent/checkpoint.py index f213c84..6631673 100644 --- a/bytelatent/checkpoint.py +++ b/bytelatent/checkpoint.py @@ -4,8 +4,6 @@ import json import logging import os import re -from pathlib import Path -from typing import List, Optional, Tuple import fsspec import torch @@ -70,26 +68,29 @@ def consolidate_checkpoints(fs: fsspec.AbstractFileSystem, ckpt_dir: str): Returns the path to the consolidated checkpoint """ - consolidate_path = Path(ckpt_dir) / CONSOLIDATE_FOLDER - if not (consolidate_path / CONSOLIDATE_NAME).exists(): - consolidate_path.mkdir(exist_ok=True) - logger.info(f"Consolidating to: {str(consolidate_path)}") - dcp_to_torch_save(ckpt_dir, str(consolidate_path / CONSOLIDATE_NAME)) - (consolidate_path / CONFIG_NAME).write_text( - (Path(ckpt_dir) / CONFIG_NAME).read_text() + consolidate_path = os.path.join(ckpt_dir, CONSOLIDATE_FOLDER) + consolidate_name = os.path.join(consolidate_path, CONSOLIDATE_NAME) + if not fs.exists(consolidate_name): + fs.mkdirs(consolidate_path, exist_ok=True) + logger.info(f"Consolidating to: {consolidate_path}") + dcp_to_torch_save(ckpt_dir, consolidate_name) + fs.write_text( + os.path.join(consolidate_path, CONFIG_NAME), + fs.read_text(os.path.join(ckpt_dir, CONFIG_NAME)), ) logger.info("Consolidated !") return consolidate_path def load_from_checkpoint( + fs: fsspec.AbstractFileSystem, 
ckpt_dir: str, model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, + optimizer: torch.optim.Optimizer | None = None, model_key: str = "model", optim_key: str = "optim", ): - if not (Path(ckpt_dir) / ".metadata").exists(): + if not fs.exists(os.path.join(ckpt_dir, ".metadata")): raise ValueError( f"Please convert the checkpoint distcp format using `torch.distributed.checkpoint.format_utils.torch_save_to_dcp` before loading it" ) @@ -121,13 +122,13 @@ class CheckpointManager: self.existing_saves = self.get_existing_saves() - def get_existing_saves(self) -> List[Path]: + def get_existing_saves(self) -> list[str]: folders = [ p - for p in Path(self.path).iterdir() - if p.is_dir() and re.match(RE_FOLDER, p.name) + for p in self.fs.ls(self.path) + if self.fs.isdir(p) and re.match(RE_FOLDER, os.path.basename(p)) ] - folders.sort(key=lambda p: _get_key_step(p.name)) + folders.sort(key=lambda p: _get_key_step(os.path.basename(p))) return folders def clean_up(self): @@ -136,8 +137,9 @@ class CheckpointManager: eval_folders = [] other_folders = [] for p in self.existing_saves: - is_dump = _get_key_step(p.name) % self.dump_every.every == 0 - is_eval = _get_key_step(p.name) % self.eval_every.every == 0 + assert isinstance(p, str), f"Base path type: {p}" + is_dump = _get_key_step(os.path.basename(p)) % self.dump_every.every == 0 + is_eval = _get_key_step(os.path.basename(p)) % self.eval_every.every == 0 if is_dump: dump_folders.append(p) if is_eval: @@ -161,40 +163,39 @@ class CheckpointManager: if dist.get_rank() == 0: for folder in folder_to_remove: - for file in folder.iterdir(): - if file.is_file(): - file.unlink() - elif file.is_dir(): - assert file.name in [CONSOLIDATE_FOLDER] - for f in file.iterdir(): - f.unlink() - file.rmdir() - folder.rmdir() + for file in self.fs.ls(folder): + if self.fs.isfile(file): + self.fs.rm_file(file) + elif self.fs.isdir(file): + assert os.path.name(file) in [CONSOLIDATE_FOLDER] + for f in self.fs.ls(file): + self.fs.rm(f) + self.fs.rmdir(file) + self.fs.rmdir(folder) dist.barrier() self.existing_saves = list(folder_to_keep) - self.existing_saves.sort(key=lambda p: _get_key_step(p.name)) + self.existing_saves.sort(key=lambda p: _get_key_step(os.path.basename(p))) - def get_last_step_path(self, dp_rank: int = 0) -> Optional[Path]: + def get_last_step_path(self, dp_rank: int = 0) -> str | None: path = None for p in reversed(self.existing_saves): - if (p / TRAIN_STATE_NAME.format(dp_rank)).is_file(): + + if self.fs.isfile(os.path.join(p, TRAIN_STATE_NAME.format(dp_rank))): path = p break return path - def _create_folder(self, base_path: Path, folder_name: str) -> Path: - folder = base_path / folder_name + def _create_folder(self, base_path: str, folder_name: str) -> str: + folder = os.path.join(base_path, folder_name) if get_is_master(): - folder.mkdir(parents=False, exist_ok=True) + self.fs.mkdirs(folder, exist_ok=True) if dist.is_initialized(): dist.barrier() return folder - def _get_dp_tp_mesh( - self, device_mesh: Optional[DeviceMesh] = None - ) -> Tuple[int, int]: + def _get_dp_tp_mesh(self, device_mesh: DeviceMesh | None = None) -> tuple[int, int]: dp_rank = 0 tp_rank = 0 if device_mesh is not None: @@ -222,14 +223,14 @@ class CheckpointManager: model, optimizer, train_state, - config, - device_mesh: Optional[DeviceMesh] = None, + config: BaseModel, + device_mesh: DeviceMesh | None = None, ) -> bool: # When creating directory check if only rank0 or is there other solution - path = Path(self.path) + path = self.path curr_save_dir = 
self._create_folder(path, FOLDER_NAME.format(train_state.step)) - logger.info(f"Saving to: {str(curr_save_dir)}") + logger.info(f"Saving to: {curr_save_dir}") if dist.is_initialized(): dist.barrier() @@ -242,17 +243,19 @@ class CheckpointManager: if dist.is_initialized(): dist.barrier() + print("config type", type(config)) if get_is_master(): - config.dump_to_yaml_file(curr_save_dir / CONFIG_NAME) + self.fs.write_text( + os.path.join(curr_save_dir, CONFIG_NAME), config.model_dump_json() + ) # Add json dump here dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh) if tp_rank == 0: train_state_name = TRAIN_STATE_NAME.format(dp_rank) - logger.info( - f"Saving train state to: {str(curr_save_dir / train_state_name)}" - ) - with open(curr_save_dir / train_state_name, "w") as f: + train_state_full_path = os.path.join(curr_save_dir, train_state_name) + logger.info(f"Saving train state to: {train_state_full_path}") + with self.fs.open(train_state_full_path, "w") as f: json.dump(train_state.state_dict(), f) logger.info("Train state saved !") @@ -271,7 +274,7 @@ class CheckpointManager: optimizer, train_state, device_mesh: DeviceMesh, - path: Optional[Path] = None, + path: str | None = None, ): dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh) # Loading tries to load the provided path, if not available the last saved step and finally from the init path @@ -284,12 +287,12 @@ class CheckpointManager: # Only load train state if it's provided, the files exist and we're not loading from init path train_state_name = TRAIN_STATE_NAME.format(dp_rank) logger.info("Reloading train state") - with open(path / train_state_name, "r") as f: + with self.fs.open(os.path.join(path, train_state_name), "r") as f: train_state_dict = json.load(f) train_state.load_state_dict(train_state_dict) logger.info("Train state reloaded") - logger.info(f"Loading from: {str(path)}") + logger.info(f"Loading from: {path}") state_dict = self.get_state_dict( model=model, optimizer=optimizer, diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index 4e7b99e..fa18bf1 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -16,6 +16,10 @@ from bytelatent import ByteLatentError from bytelatent.data.data_types import BltExample from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator +from bytelatent.preprocess.preprocess_entropies import ( + get_id_key, + get_text, +) logger = getLogger(__name__) @@ -32,6 +36,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState): arrow_batch_size: int = 100 s3_profile: str | None filesystem_type: str | None = None + file_format: str def build(self) -> "ArrowFileIterator": arrow_file = ArrowFileIterator( @@ -44,6 +49,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState): dataset_files=self.dataset_files, s3_profile=self.s3_profile, filesystem_type=self.filesystem_type, + file_format=self.file_format, ) if self.row_num != 0: arrow_file._set_row_num(self.row_num) @@ -70,6 +76,7 @@ class ArrowFileIterator(StatefulIterator): dataset_files: list[str] | None = None, s3_profile: str | None = None, filesystem_type: str | None = None, + file_format: str = "arrow", ): assert 0 <= worker_id < num_workers, (worker_id, num_workers) if file_path is None and dataset_files is None: @@ -87,12 +94,16 @@ class ArrowFileIterator(StatefulIterator): self.arrow_batch_size = arrow_batch_size self.s3_profile = s3_profile self.filesystem_type 
= filesystem_type + self.file_format = file_format self.fs = None if self.filesystem_type is not None: if self.filesystem_type == "file": self.fs = fsspec.filesystem("file") elif self.filesystem_type == "s3": self.fs = fsspec.filesystem("s3", profile=s3_profile) + else: + raise ValueError("Unknown filesystem") + logger.info("Arrow iterator using fs=%s", self.fs) if dataset_files is None: # Prepare arrow shards @@ -153,6 +164,7 @@ class ArrowFileIterator(StatefulIterator): dataset_files=self.dataset_files, s3_profile=self.s3_profile, filesystem_type=self.filesystem_type, + file_format=self.file_format, ) def create_iter( @@ -164,7 +176,7 @@ class ArrowFileIterator(StatefulIterator): else: filesystem = None self.dataset = pa.dataset.dataset( - self.dataset_files, format="arrow", filesystem=filesystem + self.dataset_files, format=self.file_format, filesystem=filesystem ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size @@ -173,13 +185,22 @@ class ArrowFileIterator(StatefulIterator): if self.batch_to_consume is not None: batch_columns: dict[str, list] = self.batch_to_consume self.batch_to_consume = None - sample_ids = batch_columns["sample_id"] - texts = batch_columns["text"] - entropies = batch_columns["entropies"] + if self.file_format == "arrow": + sample_ids = batch_columns["sample_id"] + texts = batch_columns["text"] + entropies = batch_columns["entropies"] + elif self.file_format == "json": + # This data hasn't been preprocessed to a uniform format, + # so we have to do it now and omit entropies + sample_ids = batch_columns[get_id_key(batch_columns)] + texts = get_text(batch_columns) + entropies = None + else: + raise ValueError(f"Unknown file format: {self.file_format}") for i in range(len(sample_ids)): out = BltExample( sample_id=sample_ids[i], - entropies=entropies[i], + entropies=entropies[i] if entropies is not None else None, text=texts[i], tokens=None, mask=None, @@ -191,13 +212,22 @@ class ArrowFileIterator(StatefulIterator): for batch in self.batch_iterator: batch_columns = batch.to_pydict() - sample_ids = batch_columns["sample_id"] - texts = batch_columns["text"] - entropies = batch_columns["entropies"] + if self.file_format == "arrow": + sample_ids = batch_columns["sample_id"] + texts = batch_columns["text"] + entropies = batch_columns["entropies"] + elif self.file_format == "json": + # This data hasn't been preprocessed to a uniform format, + # so we have to do it now and omit entropies + sample_ids = batch_columns[get_id_key(batch_columns)] + texts = get_text(batch_columns) + entropies = None + else: + raise ValueError(f"Unknown file format: {self.file_format}") for i in range(len(sample_ids)): out = BltExample( sample_id=sample_ids[i], - entropies=entropies[i], + entropies=entropies[i] if entropies is not None else None, text=texts[i], tokens=None, mask=None, @@ -231,13 +261,24 @@ class ArrowFileIterator(StatefulIterator): for batch in self.batch_iterator: if len(batch) > curr_remaining: batch_columns: dict[str, list] = batch.to_pydict() - batch_columns["sample_id"] = batch_columns["sample_id"][ - curr_remaining: - ] - batch_columns["entropies"] = batch_columns["entropies"][ - curr_remaining: - ] - batch_columns["text"] = batch_columns["text"][curr_remaining:] + if self.file_format == "arrow": + leftover_sample_ids = batch_columns["sample_id"][ + curr_remaining: + ] + leftover_entropies = batch_columns["entropies"][curr_remaining:] + leftover_texts = batch_columns["text"][curr_remaining:] + elif self.file_format == "json": + 
leftover_sample_ids = batch_columns[get_id_key(batch_columns)][ + curr_remaining: + ] + leftover_entropies = None + leftover_texts = get_text(batch_columns)[curr_remaining:] + else: + raise ValueError(f"Unknown file format: {self.file_format}") + + batch_columns["sample_id"] = leftover_sample_ids + batch_columns["entropies"] = leftover_entropies + batch_columns["text"] = leftover_texts self.batch_to_consume = batch_columns break elif len(batch) == curr_remaining: @@ -250,30 +291,3 @@ class ArrowFileIterator(StatefulIterator): logger.info( f"Finished setting arrow position to {target_row_num} for {self.dataset_files}" ) - - -TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" - - -def find_and_sanitize_chunks( - dataset_path: str, - world_size: int, - file_pattern: str = TRAIN_DATA_FILE_PATTERN, - s3_profile: str | None = None, -): - fs = get_fs(dataset_path, s3_profile=s3_profile) - path_with_glob = os.path.join(dataset_path, file_pattern) - dataset_chunks = fs.glob(path_with_glob) - n_chunks = len(dataset_chunks) - - if n_chunks > world_size: - n_discard = n_chunks - world_size - dataset_chunks = dataset_chunks[:world_size] - else: - assert ( - world_size % n_chunks == 0 - ), "World size should be a multiple of number of chunks" - - assert n_chunks > 0, f"No valid chunks in {dataset_path}" - - return dataset_chunks diff --git a/bytelatent/distributed.py b/bytelatent/distributed.py index 168cb7c..f4b57e2 100644 --- a/bytelatent/distributed.py +++ b/bytelatent/distributed.py @@ -127,6 +127,16 @@ def dist_max(x: Union[int, float], mesh: DeviceMesh = None): return tensor +def dist_sum( + x: Union[int, float], mesh: DeviceMesh = None, reduce_dtype: torch.dtype = None +): + tensor = torch.tensor(x).cuda() + if reduce_dtype is not None: + tensor = tensor.to(reduce_dtype) + dist.all_reduce(tensor, op=ReduceOp.SUM, group=mesh.get_group() if mesh else None) + return tensor + + def dist_mean(x: Union[int, float], mesh: DeviceMesh = None): tensor = torch.tensor(x).cuda() dist.all_reduce(tensor, op=ReduceOp.AVG, group=mesh.get_group() if mesh else None) @@ -236,7 +246,7 @@ def setup_env(env_args: EnvironmentArgs): logger.warning(f"WARNING: Setting {name} to {value}") -def setup_torch_distributed(dist_args): +def setup_torch_distributed(dist_args: DistributedArgs): """ Handle single and multi-GPU / multi-node / SLURM jobs. 
Initialize the following variables: @@ -388,14 +398,14 @@ def clean_env(): def parallelize_model( - model, + model: torch.nn.Module, device_mesh, model_args, distributed_args: DistributedArgs, fsdp_grouping_plan: Optional[List[Tuple[str, bool]]] = None, tp_parallelize=None, no_recompute_ops=None, -): +) -> torch.nn.Module: if distributed_args.tp_size > 1: assert ( distributed_args.fsdp_type == "full_shard" @@ -429,6 +439,8 @@ def parallelize_model( device_mesh["dp_shard"].size() == 1 ), "dp_shard must be 1 for no_shard fsdp_type" + # TODO: Remove with something better + # model = model.to(param_dtype) fsdp_config = dict( mp_policy=( MixedPrecisionPolicy( diff --git a/bytelatent/eval.py b/bytelatent/eval.py index ae73066..1943a33 100644 --- a/bytelatent/eval.py +++ b/bytelatent/eval.py @@ -15,9 +15,16 @@ from lm_eval.api.model import LM from omegaconf import OmegaConf from pydantic import BaseModel, ConfigDict -from bytelatent.args import EvalArgs, ValidationArgs, parse_args +from bytelatent.args import ( + EvalArgs, + TrainArgs, + ValidationArgs, + find_and_sanitize_chunks, + parse_args, +) from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints from bytelatent.data.file_util import get_fs +from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator from bytelatent.distributed import ( DistributedArgs, dist_mean_dict, @@ -117,19 +124,40 @@ class EvalHarnessLM(LM): return results -def eval_on_val(generator, val_args: ValidationArgs, train_cfg): - srcs = {} +def eval_on_val(generator, val_args: ValidationArgs, train_cfg: TrainArgs): + srcs = [] for src in val_args.sources: path = os.path.join(val_args.root_dir, src) - srcs[path] = 1.0 + srcs.append(path) + for src in train_cfg.data.sources: path = os.path.join(train_cfg.data.root_dir, src) - srcs[path] = 1.0 + srcs.append(path) - multi_state = init_choice_state( - "", srcs, 0, get_global_rank(), get_world_size(), "*.val.jsonl" - ) - path_to_iter = setup_sources(multi_state) + path_to_iter = {} + for path in srcs: + chunks = find_and_sanitize_chunks( + path, + world_size=1, + file_pattern="*.val.jsonl", + s3_profile=train_cfg.data.s3_profile, + ) + assert ( + len(chunks) == 1 + ), f"There should be only 1 chunk per validation file, but found: {chunks}" + chunk = chunks[0] + iterator = ArrowFileIterator( + dataset_files=[chunk], + file_path=None, + preprocess_dir=None, + entropy_model_name=None, + worker_id=0, + num_workers=1, + arrow_batch_size=train_cfg.data.arrow_batch_size, + s3_profile=train_cfg.data.s3_profile, + file_format="json", + ) + path_to_iter[path] = iterator max_gen_len = generator.max_gen_len # We temporarily lower max gen len @@ -137,16 +165,11 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg): all_val_metrics = {} for src in path_to_iter: - jsonl_iterator = path_to_iter[src] + example_iterator = path_to_iter[src].create_iter() texts = [] logger.info(f"Running validation on {src}...") - for step, (content, state) in enumerate(jsonl_iterator): - if state["current_iter"] > 0 or ( - val_args.max_steps is not None and step >= val_args.max_steps - ): - break - content_key = "text" if ("text" in content) else "content" - texts.append(content[content_key]) + for step, example in enumerate(example_iterator): + texts.append(example.text) _, loglikelihood, _ = generator.generate(texts) @@ -191,7 +214,7 @@ def launch_eval(eval_args: EvalArgs): else: consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER) if not fs.exists(consolidate_path) and get_global_rank() == 0: - 
consolidate_path = consolidate_checkpoints(eval_args.ckpt_dir) + consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir) fs.mkdirs(eval_args.dump_dir, exist_ok=True) with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f: @@ -210,10 +233,12 @@ def launch_eval(eval_args: EvalArgs): wrap = EvalHarnessLM(generator) # Redo - results = simple_evaluate(wrap, eval_args.harness.model_dump()) + results = simple_evaluate(wrap, **eval_args.harness.model_dump()) + val_results = None if eval_args.validation: val_results = eval_on_val(generator, eval_args.validation, train_cfg) + if get_global_rank() == 0: with fs.open(os.path.join(eval_args.dump_dir, "results.json"), "w") as f: f.write(json.dumps(results)) @@ -222,6 +247,7 @@ def launch_eval(eval_args: EvalArgs): with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f: f.write(json.dumps(val_results)) logger.info(f"All validation results: {val_results}") + if eval_args.metric_log_dir and get_global_rank() == 0: metric_log_path = os.path.join(eval_args.metric_log_dir, "metrics.eval.jsonl") diff --git a/bytelatent/generate.py b/bytelatent/generate.py index eb79d81..9d44f30 100644 --- a/bytelatent/generate.py +++ b/bytelatent/generate.py @@ -387,8 +387,7 @@ def load_consolidated_model_and_tokenizer( ): train_args_path = os.path.join(consolidated_path, "params.json") fs = get_fs(train_args_path) - with fs.open(train_args_path) as f: - train_args = TrainArgs.model_validate_json(f.read()) + train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path)) if train_args.train_entropy_model: model_args = train_args.entropy_model @@ -401,7 +400,8 @@ def load_consolidated_model_and_tokenizer( train_args.distributed.model_dtype ] tokenizer = train_args.data.tokenizer_args.build() - st_dict = torch.load(consolidated_path / CONSOLIDATE_NAME, weights_only=True) + with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f: + st_dict = torch.load(f, weights_only=True) model.load_state_dict(st_dict["model"]) model = model.cuda().eval() for param in model.parameters(): diff --git a/bytelatent/norms.py b/bytelatent/norms.py new file mode 100644 index 0000000..6028b13 --- /dev/null +++ b/bytelatent/norms.py @@ -0,0 +1,99 @@ +from typing import Optional, Tuple, Dict, List +import torch +from torch import Tensor +from torch.utils._foreach_utils import ( + _device_has_foreach_support, + _group_tensors_by_device_and_dtype, + _has_foreach_support, +) + + +@torch.no_grad() +def fixed_clip_grad_norm_( + parameters: torch.Tensor | list[torch.Tensor], + max_norm: float, + norm_type: float = 2.0, + error_if_nonfinite: bool = False, + foreach: Optional[bool] = None, +) -> torch.Tensor: + r"""Clip the gradient norm of an iterable of parameters. + + The norm is computed over the norms of the individual gradients of all parameters, + as if the norms of the individual gradients were concatenated into a single vector. + Gradients are modified in-place. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float): max norm of the gradients + norm_type (float): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + error_if_nonfinite (bool): if True, an error is thrown if the total + norm of the gradients from :attr:`parameters` is ``nan``, + ``inf``, or ``-inf``. Default: False (will switch to True in the future) + foreach (bool): use the faster foreach-based implementation. 
+ If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently + fall back to the slow implementation for other device types. + Default: ``None`` + + Returns: + Total norm of the parameter gradients (viewed as a single vector). + """ + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + grads = [p.grad.to(torch.bfloat16) for p in parameters if p.grad is not None] + max_norm = float(max_norm) + norm_type = float(norm_type) + if len(grads) == 0: + return torch.tensor(0.0) + first_device = grads[0].device + grouped_grads: Dict[ + Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]] + ] = _group_tensors_by_device_and_dtype( + [grads] + ) # type: ignore[assignment] + + norms: List[Tensor] = [] + for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment] + if (foreach is None and _has_foreach_support(device_grads, device)) or ( + foreach and _device_has_foreach_support(device) + ): + norms.extend(torch._foreach_norm(device_grads, norm_type)) + elif foreach: + raise RuntimeError( + f"foreach=True was passed, but can't use the foreach API on {device.type} tensors" + ) + else: + norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads]) + + total_norm = torch.linalg.vector_norm( + torch.stack([norm.to(first_device) for norm in norms]), norm_type + ) + + if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()): + raise RuntimeError( + f"The total norm of order {norm_type} for gradients from " + "`parameters` is non-finite, so it cannot be clipped. To disable " + "this error and scale the gradients by the non-finite norm anyway, " + "set `error_if_nonfinite=False`" + ) + clip_coef = max_norm / (total_norm + 1e-6) + # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so + # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization + # when the gradients do not reside in CPU memory. 
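+    # A tiny worked illustration of the clamp below (numbers are hypothetical, not
+    # from the training code): with max_norm=1.0 and total_norm=4.0, clip_coef is
+    # roughly 0.25, so every gradient is scaled by 0.25; with total_norm=0.5,
+    # clip_coef is roughly 2.0 and the clamp caps it at 1.0, leaving the gradients
+    # unchanged.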
+ clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment] + if (foreach is None and _has_foreach_support(device_grads, device)) or ( + foreach and _device_has_foreach_support(device) + ): + torch._foreach_mul_(device_grads, clip_coef_clamped.to(device)) + elif foreach: + raise RuntimeError( + f"foreach=True was passed, but can't use the foreach API on {device.type} tensors" + ) + else: + clip_coef_clamped_device = clip_coef_clamped.to(device) + for g in device_grads: + g.mul_(clip_coef_clamped_device) + + return total_norm diff --git a/bytelatent/preprocess/preprocess_entropies.py b/bytelatent/preprocess/preprocess_entropies.py index 519da94..31a4802 100644 --- a/bytelatent/preprocess/preprocess_entropies.py +++ b/bytelatent/preprocess/preprocess_entropies.py @@ -15,29 +15,37 @@ from bytelatent.entropy_model import load_entropy_model from bytelatent.tokenizers.build_tokenizer import TokenizerArgs -def get_id_from_doc(doc: dict) -> int: +def get_id_key(doc: dict) -> int: """ We need a reliable way to ensure that samples from jsonl and arrow are the same, but there is no unique id field, so derive the best possible """ if "sample_id" in doc: - sample_id = doc["sample_id"] + return "sample_id" elif "title" in doc: - sample_id = doc["title"] + return "title" elif "qid" in doc: - sample_id = doc["qid"] + return "qid" elif "paper_id" in doc: - sample_id = doc["paper_id"] + return "paper_id" elif "path" in doc: - sample_id = doc["path"] + return "path" elif "url" in doc: - sample_id = doc["url"] + return "url" elif "id" in doc: - sample_id = doc["id"] + return "id" else: raise ValueError(f"Could not find a id key from: {doc.keys()}") - return str(sample_id) + + +def get_id_from_doc(doc: dict) -> int: + """ + We need a reliable way to ensure that samples from jsonl + and arrow are the same, but there is no unique id field, + so derive the best possible + """ + return str(doc[get_id_key(doc)]) def get_text(doc: dict): diff --git a/bytelatent/train.py b/bytelatent/train.py index 6b20ecd..2d43270 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -2,6 +2,8 @@ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
import gc +import math +import numpy as np import logging import os import sys @@ -25,6 +27,7 @@ from torch.optim import lr_scheduler from bytelatent.args import TrainArgs, parse_args from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint +from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.multiprocess_iterator import ( MultiprocessIterator, MultiprocessIteratorState, @@ -33,7 +36,9 @@ from bytelatent.data.iterators.packing_iterator import PackingIteratorState from bytelatent.distributed import ( check_model_value_range, clean_env, + dist_mean, dist_mean_dict, + dist_sum, get_device_mesh, get_is_master, get_world_size, @@ -47,6 +52,7 @@ from bytelatent.eval import EVAL_FOLDER_NAME, launch_eval from bytelatent.logger import init_logger from bytelatent.metrics import GPUMemoryMonitor, MetricLogger, get_num_params from bytelatent.model.blt import ByteLatentTransformer +from bytelatent.norms import fixed_clip_grad_norm_ from bytelatent.optim import build_optimizer from bytelatent.probe import AutoProbeD from bytelatent.profiling import maybe_run_profiler @@ -295,8 +301,11 @@ def train(args: TrainArgs): if args.checkpoint.init_ckpt_path: logger.info(f"Loading initial model from {args.checkpoint.init_ckpt_path}") + ckpt_fs = get_fs( + args.checkpoint.init_ckpt_path, s3_profile=args.checkpoint.s3_profile + ) load_from_checkpoint( - args.checkpoint.init_ckpt_path, model, model_key="model" + ckpt_fs, args.checkpoint.init_ckpt_path, model, model_key="model" ) # Put model_key="" if its directly the model checkpoint model.rope_embeddings.reset_parameters() # For RoPe initialization since it's a buffer it might not be loaded else: @@ -364,6 +373,9 @@ def train(args: TrainArgs): time_last_log = timer() gc.collect() saved = False + step_losses: list[float] = [] + step_tok_losses: list[float] = [] + n_bytes: int = 0 while train_state.step < args.steps and ( args.max_steps is None or train_state.step < args.max_steps ): @@ -385,6 +397,24 @@ def train(args: TrainArgs): batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda() mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda() + if args.data.tokenizer_args.name in ["bytes", "blt"]: + if mask is None: + n_bytes += batch_y.numel() + else: + n_bytes += mask.sum() + elif args.data.tokenizer_args.name in ["sp", "tiktoken"]: + for example in batch.y: + target_tokens = tokenizer.decode(example.tolist(), cut_at_eos=False) + n_bytes += ( + len(bytes(target_tokens, encoding="utf-8", errors="ignore")) + + sum(example == tokenizer.eos_id) + + sum(example == tokenizer.bos_id) + ) + else: + raise ValueError( + f"Unexpected tokenizer to count n_bytes for: {args.data.tokenizer_args.name}" + ) + if ( not args.train_entropy_model and args.model.encoder_enable_byte_ngrams @@ -459,7 +489,7 @@ def train(args: TrainArgs): batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids ) - loss, _ = compute_loss(pred, batch_y, mask, train_state.scale) + loss, tok_loss = compute_loss(pred, batch_y, mask, train_state.scale) # We scale loss with grad_acc_steps so the gradient is the same # regardless of grad_acc_steps @@ -470,8 +500,14 @@ def train(args: TrainArgs): # For logging we undo that scaling loss = loss.detach() * args.grad_acc_steps - grad_norm = torch.nn.utils.clip_grad_norm_( - model.parameters(), max_norm=args.optim.clip, foreach=True + # Undo loss scaling so downstream down't need to worry about it + step_losses.append((loss / train_state.scale).item()) + step_tok_losses.append(tok_loss / 
train_state.scale)
+
+            # grad_norm = torch.nn.utils.clip_grad_norm_(
+            grad_norm = fixed_clip_grad_norm_(
+                model.parameters(),
+                max_norm=args.optim.clip,  # foreach=True
             )
 
             grad_norm = (
@@ -559,20 +595,33 @@ def train(args: TrainArgs):
             gpu_memory_monitor.reset_peak_stats()
             nwords_since_last_log = 0
             time_last_log = timer()
+            stacked_tok_loss = torch.cat(step_tok_losses, dim=0)
+            total_tok_loss = dist_sum(
+                stacked_tok_loss.sum().item(), reduce_dtype=torch.bfloat16
+            )
+            total_n_bytes = dist_sum(n_bytes, reduce_dtype=torch.bfloat16)
+            avg_bpb = total_tok_loss / math.log(2) / total_n_bytes
+            avg_loss = dist_mean(np.mean(step_losses).item())
             logger.info(
                 f"step: {train_state.step}"
                 f" acc: {train_state.acc_step}"
-                f" loss: {round(loss.item(),4):>7}"
+                f" loss: step={round(loss.item(),4):>7} avg={avg_loss}"
+                f" bpb: {avg_bpb:.3f}"
                 f" grad: {grad_norm:.2e}"
                 f" flops: {FLOPS:.2e}"
                 f" wps: {wps:.2e}"
                 f" iter: {curr_iter_time:>7}"
                 f" data: {data_load_time:>5}"
                 f" lr: {curr_lr:.2e}"
+                f" n_bytes={total_n_bytes}"
                 f" mem: {gpu_mem_stats.max_active_pct:.0f}%"
                 f" pow: {gpu_mem_stats.power_draw/1000} W"
             )
+            n_bytes = 0
+            step_losses = []
+            step_tok_losses = []
+
             if every_n_steps(
                 train_state, args.checkpoint.dump.every, acc_step=0
             ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):

From ab399e981d2d18f67867a209c823e2e04f939e46 Mon Sep 17 00:00:00 2001
From: Pedro Rodriguez
Date: Tue, 4 Feb 2025 18:04:54 +0000
Subject: [PATCH 19/59] Several changes to enable entropy model training/eval

Summary:
- Make the arrow iterator able to read from jsonl files; entropies are omitted in this case
- Make the data/checkpoint code fsspec compatible
- Fix issues with all-reduce of non-bf16 values in dist_sum and the norm computation
- Minimal fixes to get eval to run; it is currently slow
- Add bpb numbers during training

Test Plan:
Run
```
torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/entropy_model.yaml eval=null max_steps=10100
```
---
 bytelatent/args.py                            | 48 ++++++++-
 bytelatent/checkpoint.py                      | 97 +++++++++--------
 bytelatent/data/iterators/arrow_iterator.py   | 97 +++++++++--------
 bytelatent/distributed.py                     | 18 +++-
 bytelatent/eval.py                            | 64 +++++++----
 bytelatent/generate.py                        |  6 +-
 bytelatent/norms.py                           | 100 ++++++++++++++++++
 bytelatent/preprocess/preprocess_entropies.py | 26 +++--
 bytelatent/train.py                           | 59 ++++++++++-
 9 files changed, 381 insertions(+), 134 deletions(-)
 create mode 100644 bytelatent/norms.py

diff --git a/bytelatent/args.py b/bytelatent/args.py
index d1bac46..33d12d6 100644
--- a/bytelatent/args.py
+++ b/bytelatent/args.py
@@ -1,8 +1,10 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+import json import logging import os from typing import Any +import fsspec import numpy as np import yaml from omegaconf import OmegaConf @@ -10,11 +12,9 @@ from pydantic import BaseModel, ConfigDict from bytelatent.checkpoint import CheckpointArgs from bytelatent.data.data_types import Batch +from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import StatefulIterator -from bytelatent.data.iterators.arrow_iterator import ( - ArrowFileIterator, - find_and_sanitize_chunks, -) +from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator from bytelatent.data.iterators.looping_iterator import LoopingIterator from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator @@ -53,6 +53,43 @@ def parse_args(args_cls): return pydantic_args +def read_args_file(fs: fsspec.AbstractFileSystem, path: str) -> Any: + with fs.open(path, "rt") as f: + if path.endswith(".json"): + return json.load(f) + elif path.endswith(".yaml"): + return yaml.load(f) + else: + raise ValueError("Invalid args file format") + + +TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" + + +def find_and_sanitize_chunks( + dataset_path: str, + world_size: int, + file_pattern: str, + s3_profile: str | None = None, +): + fs = get_fs(dataset_path, s3_profile=s3_profile) + path_with_glob = os.path.join(dataset_path, file_pattern) + dataset_chunks = fs.glob(path_with_glob) + n_chunks = len(dataset_chunks) + + if n_chunks > world_size: + n_discard = n_chunks - world_size + dataset_chunks = dataset_chunks[:world_size] + else: + assert ( + world_size % n_chunks == 0 + ), "World size should be a multiple of number of chunks" + + assert n_chunks > 0, f"No valid chunks in {dataset_path}" + + return dataset_chunks + + def distribute_data_to_rank( *, dataset_path: str, @@ -62,9 +99,10 @@ def distribute_data_to_rank( rank: int, world_size: int, s3_profile: str | None = None, + file_pattern: str = TRAIN_DATA_FILE_PATTERN, ) -> ArrowFileIterator: dataset_chunks = find_and_sanitize_chunks( - dataset_path, world_size, s3_profile=s3_profile + dataset_path, world_size, file_pattern, s3_profile=s3_profile ) n_workers_per_chunk = world_size // len(dataset_chunks) rank_to_arrow_iterator_params = [] diff --git a/bytelatent/checkpoint.py b/bytelatent/checkpoint.py index f213c84..6631673 100644 --- a/bytelatent/checkpoint.py +++ b/bytelatent/checkpoint.py @@ -4,8 +4,6 @@ import json import logging import os import re -from pathlib import Path -from typing import List, Optional, Tuple import fsspec import torch @@ -70,26 +68,29 @@ def consolidate_checkpoints(fs: fsspec.AbstractFileSystem, ckpt_dir: str): Returns the path to the consolidated checkpoint """ - consolidate_path = Path(ckpt_dir) / CONSOLIDATE_FOLDER - if not (consolidate_path / CONSOLIDATE_NAME).exists(): - consolidate_path.mkdir(exist_ok=True) - logger.info(f"Consolidating to: {str(consolidate_path)}") - dcp_to_torch_save(ckpt_dir, str(consolidate_path / CONSOLIDATE_NAME)) - (consolidate_path / CONFIG_NAME).write_text( - (Path(ckpt_dir) / CONFIG_NAME).read_text() + consolidate_path = os.path.join(ckpt_dir, CONSOLIDATE_FOLDER) + consolidate_name = os.path.join(consolidate_path, CONSOLIDATE_NAME) + if not fs.exists(consolidate_name): + fs.mkdirs(consolidate_path, exist_ok=True) + logger.info(f"Consolidating to: {consolidate_path}") + dcp_to_torch_save(ckpt_dir, consolidate_name) + fs.write_text( + os.path.join(consolidate_path, CONFIG_NAME), + 
fs.read_text(os.path.join(ckpt_dir, CONFIG_NAME)), ) logger.info("Consolidated !") return consolidate_path def load_from_checkpoint( + fs: fsspec.AbstractFileSystem, ckpt_dir: str, model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, + optimizer: torch.optim.Optimizer | None = None, model_key: str = "model", optim_key: str = "optim", ): - if not (Path(ckpt_dir) / ".metadata").exists(): + if not fs.exists(os.path.join(ckpt_dir, ".metadata")): raise ValueError( f"Please convert the checkpoint distcp format using `torch.distributed.checkpoint.format_utils.torch_save_to_dcp` before loading it" ) @@ -121,13 +122,13 @@ class CheckpointManager: self.existing_saves = self.get_existing_saves() - def get_existing_saves(self) -> List[Path]: + def get_existing_saves(self) -> list[str]: folders = [ p - for p in Path(self.path).iterdir() - if p.is_dir() and re.match(RE_FOLDER, p.name) + for p in self.fs.ls(self.path) + if self.fs.isdir(p) and re.match(RE_FOLDER, os.path.basename(p)) ] - folders.sort(key=lambda p: _get_key_step(p.name)) + folders.sort(key=lambda p: _get_key_step(os.path.basename(p))) return folders def clean_up(self): @@ -136,8 +137,9 @@ class CheckpointManager: eval_folders = [] other_folders = [] for p in self.existing_saves: - is_dump = _get_key_step(p.name) % self.dump_every.every == 0 - is_eval = _get_key_step(p.name) % self.eval_every.every == 0 + assert isinstance(p, str), f"Base path type: {p}" + is_dump = _get_key_step(os.path.basename(p)) % self.dump_every.every == 0 + is_eval = _get_key_step(os.path.basename(p)) % self.eval_every.every == 0 if is_dump: dump_folders.append(p) if is_eval: @@ -161,40 +163,39 @@ class CheckpointManager: if dist.get_rank() == 0: for folder in folder_to_remove: - for file in folder.iterdir(): - if file.is_file(): - file.unlink() - elif file.is_dir(): - assert file.name in [CONSOLIDATE_FOLDER] - for f in file.iterdir(): - f.unlink() - file.rmdir() - folder.rmdir() + for file in self.fs.ls(folder): + if self.fs.isfile(file): + self.fs.rm_file(file) + elif self.fs.isdir(file): + assert os.path.name(file) in [CONSOLIDATE_FOLDER] + for f in self.fs.ls(file): + self.fs.rm(f) + self.fs.rmdir(file) + self.fs.rmdir(folder) dist.barrier() self.existing_saves = list(folder_to_keep) - self.existing_saves.sort(key=lambda p: _get_key_step(p.name)) + self.existing_saves.sort(key=lambda p: _get_key_step(os.path.basename(p))) - def get_last_step_path(self, dp_rank: int = 0) -> Optional[Path]: + def get_last_step_path(self, dp_rank: int = 0) -> str | None: path = None for p in reversed(self.existing_saves): - if (p / TRAIN_STATE_NAME.format(dp_rank)).is_file(): + + if self.fs.isfile(os.path.join(p, TRAIN_STATE_NAME.format(dp_rank))): path = p break return path - def _create_folder(self, base_path: Path, folder_name: str) -> Path: - folder = base_path / folder_name + def _create_folder(self, base_path: str, folder_name: str) -> str: + folder = os.path.join(base_path, folder_name) if get_is_master(): - folder.mkdir(parents=False, exist_ok=True) + self.fs.mkdirs(folder, exist_ok=True) if dist.is_initialized(): dist.barrier() return folder - def _get_dp_tp_mesh( - self, device_mesh: Optional[DeviceMesh] = None - ) -> Tuple[int, int]: + def _get_dp_tp_mesh(self, device_mesh: DeviceMesh | None = None) -> tuple[int, int]: dp_rank = 0 tp_rank = 0 if device_mesh is not None: @@ -222,14 +223,14 @@ class CheckpointManager: model, optimizer, train_state, - config, - device_mesh: Optional[DeviceMesh] = None, + config: BaseModel, + device_mesh: DeviceMesh | None 
= None, ) -> bool: # When creating directory check if only rank0 or is there other solution - path = Path(self.path) + path = self.path curr_save_dir = self._create_folder(path, FOLDER_NAME.format(train_state.step)) - logger.info(f"Saving to: {str(curr_save_dir)}") + logger.info(f"Saving to: {curr_save_dir}") if dist.is_initialized(): dist.barrier() @@ -242,17 +243,19 @@ class CheckpointManager: if dist.is_initialized(): dist.barrier() + print("config type", type(config)) if get_is_master(): - config.dump_to_yaml_file(curr_save_dir / CONFIG_NAME) + self.fs.write_text( + os.path.join(curr_save_dir, CONFIG_NAME), config.model_dump_json() + ) # Add json dump here dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh) if tp_rank == 0: train_state_name = TRAIN_STATE_NAME.format(dp_rank) - logger.info( - f"Saving train state to: {str(curr_save_dir / train_state_name)}" - ) - with open(curr_save_dir / train_state_name, "w") as f: + train_state_full_path = os.path.join(curr_save_dir, train_state_name) + logger.info(f"Saving train state to: {train_state_full_path}") + with self.fs.open(train_state_full_path, "w") as f: json.dump(train_state.state_dict(), f) logger.info("Train state saved !") @@ -271,7 +274,7 @@ class CheckpointManager: optimizer, train_state, device_mesh: DeviceMesh, - path: Optional[Path] = None, + path: str | None = None, ): dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh) # Loading tries to load the provided path, if not available the last saved step and finally from the init path @@ -284,12 +287,12 @@ class CheckpointManager: # Only load train state if it's provided, the files exist and we're not loading from init path train_state_name = TRAIN_STATE_NAME.format(dp_rank) logger.info("Reloading train state") - with open(path / train_state_name, "r") as f: + with self.fs.open(os.path.join(path, train_state_name), "r") as f: train_state_dict = json.load(f) train_state.load_state_dict(train_state_dict) logger.info("Train state reloaded") - logger.info(f"Loading from: {str(path)}") + logger.info(f"Loading from: {path}") state_dict = self.get_state_dict( model=model, optimizer=optimizer, diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index 4e7b99e..1c68d3a 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -16,6 +16,7 @@ from bytelatent import ByteLatentError from bytelatent.data.data_types import BltExample from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator +from bytelatent.preprocess.preprocess_entropies import get_id_key, get_text logger = getLogger(__name__) @@ -32,6 +33,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState): arrow_batch_size: int = 100 s3_profile: str | None filesystem_type: str | None = None + file_format: str def build(self) -> "ArrowFileIterator": arrow_file = ArrowFileIterator( @@ -44,6 +46,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState): dataset_files=self.dataset_files, s3_profile=self.s3_profile, filesystem_type=self.filesystem_type, + file_format=self.file_format, ) if self.row_num != 0: arrow_file._set_row_num(self.row_num) @@ -70,6 +73,7 @@ class ArrowFileIterator(StatefulIterator): dataset_files: list[str] | None = None, s3_profile: str | None = None, filesystem_type: str | None = None, + file_format: str = "arrow", ): assert 0 <= worker_id < num_workers, (worker_id, num_workers) if file_path is None and dataset_files is None: @@ -87,12 
+91,16 @@ class ArrowFileIterator(StatefulIterator): self.arrow_batch_size = arrow_batch_size self.s3_profile = s3_profile self.filesystem_type = filesystem_type + self.file_format = file_format self.fs = None if self.filesystem_type is not None: if self.filesystem_type == "file": self.fs = fsspec.filesystem("file") elif self.filesystem_type == "s3": self.fs = fsspec.filesystem("s3", profile=s3_profile) + else: + raise ValueError("Unknown filesystem") + logger.info("Arrow iterator using fs=%s", self.fs) if dataset_files is None: # Prepare arrow shards @@ -153,6 +161,7 @@ class ArrowFileIterator(StatefulIterator): dataset_files=self.dataset_files, s3_profile=self.s3_profile, filesystem_type=self.filesystem_type, + file_format=self.file_format, ) def create_iter( @@ -164,7 +173,7 @@ class ArrowFileIterator(StatefulIterator): else: filesystem = None self.dataset = pa.dataset.dataset( - self.dataset_files, format="arrow", filesystem=filesystem + self.dataset_files, format=self.file_format, filesystem=filesystem ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size @@ -173,13 +182,22 @@ class ArrowFileIterator(StatefulIterator): if self.batch_to_consume is not None: batch_columns: dict[str, list] = self.batch_to_consume self.batch_to_consume = None - sample_ids = batch_columns["sample_id"] - texts = batch_columns["text"] - entropies = batch_columns["entropies"] + if self.file_format == "arrow": + sample_ids = batch_columns["sample_id"] + texts = batch_columns["text"] + entropies = batch_columns["entropies"] + elif self.file_format == "json": + # This data hasn't been preprocessed to a uniform format, + # so we have to do it now and omit entropies + sample_ids = batch_columns[get_id_key(batch_columns)] + texts = get_text(batch_columns) + entropies = None + else: + raise ValueError(f"Unknown file format: {self.file_format}") for i in range(len(sample_ids)): out = BltExample( sample_id=sample_ids[i], - entropies=entropies[i], + entropies=entropies[i] if entropies is not None else None, text=texts[i], tokens=None, mask=None, @@ -191,13 +209,22 @@ class ArrowFileIterator(StatefulIterator): for batch in self.batch_iterator: batch_columns = batch.to_pydict() - sample_ids = batch_columns["sample_id"] - texts = batch_columns["text"] - entropies = batch_columns["entropies"] + if self.file_format == "arrow": + sample_ids = batch_columns["sample_id"] + texts = batch_columns["text"] + entropies = batch_columns["entropies"] + elif self.file_format == "json": + # This data hasn't been preprocessed to a uniform format, + # so we have to do it now and omit entropies + sample_ids = batch_columns[get_id_key(batch_columns)] + texts = get_text(batch_columns) + entropies = None + else: + raise ValueError(f"Unknown file format: {self.file_format}") for i in range(len(sample_ids)): out = BltExample( sample_id=sample_ids[i], - entropies=entropies[i], + entropies=entropies[i] if entropies is not None else None, text=texts[i], tokens=None, mask=None, @@ -231,13 +258,24 @@ class ArrowFileIterator(StatefulIterator): for batch in self.batch_iterator: if len(batch) > curr_remaining: batch_columns: dict[str, list] = batch.to_pydict() - batch_columns["sample_id"] = batch_columns["sample_id"][ - curr_remaining: - ] - batch_columns["entropies"] = batch_columns["entropies"][ - curr_remaining: - ] - batch_columns["text"] = batch_columns["text"][curr_remaining:] + if self.file_format == "arrow": + leftover_sample_ids = batch_columns["sample_id"][ + curr_remaining: + ] + leftover_entropies = 
batch_columns["entropies"][curr_remaining:] + leftover_texts = batch_columns["text"][curr_remaining:] + elif self.file_format == "json": + leftover_sample_ids = batch_columns[get_id_key(batch_columns)][ + curr_remaining: + ] + leftover_entropies = None + leftover_texts = get_text(batch_columns)[curr_remaining:] + else: + raise ValueError(f"Unknown file format: {self.file_format}") + + batch_columns["sample_id"] = leftover_sample_ids + batch_columns["entropies"] = leftover_entropies + batch_columns["text"] = leftover_texts self.batch_to_consume = batch_columns break elif len(batch) == curr_remaining: @@ -250,30 +288,3 @@ class ArrowFileIterator(StatefulIterator): logger.info( f"Finished setting arrow position to {target_row_num} for {self.dataset_files}" ) - - -TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" - - -def find_and_sanitize_chunks( - dataset_path: str, - world_size: int, - file_pattern: str = TRAIN_DATA_FILE_PATTERN, - s3_profile: str | None = None, -): - fs = get_fs(dataset_path, s3_profile=s3_profile) - path_with_glob = os.path.join(dataset_path, file_pattern) - dataset_chunks = fs.glob(path_with_glob) - n_chunks = len(dataset_chunks) - - if n_chunks > world_size: - n_discard = n_chunks - world_size - dataset_chunks = dataset_chunks[:world_size] - else: - assert ( - world_size % n_chunks == 0 - ), "World size should be a multiple of number of chunks" - - assert n_chunks > 0, f"No valid chunks in {dataset_path}" - - return dataset_chunks diff --git a/bytelatent/distributed.py b/bytelatent/distributed.py index 168cb7c..f4b57e2 100644 --- a/bytelatent/distributed.py +++ b/bytelatent/distributed.py @@ -127,6 +127,16 @@ def dist_max(x: Union[int, float], mesh: DeviceMesh = None): return tensor +def dist_sum( + x: Union[int, float], mesh: DeviceMesh = None, reduce_dtype: torch.dtype = None +): + tensor = torch.tensor(x).cuda() + if reduce_dtype is not None: + tensor = tensor.to(reduce_dtype) + dist.all_reduce(tensor, op=ReduceOp.SUM, group=mesh.get_group() if mesh else None) + return tensor + + def dist_mean(x: Union[int, float], mesh: DeviceMesh = None): tensor = torch.tensor(x).cuda() dist.all_reduce(tensor, op=ReduceOp.AVG, group=mesh.get_group() if mesh else None) @@ -236,7 +246,7 @@ def setup_env(env_args: EnvironmentArgs): logger.warning(f"WARNING: Setting {name} to {value}") -def setup_torch_distributed(dist_args): +def setup_torch_distributed(dist_args: DistributedArgs): """ Handle single and multi-GPU / multi-node / SLURM jobs. 
Initialize the following variables: @@ -388,14 +398,14 @@ def clean_env(): def parallelize_model( - model, + model: torch.nn.Module, device_mesh, model_args, distributed_args: DistributedArgs, fsdp_grouping_plan: Optional[List[Tuple[str, bool]]] = None, tp_parallelize=None, no_recompute_ops=None, -): +) -> torch.nn.Module: if distributed_args.tp_size > 1: assert ( distributed_args.fsdp_type == "full_shard" @@ -429,6 +439,8 @@ def parallelize_model( device_mesh["dp_shard"].size() == 1 ), "dp_shard must be 1 for no_shard fsdp_type" + # TODO: Remove with something better + # model = model.to(param_dtype) fsdp_config = dict( mp_policy=( MixedPrecisionPolicy( diff --git a/bytelatent/eval.py b/bytelatent/eval.py index ae73066..1943a33 100644 --- a/bytelatent/eval.py +++ b/bytelatent/eval.py @@ -15,9 +15,16 @@ from lm_eval.api.model import LM from omegaconf import OmegaConf from pydantic import BaseModel, ConfigDict -from bytelatent.args import EvalArgs, ValidationArgs, parse_args +from bytelatent.args import ( + EvalArgs, + TrainArgs, + ValidationArgs, + find_and_sanitize_chunks, + parse_args, +) from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints from bytelatent.data.file_util import get_fs +from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator from bytelatent.distributed import ( DistributedArgs, dist_mean_dict, @@ -117,19 +124,40 @@ class EvalHarnessLM(LM): return results -def eval_on_val(generator, val_args: ValidationArgs, train_cfg): - srcs = {} +def eval_on_val(generator, val_args: ValidationArgs, train_cfg: TrainArgs): + srcs = [] for src in val_args.sources: path = os.path.join(val_args.root_dir, src) - srcs[path] = 1.0 + srcs.append(path) + for src in train_cfg.data.sources: path = os.path.join(train_cfg.data.root_dir, src) - srcs[path] = 1.0 + srcs.append(path) - multi_state = init_choice_state( - "", srcs, 0, get_global_rank(), get_world_size(), "*.val.jsonl" - ) - path_to_iter = setup_sources(multi_state) + path_to_iter = {} + for path in srcs: + chunks = find_and_sanitize_chunks( + path, + world_size=1, + file_pattern="*.val.jsonl", + s3_profile=train_cfg.data.s3_profile, + ) + assert ( + len(chunks) == 1 + ), f"There should be only 1 chunk per validation file, but found: {chunks}" + chunk = chunks[0] + iterator = ArrowFileIterator( + dataset_files=[chunk], + file_path=None, + preprocess_dir=None, + entropy_model_name=None, + worker_id=0, + num_workers=1, + arrow_batch_size=train_cfg.data.arrow_batch_size, + s3_profile=train_cfg.data.s3_profile, + file_format="json", + ) + path_to_iter[path] = iterator max_gen_len = generator.max_gen_len # We temporarily lower max gen len @@ -137,16 +165,11 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg): all_val_metrics = {} for src in path_to_iter: - jsonl_iterator = path_to_iter[src] + example_iterator = path_to_iter[src].create_iter() texts = [] logger.info(f"Running validation on {src}...") - for step, (content, state) in enumerate(jsonl_iterator): - if state["current_iter"] > 0 or ( - val_args.max_steps is not None and step >= val_args.max_steps - ): - break - content_key = "text" if ("text" in content) else "content" - texts.append(content[content_key]) + for step, example in enumerate(example_iterator): + texts.append(example.text) _, loglikelihood, _ = generator.generate(texts) @@ -191,7 +214,7 @@ def launch_eval(eval_args: EvalArgs): else: consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER) if not fs.exists(consolidate_path) and get_global_rank() == 0: - 
consolidate_path = consolidate_checkpoints(eval_args.ckpt_dir) + consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir) fs.mkdirs(eval_args.dump_dir, exist_ok=True) with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f: @@ -210,10 +233,12 @@ def launch_eval(eval_args: EvalArgs): wrap = EvalHarnessLM(generator) # Redo - results = simple_evaluate(wrap, eval_args.harness.model_dump()) + results = simple_evaluate(wrap, **eval_args.harness.model_dump()) + val_results = None if eval_args.validation: val_results = eval_on_val(generator, eval_args.validation, train_cfg) + if get_global_rank() == 0: with fs.open(os.path.join(eval_args.dump_dir, "results.json"), "w") as f: f.write(json.dumps(results)) @@ -222,6 +247,7 @@ def launch_eval(eval_args: EvalArgs): with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f: f.write(json.dumps(val_results)) logger.info(f"All validation results: {val_results}") + if eval_args.metric_log_dir and get_global_rank() == 0: metric_log_path = os.path.join(eval_args.metric_log_dir, "metrics.eval.jsonl") diff --git a/bytelatent/generate.py b/bytelatent/generate.py index eb79d81..9d44f30 100644 --- a/bytelatent/generate.py +++ b/bytelatent/generate.py @@ -387,8 +387,7 @@ def load_consolidated_model_and_tokenizer( ): train_args_path = os.path.join(consolidated_path, "params.json") fs = get_fs(train_args_path) - with fs.open(train_args_path) as f: - train_args = TrainArgs.model_validate_json(f.read()) + train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path)) if train_args.train_entropy_model: model_args = train_args.entropy_model @@ -401,7 +400,8 @@ def load_consolidated_model_and_tokenizer( train_args.distributed.model_dtype ] tokenizer = train_args.data.tokenizer_args.build() - st_dict = torch.load(consolidated_path / CONSOLIDATE_NAME, weights_only=True) + with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f: + st_dict = torch.load(f, weights_only=True) model.load_state_dict(st_dict["model"]) model = model.cuda().eval() for param in model.parameters(): diff --git a/bytelatent/norms.py b/bytelatent/norms.py new file mode 100644 index 0000000..81d1652 --- /dev/null +++ b/bytelatent/norms.py @@ -0,0 +1,100 @@ +from typing import Dict, List, Optional, Tuple + +import torch +from torch import Tensor +from torch.utils._foreach_utils import ( + _device_has_foreach_support, + _group_tensors_by_device_and_dtype, + _has_foreach_support, +) + + +@torch.no_grad() +def fixed_clip_grad_norm_( + parameters: torch.Tensor | list[torch.Tensor], + max_norm: float, + norm_type: float = 2.0, + error_if_nonfinite: bool = False, + foreach: Optional[bool] = None, +) -> torch.Tensor: + r"""Clip the gradient norm of an iterable of parameters. + + The norm is computed over the norms of the individual gradients of all parameters, + as if the norms of the individual gradients were concatenated into a single vector. + Gradients are modified in-place. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float): max norm of the gradients + norm_type (float): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + error_if_nonfinite (bool): if True, an error is thrown if the total + norm of the gradients from :attr:`parameters` is ``nan``, + ``inf``, or ``-inf``. Default: False (will switch to True in the future) + foreach (bool): use the faster foreach-based implementation. 
+ If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently + fall back to the slow implementation for other device types. + Default: ``None`` + + Returns: + Total norm of the parameter gradients (viewed as a single vector). + """ + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + grads = [p.grad.to(torch.bfloat16) for p in parameters if p.grad is not None] + max_norm = float(max_norm) + norm_type = float(norm_type) + if len(grads) == 0: + return torch.tensor(0.0) + first_device = grads[0].device + grouped_grads: Dict[ + Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]] + ] = _group_tensors_by_device_and_dtype( + [grads] + ) # type: ignore[assignment] + + norms: List[Tensor] = [] + for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment] + if (foreach is None and _has_foreach_support(device_grads, device)) or ( + foreach and _device_has_foreach_support(device) + ): + norms.extend(torch._foreach_norm(device_grads, norm_type)) + elif foreach: + raise RuntimeError( + f"foreach=True was passed, but can't use the foreach API on {device.type} tensors" + ) + else: + norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads]) + + total_norm = torch.linalg.vector_norm( + torch.stack([norm.to(first_device) for norm in norms]), norm_type + ) + + if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()): + raise RuntimeError( + f"The total norm of order {norm_type} for gradients from " + "`parameters` is non-finite, so it cannot be clipped. To disable " + "this error and scale the gradients by the non-finite norm anyway, " + "set `error_if_nonfinite=False`" + ) + clip_coef = max_norm / (total_norm + 1e-6) + # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so + # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization + # when the gradients do not reside in CPU memory. 
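+    # Note: `grads` at the top of this function holds bfloat16 versions of the
+    # parameter gradients, so total_norm and clip_coef here are computed in
+    # bfloat16 (typically lower precision than torch.nn.utils.clip_grad_norm_,
+    # which works in the gradients' own dtype).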
+ clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment] + if (foreach is None and _has_foreach_support(device_grads, device)) or ( + foreach and _device_has_foreach_support(device) + ): + torch._foreach_mul_(device_grads, clip_coef_clamped.to(device)) + elif foreach: + raise RuntimeError( + f"foreach=True was passed, but can't use the foreach API on {device.type} tensors" + ) + else: + clip_coef_clamped_device = clip_coef_clamped.to(device) + for g in device_grads: + g.mul_(clip_coef_clamped_device) + + return total_norm diff --git a/bytelatent/preprocess/preprocess_entropies.py b/bytelatent/preprocess/preprocess_entropies.py index 519da94..31a4802 100644 --- a/bytelatent/preprocess/preprocess_entropies.py +++ b/bytelatent/preprocess/preprocess_entropies.py @@ -15,29 +15,37 @@ from bytelatent.entropy_model import load_entropy_model from bytelatent.tokenizers.build_tokenizer import TokenizerArgs -def get_id_from_doc(doc: dict) -> int: +def get_id_key(doc: dict) -> int: """ We need a reliable way to ensure that samples from jsonl and arrow are the same, but there is no unique id field, so derive the best possible """ if "sample_id" in doc: - sample_id = doc["sample_id"] + return "sample_id" elif "title" in doc: - sample_id = doc["title"] + return "title" elif "qid" in doc: - sample_id = doc["qid"] + return "qid" elif "paper_id" in doc: - sample_id = doc["paper_id"] + return "paper_id" elif "path" in doc: - sample_id = doc["path"] + return "path" elif "url" in doc: - sample_id = doc["url"] + return "url" elif "id" in doc: - sample_id = doc["id"] + return "id" else: raise ValueError(f"Could not find a id key from: {doc.keys()}") - return str(sample_id) + + +def get_id_from_doc(doc: dict) -> int: + """ + We need a reliable way to ensure that samples from jsonl + and arrow are the same, but there is no unique id field, + so derive the best possible + """ + return str(doc[get_id_key(doc)]) def get_text(doc: dict): diff --git a/bytelatent/train.py b/bytelatent/train.py index 6b20ecd..16c6865 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -3,6 +3,7 @@ import gc import logging +import math import os import sys from contextlib import ExitStack @@ -12,6 +13,7 @@ from pathlib import Path from timeit import default_timer as timer from typing import Any, TypeVar +import numpy as np import torch import torch.distributed import torch.nn.functional @@ -25,6 +27,7 @@ from torch.optim import lr_scheduler from bytelatent.args import TrainArgs, parse_args from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint +from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.multiprocess_iterator import ( MultiprocessIterator, MultiprocessIteratorState, @@ -33,7 +36,9 @@ from bytelatent.data.iterators.packing_iterator import PackingIteratorState from bytelatent.distributed import ( check_model_value_range, clean_env, + dist_mean, dist_mean_dict, + dist_sum, get_device_mesh, get_is_master, get_world_size, @@ -47,6 +52,7 @@ from bytelatent.eval import EVAL_FOLDER_NAME, launch_eval from bytelatent.logger import init_logger from bytelatent.metrics import GPUMemoryMonitor, MetricLogger, get_num_params from bytelatent.model.blt import ByteLatentTransformer +from bytelatent.norms import fixed_clip_grad_norm_ from bytelatent.optim import build_optimizer from bytelatent.probe import AutoProbeD from bytelatent.profiling import maybe_run_profiler @@ -295,8 +301,11 @@ def train(args: 
TrainArgs):
     if args.checkpoint.init_ckpt_path:
         logger.info(f"Loading initial model from {args.checkpoint.init_ckpt_path}")
+        ckpt_fs = get_fs(
+            args.checkpoint.init_ckpt_path, s3_profile=args.checkpoint.s3_profile
+        )
         load_from_checkpoint(
-            args.checkpoint.init_ckpt_path, model, model_key="model"
+            ckpt_fs, args.checkpoint.init_ckpt_path, model, model_key="model"
         )  # Put model_key="" if its directly the model checkpoint
         model.rope_embeddings.reset_parameters()  # For RoPe initialization since it's a buffer it might not be loaded
     else:
@@ -364,6 +373,9 @@ def train(args: TrainArgs):
     time_last_log = timer()
     gc.collect()
     saved = False
+    step_losses: list[float] = []
+    step_tok_losses: list[float] = []
+    n_bytes: int = 0
     while train_state.step < args.steps and (
         args.max_steps is None or train_state.step < args.max_steps
     ):
@@ -385,6 +397,24 @@ def train(args: TrainArgs):
         batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda()
         mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
 
+        if args.data.tokenizer_args.name in ["bytes", "blt"]:
+            if mask is None:
+                n_bytes += batch_y.numel()
+            else:
+                n_bytes += mask.sum()
+        elif args.data.tokenizer_args.name in ["sp", "tiktoken"]:
+            for example in batch.y:
+                target_tokens = tokenizer.decode(example.tolist(), cut_at_eos=False)
+                n_bytes += (
+                    len(bytes(target_tokens, encoding="utf-8", errors="ignore"))
+                    + sum(example == tokenizer.eos_id)
+                    + sum(example == tokenizer.bos_id)
+                )
+        else:
+            raise ValueError(
+                f"Unexpected tokenizer to count n_bytes for: {args.data.tokenizer_args.name}"
+            )
+
         if (
             not args.train_entropy_model
             and args.model.encoder_enable_byte_ngrams
@@ -459,7 +489,7 @@ def train(args: TrainArgs):
                 batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids
             )
 
-            loss, _ = compute_loss(pred, batch_y, mask, train_state.scale)
+            loss, tok_loss = compute_loss(pred, batch_y, mask, train_state.scale)
 
             # We scale loss with grad_acc_steps so the gradient is the same
             # regardless of grad_acc_steps
@@ -470,8 +500,14 @@ def train(args: TrainArgs):
             # For logging we undo that scaling
            loss = loss.detach() * args.grad_acc_steps
 
-            grad_norm = torch.nn.utils.clip_grad_norm_(
-                model.parameters(), max_norm=args.optim.clip, foreach=True
+            # Undo loss scaling so downstream doesn't need to worry about it
+            step_losses.append((loss / train_state.scale).item())
+            step_tok_losses.append(tok_loss / train_state.scale)
+
+            # grad_norm = torch.nn.utils.clip_grad_norm_(
+            grad_norm = fixed_clip_grad_norm_(
+                model.parameters(),
+                max_norm=args.optim.clip,  # foreach=True
             )
 
             grad_norm = (
@@ -559,20 +595,33 @@ def train(args: TrainArgs):
             gpu_memory_monitor.reset_peak_stats()
             nwords_since_last_log = 0
             time_last_log = timer()
+            stacked_tok_loss = torch.cat(step_tok_losses, dim=0)
+            total_tok_loss = dist_sum(
+                stacked_tok_loss.sum().item(), reduce_dtype=torch.bfloat16
+            )
+            total_n_bytes = dist_sum(n_bytes, reduce_dtype=torch.bfloat16)
+            avg_bpb = total_tok_loss / math.log(2) / total_n_bytes
+            avg_loss = dist_mean(np.mean(step_losses).item())
             logger.info(
                 f"step: {train_state.step}"
                 f" acc: {train_state.acc_step}"
-                f" loss: {round(loss.item(),4):>7}"
+                f" loss: step={round(loss.item(),4):>7} avg={avg_loss}"
+                f" bpb: {avg_bpb:.3f}"
                 f" grad: {grad_norm:.2e}"
                 f" flops: {FLOPS:.2e}"
                 f" wps: {wps:.2e}"
                 f" iter: {curr_iter_time:>7}"
                 f" data: {data_load_time:>5}"
                 f" lr: {curr_lr:.2e}"
+                f" n_bytes={total_n_bytes}"
                 f" mem: {gpu_mem_stats.max_active_pct:.0f}%"
                 f" pow: {gpu_mem_stats.power_draw/1000} W"
             )
+            n_bytes = 0
+            step_losses = []
+            step_tok_losses = []
+
             if every_n_steps(
                 train_state, args.checkpoint.dump.every, acc_step=0
             ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):

From bc39591032f26054dbeb8baa303f648fcabef828 Mon Sep 17 00:00:00 2001
From: Pedro Rodriguez
Date: Tue, 4 Feb 2025 18:19:41 +0000
Subject: [PATCH 20/59] Several changes to enable entropy model training/eval

Summary:
- Make the arrow iterator able to read from jsonl files; entropies are omitted in this case
- Make the data/checkpoint code fsspec compatible
- Fix issues with all-reduce of non-bf16 values in dist_sum and the norm computation
- Minimal fixes to get eval to run; it is currently slow
- Add bpb numbers during training

Test Plan:
Run
```
torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/entropy_model.yaml eval=null max_steps=10100
```
```
python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null
```
```
torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null
```
---
 bytelatent/args.py                            | 48 ++++++++-
 bytelatent/checkpoint.py                      | 97 +++++++++--------
 bytelatent/data/iterators/arrow_iterator.py   | 97 +++++++++--------
 bytelatent/distributed.py                     | 18 +++-
 bytelatent/eval.py                            | 64 +++++++----
 bytelatent/generate.py                        |  6 +-
 bytelatent/norms.py                           | 100 ++++++++++++++++++
 bytelatent/preprocess/preprocess_entropies.py | 26 +++--
 bytelatent/train.py                           | 59 ++++++++++-
 9 files changed, 381 insertions(+), 134 deletions(-)
 create mode 100644 bytelatent/norms.py

diff --git a/bytelatent/args.py b/bytelatent/args.py
index d1bac46..33d12d6 100644
--- a/bytelatent/args.py
+++ b/bytelatent/args.py
@@ -1,8 +1,10 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+import json
 import logging
 import os
 from typing import Any
 
+import fsspec
 import numpy as np
 import yaml
 from omegaconf import OmegaConf
@@ -10,11 +12,9 @@ from pydantic import BaseModel, ConfigDict
 
 from bytelatent.checkpoint import CheckpointArgs
 from bytelatent.data.data_types import Batch
+from bytelatent.data.file_util import get_fs
 from bytelatent.data.iterators.abstract_iterator import StatefulIterator
-from bytelatent.data.iterators.arrow_iterator import (
-    ArrowFileIterator,
-    find_and_sanitize_chunks,
-)
+from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
 from bytelatent.data.iterators.looping_iterator import LoopingIterator
 from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator
 from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator
@@ -53,6 +53,43 @@ def parse_args(args_cls):
     return pydantic_args
 
+
+def read_args_file(fs: fsspec.AbstractFileSystem, path: str) -> Any:
+    with fs.open(path, "rt") as f:
+        if path.endswith(".json"):
+            return json.load(f)
+        elif path.endswith(".yaml"):
+            return yaml.load(f)
+        else:
+            raise ValueError("Invalid args file format")
+
+
+TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"
+
+
+def find_and_sanitize_chunks(
+    dataset_path: str,
+    world_size: int,
+    file_pattern: str,
+    s3_profile: str | None = None,
+):
+    fs = get_fs(dataset_path, s3_profile=s3_profile)
+    path_with_glob = os.path.join(dataset_path, file_pattern)
+    dataset_chunks = fs.glob(path_with_glob)
+    n_chunks = len(dataset_chunks)
+
+    if n_chunks > world_size:
+        n_discard = n_chunks - world_size
+        dataset_chunks = dataset_chunks[:world_size]
+    else:
+        assert (
+            world_size % n_chunks == 0
+        ), "World size should be a multiple of number of chunks"
+
+    assert n_chunks > 0, f"No valid chunks in {dataset_path}"
+
+    return dataset_chunks
+
+
 def 
distribute_data_to_rank( *, dataset_path: str, @@ -62,9 +99,10 @@ def distribute_data_to_rank( rank: int, world_size: int, s3_profile: str | None = None, + file_pattern: str = TRAIN_DATA_FILE_PATTERN, ) -> ArrowFileIterator: dataset_chunks = find_and_sanitize_chunks( - dataset_path, world_size, s3_profile=s3_profile + dataset_path, world_size, file_pattern, s3_profile=s3_profile ) n_workers_per_chunk = world_size // len(dataset_chunks) rank_to_arrow_iterator_params = [] diff --git a/bytelatent/checkpoint.py b/bytelatent/checkpoint.py index f213c84..6631673 100644 --- a/bytelatent/checkpoint.py +++ b/bytelatent/checkpoint.py @@ -4,8 +4,6 @@ import json import logging import os import re -from pathlib import Path -from typing import List, Optional, Tuple import fsspec import torch @@ -70,26 +68,29 @@ def consolidate_checkpoints(fs: fsspec.AbstractFileSystem, ckpt_dir: str): Returns the path to the consolidated checkpoint """ - consolidate_path = Path(ckpt_dir) / CONSOLIDATE_FOLDER - if not (consolidate_path / CONSOLIDATE_NAME).exists(): - consolidate_path.mkdir(exist_ok=True) - logger.info(f"Consolidating to: {str(consolidate_path)}") - dcp_to_torch_save(ckpt_dir, str(consolidate_path / CONSOLIDATE_NAME)) - (consolidate_path / CONFIG_NAME).write_text( - (Path(ckpt_dir) / CONFIG_NAME).read_text() + consolidate_path = os.path.join(ckpt_dir, CONSOLIDATE_FOLDER) + consolidate_name = os.path.join(consolidate_path, CONSOLIDATE_NAME) + if not fs.exists(consolidate_name): + fs.mkdirs(consolidate_path, exist_ok=True) + logger.info(f"Consolidating to: {consolidate_path}") + dcp_to_torch_save(ckpt_dir, consolidate_name) + fs.write_text( + os.path.join(consolidate_path, CONFIG_NAME), + fs.read_text(os.path.join(ckpt_dir, CONFIG_NAME)), ) logger.info("Consolidated !") return consolidate_path def load_from_checkpoint( + fs: fsspec.AbstractFileSystem, ckpt_dir: str, model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, + optimizer: torch.optim.Optimizer | None = None, model_key: str = "model", optim_key: str = "optim", ): - if not (Path(ckpt_dir) / ".metadata").exists(): + if not fs.exists(os.path.join(ckpt_dir, ".metadata")): raise ValueError( f"Please convert the checkpoint distcp format using `torch.distributed.checkpoint.format_utils.torch_save_to_dcp` before loading it" ) @@ -121,13 +122,13 @@ class CheckpointManager: self.existing_saves = self.get_existing_saves() - def get_existing_saves(self) -> List[Path]: + def get_existing_saves(self) -> list[str]: folders = [ p - for p in Path(self.path).iterdir() - if p.is_dir() and re.match(RE_FOLDER, p.name) + for p in self.fs.ls(self.path) + if self.fs.isdir(p) and re.match(RE_FOLDER, os.path.basename(p)) ] - folders.sort(key=lambda p: _get_key_step(p.name)) + folders.sort(key=lambda p: _get_key_step(os.path.basename(p))) return folders def clean_up(self): @@ -136,8 +137,9 @@ class CheckpointManager: eval_folders = [] other_folders = [] for p in self.existing_saves: - is_dump = _get_key_step(p.name) % self.dump_every.every == 0 - is_eval = _get_key_step(p.name) % self.eval_every.every == 0 + assert isinstance(p, str), f"Base path type: {p}" + is_dump = _get_key_step(os.path.basename(p)) % self.dump_every.every == 0 + is_eval = _get_key_step(os.path.basename(p)) % self.eval_every.every == 0 if is_dump: dump_folders.append(p) if is_eval: @@ -161,40 +163,39 @@ class CheckpointManager: if dist.get_rank() == 0: for folder in folder_to_remove: - for file in folder.iterdir(): - if file.is_file(): - file.unlink() - elif file.is_dir(): - assert 
file.name in [CONSOLIDATE_FOLDER] - for f in file.iterdir(): - f.unlink() - file.rmdir() - folder.rmdir() + for file in self.fs.ls(folder): + if self.fs.isfile(file): + self.fs.rm_file(file) + elif self.fs.isdir(file): + assert os.path.name(file) in [CONSOLIDATE_FOLDER] + for f in self.fs.ls(file): + self.fs.rm(f) + self.fs.rmdir(file) + self.fs.rmdir(folder) dist.barrier() self.existing_saves = list(folder_to_keep) - self.existing_saves.sort(key=lambda p: _get_key_step(p.name)) + self.existing_saves.sort(key=lambda p: _get_key_step(os.path.basename(p))) - def get_last_step_path(self, dp_rank: int = 0) -> Optional[Path]: + def get_last_step_path(self, dp_rank: int = 0) -> str | None: path = None for p in reversed(self.existing_saves): - if (p / TRAIN_STATE_NAME.format(dp_rank)).is_file(): + + if self.fs.isfile(os.path.join(p, TRAIN_STATE_NAME.format(dp_rank))): path = p break return path - def _create_folder(self, base_path: Path, folder_name: str) -> Path: - folder = base_path / folder_name + def _create_folder(self, base_path: str, folder_name: str) -> str: + folder = os.path.join(base_path, folder_name) if get_is_master(): - folder.mkdir(parents=False, exist_ok=True) + self.fs.mkdirs(folder, exist_ok=True) if dist.is_initialized(): dist.barrier() return folder - def _get_dp_tp_mesh( - self, device_mesh: Optional[DeviceMesh] = None - ) -> Tuple[int, int]: + def _get_dp_tp_mesh(self, device_mesh: DeviceMesh | None = None) -> tuple[int, int]: dp_rank = 0 tp_rank = 0 if device_mesh is not None: @@ -222,14 +223,14 @@ class CheckpointManager: model, optimizer, train_state, - config, - device_mesh: Optional[DeviceMesh] = None, + config: BaseModel, + device_mesh: DeviceMesh | None = None, ) -> bool: # When creating directory check if only rank0 or is there other solution - path = Path(self.path) + path = self.path curr_save_dir = self._create_folder(path, FOLDER_NAME.format(train_state.step)) - logger.info(f"Saving to: {str(curr_save_dir)}") + logger.info(f"Saving to: {curr_save_dir}") if dist.is_initialized(): dist.barrier() @@ -242,17 +243,19 @@ class CheckpointManager: if dist.is_initialized(): dist.barrier() + print("config type", type(config)) if get_is_master(): - config.dump_to_yaml_file(curr_save_dir / CONFIG_NAME) + self.fs.write_text( + os.path.join(curr_save_dir, CONFIG_NAME), config.model_dump_json() + ) # Add json dump here dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh) if tp_rank == 0: train_state_name = TRAIN_STATE_NAME.format(dp_rank) - logger.info( - f"Saving train state to: {str(curr_save_dir / train_state_name)}" - ) - with open(curr_save_dir / train_state_name, "w") as f: + train_state_full_path = os.path.join(curr_save_dir, train_state_name) + logger.info(f"Saving train state to: {train_state_full_path}") + with self.fs.open(train_state_full_path, "w") as f: json.dump(train_state.state_dict(), f) logger.info("Train state saved !") @@ -271,7 +274,7 @@ class CheckpointManager: optimizer, train_state, device_mesh: DeviceMesh, - path: Optional[Path] = None, + path: str | None = None, ): dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh) # Loading tries to load the provided path, if not available the last saved step and finally from the init path @@ -284,12 +287,12 @@ class CheckpointManager: # Only load train state if it's provided, the files exist and we're not loading from init path train_state_name = TRAIN_STATE_NAME.format(dp_rank) logger.info("Reloading train state") - with open(path / train_state_name, "r") as f: + with self.fs.open(os.path.join(path, 
train_state_name), "r") as f: train_state_dict = json.load(f) train_state.load_state_dict(train_state_dict) logger.info("Train state reloaded") - logger.info(f"Loading from: {str(path)}") + logger.info(f"Loading from: {path}") state_dict = self.get_state_dict( model=model, optimizer=optimizer, diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index 4e7b99e..1c68d3a 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -16,6 +16,7 @@ from bytelatent import ByteLatentError from bytelatent.data.data_types import BltExample from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator +from bytelatent.preprocess.preprocess_entropies import get_id_key, get_text logger = getLogger(__name__) @@ -32,6 +33,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState): arrow_batch_size: int = 100 s3_profile: str | None filesystem_type: str | None = None + file_format: str def build(self) -> "ArrowFileIterator": arrow_file = ArrowFileIterator( @@ -44,6 +46,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState): dataset_files=self.dataset_files, s3_profile=self.s3_profile, filesystem_type=self.filesystem_type, + file_format=self.file_format, ) if self.row_num != 0: arrow_file._set_row_num(self.row_num) @@ -70,6 +73,7 @@ class ArrowFileIterator(StatefulIterator): dataset_files: list[str] | None = None, s3_profile: str | None = None, filesystem_type: str | None = None, + file_format: str = "arrow", ): assert 0 <= worker_id < num_workers, (worker_id, num_workers) if file_path is None and dataset_files is None: @@ -87,12 +91,16 @@ class ArrowFileIterator(StatefulIterator): self.arrow_batch_size = arrow_batch_size self.s3_profile = s3_profile self.filesystem_type = filesystem_type + self.file_format = file_format self.fs = None if self.filesystem_type is not None: if self.filesystem_type == "file": self.fs = fsspec.filesystem("file") elif self.filesystem_type == "s3": self.fs = fsspec.filesystem("s3", profile=s3_profile) + else: + raise ValueError("Unknown filesystem") + logger.info("Arrow iterator using fs=%s", self.fs) if dataset_files is None: # Prepare arrow shards @@ -153,6 +161,7 @@ class ArrowFileIterator(StatefulIterator): dataset_files=self.dataset_files, s3_profile=self.s3_profile, filesystem_type=self.filesystem_type, + file_format=self.file_format, ) def create_iter( @@ -164,7 +173,7 @@ class ArrowFileIterator(StatefulIterator): else: filesystem = None self.dataset = pa.dataset.dataset( - self.dataset_files, format="arrow", filesystem=filesystem + self.dataset_files, format=self.file_format, filesystem=filesystem ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size @@ -173,13 +182,22 @@ class ArrowFileIterator(StatefulIterator): if self.batch_to_consume is not None: batch_columns: dict[str, list] = self.batch_to_consume self.batch_to_consume = None - sample_ids = batch_columns["sample_id"] - texts = batch_columns["text"] - entropies = batch_columns["entropies"] + if self.file_format == "arrow": + sample_ids = batch_columns["sample_id"] + texts = batch_columns["text"] + entropies = batch_columns["entropies"] + elif self.file_format == "json": + # This data hasn't been preprocessed to a uniform format, + # so we have to do it now and omit entropies + sample_ids = batch_columns[get_id_key(batch_columns)] + texts = get_text(batch_columns) + entropies = None + else: + raise 
ValueError(f"Unknown file format: {self.file_format}") for i in range(len(sample_ids)): out = BltExample( sample_id=sample_ids[i], - entropies=entropies[i], + entropies=entropies[i] if entropies is not None else None, text=texts[i], tokens=None, mask=None, @@ -191,13 +209,22 @@ class ArrowFileIterator(StatefulIterator): for batch in self.batch_iterator: batch_columns = batch.to_pydict() - sample_ids = batch_columns["sample_id"] - texts = batch_columns["text"] - entropies = batch_columns["entropies"] + if self.file_format == "arrow": + sample_ids = batch_columns["sample_id"] + texts = batch_columns["text"] + entropies = batch_columns["entropies"] + elif self.file_format == "json": + # This data hasn't been preprocessed to a uniform format, + # so we have to do it now and omit entropies + sample_ids = batch_columns[get_id_key(batch_columns)] + texts = get_text(batch_columns) + entropies = None + else: + raise ValueError(f"Unknown file format: {self.file_format}") for i in range(len(sample_ids)): out = BltExample( sample_id=sample_ids[i], - entropies=entropies[i], + entropies=entropies[i] if entropies is not None else None, text=texts[i], tokens=None, mask=None, @@ -231,13 +258,24 @@ class ArrowFileIterator(StatefulIterator): for batch in self.batch_iterator: if len(batch) > curr_remaining: batch_columns: dict[str, list] = batch.to_pydict() - batch_columns["sample_id"] = batch_columns["sample_id"][ - curr_remaining: - ] - batch_columns["entropies"] = batch_columns["entropies"][ - curr_remaining: - ] - batch_columns["text"] = batch_columns["text"][curr_remaining:] + if self.file_format == "arrow": + leftover_sample_ids = batch_columns["sample_id"][ + curr_remaining: + ] + leftover_entropies = batch_columns["entropies"][curr_remaining:] + leftover_texts = batch_columns["text"][curr_remaining:] + elif self.file_format == "json": + leftover_sample_ids = batch_columns[get_id_key(batch_columns)][ + curr_remaining: + ] + leftover_entropies = None + leftover_texts = get_text(batch_columns)[curr_remaining:] + else: + raise ValueError(f"Unknown file format: {self.file_format}") + + batch_columns["sample_id"] = leftover_sample_ids + batch_columns["entropies"] = leftover_entropies + batch_columns["text"] = leftover_texts self.batch_to_consume = batch_columns break elif len(batch) == curr_remaining: @@ -250,30 +288,3 @@ class ArrowFileIterator(StatefulIterator): logger.info( f"Finished setting arrow position to {target_row_num} for {self.dataset_files}" ) - - -TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" - - -def find_and_sanitize_chunks( - dataset_path: str, - world_size: int, - file_pattern: str = TRAIN_DATA_FILE_PATTERN, - s3_profile: str | None = None, -): - fs = get_fs(dataset_path, s3_profile=s3_profile) - path_with_glob = os.path.join(dataset_path, file_pattern) - dataset_chunks = fs.glob(path_with_glob) - n_chunks = len(dataset_chunks) - - if n_chunks > world_size: - n_discard = n_chunks - world_size - dataset_chunks = dataset_chunks[:world_size] - else: - assert ( - world_size % n_chunks == 0 - ), "World size should be a multiple of number of chunks" - - assert n_chunks > 0, f"No valid chunks in {dataset_path}" - - return dataset_chunks diff --git a/bytelatent/distributed.py b/bytelatent/distributed.py index 168cb7c..f4b57e2 100644 --- a/bytelatent/distributed.py +++ b/bytelatent/distributed.py @@ -127,6 +127,16 @@ def dist_max(x: Union[int, float], mesh: DeviceMesh = None): return tensor +def dist_sum( + x: Union[int, float], mesh: DeviceMesh = None, reduce_dtype: torch.dtype = None +): + tensor 
= torch.tensor(x).cuda() + if reduce_dtype is not None: + tensor = tensor.to(reduce_dtype) + dist.all_reduce(tensor, op=ReduceOp.SUM, group=mesh.get_group() if mesh else None) + return tensor + + def dist_mean(x: Union[int, float], mesh: DeviceMesh = None): tensor = torch.tensor(x).cuda() dist.all_reduce(tensor, op=ReduceOp.AVG, group=mesh.get_group() if mesh else None) @@ -236,7 +246,7 @@ def setup_env(env_args: EnvironmentArgs): logger.warning(f"WARNING: Setting {name} to {value}") -def setup_torch_distributed(dist_args): +def setup_torch_distributed(dist_args: DistributedArgs): """ Handle single and multi-GPU / multi-node / SLURM jobs. Initialize the following variables: @@ -388,14 +398,14 @@ def clean_env(): def parallelize_model( - model, + model: torch.nn.Module, device_mesh, model_args, distributed_args: DistributedArgs, fsdp_grouping_plan: Optional[List[Tuple[str, bool]]] = None, tp_parallelize=None, no_recompute_ops=None, -): +) -> torch.nn.Module: if distributed_args.tp_size > 1: assert ( distributed_args.fsdp_type == "full_shard" @@ -429,6 +439,8 @@ def parallelize_model( device_mesh["dp_shard"].size() == 1 ), "dp_shard must be 1 for no_shard fsdp_type" + # TODO: Remove with something better + # model = model.to(param_dtype) fsdp_config = dict( mp_policy=( MixedPrecisionPolicy( diff --git a/bytelatent/eval.py b/bytelatent/eval.py index ae73066..1943a33 100644 --- a/bytelatent/eval.py +++ b/bytelatent/eval.py @@ -15,9 +15,16 @@ from lm_eval.api.model import LM from omegaconf import OmegaConf from pydantic import BaseModel, ConfigDict -from bytelatent.args import EvalArgs, ValidationArgs, parse_args +from bytelatent.args import ( + EvalArgs, + TrainArgs, + ValidationArgs, + find_and_sanitize_chunks, + parse_args, +) from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints from bytelatent.data.file_util import get_fs +from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator from bytelatent.distributed import ( DistributedArgs, dist_mean_dict, @@ -117,19 +124,40 @@ class EvalHarnessLM(LM): return results -def eval_on_val(generator, val_args: ValidationArgs, train_cfg): - srcs = {} +def eval_on_val(generator, val_args: ValidationArgs, train_cfg: TrainArgs): + srcs = [] for src in val_args.sources: path = os.path.join(val_args.root_dir, src) - srcs[path] = 1.0 + srcs.append(path) + for src in train_cfg.data.sources: path = os.path.join(train_cfg.data.root_dir, src) - srcs[path] = 1.0 + srcs.append(path) - multi_state = init_choice_state( - "", srcs, 0, get_global_rank(), get_world_size(), "*.val.jsonl" - ) - path_to_iter = setup_sources(multi_state) + path_to_iter = {} + for path in srcs: + chunks = find_and_sanitize_chunks( + path, + world_size=1, + file_pattern="*.val.jsonl", + s3_profile=train_cfg.data.s3_profile, + ) + assert ( + len(chunks) == 1 + ), f"There should be only 1 chunk per validation file, but found: {chunks}" + chunk = chunks[0] + iterator = ArrowFileIterator( + dataset_files=[chunk], + file_path=None, + preprocess_dir=None, + entropy_model_name=None, + worker_id=0, + num_workers=1, + arrow_batch_size=train_cfg.data.arrow_batch_size, + s3_profile=train_cfg.data.s3_profile, + file_format="json", + ) + path_to_iter[path] = iterator max_gen_len = generator.max_gen_len # We temporarily lower max gen len @@ -137,16 +165,11 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg): all_val_metrics = {} for src in path_to_iter: - jsonl_iterator = path_to_iter[src] + example_iterator = path_to_iter[src].create_iter() texts = [] 
logger.info(f"Running validation on {src}...") - for step, (content, state) in enumerate(jsonl_iterator): - if state["current_iter"] > 0 or ( - val_args.max_steps is not None and step >= val_args.max_steps - ): - break - content_key = "text" if ("text" in content) else "content" - texts.append(content[content_key]) + for step, example in enumerate(example_iterator): + texts.append(example.text) _, loglikelihood, _ = generator.generate(texts) @@ -191,7 +214,7 @@ def launch_eval(eval_args: EvalArgs): else: consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER) if not fs.exists(consolidate_path) and get_global_rank() == 0: - consolidate_path = consolidate_checkpoints(eval_args.ckpt_dir) + consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir) fs.mkdirs(eval_args.dump_dir, exist_ok=True) with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f: @@ -210,10 +233,12 @@ def launch_eval(eval_args: EvalArgs): wrap = EvalHarnessLM(generator) # Redo - results = simple_evaluate(wrap, eval_args.harness.model_dump()) + results = simple_evaluate(wrap, **eval_args.harness.model_dump()) + val_results = None if eval_args.validation: val_results = eval_on_val(generator, eval_args.validation, train_cfg) + if get_global_rank() == 0: with fs.open(os.path.join(eval_args.dump_dir, "results.json"), "w") as f: f.write(json.dumps(results)) @@ -222,6 +247,7 @@ def launch_eval(eval_args: EvalArgs): with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f: f.write(json.dumps(val_results)) logger.info(f"All validation results: {val_results}") + if eval_args.metric_log_dir and get_global_rank() == 0: metric_log_path = os.path.join(eval_args.metric_log_dir, "metrics.eval.jsonl") diff --git a/bytelatent/generate.py b/bytelatent/generate.py index eb79d81..9d44f30 100644 --- a/bytelatent/generate.py +++ b/bytelatent/generate.py @@ -387,8 +387,7 @@ def load_consolidated_model_and_tokenizer( ): train_args_path = os.path.join(consolidated_path, "params.json") fs = get_fs(train_args_path) - with fs.open(train_args_path) as f: - train_args = TrainArgs.model_validate_json(f.read()) + train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path)) if train_args.train_entropy_model: model_args = train_args.entropy_model @@ -401,7 +400,8 @@ def load_consolidated_model_and_tokenizer( train_args.distributed.model_dtype ] tokenizer = train_args.data.tokenizer_args.build() - st_dict = torch.load(consolidated_path / CONSOLIDATE_NAME, weights_only=True) + with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f: + st_dict = torch.load(f, weights_only=True) model.load_state_dict(st_dict["model"]) model = model.cuda().eval() for param in model.parameters(): diff --git a/bytelatent/norms.py b/bytelatent/norms.py new file mode 100644 index 0000000..81d1652 --- /dev/null +++ b/bytelatent/norms.py @@ -0,0 +1,100 @@ +from typing import Dict, List, Optional, Tuple + +import torch +from torch import Tensor +from torch.utils._foreach_utils import ( + _device_has_foreach_support, + _group_tensors_by_device_and_dtype, + _has_foreach_support, +) + + +@torch.no_grad() +def fixed_clip_grad_norm_( + parameters: torch.Tensor | list[torch.Tensor], + max_norm: float, + norm_type: float = 2.0, + error_if_nonfinite: bool = False, + foreach: Optional[bool] = None, +) -> torch.Tensor: + r"""Clip the gradient norm of an iterable of parameters. 
+ + The norm is computed over the norms of the individual gradients of all parameters, + as if the norms of the individual gradients were concatenated into a single vector. + Gradients are modified in-place. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float): max norm of the gradients + norm_type (float): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + error_if_nonfinite (bool): if True, an error is thrown if the total + norm of the gradients from :attr:`parameters` is ``nan``, + ``inf``, or ``-inf``. Default: False (will switch to True in the future) + foreach (bool): use the faster foreach-based implementation. + If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently + fall back to the slow implementation for other device types. + Default: ``None`` + + Returns: + Total norm of the parameter gradients (viewed as a single vector). + """ + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + grads = [p.grad.to(torch.bfloat16) for p in parameters if p.grad is not None] + max_norm = float(max_norm) + norm_type = float(norm_type) + if len(grads) == 0: + return torch.tensor(0.0) + first_device = grads[0].device + grouped_grads: Dict[ + Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]] + ] = _group_tensors_by_device_and_dtype( + [grads] + ) # type: ignore[assignment] + + norms: List[Tensor] = [] + for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment] + if (foreach is None and _has_foreach_support(device_grads, device)) or ( + foreach and _device_has_foreach_support(device) + ): + norms.extend(torch._foreach_norm(device_grads, norm_type)) + elif foreach: + raise RuntimeError( + f"foreach=True was passed, but can't use the foreach API on {device.type} tensors" + ) + else: + norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads]) + + total_norm = torch.linalg.vector_norm( + torch.stack([norm.to(first_device) for norm in norms]), norm_type + ) + + if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()): + raise RuntimeError( + f"The total norm of order {norm_type} for gradients from " + "`parameters` is non-finite, so it cannot be clipped. To disable " + "this error and scale the gradients by the non-finite norm anyway, " + "set `error_if_nonfinite=False`" + ) + clip_coef = max_norm / (total_norm + 1e-6) + # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so + # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization + # when the gradients do not reside in CPU memory. 
+ clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment] + if (foreach is None and _has_foreach_support(device_grads, device)) or ( + foreach and _device_has_foreach_support(device) + ): + torch._foreach_mul_(device_grads, clip_coef_clamped.to(device)) + elif foreach: + raise RuntimeError( + f"foreach=True was passed, but can't use the foreach API on {device.type} tensors" + ) + else: + clip_coef_clamped_device = clip_coef_clamped.to(device) + for g in device_grads: + g.mul_(clip_coef_clamped_device) + + return total_norm diff --git a/bytelatent/preprocess/preprocess_entropies.py b/bytelatent/preprocess/preprocess_entropies.py index 519da94..31a4802 100644 --- a/bytelatent/preprocess/preprocess_entropies.py +++ b/bytelatent/preprocess/preprocess_entropies.py @@ -15,29 +15,37 @@ from bytelatent.entropy_model import load_entropy_model from bytelatent.tokenizers.build_tokenizer import TokenizerArgs -def get_id_from_doc(doc: dict) -> int: +def get_id_key(doc: dict) -> int: """ We need a reliable way to ensure that samples from jsonl and arrow are the same, but there is no unique id field, so derive the best possible """ if "sample_id" in doc: - sample_id = doc["sample_id"] + return "sample_id" elif "title" in doc: - sample_id = doc["title"] + return "title" elif "qid" in doc: - sample_id = doc["qid"] + return "qid" elif "paper_id" in doc: - sample_id = doc["paper_id"] + return "paper_id" elif "path" in doc: - sample_id = doc["path"] + return "path" elif "url" in doc: - sample_id = doc["url"] + return "url" elif "id" in doc: - sample_id = doc["id"] + return "id" else: raise ValueError(f"Could not find a id key from: {doc.keys()}") - return str(sample_id) + + +def get_id_from_doc(doc: dict) -> int: + """ + We need a reliable way to ensure that samples from jsonl + and arrow are the same, but there is no unique id field, + so derive the best possible + """ + return str(doc[get_id_key(doc)]) def get_text(doc: dict): diff --git a/bytelatent/train.py b/bytelatent/train.py index 6b20ecd..16c6865 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -3,6 +3,7 @@ import gc import logging +import math import os import sys from contextlib import ExitStack @@ -12,6 +13,7 @@ from pathlib import Path from timeit import default_timer as timer from typing import Any, TypeVar +import numpy as np import torch import torch.distributed import torch.nn.functional @@ -25,6 +27,7 @@ from torch.optim import lr_scheduler from bytelatent.args import TrainArgs, parse_args from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint +from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.multiprocess_iterator import ( MultiprocessIterator, MultiprocessIteratorState, @@ -33,7 +36,9 @@ from bytelatent.data.iterators.packing_iterator import PackingIteratorState from bytelatent.distributed import ( check_model_value_range, clean_env, + dist_mean, dist_mean_dict, + dist_sum, get_device_mesh, get_is_master, get_world_size, @@ -47,6 +52,7 @@ from bytelatent.eval import EVAL_FOLDER_NAME, launch_eval from bytelatent.logger import init_logger from bytelatent.metrics import GPUMemoryMonitor, MetricLogger, get_num_params from bytelatent.model.blt import ByteLatentTransformer +from bytelatent.norms import fixed_clip_grad_norm_ from bytelatent.optim import build_optimizer from bytelatent.probe import AutoProbeD from bytelatent.profiling import maybe_run_profiler @@ -295,8 +301,11 @@ def train(args: 
TrainArgs): if args.checkpoint.init_ckpt_path: logger.info(f"Loading initial model from {args.checkpoint.init_ckpt_path}") + ckpt_fs = get_fs( + args.checkpoint.init_ckpt_path, s3_profile=args.checkpoint.s3_profile + ) load_from_checkpoint( - args.checkpoint.init_ckpt_path, model, model_key="model" + ckpt_fs, args.checkpoint.init_ckpt_path, model, model_key="model" ) # Put model_key="" if its directly the model checkpoint model.rope_embeddings.reset_parameters() # For RoPe initialization since it's a buffer it might not be loaded else: @@ -364,6 +373,9 @@ def train(args: TrainArgs): time_last_log = timer() gc.collect() saved = False + step_losses: list[float] = [] + step_tok_losses: list[float] = [] + n_bytes: int = 0 while train_state.step < args.steps and ( args.max_steps is None or train_state.step < args.max_steps ): @@ -385,6 +397,24 @@ def train(args: TrainArgs): batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda() mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda() + if args.data.tokenizer_args.name in ["bytes", "blt"]: + if mask is None: + n_bytes += batch_y.numel() + else: + n_bytes += mask.sum() + elif args.data.tokenizer_args.name in ["sp", "tiktoken"]: + for example in batch.y: + target_tokens = tokenizer.decode(example.tolist(), cut_at_eos=False) + n_bytes += ( + len(bytes(target_tokens, encoding="utf-8", errors="ignore")) + + sum(example == tokenizer.eos_id) + + sum(example == tokenizer.bos_id) + ) + else: + raise ValueError( + f"Unexpected tokenizer to count n_bytes for: {args.data.tokenizer_args.name}" + ) + if ( not args.train_entropy_model and args.model.encoder_enable_byte_ngrams @@ -459,7 +489,7 @@ def train(args: TrainArgs): batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids ) - loss, _ = compute_loss(pred, batch_y, mask, train_state.scale) + loss, tok_loss = compute_loss(pred, batch_y, mask, train_state.scale) # We scale loss with grad_acc_steps so the gradient is the same # regardless of grad_acc_steps @@ -470,8 +500,14 @@ def train(args: TrainArgs): # For logging we undo that scaling loss = loss.detach() * args.grad_acc_steps - grad_norm = torch.nn.utils.clip_grad_norm_( - model.parameters(), max_norm=args.optim.clip, foreach=True + # Undo loss scaling so downstream down't need to worry about it + step_losses.append((loss / train_state.scale).item()) + step_tok_losses.append(tok_loss / train_state.scale) + + # grad_norm = torch.nn.utils.clip_grad_norm_( + grad_norm = fixed_clip_grad_norm_( + model.parameters(), + max_norm=args.optim.clip, # foreach=True ) grad_norm = ( @@ -559,20 +595,33 @@ def train(args: TrainArgs): gpu_memory_monitor.reset_peak_stats() nwords_since_last_log = 0 time_last_log = timer() + stacked_tok_loss = torch.cat(step_tok_losses, dim=0) + total_tok_loss = dist_sum( + stacked_tok_loss.sum().item(), reduce_dtype=torch.bfloat16 + ) + total_n_bytes = dist_sum(n_bytes, reduce_dtype=torch.bfloat16) + avg_bpb = total_tok_loss / math.log(2) / total_n_bytes + avg_loss = dist_mean(np.mean(step_losses).item()) logger.info( f"step: {train_state.step}" f" acc: {train_state.acc_step}" - f" loss: {round(loss.item(),4):>7}" + f" loss: step={round(loss.item(),4):>7} avg={avg_loss}" + f" bpb: {avg_bpb:3f}" f" grad: {grad_norm:.2e}" f" flops: {FLOPS:.2e}" f" wps: {wps:.2e}" f" iter: {curr_iter_time:>7}" f" data: {data_load_time:>5}" f" lr: {curr_lr:.2e}" + f" n_bytes={total_n_bytes}" f" mem: {gpu_mem_stats.max_active_pct:.0f}%" f" pow: {gpu_mem_stats.power_draw/1000} W" ) + n_bytes = 0 + step_losses = [] + 
step_tok_losses = [] + if every_n_steps( train_state, args.checkpoint.dump.every, acc_step=0 ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0): From e742218d65f8988736b96e94f98e7ed773a6ddf6 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Wed, 5 Feb 2025 00:20:57 +0000 Subject: [PATCH 21/59] Update checkpointing to use fsspec Summary: - Make arrow iterator able to read from jsonl files, the entropies are omitted in this case - Make the data/checkpoint code fsspec compatible - Fix issues with all reduce with non-bf16 in dist_sum and norm computation. - Minimal fixes to get eval to run, it is slow currently - Add bpb numbers during training Test Plan: Run ``` torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/entropy_model.yaml eval=null max_steps=10100 ``` ``` python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null ``` ``` torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null ``` --- bytelatent/checkpoint.py | 97 +++++++++++++++++++++------------------- 1 file changed, 50 insertions(+), 47 deletions(-) diff --git a/bytelatent/checkpoint.py b/bytelatent/checkpoint.py index f213c84..6631673 100644 --- a/bytelatent/checkpoint.py +++ b/bytelatent/checkpoint.py @@ -4,8 +4,6 @@ import json import logging import os import re -from pathlib import Path -from typing import List, Optional, Tuple import fsspec import torch @@ -70,26 +68,29 @@ def consolidate_checkpoints(fs: fsspec.AbstractFileSystem, ckpt_dir: str): Returns the path to the consolidated checkpoint """ - consolidate_path = Path(ckpt_dir) / CONSOLIDATE_FOLDER - if not (consolidate_path / CONSOLIDATE_NAME).exists(): - consolidate_path.mkdir(exist_ok=True) - logger.info(f"Consolidating to: {str(consolidate_path)}") - dcp_to_torch_save(ckpt_dir, str(consolidate_path / CONSOLIDATE_NAME)) - (consolidate_path / CONFIG_NAME).write_text( - (Path(ckpt_dir) / CONFIG_NAME).read_text() + consolidate_path = os.path.join(ckpt_dir, CONSOLIDATE_FOLDER) + consolidate_name = os.path.join(consolidate_path, CONSOLIDATE_NAME) + if not fs.exists(consolidate_name): + fs.mkdirs(consolidate_path, exist_ok=True) + logger.info(f"Consolidating to: {consolidate_path}") + dcp_to_torch_save(ckpt_dir, consolidate_name) + fs.write_text( + os.path.join(consolidate_path, CONFIG_NAME), + fs.read_text(os.path.join(ckpt_dir, CONFIG_NAME)), ) logger.info("Consolidated !") return consolidate_path def load_from_checkpoint( + fs: fsspec.AbstractFileSystem, ckpt_dir: str, model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, + optimizer: torch.optim.Optimizer | None = None, model_key: str = "model", optim_key: str = "optim", ): - if not (Path(ckpt_dir) / ".metadata").exists(): + if not fs.exists(os.path.join(ckpt_dir, ".metadata")): raise ValueError( f"Please convert the checkpoint distcp format using `torch.distributed.checkpoint.format_utils.torch_save_to_dcp` before loading it" ) @@ -121,13 +122,13 @@ class CheckpointManager: self.existing_saves = self.get_existing_saves() - def get_existing_saves(self) -> List[Path]: + def get_existing_saves(self) -> list[str]: folders = [ p - for p in Path(self.path).iterdir() - if p.is_dir() and re.match(RE_FOLDER, p.name) + for p in self.fs.ls(self.path) + if self.fs.isdir(p) and re.match(RE_FOLDER, os.path.basename(p)) ] - folders.sort(key=lambda p: _get_key_step(p.name)) + folders.sort(key=lambda p: _get_key_step(os.path.basename(p))) return folders def clean_up(self): @@ -136,8 +137,9 @@ class CheckpointManager: 
eval_folders = [] other_folders = [] for p in self.existing_saves: - is_dump = _get_key_step(p.name) % self.dump_every.every == 0 - is_eval = _get_key_step(p.name) % self.eval_every.every == 0 + assert isinstance(p, str), f"Base path type: {p}" + is_dump = _get_key_step(os.path.basename(p)) % self.dump_every.every == 0 + is_eval = _get_key_step(os.path.basename(p)) % self.eval_every.every == 0 if is_dump: dump_folders.append(p) if is_eval: @@ -161,40 +163,39 @@ class CheckpointManager: if dist.get_rank() == 0: for folder in folder_to_remove: - for file in folder.iterdir(): - if file.is_file(): - file.unlink() - elif file.is_dir(): - assert file.name in [CONSOLIDATE_FOLDER] - for f in file.iterdir(): - f.unlink() - file.rmdir() - folder.rmdir() + for file in self.fs.ls(folder): + if self.fs.isfile(file): + self.fs.rm_file(file) + elif self.fs.isdir(file): + assert os.path.name(file) in [CONSOLIDATE_FOLDER] + for f in self.fs.ls(file): + self.fs.rm(f) + self.fs.rmdir(file) + self.fs.rmdir(folder) dist.barrier() self.existing_saves = list(folder_to_keep) - self.existing_saves.sort(key=lambda p: _get_key_step(p.name)) + self.existing_saves.sort(key=lambda p: _get_key_step(os.path.basename(p))) - def get_last_step_path(self, dp_rank: int = 0) -> Optional[Path]: + def get_last_step_path(self, dp_rank: int = 0) -> str | None: path = None for p in reversed(self.existing_saves): - if (p / TRAIN_STATE_NAME.format(dp_rank)).is_file(): + + if self.fs.isfile(os.path.join(p, TRAIN_STATE_NAME.format(dp_rank))): path = p break return path - def _create_folder(self, base_path: Path, folder_name: str) -> Path: - folder = base_path / folder_name + def _create_folder(self, base_path: str, folder_name: str) -> str: + folder = os.path.join(base_path, folder_name) if get_is_master(): - folder.mkdir(parents=False, exist_ok=True) + self.fs.mkdirs(folder, exist_ok=True) if dist.is_initialized(): dist.barrier() return folder - def _get_dp_tp_mesh( - self, device_mesh: Optional[DeviceMesh] = None - ) -> Tuple[int, int]: + def _get_dp_tp_mesh(self, device_mesh: DeviceMesh | None = None) -> tuple[int, int]: dp_rank = 0 tp_rank = 0 if device_mesh is not None: @@ -222,14 +223,14 @@ class CheckpointManager: model, optimizer, train_state, - config, - device_mesh: Optional[DeviceMesh] = None, + config: BaseModel, + device_mesh: DeviceMesh | None = None, ) -> bool: # When creating directory check if only rank0 or is there other solution - path = Path(self.path) + path = self.path curr_save_dir = self._create_folder(path, FOLDER_NAME.format(train_state.step)) - logger.info(f"Saving to: {str(curr_save_dir)}") + logger.info(f"Saving to: {curr_save_dir}") if dist.is_initialized(): dist.barrier() @@ -242,17 +243,19 @@ class CheckpointManager: if dist.is_initialized(): dist.barrier() + print("config type", type(config)) if get_is_master(): - config.dump_to_yaml_file(curr_save_dir / CONFIG_NAME) + self.fs.write_text( + os.path.join(curr_save_dir, CONFIG_NAME), config.model_dump_json() + ) # Add json dump here dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh) if tp_rank == 0: train_state_name = TRAIN_STATE_NAME.format(dp_rank) - logger.info( - f"Saving train state to: {str(curr_save_dir / train_state_name)}" - ) - with open(curr_save_dir / train_state_name, "w") as f: + train_state_full_path = os.path.join(curr_save_dir, train_state_name) + logger.info(f"Saving train state to: {train_state_full_path}") + with self.fs.open(train_state_full_path, "w") as f: json.dump(train_state.state_dict(), f) logger.info("Train state saved !") 
@@ -271,7 +274,7 @@ class CheckpointManager: optimizer, train_state, device_mesh: DeviceMesh, - path: Optional[Path] = None, + path: str | None = None, ): dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh) # Loading tries to load the provided path, if not available the last saved step and finally from the init path @@ -284,12 +287,12 @@ class CheckpointManager: # Only load train state if it's provided, the files exist and we're not loading from init path train_state_name = TRAIN_STATE_NAME.format(dp_rank) logger.info("Reloading train state") - with open(path / train_state_name, "r") as f: + with self.fs.open(os.path.join(path, train_state_name), "r") as f: train_state_dict = json.load(f) train_state.load_state_dict(train_state_dict) logger.info("Train state reloaded") - logger.info(f"Loading from: {str(path)}") + logger.info(f"Loading from: {path}") state_dict = self.get_state_dict( model=model, optimizer=optimizer, From b2058fb0f617b95f0379ca700cfa95a69b4237c6 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Wed, 5 Feb 2025 00:27:00 +0000 Subject: [PATCH 22/59] Update checkpointing to use fsspec Summary: - Make arrow iterator able to read from jsonl files, the entropies are omitted in this case - Make the data/checkpoint code fsspec compatible - Fix issues with all reduce with non-bf16 in dist_sum and norm computation. - Minimal fixes to get eval to run, it is slow currently - Add bpb numbers during training Test Plan: Run unit tests and the commands below ``` python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 ``` ``` torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 ``` ``` python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 dump_dir=s3://blt/scratch/checkpoint-test/ ``` ``` torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 dump_dir=s3://blt/scratch/checkpoint-test/ ``` --- bytelatent/checkpoint.py | 97 +++++++++++++++++++++------------------- 1 file changed, 50 insertions(+), 47 deletions(-) diff --git a/bytelatent/checkpoint.py b/bytelatent/checkpoint.py index f213c84..6631673 100644 --- a/bytelatent/checkpoint.py +++ b/bytelatent/checkpoint.py @@ -4,8 +4,6 @@ import json import logging import os import re -from pathlib import Path -from typing import List, Optional, Tuple import fsspec import torch @@ -70,26 +68,29 @@ def consolidate_checkpoints(fs: fsspec.AbstractFileSystem, ckpt_dir: str): Returns the path to the consolidated checkpoint """ - consolidate_path = Path(ckpt_dir) / CONSOLIDATE_FOLDER - if not (consolidate_path / CONSOLIDATE_NAME).exists(): - consolidate_path.mkdir(exist_ok=True) - logger.info(f"Consolidating to: {str(consolidate_path)}") - dcp_to_torch_save(ckpt_dir, str(consolidate_path / CONSOLIDATE_NAME)) - (consolidate_path / CONFIG_NAME).write_text( - (Path(ckpt_dir) / CONFIG_NAME).read_text() + consolidate_path = os.path.join(ckpt_dir, CONSOLIDATE_FOLDER) + consolidate_name = os.path.join(consolidate_path, CONSOLIDATE_NAME) + if not fs.exists(consolidate_name): + fs.mkdirs(consolidate_path, exist_ok=True) + logger.info(f"Consolidating to: {consolidate_path}") + dcp_to_torch_save(ckpt_dir, consolidate_name) + fs.write_text( + os.path.join(consolidate_path, CONFIG_NAME), + fs.read_text(os.path.join(ckpt_dir, CONFIG_NAME)), ) logger.info("Consolidated !") return consolidate_path def load_from_checkpoint( + fs: 
fsspec.AbstractFileSystem, ckpt_dir: str, model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, + optimizer: torch.optim.Optimizer | None = None, model_key: str = "model", optim_key: str = "optim", ): - if not (Path(ckpt_dir) / ".metadata").exists(): + if not fs.exists(os.path.join(ckpt_dir, ".metadata")): raise ValueError( f"Please convert the checkpoint distcp format using `torch.distributed.checkpoint.format_utils.torch_save_to_dcp` before loading it" ) @@ -121,13 +122,13 @@ class CheckpointManager: self.existing_saves = self.get_existing_saves() - def get_existing_saves(self) -> List[Path]: + def get_existing_saves(self) -> list[str]: folders = [ p - for p in Path(self.path).iterdir() - if p.is_dir() and re.match(RE_FOLDER, p.name) + for p in self.fs.ls(self.path) + if self.fs.isdir(p) and re.match(RE_FOLDER, os.path.basename(p)) ] - folders.sort(key=lambda p: _get_key_step(p.name)) + folders.sort(key=lambda p: _get_key_step(os.path.basename(p))) return folders def clean_up(self): @@ -136,8 +137,9 @@ class CheckpointManager: eval_folders = [] other_folders = [] for p in self.existing_saves: - is_dump = _get_key_step(p.name) % self.dump_every.every == 0 - is_eval = _get_key_step(p.name) % self.eval_every.every == 0 + assert isinstance(p, str), f"Base path type: {p}" + is_dump = _get_key_step(os.path.basename(p)) % self.dump_every.every == 0 + is_eval = _get_key_step(os.path.basename(p)) % self.eval_every.every == 0 if is_dump: dump_folders.append(p) if is_eval: @@ -161,40 +163,39 @@ class CheckpointManager: if dist.get_rank() == 0: for folder in folder_to_remove: - for file in folder.iterdir(): - if file.is_file(): - file.unlink() - elif file.is_dir(): - assert file.name in [CONSOLIDATE_FOLDER] - for f in file.iterdir(): - f.unlink() - file.rmdir() - folder.rmdir() + for file in self.fs.ls(folder): + if self.fs.isfile(file): + self.fs.rm_file(file) + elif self.fs.isdir(file): + assert os.path.name(file) in [CONSOLIDATE_FOLDER] + for f in self.fs.ls(file): + self.fs.rm(f) + self.fs.rmdir(file) + self.fs.rmdir(folder) dist.barrier() self.existing_saves = list(folder_to_keep) - self.existing_saves.sort(key=lambda p: _get_key_step(p.name)) + self.existing_saves.sort(key=lambda p: _get_key_step(os.path.basename(p))) - def get_last_step_path(self, dp_rank: int = 0) -> Optional[Path]: + def get_last_step_path(self, dp_rank: int = 0) -> str | None: path = None for p in reversed(self.existing_saves): - if (p / TRAIN_STATE_NAME.format(dp_rank)).is_file(): + + if self.fs.isfile(os.path.join(p, TRAIN_STATE_NAME.format(dp_rank))): path = p break return path - def _create_folder(self, base_path: Path, folder_name: str) -> Path: - folder = base_path / folder_name + def _create_folder(self, base_path: str, folder_name: str) -> str: + folder = os.path.join(base_path, folder_name) if get_is_master(): - folder.mkdir(parents=False, exist_ok=True) + self.fs.mkdirs(folder, exist_ok=True) if dist.is_initialized(): dist.barrier() return folder - def _get_dp_tp_mesh( - self, device_mesh: Optional[DeviceMesh] = None - ) -> Tuple[int, int]: + def _get_dp_tp_mesh(self, device_mesh: DeviceMesh | None = None) -> tuple[int, int]: dp_rank = 0 tp_rank = 0 if device_mesh is not None: @@ -222,14 +223,14 @@ class CheckpointManager: model, optimizer, train_state, - config, - device_mesh: Optional[DeviceMesh] = None, + config: BaseModel, + device_mesh: DeviceMesh | None = None, ) -> bool: # When creating directory check if only rank0 or is there other solution - path = Path(self.path) + path = self.path 
curr_save_dir = self._create_folder(path, FOLDER_NAME.format(train_state.step)) - logger.info(f"Saving to: {str(curr_save_dir)}") + logger.info(f"Saving to: {curr_save_dir}") if dist.is_initialized(): dist.barrier() @@ -242,17 +243,19 @@ class CheckpointManager: if dist.is_initialized(): dist.barrier() + print("config type", type(config)) if get_is_master(): - config.dump_to_yaml_file(curr_save_dir / CONFIG_NAME) + self.fs.write_text( + os.path.join(curr_save_dir, CONFIG_NAME), config.model_dump_json() + ) # Add json dump here dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh) if tp_rank == 0: train_state_name = TRAIN_STATE_NAME.format(dp_rank) - logger.info( - f"Saving train state to: {str(curr_save_dir / train_state_name)}" - ) - with open(curr_save_dir / train_state_name, "w") as f: + train_state_full_path = os.path.join(curr_save_dir, train_state_name) + logger.info(f"Saving train state to: {train_state_full_path}") + with self.fs.open(train_state_full_path, "w") as f: json.dump(train_state.state_dict(), f) logger.info("Train state saved !") @@ -271,7 +274,7 @@ class CheckpointManager: optimizer, train_state, device_mesh: DeviceMesh, - path: Optional[Path] = None, + path: str | None = None, ): dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh) # Loading tries to load the provided path, if not available the last saved step and finally from the init path @@ -284,12 +287,12 @@ class CheckpointManager: # Only load train state if it's provided, the files exist and we're not loading from init path train_state_name = TRAIN_STATE_NAME.format(dp_rank) logger.info("Reloading train state") - with open(path / train_state_name, "r") as f: + with self.fs.open(os.path.join(path, train_state_name), "r") as f: train_state_dict = json.load(f) train_state.load_state_dict(train_state_dict) logger.info("Train state reloaded") - logger.info(f"Loading from: {str(path)}") + logger.info(f"Loading from: {path}") state_dict = self.get_state_dict( model=model, optimizer=optimizer, From 4ad488940533160df500dae7a80003209ed92136 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Wed, 5 Feb 2025 00:30:25 +0000 Subject: [PATCH 23/59] Update checkpointing to use fsspec Summary: - Make the data/checkpoint code fsspec compatible Test Plan: Run unit tests and the commands below ``` python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 ``` ``` torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 ``` ``` python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 dump_dir=s3://blt/scratch/checkpoint-test/ ``` ``` torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 dump_dir=s3://blt/scratch/checkpoint-test/ ``` --- bytelatent/checkpoint.py | 97 +++++++++++++++++++++------------------- 1 file changed, 50 insertions(+), 47 deletions(-) diff --git a/bytelatent/checkpoint.py b/bytelatent/checkpoint.py index f213c84..6631673 100644 --- a/bytelatent/checkpoint.py +++ b/bytelatent/checkpoint.py @@ -4,8 +4,6 @@ import json import logging import os import re -from pathlib import Path -from typing import List, Optional, Tuple import fsspec import torch @@ -70,26 +68,29 @@ def consolidate_checkpoints(fs: fsspec.AbstractFileSystem, ckpt_dir: str): Returns the path to the consolidated checkpoint """ - consolidate_path = Path(ckpt_dir) / CONSOLIDATE_FOLDER - if not (consolidate_path / 
CONSOLIDATE_NAME).exists(): - consolidate_path.mkdir(exist_ok=True) - logger.info(f"Consolidating to: {str(consolidate_path)}") - dcp_to_torch_save(ckpt_dir, str(consolidate_path / CONSOLIDATE_NAME)) - (consolidate_path / CONFIG_NAME).write_text( - (Path(ckpt_dir) / CONFIG_NAME).read_text() + consolidate_path = os.path.join(ckpt_dir, CONSOLIDATE_FOLDER) + consolidate_name = os.path.join(consolidate_path, CONSOLIDATE_NAME) + if not fs.exists(consolidate_name): + fs.mkdirs(consolidate_path, exist_ok=True) + logger.info(f"Consolidating to: {consolidate_path}") + dcp_to_torch_save(ckpt_dir, consolidate_name) + fs.write_text( + os.path.join(consolidate_path, CONFIG_NAME), + fs.read_text(os.path.join(ckpt_dir, CONFIG_NAME)), ) logger.info("Consolidated !") return consolidate_path def load_from_checkpoint( + fs: fsspec.AbstractFileSystem, ckpt_dir: str, model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, + optimizer: torch.optim.Optimizer | None = None, model_key: str = "model", optim_key: str = "optim", ): - if not (Path(ckpt_dir) / ".metadata").exists(): + if not fs.exists(os.path.join(ckpt_dir, ".metadata")): raise ValueError( f"Please convert the checkpoint distcp format using `torch.distributed.checkpoint.format_utils.torch_save_to_dcp` before loading it" ) @@ -121,13 +122,13 @@ class CheckpointManager: self.existing_saves = self.get_existing_saves() - def get_existing_saves(self) -> List[Path]: + def get_existing_saves(self) -> list[str]: folders = [ p - for p in Path(self.path).iterdir() - if p.is_dir() and re.match(RE_FOLDER, p.name) + for p in self.fs.ls(self.path) + if self.fs.isdir(p) and re.match(RE_FOLDER, os.path.basename(p)) ] - folders.sort(key=lambda p: _get_key_step(p.name)) + folders.sort(key=lambda p: _get_key_step(os.path.basename(p))) return folders def clean_up(self): @@ -136,8 +137,9 @@ class CheckpointManager: eval_folders = [] other_folders = [] for p in self.existing_saves: - is_dump = _get_key_step(p.name) % self.dump_every.every == 0 - is_eval = _get_key_step(p.name) % self.eval_every.every == 0 + assert isinstance(p, str), f"Base path type: {p}" + is_dump = _get_key_step(os.path.basename(p)) % self.dump_every.every == 0 + is_eval = _get_key_step(os.path.basename(p)) % self.eval_every.every == 0 if is_dump: dump_folders.append(p) if is_eval: @@ -161,40 +163,39 @@ class CheckpointManager: if dist.get_rank() == 0: for folder in folder_to_remove: - for file in folder.iterdir(): - if file.is_file(): - file.unlink() - elif file.is_dir(): - assert file.name in [CONSOLIDATE_FOLDER] - for f in file.iterdir(): - f.unlink() - file.rmdir() - folder.rmdir() + for file in self.fs.ls(folder): + if self.fs.isfile(file): + self.fs.rm_file(file) + elif self.fs.isdir(file): + assert os.path.name(file) in [CONSOLIDATE_FOLDER] + for f in self.fs.ls(file): + self.fs.rm(f) + self.fs.rmdir(file) + self.fs.rmdir(folder) dist.barrier() self.existing_saves = list(folder_to_keep) - self.existing_saves.sort(key=lambda p: _get_key_step(p.name)) + self.existing_saves.sort(key=lambda p: _get_key_step(os.path.basename(p))) - def get_last_step_path(self, dp_rank: int = 0) -> Optional[Path]: + def get_last_step_path(self, dp_rank: int = 0) -> str | None: path = None for p in reversed(self.existing_saves): - if (p / TRAIN_STATE_NAME.format(dp_rank)).is_file(): + + if self.fs.isfile(os.path.join(p, TRAIN_STATE_NAME.format(dp_rank))): path = p break return path - def _create_folder(self, base_path: Path, folder_name: str) -> Path: - folder = base_path / folder_name + def 
_create_folder(self, base_path: str, folder_name: str) -> str: + folder = os.path.join(base_path, folder_name) if get_is_master(): - folder.mkdir(parents=False, exist_ok=True) + self.fs.mkdirs(folder, exist_ok=True) if dist.is_initialized(): dist.barrier() return folder - def _get_dp_tp_mesh( - self, device_mesh: Optional[DeviceMesh] = None - ) -> Tuple[int, int]: + def _get_dp_tp_mesh(self, device_mesh: DeviceMesh | None = None) -> tuple[int, int]: dp_rank = 0 tp_rank = 0 if device_mesh is not None: @@ -222,14 +223,14 @@ class CheckpointManager: model, optimizer, train_state, - config, - device_mesh: Optional[DeviceMesh] = None, + config: BaseModel, + device_mesh: DeviceMesh | None = None, ) -> bool: # When creating directory check if only rank0 or is there other solution - path = Path(self.path) + path = self.path curr_save_dir = self._create_folder(path, FOLDER_NAME.format(train_state.step)) - logger.info(f"Saving to: {str(curr_save_dir)}") + logger.info(f"Saving to: {curr_save_dir}") if dist.is_initialized(): dist.barrier() @@ -242,17 +243,19 @@ class CheckpointManager: if dist.is_initialized(): dist.barrier() + print("config type", type(config)) if get_is_master(): - config.dump_to_yaml_file(curr_save_dir / CONFIG_NAME) + self.fs.write_text( + os.path.join(curr_save_dir, CONFIG_NAME), config.model_dump_json() + ) # Add json dump here dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh) if tp_rank == 0: train_state_name = TRAIN_STATE_NAME.format(dp_rank) - logger.info( - f"Saving train state to: {str(curr_save_dir / train_state_name)}" - ) - with open(curr_save_dir / train_state_name, "w") as f: + train_state_full_path = os.path.join(curr_save_dir, train_state_name) + logger.info(f"Saving train state to: {train_state_full_path}") + with self.fs.open(train_state_full_path, "w") as f: json.dump(train_state.state_dict(), f) logger.info("Train state saved !") @@ -271,7 +274,7 @@ class CheckpointManager: optimizer, train_state, device_mesh: DeviceMesh, - path: Optional[Path] = None, + path: str | None = None, ): dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh) # Loading tries to load the provided path, if not available the last saved step and finally from the init path @@ -284,12 +287,12 @@ class CheckpointManager: # Only load train state if it's provided, the files exist and we're not loading from init path train_state_name = TRAIN_STATE_NAME.format(dp_rank) logger.info("Reloading train state") - with open(path / train_state_name, "r") as f: + with self.fs.open(os.path.join(path, train_state_name), "r") as f: train_state_dict = json.load(f) train_state.load_state_dict(train_state_dict) logger.info("Train state reloaded") - logger.info(f"Loading from: {str(path)}") + logger.info(f"Loading from: {path}") state_dict = self.get_state_dict( model=model, optimizer=optimizer, From 9cf7847e2642c6971881daf13d24a9efa09f7e92 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Wed, 5 Feb 2025 00:50:20 +0000 Subject: [PATCH 24/59] Fix distributed all reduce grad norm Summary: With >1 GPU, but only 1 node, all reduces fail when inputs are not bf16. 
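As an illustrative sketch of the call-site swap (toy linear model and loss; assumes the `bytelatent` package with the new `bytelatent/norms.py` below is importable, and an environment where torch's bf16 ops are available):
```python
# Toy sketch only: use the bf16-friendly clipper in place of the stock one.
import torch

from bytelatent.norms import fixed_clip_grad_norm_

model = torch.nn.Linear(8, 8)
loss = model(torch.randn(4, 8)).pow(2).mean()
loss.backward()

# Previously: torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0, foreach=True)
grad_norm = fixed_clip_grad_norm_(model.parameters(), max_norm=1.0)
print("total grad norm:", grad_norm)
```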
This uses a modified copy of torch's grad norm to avoid failures Test Plan: - Run unit tests: - Run single gpu training: `python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100` - Run 1 node, multi-gpu training `torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100` --- bytelatent/norms.py | 100 ++++++++++++++++++++++++++++++++++++++++++++ bytelatent/train.py | 36 ++++++++++++++-- 2 files changed, 133 insertions(+), 3 deletions(-) create mode 100644 bytelatent/norms.py diff --git a/bytelatent/norms.py b/bytelatent/norms.py new file mode 100644 index 0000000..81d1652 --- /dev/null +++ b/bytelatent/norms.py @@ -0,0 +1,100 @@ +from typing import Dict, List, Optional, Tuple + +import torch +from torch import Tensor +from torch.utils._foreach_utils import ( + _device_has_foreach_support, + _group_tensors_by_device_and_dtype, + _has_foreach_support, +) + + +@torch.no_grad() +def fixed_clip_grad_norm_( + parameters: torch.Tensor | list[torch.Tensor], + max_norm: float, + norm_type: float = 2.0, + error_if_nonfinite: bool = False, + foreach: Optional[bool] = None, +) -> torch.Tensor: + r"""Clip the gradient norm of an iterable of parameters. + + The norm is computed over the norms of the individual gradients of all parameters, + as if the norms of the individual gradients were concatenated into a single vector. + Gradients are modified in-place. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float): max norm of the gradients + norm_type (float): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + error_if_nonfinite (bool): if True, an error is thrown if the total + norm of the gradients from :attr:`parameters` is ``nan``, + ``inf``, or ``-inf``. Default: False (will switch to True in the future) + foreach (bool): use the faster foreach-based implementation. + If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently + fall back to the slow implementation for other device types. + Default: ``None`` + + Returns: + Total norm of the parameter gradients (viewed as a single vector). 
+ """ + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + grads = [p.grad.to(torch.bfloat16) for p in parameters if p.grad is not None] + max_norm = float(max_norm) + norm_type = float(norm_type) + if len(grads) == 0: + return torch.tensor(0.0) + first_device = grads[0].device + grouped_grads: Dict[ + Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]] + ] = _group_tensors_by_device_and_dtype( + [grads] + ) # type: ignore[assignment] + + norms: List[Tensor] = [] + for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment] + if (foreach is None and _has_foreach_support(device_grads, device)) or ( + foreach and _device_has_foreach_support(device) + ): + norms.extend(torch._foreach_norm(device_grads, norm_type)) + elif foreach: + raise RuntimeError( + f"foreach=True was passed, but can't use the foreach API on {device.type} tensors" + ) + else: + norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads]) + + total_norm = torch.linalg.vector_norm( + torch.stack([norm.to(first_device) for norm in norms]), norm_type + ) + + if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()): + raise RuntimeError( + f"The total norm of order {norm_type} for gradients from " + "`parameters` is non-finite, so it cannot be clipped. To disable " + "this error and scale the gradients by the non-finite norm anyway, " + "set `error_if_nonfinite=False`" + ) + clip_coef = max_norm / (total_norm + 1e-6) + # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so + # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization + # when the gradients do not reside in CPU memory. + clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment] + if (foreach is None and _has_foreach_support(device_grads, device)) or ( + foreach and _device_has_foreach_support(device) + ): + torch._foreach_mul_(device_grads, clip_coef_clamped.to(device)) + elif foreach: + raise RuntimeError( + f"foreach=True was passed, but can't use the foreach API on {device.type} tensors" + ) + else: + clip_coef_clamped_device = clip_coef_clamped.to(device) + for g in device_grads: + g.mul_(clip_coef_clamped_device) + + return total_norm diff --git a/bytelatent/train.py b/bytelatent/train.py index 6b20ecd..8fd3808 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -47,6 +47,7 @@ from bytelatent.eval import EVAL_FOLDER_NAME, launch_eval from bytelatent.logger import init_logger from bytelatent.metrics import GPUMemoryMonitor, MetricLogger, get_num_params from bytelatent.model.blt import ByteLatentTransformer +from bytelatent.norms import fixed_clip_grad_norm_ from bytelatent.optim import build_optimizer from bytelatent.probe import AutoProbeD from bytelatent.profiling import maybe_run_profiler @@ -147,9 +148,26 @@ def validate_train_args(args: TrainArgs, output_size: int): * args.distributed.tp_size != get_world_size() ): + logging.info("Modifying TrainArgs distributed config") assert get_world_size() % args.distributed.dp_shard == 0 + logging.info("World size: %s", get_world_size()) + logging.info( + "Existing setting: train_args.distributed.dp_shard=%s", + args.distributed.dp_shard, + ) + logging.info( + "Setting train_args.distributed.dp_replicate=%s, was dp_replicate=%s", + get_world_size() // args.distributed.dp_shard, + args.distributed.dp_replicate, + ) 
args.distributed.dp_replicate = get_world_size() // args.distributed.dp_shard + logging.info( + "Changing dp_replicate from %s to %s, to account for tp_size=%s", + args.distributed.dp_replicate, + args.distributed.dp_replicate // args.distributed.tp_size, + args.distributed.tp_size, + ) assert args.distributed.dp_replicate % args.distributed.tp_size == 0 args.distributed.dp_replicate = ( args.distributed.dp_replicate // args.distributed.tp_size @@ -470,9 +488,21 @@ def train(args: TrainArgs): # For logging we undo that scaling loss = loss.detach() * args.grad_acc_steps - grad_norm = torch.nn.utils.clip_grad_norm_( - model.parameters(), max_norm=args.optim.clip, foreach=True - ) + + world_size = get_world_size() + if 1 < world_size <= 8: + # For some reason, there are errors in reduces due to + # not working for non-bf16 numbers. This function is a patched + # version that converts gradients to bf16 before computing norms. + # The error only happens in distributed training on one node, + # hence the guard + grad_norm = fixed_clip_grad_norm_( + model.parameters(), max_norm=args.optim.clip, foreach=True + ) + else: + grad_norm = torch.nn.utils.clip_grad_norm_( + model.parameters(), max_norm=args.optim.clip, foreach=True + ) grad_norm = ( grad_norm.full_tensor() if isinstance(grad_norm, DTensor) else grad_norm From ac257bac19006485919aaa222b284fe9d30ad2fa Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Wed, 5 Feb 2025 00:50:20 +0000 Subject: [PATCH 25/59] Fix distributed all reduce grad norm Summary: With >1 GPU, but only 1 node, all reduces fail when inputs are not bf16. This uses a modified copy of torch's grad norm to avoid failures Test Plan: - Run unit tests: - Run single gpu training: `python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100` - Run 1 node, multi-gpu training `torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100` --- bytelatent/norms.py | 100 ++++++++++++++++++++++++++++++++++++++++++++ bytelatent/train.py | 35 ++++++++++++++-- 2 files changed, 132 insertions(+), 3 deletions(-) create mode 100644 bytelatent/norms.py diff --git a/bytelatent/norms.py b/bytelatent/norms.py new file mode 100644 index 0000000..81d1652 --- /dev/null +++ b/bytelatent/norms.py @@ -0,0 +1,100 @@ +from typing import Dict, List, Optional, Tuple + +import torch +from torch import Tensor +from torch.utils._foreach_utils import ( + _device_has_foreach_support, + _group_tensors_by_device_and_dtype, + _has_foreach_support, +) + + +@torch.no_grad() +def fixed_clip_grad_norm_( + parameters: torch.Tensor | list[torch.Tensor], + max_norm: float, + norm_type: float = 2.0, + error_if_nonfinite: bool = False, + foreach: Optional[bool] = None, +) -> torch.Tensor: + r"""Clip the gradient norm of an iterable of parameters. + + The norm is computed over the norms of the individual gradients of all parameters, + as if the norms of the individual gradients were concatenated into a single vector. + Gradients are modified in-place. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float): max norm of the gradients + norm_type (float): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + error_if_nonfinite (bool): if True, an error is thrown if the total + norm of the gradients from :attr:`parameters` is ``nan``, + ``inf``, or ``-inf``. 
Default: False (will switch to True in the future) + foreach (bool): use the faster foreach-based implementation. + If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently + fall back to the slow implementation for other device types. + Default: ``None`` + + Returns: + Total norm of the parameter gradients (viewed as a single vector). + """ + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + grads = [p.grad.to(torch.bfloat16) for p in parameters if p.grad is not None] + max_norm = float(max_norm) + norm_type = float(norm_type) + if len(grads) == 0: + return torch.tensor(0.0) + first_device = grads[0].device + grouped_grads: Dict[ + Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]] + ] = _group_tensors_by_device_and_dtype( + [grads] + ) # type: ignore[assignment] + + norms: List[Tensor] = [] + for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment] + if (foreach is None and _has_foreach_support(device_grads, device)) or ( + foreach and _device_has_foreach_support(device) + ): + norms.extend(torch._foreach_norm(device_grads, norm_type)) + elif foreach: + raise RuntimeError( + f"foreach=True was passed, but can't use the foreach API on {device.type} tensors" + ) + else: + norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads]) + + total_norm = torch.linalg.vector_norm( + torch.stack([norm.to(first_device) for norm in norms]), norm_type + ) + + if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()): + raise RuntimeError( + f"The total norm of order {norm_type} for gradients from " + "`parameters` is non-finite, so it cannot be clipped. To disable " + "this error and scale the gradients by the non-finite norm anyway, " + "set `error_if_nonfinite=False`" + ) + clip_coef = max_norm / (total_norm + 1e-6) + # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so + # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization + # when the gradients do not reside in CPU memory. 
+ clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment] + if (foreach is None and _has_foreach_support(device_grads, device)) or ( + foreach and _device_has_foreach_support(device) + ): + torch._foreach_mul_(device_grads, clip_coef_clamped.to(device)) + elif foreach: + raise RuntimeError( + f"foreach=True was passed, but can't use the foreach API on {device.type} tensors" + ) + else: + clip_coef_clamped_device = clip_coef_clamped.to(device) + for g in device_grads: + g.mul_(clip_coef_clamped_device) + + return total_norm diff --git a/bytelatent/train.py b/bytelatent/train.py index 6b20ecd..86d1c7a 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -47,6 +47,7 @@ from bytelatent.eval import EVAL_FOLDER_NAME, launch_eval from bytelatent.logger import init_logger from bytelatent.metrics import GPUMemoryMonitor, MetricLogger, get_num_params from bytelatent.model.blt import ByteLatentTransformer +from bytelatent.norms import fixed_clip_grad_norm_ from bytelatent.optim import build_optimizer from bytelatent.probe import AutoProbeD from bytelatent.profiling import maybe_run_profiler @@ -147,9 +148,26 @@ def validate_train_args(args: TrainArgs, output_size: int): * args.distributed.tp_size != get_world_size() ): + logging.info("Modifying TrainArgs distributed config") assert get_world_size() % args.distributed.dp_shard == 0 + logging.info("World size: %s", get_world_size()) + logging.info( + "Existing setting: train_args.distributed.dp_shard=%s", + args.distributed.dp_shard, + ) + logging.info( + "Setting train_args.distributed.dp_replicate=%s, was dp_replicate=%s", + get_world_size() // args.distributed.dp_shard, + args.distributed.dp_replicate, + ) args.distributed.dp_replicate = get_world_size() // args.distributed.dp_shard + logging.info( + "Changing dp_replicate from %s to %s, to account for tp_size=%s", + args.distributed.dp_replicate, + args.distributed.dp_replicate // args.distributed.tp_size, + args.distributed.tp_size, + ) assert args.distributed.dp_replicate % args.distributed.tp_size == 0 args.distributed.dp_replicate = ( args.distributed.dp_replicate // args.distributed.tp_size @@ -470,9 +488,20 @@ def train(args: TrainArgs): # For logging we undo that scaling loss = loss.detach() * args.grad_acc_steps - grad_norm = torch.nn.utils.clip_grad_norm_( - model.parameters(), max_norm=args.optim.clip, foreach=True - ) + world_size = get_world_size() + if 1 < world_size <= 8: + # For some reason, there are errors in reduces due to + # not working for non-bf16 numbers. This function is a patched + # version that converts gradients to bf16 before computing norms. 
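+            # (fixed_clip_grad_norm_ in bytelatent/norms.py casts each p.grad
+            # to torch.bfloat16 before computing the per-device norms.)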
+ # The error only happens in distributed training on one node, + # hence the guard + grad_norm = fixed_clip_grad_norm_( + model.parameters(), max_norm=args.optim.clip, foreach=True + ) + else: + grad_norm = torch.nn.utils.clip_grad_norm_( + model.parameters(), max_norm=args.optim.clip, foreach=True + ) grad_norm = ( grad_norm.full_tensor() if isinstance(grad_norm, DTensor) else grad_norm From b6e53f1d4c2418775cc2ad4050e06b6ac0d6c401 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Wed, 5 Feb 2025 00:55:18 +0000 Subject: [PATCH 26/59] Update checkpointing to use fsspec Summary: - Make the data/checkpoint code fsspec compatible Test Plan: Run unit tests and the commands below ``` python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 ``` ``` torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 ``` ``` python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 dump_dir=s3://blt/scratch/checkpoint-test/ ``` ``` torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 dump_dir=s3://blt/scratch/checkpoint-test/ ``` --- bytelatent/checkpoint.py | 97 +++++++++++++++++++++------------------- bytelatent/train.py | 6 ++- 2 files changed, 55 insertions(+), 48 deletions(-) diff --git a/bytelatent/checkpoint.py b/bytelatent/checkpoint.py index f213c84..6631673 100644 --- a/bytelatent/checkpoint.py +++ b/bytelatent/checkpoint.py @@ -4,8 +4,6 @@ import json import logging import os import re -from pathlib import Path -from typing import List, Optional, Tuple import fsspec import torch @@ -70,26 +68,29 @@ def consolidate_checkpoints(fs: fsspec.AbstractFileSystem, ckpt_dir: str): Returns the path to the consolidated checkpoint """ - consolidate_path = Path(ckpt_dir) / CONSOLIDATE_FOLDER - if not (consolidate_path / CONSOLIDATE_NAME).exists(): - consolidate_path.mkdir(exist_ok=True) - logger.info(f"Consolidating to: {str(consolidate_path)}") - dcp_to_torch_save(ckpt_dir, str(consolidate_path / CONSOLIDATE_NAME)) - (consolidate_path / CONFIG_NAME).write_text( - (Path(ckpt_dir) / CONFIG_NAME).read_text() + consolidate_path = os.path.join(ckpt_dir, CONSOLIDATE_FOLDER) + consolidate_name = os.path.join(consolidate_path, CONSOLIDATE_NAME) + if not fs.exists(consolidate_name): + fs.mkdirs(consolidate_path, exist_ok=True) + logger.info(f"Consolidating to: {consolidate_path}") + dcp_to_torch_save(ckpt_dir, consolidate_name) + fs.write_text( + os.path.join(consolidate_path, CONFIG_NAME), + fs.read_text(os.path.join(ckpt_dir, CONFIG_NAME)), ) logger.info("Consolidated !") return consolidate_path def load_from_checkpoint( + fs: fsspec.AbstractFileSystem, ckpt_dir: str, model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, + optimizer: torch.optim.Optimizer | None = None, model_key: str = "model", optim_key: str = "optim", ): - if not (Path(ckpt_dir) / ".metadata").exists(): + if not fs.exists(os.path.join(ckpt_dir, ".metadata")): raise ValueError( f"Please convert the checkpoint distcp format using `torch.distributed.checkpoint.format_utils.torch_save_to_dcp` before loading it" ) @@ -121,13 +122,13 @@ class CheckpointManager: self.existing_saves = self.get_existing_saves() - def get_existing_saves(self) -> List[Path]: + def get_existing_saves(self) -> list[str]: folders = [ p - for p in Path(self.path).iterdir() - if p.is_dir() and re.match(RE_FOLDER, p.name) + for p in 
self.fs.ls(self.path) + if self.fs.isdir(p) and re.match(RE_FOLDER, os.path.basename(p)) ] - folders.sort(key=lambda p: _get_key_step(p.name)) + folders.sort(key=lambda p: _get_key_step(os.path.basename(p))) return folders def clean_up(self): @@ -136,8 +137,9 @@ class CheckpointManager: eval_folders = [] other_folders = [] for p in self.existing_saves: - is_dump = _get_key_step(p.name) % self.dump_every.every == 0 - is_eval = _get_key_step(p.name) % self.eval_every.every == 0 + assert isinstance(p, str), f"Base path type: {p}" + is_dump = _get_key_step(os.path.basename(p)) % self.dump_every.every == 0 + is_eval = _get_key_step(os.path.basename(p)) % self.eval_every.every == 0 if is_dump: dump_folders.append(p) if is_eval: @@ -161,40 +163,39 @@ class CheckpointManager: if dist.get_rank() == 0: for folder in folder_to_remove: - for file in folder.iterdir(): - if file.is_file(): - file.unlink() - elif file.is_dir(): - assert file.name in [CONSOLIDATE_FOLDER] - for f in file.iterdir(): - f.unlink() - file.rmdir() - folder.rmdir() + for file in self.fs.ls(folder): + if self.fs.isfile(file): + self.fs.rm_file(file) + elif self.fs.isdir(file): + assert os.path.name(file) in [CONSOLIDATE_FOLDER] + for f in self.fs.ls(file): + self.fs.rm(f) + self.fs.rmdir(file) + self.fs.rmdir(folder) dist.barrier() self.existing_saves = list(folder_to_keep) - self.existing_saves.sort(key=lambda p: _get_key_step(p.name)) + self.existing_saves.sort(key=lambda p: _get_key_step(os.path.basename(p))) - def get_last_step_path(self, dp_rank: int = 0) -> Optional[Path]: + def get_last_step_path(self, dp_rank: int = 0) -> str | None: path = None for p in reversed(self.existing_saves): - if (p / TRAIN_STATE_NAME.format(dp_rank)).is_file(): + + if self.fs.isfile(os.path.join(p, TRAIN_STATE_NAME.format(dp_rank))): path = p break return path - def _create_folder(self, base_path: Path, folder_name: str) -> Path: - folder = base_path / folder_name + def _create_folder(self, base_path: str, folder_name: str) -> str: + folder = os.path.join(base_path, folder_name) if get_is_master(): - folder.mkdir(parents=False, exist_ok=True) + self.fs.mkdirs(folder, exist_ok=True) if dist.is_initialized(): dist.barrier() return folder - def _get_dp_tp_mesh( - self, device_mesh: Optional[DeviceMesh] = None - ) -> Tuple[int, int]: + def _get_dp_tp_mesh(self, device_mesh: DeviceMesh | None = None) -> tuple[int, int]: dp_rank = 0 tp_rank = 0 if device_mesh is not None: @@ -222,14 +223,14 @@ class CheckpointManager: model, optimizer, train_state, - config, - device_mesh: Optional[DeviceMesh] = None, + config: BaseModel, + device_mesh: DeviceMesh | None = None, ) -> bool: # When creating directory check if only rank0 or is there other solution - path = Path(self.path) + path = self.path curr_save_dir = self._create_folder(path, FOLDER_NAME.format(train_state.step)) - logger.info(f"Saving to: {str(curr_save_dir)}") + logger.info(f"Saving to: {curr_save_dir}") if dist.is_initialized(): dist.barrier() @@ -242,17 +243,19 @@ class CheckpointManager: if dist.is_initialized(): dist.barrier() + print("config type", type(config)) if get_is_master(): - config.dump_to_yaml_file(curr_save_dir / CONFIG_NAME) + self.fs.write_text( + os.path.join(curr_save_dir, CONFIG_NAME), config.model_dump_json() + ) # Add json dump here dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh) if tp_rank == 0: train_state_name = TRAIN_STATE_NAME.format(dp_rank) - logger.info( - f"Saving train state to: {str(curr_save_dir / train_state_name)}" - ) - with open(curr_save_dir / 
train_state_name, "w") as f: + train_state_full_path = os.path.join(curr_save_dir, train_state_name) + logger.info(f"Saving train state to: {train_state_full_path}") + with self.fs.open(train_state_full_path, "w") as f: json.dump(train_state.state_dict(), f) logger.info("Train state saved !") @@ -271,7 +274,7 @@ class CheckpointManager: optimizer, train_state, device_mesh: DeviceMesh, - path: Optional[Path] = None, + path: str | None = None, ): dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh) # Loading tries to load the provided path, if not available the last saved step and finally from the init path @@ -284,12 +287,12 @@ class CheckpointManager: # Only load train state if it's provided, the files exist and we're not loading from init path train_state_name = TRAIN_STATE_NAME.format(dp_rank) logger.info("Reloading train state") - with open(path / train_state_name, "r") as f: + with self.fs.open(os.path.join(path, train_state_name), "r") as f: train_state_dict = json.load(f) train_state.load_state_dict(train_state_dict) logger.info("Train state reloaded") - logger.info(f"Loading from: {str(path)}") + logger.info(f"Loading from: {path}") state_dict = self.get_state_dict( model=model, optimizer=optimizer, diff --git a/bytelatent/train.py b/bytelatent/train.py index 86d1c7a..c80a74c 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -25,6 +25,7 @@ from torch.optim import lr_scheduler from bytelatent.args import TrainArgs, parse_args from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint +from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.multiprocess_iterator import ( MultiprocessIterator, MultiprocessIteratorState, @@ -313,8 +314,11 @@ def train(args: TrainArgs): if args.checkpoint.init_ckpt_path: logger.info(f"Loading initial model from {args.checkpoint.init_ckpt_path}") + ckpt_fs = get_fs( + args.checkpoint.init_ckpt_path, s3_profile=args.checkpoint.s3_profile + ) load_from_checkpoint( - args.checkpoint.init_ckpt_path, model, model_key="model" + ckpt_fs, args.checkpoint.init_ckpt_path, model, model_key="model" ) # Put model_key="" if its directly the model checkpoint model.rope_embeddings.reset_parameters() # For RoPe initialization since it's a buffer it might not be loaded else: From 1450464031da3202bb35e280f1a2323cd3a43b49 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Wed, 5 Feb 2025 19:09:13 +0000 Subject: [PATCH 27/59] Update checkpointing to use fsspec Summary: - Make the data/checkpoint code fsspec compatible Test Plan: Run unit tests and the commands below ``` python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 ``` ``` torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 ``` ``` python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 dump_dir=s3://blt/scratch/checkpoint-test/ ``` ``` torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 dump_dir=s3://blt/scratch/checkpoint-test/ ``` --- bytelatent/args.py | 14 +++-- bytelatent/checkpoint.py | 115 +++++++++++++++++++++------------------ bytelatent/logger.py | 9 ++- bytelatent/metrics.py | 15 ++++- bytelatent/train.py | 31 +++++++---- 5 files changed, 112 insertions(+), 72 deletions(-) diff --git a/bytelatent/args.py b/bytelatent/args.py index d1bac46..fc72b32 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -294,6 +294,14 @@ class 
TrainArgs(BaseModel): def dump_to_yaml_file( self, path: str, log_config: bool = True, sort_keys: bool = True ): + yaml_str = self.dump_to_yaml_str(sort_keys=sort_keys) + with open(path, "w") as f: + if log_config: + logger.info("Using the following config for this run:") + logger.info(yaml_str) + f.write(yaml_str) + + def dump_to_yaml_str(self, sort_keys: bool = True): model_dict = self.model_dump(mode="json") yaml_str = yaml.dump( model_dict, @@ -301,8 +309,4 @@ class TrainArgs(BaseModel): sort_keys=sort_keys, default_flow_style=False, ) - with open(path, "w") as f: - if log_config: - logger.info("Using the following config for this run:") - logger.info(yaml_str) - f.write(yaml_str) + return yaml_str diff --git a/bytelatent/checkpoint.py b/bytelatent/checkpoint.py index f213c84..1668c88 100644 --- a/bytelatent/checkpoint.py +++ b/bytelatent/checkpoint.py @@ -4,10 +4,9 @@ import json import logging import os import re -from pathlib import Path -from typing import List, Optional, Tuple import fsspec +import s3fs import torch import torch.distributed as dist import torch.distributed.checkpoint as dcp @@ -70,26 +69,29 @@ def consolidate_checkpoints(fs: fsspec.AbstractFileSystem, ckpt_dir: str): Returns the path to the consolidated checkpoint """ - consolidate_path = Path(ckpt_dir) / CONSOLIDATE_FOLDER - if not (consolidate_path / CONSOLIDATE_NAME).exists(): - consolidate_path.mkdir(exist_ok=True) - logger.info(f"Consolidating to: {str(consolidate_path)}") - dcp_to_torch_save(ckpt_dir, str(consolidate_path / CONSOLIDATE_NAME)) - (consolidate_path / CONFIG_NAME).write_text( - (Path(ckpt_dir) / CONFIG_NAME).read_text() + consolidate_path = os.path.join(ckpt_dir, CONSOLIDATE_FOLDER) + consolidate_name = os.path.join(consolidate_path, CONSOLIDATE_NAME) + if not fs.exists(consolidate_name): + fs.mkdirs(consolidate_path, exist_ok=True) + logger.info(f"Consolidating to: {consolidate_path}") + dcp_to_torch_save(ckpt_dir, consolidate_name) + fs.write_text( + os.path.join(consolidate_path, CONFIG_NAME), + fs.read_text(os.path.join(ckpt_dir, CONFIG_NAME)), ) logger.info("Consolidated !") return consolidate_path def load_from_checkpoint( + fs: fsspec.AbstractFileSystem, ckpt_dir: str, model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, + optimizer: torch.optim.Optimizer | None = None, model_key: str = "model", optim_key: str = "optim", ): - if not (Path(ckpt_dir) / ".metadata").exists(): + if not fs.exists(os.path.join(ckpt_dir, ".metadata")): raise ValueError( f"Please convert the checkpoint distcp format using `torch.distributed.checkpoint.format_utils.torch_save_to_dcp` before loading it" ) @@ -115,19 +117,24 @@ class CheckpointManager: self.init_ckpt_path = args.init_ckpt_path self.continue_training_from_init = args.continue_training_from_init - assert self.fs.exists( - self.path - ), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)" + if not isinstance(self.fs, s3fs.S3FileSystem): + # S3 does not have a concept of directories + assert self.fs.exists( + self.path + ), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)" self.existing_saves = self.get_existing_saves() - def get_existing_saves(self) -> List[Path]: - folders = [ - p - for p in Path(self.path).iterdir() - if p.is_dir() and re.match(RE_FOLDER, p.name) - ] - folders.sort(key=lambda p: _get_key_step(p.name)) + def get_existing_saves(self) -> list[str]: + if self.fs.exists(self.path) 
and self.fs.isdir(self.path): + folders = [ + p + for p in self.fs.ls(self.path) + if self.fs.isdir(p) and re.match(RE_FOLDER, os.path.basename(p)) + ] + else: + folders = [] + folders.sort(key=lambda p: _get_key_step(os.path.basename(p))) return folders def clean_up(self): @@ -136,8 +143,9 @@ class CheckpointManager: eval_folders = [] other_folders = [] for p in self.existing_saves: - is_dump = _get_key_step(p.name) % self.dump_every.every == 0 - is_eval = _get_key_step(p.name) % self.eval_every.every == 0 + assert isinstance(p, str), f"Base path type: {p}" + is_dump = _get_key_step(os.path.basename(p)) % self.dump_every.every == 0 + is_eval = _get_key_step(os.path.basename(p)) % self.eval_every.every == 0 if is_dump: dump_folders.append(p) if is_eval: @@ -161,40 +169,39 @@ class CheckpointManager: if dist.get_rank() == 0: for folder in folder_to_remove: - for file in folder.iterdir(): - if file.is_file(): - file.unlink() - elif file.is_dir(): - assert file.name in [CONSOLIDATE_FOLDER] - for f in file.iterdir(): - f.unlink() - file.rmdir() - folder.rmdir() + for file in self.fs.ls(folder): + if self.fs.isfile(file): + self.fs.rm_file(file) + elif self.fs.isdir(file): + assert os.path.name(file) in [CONSOLIDATE_FOLDER] + for f in self.fs.ls(file): + self.fs.rm(f) + self.fs.rmdir(file) + self.fs.rmdir(folder) dist.barrier() self.existing_saves = list(folder_to_keep) - self.existing_saves.sort(key=lambda p: _get_key_step(p.name)) + self.existing_saves.sort(key=lambda p: _get_key_step(os.path.basename(p))) - def get_last_step_path(self, dp_rank: int = 0) -> Optional[Path]: + def get_last_step_path(self, dp_rank: int = 0) -> str | None: path = None for p in reversed(self.existing_saves): - if (p / TRAIN_STATE_NAME.format(dp_rank)).is_file(): + + if self.fs.isfile(os.path.join(p, TRAIN_STATE_NAME.format(dp_rank))): path = p break return path - def _create_folder(self, base_path: Path, folder_name: str) -> Path: - folder = base_path / folder_name + def _create_folder(self, base_path: str, folder_name: str) -> str: + folder = os.path.join(base_path, folder_name) if get_is_master(): - folder.mkdir(parents=False, exist_ok=True) + self.fs.mkdirs(folder, exist_ok=True) if dist.is_initialized(): dist.barrier() return folder - def _get_dp_tp_mesh( - self, device_mesh: Optional[DeviceMesh] = None - ) -> Tuple[int, int]: + def _get_dp_tp_mesh(self, device_mesh: DeviceMesh | None = None) -> tuple[int, int]: dp_rank = 0 tp_rank = 0 if device_mesh is not None: @@ -222,14 +229,14 @@ class CheckpointManager: model, optimizer, train_state, - config, - device_mesh: Optional[DeviceMesh] = None, + config: BaseModel, + device_mesh: DeviceMesh | None = None, ) -> bool: # When creating directory check if only rank0 or is there other solution - path = Path(self.path) + path = self.path curr_save_dir = self._create_folder(path, FOLDER_NAME.format(train_state.step)) - logger.info(f"Saving to: {str(curr_save_dir)}") + logger.info(f"Saving to: {curr_save_dir}") if dist.is_initialized(): dist.barrier() @@ -242,17 +249,19 @@ class CheckpointManager: if dist.is_initialized(): dist.barrier() + print("config type", type(config)) if get_is_master(): - config.dump_to_yaml_file(curr_save_dir / CONFIG_NAME) + self.fs.write_text( + os.path.join(curr_save_dir, CONFIG_NAME), config.model_dump_json() + ) # Add json dump here dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh) if tp_rank == 0: train_state_name = TRAIN_STATE_NAME.format(dp_rank) - logger.info( - f"Saving train state to: {str(curr_save_dir / train_state_name)}" - ) - 
with open(curr_save_dir / train_state_name, "w") as f: + train_state_full_path = os.path.join(curr_save_dir, train_state_name) + logger.info(f"Saving train state to: {train_state_full_path}") + with self.fs.open(train_state_full_path, "w") as f: json.dump(train_state.state_dict(), f) logger.info("Train state saved !") @@ -271,7 +280,7 @@ class CheckpointManager: optimizer, train_state, device_mesh: DeviceMesh, - path: Optional[Path] = None, + path: str | None = None, ): dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh) # Loading tries to load the provided path, if not available the last saved step and finally from the init path @@ -284,12 +293,12 @@ class CheckpointManager: # Only load train state if it's provided, the files exist and we're not loading from init path train_state_name = TRAIN_STATE_NAME.format(dp_rank) logger.info("Reloading train state") - with open(path / train_state_name, "r") as f: + with self.fs.open(os.path.join(path, train_state_name), "r") as f: train_state_dict = json.load(f) train_state.load_state_dict(train_state_dict) logger.info("Train state reloaded") - logger.info(f"Loading from: {str(path)}") + logger.info(f"Loading from: {path}") state_dict = self.get_state_dict( model=model, optimizer=optimizer, diff --git a/bytelatent/logger.py b/bytelatent/logger.py index 87f04cc..6f9a397 100644 --- a/bytelatent/logger.py +++ b/bytelatent/logger.py @@ -6,6 +6,8 @@ import sys import time from datetime import timedelta +import fsspec + from bytelatent.distributed import get_global_rank, get_is_slurm_job @@ -92,6 +94,7 @@ def init_logger( *, name: str | None = None, level: str = "INFO", + fs: fsspec.AbstractFileSystem | None = None, ): """ Setup logging. @@ -121,7 +124,11 @@ def init_logger( if log_file is not None and get_global_rank() == 0: # build file handler - file_handler = logging.FileHandler(log_file, "a") + if fs is None: + file_handler = logging.FileHandler(log_file, "a") + else: + file_stream = fs.open(log_file, mode="a") + file_handler = logging.StreamHandler(file_stream) file_handler.setLevel(logging.NOTSET) file_handler.setFormatter(LogFormatter()) # update logger diff --git a/bytelatent/metrics.py b/bytelatent/metrics.py index 77dc4d7..2df5be7 100644 --- a/bytelatent/metrics.py +++ b/bytelatent/metrics.py @@ -9,6 +9,7 @@ from datetime import datetime, timezone from pathlib import Path from typing import Any, Union +import fsspec import torch import torch.nn as nn import wandb @@ -54,14 +55,24 @@ class LoggingArgs(BaseModel): class MetricLogger: - def __init__(self, outdir: Path, args: Any | None = None): + def __init__( + self, + outdir: Path, + # args: TrainArgs + args: Any | None = None, + fs: fsspec.AbstractFileSystem | None = None, + ): self.outdir = outdir self.jsonl_writer = None + self.fs = fs self.args = args def open(self): if self.jsonl_writer is None: - self.jsonl_writer = open(self.outdir, "a") + if self.fs is None: + self.jsonl_writer = open(self.outdir, "a") + else: + self.jsonl_writer = self.fs.open(self.outdir, "a") if ( self.args is not None and self.args.logging.wandb is not None diff --git a/bytelatent/train.py b/bytelatent/train.py index 86d1c7a..2c3ea01 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -8,7 +8,6 @@ import sys from contextlib import ExitStack from copy import deepcopy from dataclasses import asdict, dataclass -from pathlib import Path from timeit import default_timer as timer from typing import Any, TypeVar @@ -18,13 +17,13 @@ import torch.nn.functional import torch.nn.functional as F import wandb import 
xformers.profiler -from omegaconf import OmegaConf from torch.distributed._tensor import DTensor from torch.distributed.checkpoint.stateful import Stateful from torch.optim import lr_scheduler from bytelatent.args import TrainArgs, parse_args from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint +from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.multiprocess_iterator import ( MultiprocessIterator, MultiprocessIteratorState, @@ -136,11 +135,12 @@ def validate_train_args(args: TrainArgs, output_size: int): if args.checkpoint.path is None: logger.info(f"Setting checkpoint path to {args.checkpoint.path}") - args.checkpoint.path = str(Path(args.dump_dir) / "checkpoints") + args.checkpoint.path = os.path.join(args.dump_dir, "checkpoints") + data_fs = get_fs(args.data.root_dir, s3_profile=args.data.s3_profile) for source in args.data.sources: data_path = os.path.join(args.data.root_dir, source) - assert os.path.exists(data_path), f"{data_path} doesn't exist" + assert data_fs.exists(data_path), f"{data_path} doesn't exist" if ( args.distributed.dp_replicate @@ -255,10 +255,15 @@ def train(args: TrainArgs): args, tokenizer.n_words, ) + dump_fs = get_fs(args.dump_dir, s3_profile=args.checkpoint.s3_profile) if get_is_master(): - os.makedirs(args.dump_dir, exist_ok=True) - args.dump_to_yaml_file(Path(args.dump_dir) / "config.yaml") - init_logger(Path(args.dump_dir) / "train.log") + dump_fs.mkdirs(args.dump_dir, exist_ok=True) + config_yaml_str = args.dump_to_yaml_str() + logging.info("TrainArgs: \n%s", config_yaml_str) + dump_fs.write_text( + os.path.join(args.dump_dir, "config.yaml"), config_yaml_str + ) + init_logger(os.path.join(args.dump_dir, "train.log"), fs=dump_fs) init_signal_handler(set_preemption_flag) # For handling preemption signals. 
setup_env(args.env) setup_torch_distributed(args.distributed) @@ -313,8 +318,11 @@ def train(args: TrainArgs): if args.checkpoint.init_ckpt_path: logger.info(f"Loading initial model from {args.checkpoint.init_ckpt_path}") + ckpt_fs = get_fs( + args.checkpoint.init_ckpt_path, s3_profile=args.checkpoint.s3_profile + ) load_from_checkpoint( - args.checkpoint.init_ckpt_path, model, model_key="model" + ckpt_fs, args.checkpoint.init_ckpt_path, model, model_key="model" ) # Put model_key="" if its directly the model checkpoint model.rope_embeddings.reset_parameters() # For RoPe initialization since it's a buffer it might not be loaded else: @@ -351,13 +359,14 @@ def train(args: TrainArgs): checkpoint.load(model, optimizer, train_state, world_mesh) # Either load from latest checkpoint or start from scratch if args.probe_freq is not None: + # TODO: Convert this to fsspec compatible if get_is_master(): - os.makedirs(Path(args.dump_dir) / "probe", exist_ok=True) + os.makedirs(os.path.join(args.dump_dir, "probe"), exist_ok=True) torch.distributed.barrier() probe = AutoProbeD( model, ( - Path(args.dump_dir) / "probe" / f"probe.{dp_rank}.jsonl" + os.path.join(args.dump_dir, "probe", f"probe.{dp_rank}.jsonl") if (dp_rank % 128 == 0) else None ), @@ -369,7 +378,7 @@ def train(args: TrainArgs): # train loop model.train() metric_logger = context_stack.enter_context( - MetricLogger(Path(args.dump_dir) / "metrics.jsonl", args) + MetricLogger(os.path.join(args.dump_dir, "metrics.jsonl"), args, fs=dump_fs) ) data_loader = train_state.data_loader_state.build() batch_iterator = data_loader.create_iter() From 2f42633b078b6b0fe5b6518f395704452436ad76 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Wed, 5 Feb 2025 22:26:31 +0000 Subject: [PATCH 28/59] Add bpb and n_bytes to metric logging Summary: Test Plan: --- bytelatent/distributed.py | 16 +++++++++++--- bytelatent/train.py | 46 +++++++++++++++++++++++++++++++++++++-- 2 files changed, 57 insertions(+), 5 deletions(-) diff --git a/bytelatent/distributed.py b/bytelatent/distributed.py index 168cb7c..b3f31a6 100644 --- a/bytelatent/distributed.py +++ b/bytelatent/distributed.py @@ -127,6 +127,16 @@ def dist_max(x: Union[int, float], mesh: DeviceMesh = None): return tensor +def dist_sum( + x: Union[int, float], mesh: DeviceMesh = None, reduce_dtype: torch.dtype = None +): + tensor = torch.tensor(x).cuda() + if reduce_dtype is not None: + tensor = tensor.to(reduce_dtype) + dist.all_reduce(tensor, op=ReduceOp.SUM, group=mesh.get_group() if mesh else None) + return tensor + + def dist_mean(x: Union[int, float], mesh: DeviceMesh = None): tensor = torch.tensor(x).cuda() dist.all_reduce(tensor, op=ReduceOp.AVG, group=mesh.get_group() if mesh else None) @@ -236,7 +246,7 @@ def setup_env(env_args: EnvironmentArgs): logger.warning(f"WARNING: Setting {name} to {value}") -def setup_torch_distributed(dist_args): +def setup_torch_distributed(dist_args: DistributedArgs): """ Handle single and multi-GPU / multi-node / SLURM jobs. 
Initialize the following variables: @@ -388,14 +398,14 @@ def clean_env(): def parallelize_model( - model, + model: torch.nn.Module, device_mesh, model_args, distributed_args: DistributedArgs, fsdp_grouping_plan: Optional[List[Tuple[str, bool]]] = None, tp_parallelize=None, no_recompute_ops=None, -): +) -> torch.nn.Module: if distributed_args.tp_size > 1: assert ( distributed_args.fsdp_type == "full_shard" diff --git a/bytelatent/train.py b/bytelatent/train.py index 2c3ea01..4641746 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -3,6 +3,7 @@ import gc import logging +import math import os import sys from contextlib import ExitStack @@ -11,6 +12,7 @@ from dataclasses import asdict, dataclass from timeit import default_timer as timer from typing import Any, TypeVar +import numpy as np import torch import torch.distributed import torch.nn.functional @@ -32,7 +34,9 @@ from bytelatent.data.iterators.packing_iterator import PackingIteratorState from bytelatent.distributed import ( check_model_value_range, clean_env, + dist_mean, dist_mean_dict, + dist_sum, get_device_mesh, get_is_master, get_world_size, @@ -391,6 +395,9 @@ def train(args: TrainArgs): time_last_log = timer() gc.collect() saved = False + step_losses: list[float] = [] + step_tok_losses: list[float] = [] + n_bytes: int = 0 while train_state.step < args.steps and ( args.max_steps is None or train_state.step < args.max_steps ): @@ -412,6 +419,24 @@ def train(args: TrainArgs): batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda() mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda() + if args.data.tokenizer_args.name in ["bytes", "blt"]: + if mask is None: + n_bytes += batch_y.numel() + else: + n_bytes += mask.sum() + elif args.data.tokenizer_args.name in ["sp", "tiktoken"]: + for example in batch.y: + target_tokens = tokenizer.decode(example.tolist(), cut_at_eos=False) + n_bytes += ( + len(bytes(target_tokens, encoding="utf-8", errors="ignore")) + + sum(example == tokenizer.eos_id) + + sum(example == tokenizer.bos_id) + ) + else: + raise ValueError( + f"Unexpected tokenizer to count n_bytes for: {args.data.tokenizer_args.name}" + ) + if ( not args.train_entropy_model and args.model.encoder_enable_byte_ngrams @@ -486,7 +511,7 @@ def train(args: TrainArgs): batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids ) - loss, _ = compute_loss(pred, batch_y, mask, train_state.scale) + loss, tok_loss = compute_loss(pred, batch_y, mask, train_state.scale) # We scale loss with grad_acc_steps so the gradient is the same # regardless of grad_acc_steps @@ -497,6 +522,10 @@ def train(args: TrainArgs): # For logging we undo that scaling loss = loss.detach() * args.grad_acc_steps + # Undo loss scaling so downstream down't need to worry about it + step_losses.append((loss / train_state.scale).item()) + step_tok_losses.append(tok_loss / train_state.scale) + world_size = get_world_size() if 1 < world_size <= 8: # For some reason, there are errors in reduces due to @@ -597,20 +626,33 @@ def train(args: TrainArgs): gpu_memory_monitor.reset_peak_stats() nwords_since_last_log = 0 time_last_log = timer() + stacked_tok_loss = torch.cat(step_tok_losses, dim=0) + total_tok_loss = dist_sum( + stacked_tok_loss.sum().item(), reduce_dtype=torch.bfloat16 + ) + total_n_bytes = dist_sum(n_bytes, reduce_dtype=torch.bfloat16) + avg_bpb = total_tok_loss / math.log(2) / total_n_bytes + avg_loss = dist_mean(np.mean(step_losses).item()) logger.info( f"step: {train_state.step}" f" acc: {train_state.acc_step}" - f" 
loss: {round(loss.item(),4):>7}" + f" loss: step={round(loss.item(),4):>7} avg={avg_loss}" + f" bpb: {avg_bpb:3f}" f" grad: {grad_norm:.2e}" f" flops: {FLOPS:.2e}" f" wps: {wps:.2e}" f" iter: {curr_iter_time:>7}" f" data: {data_load_time:>5}" f" lr: {curr_lr:.2e}" + f" n_bytes={total_n_bytes}" f" mem: {gpu_mem_stats.max_active_pct:.0f}%" f" pow: {gpu_mem_stats.power_draw/1000} W" ) + n_bytes = 0 + step_losses = [] + step_tok_losses = [] + if every_n_steps( train_state, args.checkpoint.dump.every, acc_step=0 ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0): From 48cf4dfee14d060c6be1615484c85d34ac34296f Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Wed, 5 Feb 2025 22:47:00 +0000 Subject: [PATCH 29/59] Allow ArrowIterator to read from json Summary: Test Plan: --- bytelatent/args.py | 38 +++++++- bytelatent/data/iterators/arrow_iterator.py | 97 +++++++++++-------- bytelatent/preprocess/preprocess_entropies.py | 26 +++-- 3 files changed, 104 insertions(+), 57 deletions(-) diff --git a/bytelatent/args.py b/bytelatent/args.py index fc72b32..263e8e3 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -1,8 +1,10 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. +import json import logging import os from typing import Any +import fsspec import numpy as np import yaml from omegaconf import OmegaConf @@ -10,11 +12,9 @@ from pydantic import BaseModel, ConfigDict from bytelatent.checkpoint import CheckpointArgs from bytelatent.data.data_types import Batch +from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import StatefulIterator -from bytelatent.data.iterators.arrow_iterator import ( - ArrowFileIterator, - find_and_sanitize_chunks, -) +from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator from bytelatent.data.iterators.looping_iterator import LoopingIterator from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator @@ -53,6 +53,33 @@ def parse_args(args_cls): return pydantic_args +TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" + + +def find_and_sanitize_chunks( + dataset_path: str, + world_size: int, + file_pattern: str, + s3_profile: str | None = None, +): + fs = get_fs(dataset_path, s3_profile=s3_profile) + path_with_glob = os.path.join(dataset_path, file_pattern) + dataset_chunks = fs.glob(path_with_glob) + n_chunks = len(dataset_chunks) + + if n_chunks > world_size: + n_discard = n_chunks - world_size + dataset_chunks = dataset_chunks[:world_size] + else: + assert ( + world_size % n_chunks == 0 + ), "World size should be a multiple of number of chunks" + + assert n_chunks > 0, f"No valid chunks in {dataset_path}" + + return dataset_chunks + + def distribute_data_to_rank( *, dataset_path: str, @@ -62,9 +89,10 @@ def distribute_data_to_rank( rank: int, world_size: int, s3_profile: str | None = None, + file_pattern: str = TRAIN_DATA_FILE_PATTERN, ) -> ArrowFileIterator: dataset_chunks = find_and_sanitize_chunks( - dataset_path, world_size, s3_profile=s3_profile + dataset_path, world_size, file_pattern, s3_profile=s3_profile ) n_workers_per_chunk = world_size // len(dataset_chunks) rank_to_arrow_iterator_params = [] diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index 4e7b99e..1c68d3a 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -16,6 +16,7 @@ from bytelatent import 
ByteLatentError from bytelatent.data.data_types import BltExample from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator +from bytelatent.preprocess.preprocess_entropies import get_id_key, get_text logger = getLogger(__name__) @@ -32,6 +33,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState): arrow_batch_size: int = 100 s3_profile: str | None filesystem_type: str | None = None + file_format: str def build(self) -> "ArrowFileIterator": arrow_file = ArrowFileIterator( @@ -44,6 +46,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState): dataset_files=self.dataset_files, s3_profile=self.s3_profile, filesystem_type=self.filesystem_type, + file_format=self.file_format, ) if self.row_num != 0: arrow_file._set_row_num(self.row_num) @@ -70,6 +73,7 @@ class ArrowFileIterator(StatefulIterator): dataset_files: list[str] | None = None, s3_profile: str | None = None, filesystem_type: str | None = None, + file_format: str = "arrow", ): assert 0 <= worker_id < num_workers, (worker_id, num_workers) if file_path is None and dataset_files is None: @@ -87,12 +91,16 @@ class ArrowFileIterator(StatefulIterator): self.arrow_batch_size = arrow_batch_size self.s3_profile = s3_profile self.filesystem_type = filesystem_type + self.file_format = file_format self.fs = None if self.filesystem_type is not None: if self.filesystem_type == "file": self.fs = fsspec.filesystem("file") elif self.filesystem_type == "s3": self.fs = fsspec.filesystem("s3", profile=s3_profile) + else: + raise ValueError("Unknown filesystem") + logger.info("Arrow iterator using fs=%s", self.fs) if dataset_files is None: # Prepare arrow shards @@ -153,6 +161,7 @@ class ArrowFileIterator(StatefulIterator): dataset_files=self.dataset_files, s3_profile=self.s3_profile, filesystem_type=self.filesystem_type, + file_format=self.file_format, ) def create_iter( @@ -164,7 +173,7 @@ class ArrowFileIterator(StatefulIterator): else: filesystem = None self.dataset = pa.dataset.dataset( - self.dataset_files, format="arrow", filesystem=filesystem + self.dataset_files, format=self.file_format, filesystem=filesystem ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size @@ -173,13 +182,22 @@ class ArrowFileIterator(StatefulIterator): if self.batch_to_consume is not None: batch_columns: dict[str, list] = self.batch_to_consume self.batch_to_consume = None - sample_ids = batch_columns["sample_id"] - texts = batch_columns["text"] - entropies = batch_columns["entropies"] + if self.file_format == "arrow": + sample_ids = batch_columns["sample_id"] + texts = batch_columns["text"] + entropies = batch_columns["entropies"] + elif self.file_format == "json": + # This data hasn't been preprocessed to a uniform format, + # so we have to do it now and omit entropies + sample_ids = batch_columns[get_id_key(batch_columns)] + texts = get_text(batch_columns) + entropies = None + else: + raise ValueError(f"Unknown file format: {self.file_format}") for i in range(len(sample_ids)): out = BltExample( sample_id=sample_ids[i], - entropies=entropies[i], + entropies=entropies[i] if entropies is not None else None, text=texts[i], tokens=None, mask=None, @@ -191,13 +209,22 @@ class ArrowFileIterator(StatefulIterator): for batch in self.batch_iterator: batch_columns = batch.to_pydict() - sample_ids = batch_columns["sample_id"] - texts = batch_columns["text"] - entropies = batch_columns["entropies"] + if self.file_format == "arrow": + sample_ids = batch_columns["sample_id"] + 
texts = batch_columns["text"] + entropies = batch_columns["entropies"] + elif self.file_format == "json": + # This data hasn't been preprocessed to a uniform format, + # so we have to do it now and omit entropies + sample_ids = batch_columns[get_id_key(batch_columns)] + texts = get_text(batch_columns) + entropies = None + else: + raise ValueError(f"Unknown file format: {self.file_format}") for i in range(len(sample_ids)): out = BltExample( sample_id=sample_ids[i], - entropies=entropies[i], + entropies=entropies[i] if entropies is not None else None, text=texts[i], tokens=None, mask=None, @@ -231,13 +258,24 @@ class ArrowFileIterator(StatefulIterator): for batch in self.batch_iterator: if len(batch) > curr_remaining: batch_columns: dict[str, list] = batch.to_pydict() - batch_columns["sample_id"] = batch_columns["sample_id"][ - curr_remaining: - ] - batch_columns["entropies"] = batch_columns["entropies"][ - curr_remaining: - ] - batch_columns["text"] = batch_columns["text"][curr_remaining:] + if self.file_format == "arrow": + leftover_sample_ids = batch_columns["sample_id"][ + curr_remaining: + ] + leftover_entropies = batch_columns["entropies"][curr_remaining:] + leftover_texts = batch_columns["text"][curr_remaining:] + elif self.file_format == "json": + leftover_sample_ids = batch_columns[get_id_key(batch_columns)][ + curr_remaining: + ] + leftover_entropies = None + leftover_texts = get_text(batch_columns)[curr_remaining:] + else: + raise ValueError(f"Unknown file format: {self.file_format}") + + batch_columns["sample_id"] = leftover_sample_ids + batch_columns["entropies"] = leftover_entropies + batch_columns["text"] = leftover_texts self.batch_to_consume = batch_columns break elif len(batch) == curr_remaining: @@ -250,30 +288,3 @@ class ArrowFileIterator(StatefulIterator): logger.info( f"Finished setting arrow position to {target_row_num} for {self.dataset_files}" ) - - -TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" - - -def find_and_sanitize_chunks( - dataset_path: str, - world_size: int, - file_pattern: str = TRAIN_DATA_FILE_PATTERN, - s3_profile: str | None = None, -): - fs = get_fs(dataset_path, s3_profile=s3_profile) - path_with_glob = os.path.join(dataset_path, file_pattern) - dataset_chunks = fs.glob(path_with_glob) - n_chunks = len(dataset_chunks) - - if n_chunks > world_size: - n_discard = n_chunks - world_size - dataset_chunks = dataset_chunks[:world_size] - else: - assert ( - world_size % n_chunks == 0 - ), "World size should be a multiple of number of chunks" - - assert n_chunks > 0, f"No valid chunks in {dataset_path}" - - return dataset_chunks diff --git a/bytelatent/preprocess/preprocess_entropies.py b/bytelatent/preprocess/preprocess_entropies.py index 519da94..31a4802 100644 --- a/bytelatent/preprocess/preprocess_entropies.py +++ b/bytelatent/preprocess/preprocess_entropies.py @@ -15,29 +15,37 @@ from bytelatent.entropy_model import load_entropy_model from bytelatent.tokenizers.build_tokenizer import TokenizerArgs -def get_id_from_doc(doc: dict) -> int: +def get_id_key(doc: dict) -> int: """ We need a reliable way to ensure that samples from jsonl and arrow are the same, but there is no unique id field, so derive the best possible """ if "sample_id" in doc: - sample_id = doc["sample_id"] + return "sample_id" elif "title" in doc: - sample_id = doc["title"] + return "title" elif "qid" in doc: - sample_id = doc["qid"] + return "qid" elif "paper_id" in doc: - sample_id = doc["paper_id"] + return "paper_id" elif "path" in doc: - sample_id = doc["path"] + return "path" elif 
"url" in doc: - sample_id = doc["url"] + return "url" elif "id" in doc: - sample_id = doc["id"] + return "id" else: raise ValueError(f"Could not find a id key from: {doc.keys()}") - return str(sample_id) + + +def get_id_from_doc(doc: dict) -> int: + """ + We need a reliable way to ensure that samples from jsonl + and arrow are the same, but there is no unique id field, + so derive the best possible + """ + return str(doc[get_id_key(doc)]) def get_text(doc: dict): From 8f1a9a858e06c225b540e257b451b9223f5672eb Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Wed, 5 Feb 2025 22:47:01 +0000 Subject: [PATCH 30/59] Minimal working eval Summary: Test Plan: --- bytelatent/eval.py | 64 +++++++++++++++++++++++++++++------------- bytelatent/generate.py | 6 ++-- 2 files changed, 48 insertions(+), 22 deletions(-) diff --git a/bytelatent/eval.py b/bytelatent/eval.py index ae73066..1943a33 100644 --- a/bytelatent/eval.py +++ b/bytelatent/eval.py @@ -15,9 +15,16 @@ from lm_eval.api.model import LM from omegaconf import OmegaConf from pydantic import BaseModel, ConfigDict -from bytelatent.args import EvalArgs, ValidationArgs, parse_args +from bytelatent.args import ( + EvalArgs, + TrainArgs, + ValidationArgs, + find_and_sanitize_chunks, + parse_args, +) from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints from bytelatent.data.file_util import get_fs +from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator from bytelatent.distributed import ( DistributedArgs, dist_mean_dict, @@ -117,19 +124,40 @@ class EvalHarnessLM(LM): return results -def eval_on_val(generator, val_args: ValidationArgs, train_cfg): - srcs = {} +def eval_on_val(generator, val_args: ValidationArgs, train_cfg: TrainArgs): + srcs = [] for src in val_args.sources: path = os.path.join(val_args.root_dir, src) - srcs[path] = 1.0 + srcs.append(path) + for src in train_cfg.data.sources: path = os.path.join(train_cfg.data.root_dir, src) - srcs[path] = 1.0 + srcs.append(path) - multi_state = init_choice_state( - "", srcs, 0, get_global_rank(), get_world_size(), "*.val.jsonl" - ) - path_to_iter = setup_sources(multi_state) + path_to_iter = {} + for path in srcs: + chunks = find_and_sanitize_chunks( + path, + world_size=1, + file_pattern="*.val.jsonl", + s3_profile=train_cfg.data.s3_profile, + ) + assert ( + len(chunks) == 1 + ), f"There should be only 1 chunk per validation file, but found: {chunks}" + chunk = chunks[0] + iterator = ArrowFileIterator( + dataset_files=[chunk], + file_path=None, + preprocess_dir=None, + entropy_model_name=None, + worker_id=0, + num_workers=1, + arrow_batch_size=train_cfg.data.arrow_batch_size, + s3_profile=train_cfg.data.s3_profile, + file_format="json", + ) + path_to_iter[path] = iterator max_gen_len = generator.max_gen_len # We temporarily lower max gen len @@ -137,16 +165,11 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg): all_val_metrics = {} for src in path_to_iter: - jsonl_iterator = path_to_iter[src] + example_iterator = path_to_iter[src].create_iter() texts = [] logger.info(f"Running validation on {src}...") - for step, (content, state) in enumerate(jsonl_iterator): - if state["current_iter"] > 0 or ( - val_args.max_steps is not None and step >= val_args.max_steps - ): - break - content_key = "text" if ("text" in content) else "content" - texts.append(content[content_key]) + for step, example in enumerate(example_iterator): + texts.append(example.text) _, loglikelihood, _ = generator.generate(texts) @@ -191,7 +214,7 @@ def launch_eval(eval_args: 
EvalArgs): else: consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER) if not fs.exists(consolidate_path) and get_global_rank() == 0: - consolidate_path = consolidate_checkpoints(eval_args.ckpt_dir) + consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir) fs.mkdirs(eval_args.dump_dir, exist_ok=True) with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f: @@ -210,10 +233,12 @@ def launch_eval(eval_args: EvalArgs): wrap = EvalHarnessLM(generator) # Redo - results = simple_evaluate(wrap, eval_args.harness.model_dump()) + results = simple_evaluate(wrap, **eval_args.harness.model_dump()) + val_results = None if eval_args.validation: val_results = eval_on_val(generator, eval_args.validation, train_cfg) + if get_global_rank() == 0: with fs.open(os.path.join(eval_args.dump_dir, "results.json"), "w") as f: f.write(json.dumps(results)) @@ -222,6 +247,7 @@ def launch_eval(eval_args: EvalArgs): with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f: f.write(json.dumps(val_results)) logger.info(f"All validation results: {val_results}") + if eval_args.metric_log_dir and get_global_rank() == 0: metric_log_path = os.path.join(eval_args.metric_log_dir, "metrics.eval.jsonl") diff --git a/bytelatent/generate.py b/bytelatent/generate.py index eb79d81..9d44f30 100644 --- a/bytelatent/generate.py +++ b/bytelatent/generate.py @@ -387,8 +387,7 @@ def load_consolidated_model_and_tokenizer( ): train_args_path = os.path.join(consolidated_path, "params.json") fs = get_fs(train_args_path) - with fs.open(train_args_path) as f: - train_args = TrainArgs.model_validate_json(f.read()) + train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path)) if train_args.train_entropy_model: model_args = train_args.entropy_model @@ -401,7 +400,8 @@ def load_consolidated_model_and_tokenizer( train_args.distributed.model_dtype ] tokenizer = train_args.data.tokenizer_args.build() - st_dict = torch.load(consolidated_path / CONSOLIDATE_NAME, weights_only=True) + with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f: + st_dict = torch.load(f, weights_only=True) model.load_state_dict(st_dict["model"]) model = model.cuda().eval() for param in model.parameters(): From 45bfe94c1eb16d89ac4553a040c0647573638de3 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Thu, 6 Feb 2025 01:27:15 +0000 Subject: [PATCH 31/59] Broken train reproducing bf16 error Summary: Test Plan: --- bytelatent/broken_train.py | 623 +++++++++++++++++++++++++++++++++++++ bytelatent/train.py | 2 +- 2 files changed, 624 insertions(+), 1 deletion(-) create mode 100644 bytelatent/broken_train.py diff --git a/bytelatent/broken_train.py b/bytelatent/broken_train.py new file mode 100644 index 0000000..e1630a7 --- /dev/null +++ b/bytelatent/broken_train.py @@ -0,0 +1,623 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
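+#
+# Stripped-down training script intended to reproduce the bf16 all-reduce
+# error hit during grad-norm clipping on single-node, multi-GPU runs
+# (see fixed_clip_grad_norm_ in bytelatent/norms.py).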
+ +from datetime import timedelta +from enum import Enum +from functools import lru_cache +import logging +import math +import sys +import time +from typing import Dict, List, Optional, Tuple +from torch import Tensor +import os +import pickle + +import fsspec +import torch +import torch.distributed +import torch.nn.functional +import torch.nn.functional as F +from torch.distributed._tensor import DTensor + +from torch.distributed.device_mesh import init_device_mesh +from torch.utils._foreach_utils import ( + _device_has_foreach_support, + _group_tensors_by_device_and_dtype, + _has_foreach_support, +) + +from bytelatent.args import TrainArgs +from bytelatent.distributed import ( + DistributedArgs, + check_model_value_range, + parallelize_model, + setup_env, + setup_torch_distributed, +) + +logger = logging.getLogger() + + +def set_root_log_level(log_level: str): + logger = logging.getLogger() + level: int | str = log_level.upper() + try: + level = int(log_level) + except ValueError: + pass + try: + logger.setLevel(level) # type: ignore + except Exception: + logger.warning( + f"Failed to set logging level to {log_level}, using default 'NOTSET'" + ) + logger.setLevel(logging.NOTSET) + + +class LogFormatter(logging.Formatter): + """ + Custom logger for distributed jobs, displaying rank + and preserving indent from the custom prefix format. + """ + + def __init__(self): + self.start_time = time.time() + self.rank = get_global_rank() + self.show_rank = not get_is_slurm_job() # srun has --label + + def formatTime(self, record): + subsecond, seconds = math.modf(record.created) + curr_date = ( + time.strftime("%y-%m-%d %H:%M:%S", time.localtime(seconds)) + + f".{int(subsecond * 1_000_000):06d}" + ) + delta = timedelta(seconds=round(record.created - self.start_time)) + return f"{curr_date} - {delta}" + + def formatPrefix(self, record): + fmt_time = self.formatTime(record) + if self.show_rank: + return f"{self.rank}: {record.levelname:<7} {fmt_time} - " + else: + return f"{record.levelname:<7} {fmt_time} - " + + def formatMessage(self, record, indent: str): + content = record.getMessage() + content = content.replace("\n", "\n" + indent) + # Exception handling as in the default formatter, albeit with indenting + # according to our custom prefix + if record.exc_info: + # Cache the traceback text to avoid converting it multiple times + # (it's constant anyway) + if not record.exc_text: + record.exc_text = self.formatException(record.exc_info) + if record.exc_text: + if content[-1:] != "\n": + content = content + "\n" + indent + content = content + indent.join( + [l + "\n" for l in record.exc_text.splitlines()] + ) + if content[-1:] == "\n": + content = content[:-1] + if record.stack_info: + if content[-1:] != "\n": + content = content + "\n" + indent + stack_text = self.formatStack(record.stack_info) + content = content + indent.join([l + "\n" for l in stack_text.splitlines()]) + if content[-1:] == "\n": + content = content[:-1] + + return content + + def format(self, record): + prefix = self.formatPrefix(record) + indent = " " * len(prefix) + content = self.formatMessage(record, indent) + return prefix + content + + +def init_logger( + log_file: str | None = None, + *, + name: str | None = None, + level: str = "INFO", + fs: fsspec.AbstractFileSystem | None = None, +): + """ + Setup logging. + + Args: + log_file: A file name to save file logs to. + name: The name of the logger to configure, by default the root logger. + level: The logging level to use. 
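+        fs: Optional fsspec filesystem, accepted for parity with
+            bytelatent.logger.init_logger; this copy only configures
+            stream handlers.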
+ """ + set_root_log_level(level) + logger = logging.getLogger(name) + + # stdout: everything + stdout_handler = logging.StreamHandler(sys.stdout) + stdout_handler.setLevel(logging.NOTSET) + stdout_handler.setFormatter(LogFormatter()) + + # stderr: warnings / errors and above + stderr_handler = logging.StreamHandler(sys.stderr) + stderr_handler.setLevel(logging.WARNING) + stderr_handler.setFormatter(LogFormatter()) + + # set stream handlers + logger.handlers.clear() + logger.handlers.append(stdout_handler) + logger.handlers.append(stderr_handler) + + +@torch.no_grad() +def fixed_clip_grad_norm_( + parameters: torch.Tensor | list[torch.Tensor], + max_norm: float, + norm_type: float = 2.0, + error_if_nonfinite: bool = False, + foreach: Optional[bool] = None, +) -> torch.Tensor: + r"""Clip the gradient norm of an iterable of parameters. + + The norm is computed over the norms of the individual gradients of all parameters, + as if the norms of the individual gradients were concatenated into a single vector. + Gradients are modified in-place. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float): max norm of the gradients + norm_type (float): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + error_if_nonfinite (bool): if True, an error is thrown if the total + norm of the gradients from :attr:`parameters` is ``nan``, + ``inf``, or ``-inf``. Default: False (will switch to True in the future) + foreach (bool): use the faster foreach-based implementation. + If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently + fall back to the slow implementation for other device types. + Default: ``None`` + + Returns: + Total norm of the parameter gradients (viewed as a single vector). + """ + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + grads = [p.grad.to(torch.bfloat16) for p in parameters if p.grad is not None] + max_norm = float(max_norm) + norm_type = float(norm_type) + if len(grads) == 0: + return torch.tensor(0.0) + first_device = grads[0].device + grouped_grads: Dict[ + Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]] + ] = _group_tensors_by_device_and_dtype( + [grads] + ) # type: ignore[assignment] + + norms: List[Tensor] = [] + for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment] + if (foreach is None and _has_foreach_support(device_grads, device)) or ( + foreach and _device_has_foreach_support(device) + ): + norms.extend(torch._foreach_norm(device_grads, norm_type)) + elif foreach: + raise RuntimeError( + f"foreach=True was passed, but can't use the foreach API on {device.type} tensors" + ) + else: + norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads]) + + total_norm = torch.linalg.vector_norm( + torch.stack([norm.to(first_device) for norm in norms]), norm_type + ) + + if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()): + raise RuntimeError( + f"The total norm of order {norm_type} for gradients from " + "`parameters` is non-finite, so it cannot be clipped. 
To disable " + "this error and scale the gradients by the non-finite norm anyway, " + "set `error_if_nonfinite=False`" + ) + clip_coef = max_norm / (total_norm + 1e-6) + # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so + # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization + # when the gradients do not reside in CPU memory. + clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment] + if (foreach is None and _has_foreach_support(device_grads, device)) or ( + foreach and _device_has_foreach_support(device) + ): + torch._foreach_mul_(device_grads, clip_coef_clamped.to(device)) + elif foreach: + raise RuntimeError( + f"foreach=True was passed, but can't use the foreach API on {device.type} tensors" + ) + else: + clip_coef_clamped_device = clip_coef_clamped.to(device) + for g in device_grads: + g.mul_(clip_coef_clamped_device) + + return total_norm + + +def get_no_recompute_ops(): + return None + + +@lru_cache() +def get_is_torch_run() -> bool: + return os.environ.get("LOCAL_RANK") is not None + + +@lru_cache() +def get_is_slurm_job() -> bool: + return "SLURM_JOB_ID" in os.environ and not get_is_torch_run() + + +@lru_cache() +def get_global_rank() -> int: + if get_is_torch_run(): + return int(os.environ["RANK"]) + elif get_is_slurm_job(): + return int(os.environ["SLURM_PROCID"]) + else: + return 0 + + +@lru_cache() +def get_local_rank() -> int: + if get_is_torch_run(): + return int(os.environ["LOCAL_RANK"]) + elif get_is_slurm_job(): + return int(os.environ["SLURM_LOCALID"]) + else: + return 0 + + +@lru_cache() +def get_world_size() -> int: + if get_is_torch_run(): + return int(os.environ["WORLD_SIZE"]) + elif get_is_slurm_job(): + return int(os.environ["SLURM_NTASKS"]) + else: + return 1 + + +@lru_cache() +def get_is_master() -> bool: + return get_global_rank() == 0 + + +def validate_train_args(args: TrainArgs, output_size: int): + # assert args.model is not None or args.entropy_model is not None + if args.entropy_model is not None: + logger.info(f"Setting model output size to {args.entropy_model.vocab_size}") + args.entropy_model.vocab_size = output_size + + assert args.dump_dir, "Dump dir not set" + + if ( + args.distributed.dp_replicate + * args.distributed.dp_shard + * args.distributed.tp_size + != get_world_size() + ): + logging.info("Modifying TrainArgs distributed config") + assert get_world_size() % args.distributed.dp_shard == 0 + logging.info("World size: %s", get_world_size()) + logging.info( + "Existing setting: train_args.distributed.dp_shard=%s", + args.distributed.dp_shard, + ) + logging.info( + "Setting train_args.distributed.dp_replicate=%s, was dp_replicate=%s", + get_world_size() // args.distributed.dp_shard, + args.distributed.dp_replicate, + ) + args.distributed.dp_replicate = get_world_size() // args.distributed.dp_shard + + logging.info( + "Changing dp_replicate from %s to %s, to account for tp_size=%s", + args.distributed.dp_replicate, + args.distributed.dp_replicate // args.distributed.tp_size, + args.distributed.tp_size, + ) + assert args.distributed.dp_replicate % args.distributed.tp_size == 0 + args.distributed.dp_replicate = ( + args.distributed.dp_replicate // args.distributed.tp_size + ) + + logger.warning( + f"Setting Data Parallel size to {args.distributed.dp_replicate * args.distributed.dp_shard}" + ) + assert ( + args.distributed.dp_replicate + * args.distributed.dp_shard + * 
args.distributed.tp_size + == get_world_size() + ) + + if args.distributed.fsdp_type == "no_shard": + assert ( + args.distributed.dp_shard == 1 + and args.distributed.dp_replicate == get_world_size() + ) + + if args.model is not None: + args.model.max_seqlen = args.data.seq_len + if args.entropy_model is not None: + args.entropy_model.max_seqlen = args.data.seq_len + + if args.distributed.tp_size == 1: + logger.warning( + "Tensor parallelism has not been tested for a while, use at your own risk" + ) + + assert ( + args.probe_freq != args.profiling.mem_steps + ), "Don't profile during probe step" + assert ( + args.probe_freq != args.profiling.profile_steps + ), "Don't profile during probe step" + if args.logging.wandb is not None: + args.logging.wandb.name = args.name + + if args.probe_freq is not None: + assert ( + args.distributed.tp_size == 1 + ), "Probing not supported with tensor parallelism" + assert ( + args.distributed.selective_activation_checkpointing is False + ), "Probing not supported with selective activation checkpointing" + + +def compute_loss(p, y, mask, scale): + tok_loss = scale * F.cross_entropy( + p.flatten(0, 1), y.flatten(0, 1), reduction="none" + ) + if mask is None: + loss = tok_loss.mean() + else: + mask = mask.flatten(0, 1) + tok_loss = tok_loss * mask + loss = tok_loss.sum() / (mask.sum() + 1e-6) + return loss, tok_loss + + +def get_device_mesh(distributed_args): + tp_size = distributed_args.tp_size + dp_replicate = distributed_args.dp_replicate + dp_shard = distributed_args.dp_shard + + assert ( + dp_replicate * dp_shard * tp_size == get_world_size() + ), f"dp_replicate * dp_shard * tp_size ({dp_replicate} * {dp_shard} * {tp_size}) != world_size ({get_world_size()})" + + dims = [] + names = [] + if dp_replicate >= 1: + dims.append(dp_replicate) + names.append("dp_replicate") + if dp_shard > 1 or distributed_args.fsdp_type == "no_shard": + dims.append(dp_shard) + names.append("dp_shard") + if tp_size > 1: + dims.append(tp_size) + names.append("tp") + dims = tuple(dims) + names = tuple(names) + + return init_device_mesh("cuda", mesh_shape=dims, mesh_dim_names=names) + + +def build_fsdp_grouping_plan(): + group_plan: tuple[int, bool] = [] + + # Grouping and output seperately + group_plan.append(("tok_embeddings", False)) + + # Grouping by layers + # for i in range(model_args.n_layers): + # group_plan.append((f"layers.{i}", False)) + + group_plan.append(("output", True)) + + return group_plan + + +class MinimalModel(torch.nn.Module): + def __init__(self, dim: int, vocab_size: int): + super().__init__() + self.tok_embeddings = torch.nn.Embedding(vocab_size, dim) + + # self.norm = RMSNorm(args.dim, eps=args.norm_eps) + # self.layers = torch.nn.ModuleList() + # for _ in range(args.n_layers): + # self.layers.append(TransformerBlock(args)) + + self.output = torch.nn.Linear( + dim, + vocab_size, + bias=False, + ) + + def forward(self, tokens): + h = self.tok_embeddings(tokens) + logits = self.output(h) + # logits = self.output(self.norm(h)) + return logits + + def reset_parameters(self, init_std=None): + pass + + def init_weights(self): + pass + + +def train(): + args = TrainArgs( + dump_dir="/tmp", + name="debug_bf16", + model=None, + entropy_model=None, + distributed=DistributedArgs( + fsdp_type="full_shard", + model_dtype="bf16", + matmul_allow_tf32=False, + selective_activation_checkpointing=False, + tp_size=1, + ), + ) + tokenizer = args.data.tokenizer_args.build() + validate_train_args( + args, + tokenizer.n_words, + ) + dump_fs = fsspec.filesystem("file") + 
init_logger(os.path.join(args.dump_dir, "train.log"), fs=dump_fs) + setup_env(args.env) + setup_torch_distributed(args.distributed) + world_mesh = get_device_mesh(args.distributed) + logger.info(f"Starting job: {args.name}") + + # build dataloader + # need dp world size and rank + dp_mesh = world_mesh["dp_replicate"] + dp_degree = dp_mesh.size() + dp_rank = dp_mesh.get_local_rank() + if args.distributed.dp_shard > 1: + dp_rank = dp_rank * dp_degree + world_mesh["dp_shard"].get_local_rank() + dp_degree *= world_mesh["dp_shard"].size() + + logger.info(f"Running on dp rank : {dp_rank}") + logger.info(f"Running on dp size : {dp_degree}") + + torch.manual_seed(args.seed) + logger.info("Building model") + + # Initializing Model in meta device allows us to initialize models much bigger than 1 gpu's memory + with torch.device("meta"): + model = MinimalModel(768, tokenizer.n_words) + + model = parallelize_model( + model, + world_mesh, + args.model, + args.distributed, + fsdp_grouping_plan=build_fsdp_grouping_plan(), + tp_parallelize=None, + no_recompute_ops=get_no_recompute_ops(), + ) + + # Once we shard the model on different gpus we can actually initialize the model + # First we create empty tensors of the correct shapes + model = model.to_empty(device="cuda") + # Then we init the model. Please make sure this function initializes *ALL* parameters + # and buffers, otherwise you will have random values in the unitialized tensors + # which will silently fail (give nan gradients for example) + + with torch.random.fork_rng(devices=[torch.cuda.current_device()]): + torch.manual_seed(42) + model.init_weights() + check_model_value_range(model, range=10.0, std=1.0) + + # data_loader = args.data.build_from_rank(dp_rank, dp_degree) + + # train loop + model.train() + # data_loader = train_state.data_loader_state.build() + # batch_iterator = data_loader.create_iter() + # batch = next(batch_iterator) + # with open(f"/storage/home/par/toy-data/batch_{dp_rank}.pickle", "wb") as f: + # pickle.dump(batch, f) + with open(f"/storage/home/par/toy-data/batch_{dp_rank}.pickle", "rb") as f: + batch = pickle.load(f) + + batch_x = torch.from_numpy( + batch.x, + ).cuda() + batch_y = torch.from_numpy(batch.y).cuda() + mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda() + pred = model(batch_x) + loss, _ = compute_loss(pred, batch_y, mask, 1.0) + + # We scale loss with grad_acc_steps so the gradient is the same + # regardless of grad_acc_steps + loss = loss / args.grad_acc_steps + + # backward on scaled loss to create scaled gradients + loss.backward() + # For logging we undo that scaling + loss = loss.detach() * args.grad_acc_steps + + world_size = get_world_size() + if 1 < world_size <= 8 and False: + # For some reason, there are errors in reduces due to + # not working for non-bf16 numbers. This function is a patched + # version that converts gradients to bf16 before computing norms. 
+ # The error only happens in distributed training on one node, + # hence the guard + grad_norm = fixed_clip_grad_norm_( + model.parameters(), max_norm=args.optim.clip, foreach=True + ) + else: + grad_norm = torch.nn.utils.clip_grad_norm_( + model.parameters(), max_norm=args.optim.clip, foreach=True + ) + + grad_norm = ( + grad_norm.full_tensor() if isinstance(grad_norm, DTensor) else grad_norm + ).item() + + # if isinstance(data_loader, MultiprocessIterator): + # logger.info("Closing MP iterator before exiting") + # data_loader.shutdown() + + +def main(): + """ + The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments + This accepts arguments as a dot list + So if the dataclass looks like + + @dataclass + class DummyArgs: + name: str + model: LMTransformerArgsgs + + @dataclass + class LMTransformerArgsgs: + dim: int + + Then you can pass model.dim=32 to change values in LMTransformerArgsgs + or just name=tictac for top level attributes. + + The behavior here is as follows: + 1. We instantiate TrainArgs with its default values + 2. We override those default values with the ones in the provided config file + 3. We override the result with the additional arguments provided through command line + + For example, if the config is the following + + model: + dim: 128 + n_layers: 4 + + and you call train.py with train.py model.dim=64 + + Then the final TrainArgs will have + + model: + dim: 64 + n_layers: 4 + + Plus all the default values in TrainArgs dataclass. + """ + train() + + +if __name__ == "__main__": + main() diff --git a/bytelatent/train.py b/bytelatent/train.py index 4641746..a775e46 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -527,7 +527,7 @@ def train(args: TrainArgs): step_tok_losses.append(tok_loss / train_state.scale) world_size = get_world_size() - if 1 < world_size <= 8: + if 1 < world_size <= 8 and False: # For some reason, there are errors in reduces due to # not working for non-bf16 numbers. This function is a patched # version that converts gradients to bf16 before computing norms. From 341264685a2dd979f8c204b27bc6faa23dd5c2ba Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Thu, 6 Feb 2025 17:33:38 +0000 Subject: [PATCH 32/59] Update checkpointing to use fsspec Summary: - Make the data/checkpoint code fsspec compatible - Still will not work with s3 saves, due to `torch.distributed.checkpoint.save` not being out of the box workable with `fsspec`. 
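Outside of that distributed save, the checkpoint manager now works on plain path strings through fsspec; a hedged sketch of the step-folder listing pattern it relies on (folder naming and the regex below are assumptions, not the repo's exact constants):

```python
# Sketch of the fsspec-based step-folder listing used by CheckpointManager:
# ls() returns strings, so matching and sorting go through os.path.basename.
import os
import re

import fsspec

STEP_FOLDER_RE = r"\d{10}$"                # assumed pattern, e.g. "0000001000"
fs = fsspec.filesystem("file")
ckpt_root = "/tmp/checkpoints"             # hypothetical checkpoint root

folders: list[str] = []
if fs.exists(ckpt_root) and fs.isdir(ckpt_root):
    folders = [
        p
        for p in fs.ls(ckpt_root)
        if fs.isdir(p) and re.match(STEP_FOLDER_RE, os.path.basename(p))
    ]
folders.sort(key=lambda p: int(os.path.basename(p)))  # stand-in for the repo's step parser
```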
Will implement in followup PR Test Plan: Run unit tests and the commands below ``` python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 ``` ``` torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 ``` These currently won't work due to the torch distributed save, but theses hould be tested at a later date ``` python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 dump_dir=s3://blt/scratch/checkpoint-test/ ``` ``` torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 dump_dir=s3://blt/scratch/checkpoint-test/ ``` --- bytelatent/args.py | 14 +++-- bytelatent/checkpoint.py | 115 +++++++++++++++++++++------------------ bytelatent/logger.py | 9 ++- bytelatent/metrics.py | 15 ++++- bytelatent/train.py | 31 +++++++---- 5 files changed, 112 insertions(+), 72 deletions(-) diff --git a/bytelatent/args.py b/bytelatent/args.py index d1bac46..fc72b32 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -294,6 +294,14 @@ class TrainArgs(BaseModel): def dump_to_yaml_file( self, path: str, log_config: bool = True, sort_keys: bool = True ): + yaml_str = self.dump_to_yaml_str(sort_keys=sort_keys) + with open(path, "w") as f: + if log_config: + logger.info("Using the following config for this run:") + logger.info(yaml_str) + f.write(yaml_str) + + def dump_to_yaml_str(self, sort_keys: bool = True): model_dict = self.model_dump(mode="json") yaml_str = yaml.dump( model_dict, @@ -301,8 +309,4 @@ class TrainArgs(BaseModel): sort_keys=sort_keys, default_flow_style=False, ) - with open(path, "w") as f: - if log_config: - logger.info("Using the following config for this run:") - logger.info(yaml_str) - f.write(yaml_str) + return yaml_str diff --git a/bytelatent/checkpoint.py b/bytelatent/checkpoint.py index f213c84..1668c88 100644 --- a/bytelatent/checkpoint.py +++ b/bytelatent/checkpoint.py @@ -4,10 +4,9 @@ import json import logging import os import re -from pathlib import Path -from typing import List, Optional, Tuple import fsspec +import s3fs import torch import torch.distributed as dist import torch.distributed.checkpoint as dcp @@ -70,26 +69,29 @@ def consolidate_checkpoints(fs: fsspec.AbstractFileSystem, ckpt_dir: str): Returns the path to the consolidated checkpoint """ - consolidate_path = Path(ckpt_dir) / CONSOLIDATE_FOLDER - if not (consolidate_path / CONSOLIDATE_NAME).exists(): - consolidate_path.mkdir(exist_ok=True) - logger.info(f"Consolidating to: {str(consolidate_path)}") - dcp_to_torch_save(ckpt_dir, str(consolidate_path / CONSOLIDATE_NAME)) - (consolidate_path / CONFIG_NAME).write_text( - (Path(ckpt_dir) / CONFIG_NAME).read_text() + consolidate_path = os.path.join(ckpt_dir, CONSOLIDATE_FOLDER) + consolidate_name = os.path.join(consolidate_path, CONSOLIDATE_NAME) + if not fs.exists(consolidate_name): + fs.mkdirs(consolidate_path, exist_ok=True) + logger.info(f"Consolidating to: {consolidate_path}") + dcp_to_torch_save(ckpt_dir, consolidate_name) + fs.write_text( + os.path.join(consolidate_path, CONFIG_NAME), + fs.read_text(os.path.join(ckpt_dir, CONFIG_NAME)), ) logger.info("Consolidated !") return consolidate_path def load_from_checkpoint( + fs: fsspec.AbstractFileSystem, ckpt_dir: str, model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, + optimizer: torch.optim.Optimizer | None = None, model_key: str = "model", optim_key: str = "optim", ): - if 
not (Path(ckpt_dir) / ".metadata").exists(): + if not fs.exists(os.path.join(ckpt_dir, ".metadata")): raise ValueError( f"Please convert the checkpoint distcp format using `torch.distributed.checkpoint.format_utils.torch_save_to_dcp` before loading it" ) @@ -115,19 +117,24 @@ class CheckpointManager: self.init_ckpt_path = args.init_ckpt_path self.continue_training_from_init = args.continue_training_from_init - assert self.fs.exists( - self.path - ), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)" + if not isinstance(self.fs, s3fs.S3FileSystem): + # S3 does not have a concept of directories + assert self.fs.exists( + self.path + ), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)" self.existing_saves = self.get_existing_saves() - def get_existing_saves(self) -> List[Path]: - folders = [ - p - for p in Path(self.path).iterdir() - if p.is_dir() and re.match(RE_FOLDER, p.name) - ] - folders.sort(key=lambda p: _get_key_step(p.name)) + def get_existing_saves(self) -> list[str]: + if self.fs.exists(self.path) and self.fs.isdir(self.path): + folders = [ + p + for p in self.fs.ls(self.path) + if self.fs.isdir(p) and re.match(RE_FOLDER, os.path.basename(p)) + ] + else: + folders = [] + folders.sort(key=lambda p: _get_key_step(os.path.basename(p))) return folders def clean_up(self): @@ -136,8 +143,9 @@ class CheckpointManager: eval_folders = [] other_folders = [] for p in self.existing_saves: - is_dump = _get_key_step(p.name) % self.dump_every.every == 0 - is_eval = _get_key_step(p.name) % self.eval_every.every == 0 + assert isinstance(p, str), f"Base path type: {p}" + is_dump = _get_key_step(os.path.basename(p)) % self.dump_every.every == 0 + is_eval = _get_key_step(os.path.basename(p)) % self.eval_every.every == 0 if is_dump: dump_folders.append(p) if is_eval: @@ -161,40 +169,39 @@ class CheckpointManager: if dist.get_rank() == 0: for folder in folder_to_remove: - for file in folder.iterdir(): - if file.is_file(): - file.unlink() - elif file.is_dir(): - assert file.name in [CONSOLIDATE_FOLDER] - for f in file.iterdir(): - f.unlink() - file.rmdir() - folder.rmdir() + for file in self.fs.ls(folder): + if self.fs.isfile(file): + self.fs.rm_file(file) + elif self.fs.isdir(file): + assert os.path.name(file) in [CONSOLIDATE_FOLDER] + for f in self.fs.ls(file): + self.fs.rm(f) + self.fs.rmdir(file) + self.fs.rmdir(folder) dist.barrier() self.existing_saves = list(folder_to_keep) - self.existing_saves.sort(key=lambda p: _get_key_step(p.name)) + self.existing_saves.sort(key=lambda p: _get_key_step(os.path.basename(p))) - def get_last_step_path(self, dp_rank: int = 0) -> Optional[Path]: + def get_last_step_path(self, dp_rank: int = 0) -> str | None: path = None for p in reversed(self.existing_saves): - if (p / TRAIN_STATE_NAME.format(dp_rank)).is_file(): + + if self.fs.isfile(os.path.join(p, TRAIN_STATE_NAME.format(dp_rank))): path = p break return path - def _create_folder(self, base_path: Path, folder_name: str) -> Path: - folder = base_path / folder_name + def _create_folder(self, base_path: str, folder_name: str) -> str: + folder = os.path.join(base_path, folder_name) if get_is_master(): - folder.mkdir(parents=False, exist_ok=True) + self.fs.mkdirs(folder, exist_ok=True) if dist.is_initialized(): dist.barrier() return folder - def _get_dp_tp_mesh( - self, device_mesh: Optional[DeviceMesh] = None - ) -> Tuple[int, int]: + def _get_dp_tp_mesh(self, device_mesh: 
DeviceMesh | None = None) -> tuple[int, int]: dp_rank = 0 tp_rank = 0 if device_mesh is not None: @@ -222,14 +229,14 @@ class CheckpointManager: model, optimizer, train_state, - config, - device_mesh: Optional[DeviceMesh] = None, + config: BaseModel, + device_mesh: DeviceMesh | None = None, ) -> bool: # When creating directory check if only rank0 or is there other solution - path = Path(self.path) + path = self.path curr_save_dir = self._create_folder(path, FOLDER_NAME.format(train_state.step)) - logger.info(f"Saving to: {str(curr_save_dir)}") + logger.info(f"Saving to: {curr_save_dir}") if dist.is_initialized(): dist.barrier() @@ -242,17 +249,19 @@ class CheckpointManager: if dist.is_initialized(): dist.barrier() + print("config type", type(config)) if get_is_master(): - config.dump_to_yaml_file(curr_save_dir / CONFIG_NAME) + self.fs.write_text( + os.path.join(curr_save_dir, CONFIG_NAME), config.model_dump_json() + ) # Add json dump here dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh) if tp_rank == 0: train_state_name = TRAIN_STATE_NAME.format(dp_rank) - logger.info( - f"Saving train state to: {str(curr_save_dir / train_state_name)}" - ) - with open(curr_save_dir / train_state_name, "w") as f: + train_state_full_path = os.path.join(curr_save_dir, train_state_name) + logger.info(f"Saving train state to: {train_state_full_path}") + with self.fs.open(train_state_full_path, "w") as f: json.dump(train_state.state_dict(), f) logger.info("Train state saved !") @@ -271,7 +280,7 @@ class CheckpointManager: optimizer, train_state, device_mesh: DeviceMesh, - path: Optional[Path] = None, + path: str | None = None, ): dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh) # Loading tries to load the provided path, if not available the last saved step and finally from the init path @@ -284,12 +293,12 @@ class CheckpointManager: # Only load train state if it's provided, the files exist and we're not loading from init path train_state_name = TRAIN_STATE_NAME.format(dp_rank) logger.info("Reloading train state") - with open(path / train_state_name, "r") as f: + with self.fs.open(os.path.join(path, train_state_name), "r") as f: train_state_dict = json.load(f) train_state.load_state_dict(train_state_dict) logger.info("Train state reloaded") - logger.info(f"Loading from: {str(path)}") + logger.info(f"Loading from: {path}") state_dict = self.get_state_dict( model=model, optimizer=optimizer, diff --git a/bytelatent/logger.py b/bytelatent/logger.py index 87f04cc..6f9a397 100644 --- a/bytelatent/logger.py +++ b/bytelatent/logger.py @@ -6,6 +6,8 @@ import sys import time from datetime import timedelta +import fsspec + from bytelatent.distributed import get_global_rank, get_is_slurm_job @@ -92,6 +94,7 @@ def init_logger( *, name: str | None = None, level: str = "INFO", + fs: fsspec.AbstractFileSystem | None = None, ): """ Setup logging. 
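The hunk below backs the file handler with an fsspec handle whenever a filesystem is passed; a minimal, self-contained sketch of that wiring (the log path is hypothetical):

```python
# An fsspec file handle is an ordinary file-like object, so it can feed a
# logging.StreamHandler directly; this mirrors the fs-is-not-None branch below.
import logging

import fsspec

fs = fsspec.filesystem("file")                        # could equally be an s3 filesystem
stream = fs.open("/tmp/example_train.log", mode="a")  # append, like FileHandler(log_file, "a")
handler = logging.StreamHandler(stream)
handler.setLevel(logging.NOTSET)
handler.setFormatter(logging.Formatter("%(levelname)s %(message)s"))

log = logging.getLogger("fsspec_demo")
log.addHandler(handler)
log.warning("written through an fsspec-backed handler")
```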
@@ -121,7 +124,11 @@ def init_logger( if log_file is not None and get_global_rank() == 0: # build file handler - file_handler = logging.FileHandler(log_file, "a") + if fs is None: + file_handler = logging.FileHandler(log_file, "a") + else: + file_stream = fs.open(log_file, mode="a") + file_handler = logging.StreamHandler(file_stream) file_handler.setLevel(logging.NOTSET) file_handler.setFormatter(LogFormatter()) # update logger diff --git a/bytelatent/metrics.py b/bytelatent/metrics.py index 77dc4d7..2df5be7 100644 --- a/bytelatent/metrics.py +++ b/bytelatent/metrics.py @@ -9,6 +9,7 @@ from datetime import datetime, timezone from pathlib import Path from typing import Any, Union +import fsspec import torch import torch.nn as nn import wandb @@ -54,14 +55,24 @@ class LoggingArgs(BaseModel): class MetricLogger: - def __init__(self, outdir: Path, args: Any | None = None): + def __init__( + self, + outdir: Path, + # args: TrainArgs + args: Any | None = None, + fs: fsspec.AbstractFileSystem | None = None, + ): self.outdir = outdir self.jsonl_writer = None + self.fs = fs self.args = args def open(self): if self.jsonl_writer is None: - self.jsonl_writer = open(self.outdir, "a") + if self.fs is None: + self.jsonl_writer = open(self.outdir, "a") + else: + self.jsonl_writer = self.fs.open(self.outdir, "a") if ( self.args is not None and self.args.logging.wandb is not None diff --git a/bytelatent/train.py b/bytelatent/train.py index 86d1c7a..2c3ea01 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -8,7 +8,6 @@ import sys from contextlib import ExitStack from copy import deepcopy from dataclasses import asdict, dataclass -from pathlib import Path from timeit import default_timer as timer from typing import Any, TypeVar @@ -18,13 +17,13 @@ import torch.nn.functional import torch.nn.functional as F import wandb import xformers.profiler -from omegaconf import OmegaConf from torch.distributed._tensor import DTensor from torch.distributed.checkpoint.stateful import Stateful from torch.optim import lr_scheduler from bytelatent.args import TrainArgs, parse_args from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint +from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.multiprocess_iterator import ( MultiprocessIterator, MultiprocessIteratorState, @@ -136,11 +135,12 @@ def validate_train_args(args: TrainArgs, output_size: int): if args.checkpoint.path is None: logger.info(f"Setting checkpoint path to {args.checkpoint.path}") - args.checkpoint.path = str(Path(args.dump_dir) / "checkpoints") + args.checkpoint.path = os.path.join(args.dump_dir, "checkpoints") + data_fs = get_fs(args.data.root_dir, s3_profile=args.data.s3_profile) for source in args.data.sources: data_path = os.path.join(args.data.root_dir, source) - assert os.path.exists(data_path), f"{data_path} doesn't exist" + assert data_fs.exists(data_path), f"{data_path} doesn't exist" if ( args.distributed.dp_replicate @@ -255,10 +255,15 @@ def train(args: TrainArgs): args, tokenizer.n_words, ) + dump_fs = get_fs(args.dump_dir, s3_profile=args.checkpoint.s3_profile) if get_is_master(): - os.makedirs(args.dump_dir, exist_ok=True) - args.dump_to_yaml_file(Path(args.dump_dir) / "config.yaml") - init_logger(Path(args.dump_dir) / "train.log") + dump_fs.mkdirs(args.dump_dir, exist_ok=True) + config_yaml_str = args.dump_to_yaml_str() + logging.info("TrainArgs: \n%s", config_yaml_str) + dump_fs.write_text( + os.path.join(args.dump_dir, "config.yaml"), config_yaml_str + ) + 
init_logger(os.path.join(args.dump_dir, "train.log"), fs=dump_fs) init_signal_handler(set_preemption_flag) # For handling preemption signals. setup_env(args.env) setup_torch_distributed(args.distributed) @@ -313,8 +318,11 @@ def train(args: TrainArgs): if args.checkpoint.init_ckpt_path: logger.info(f"Loading initial model from {args.checkpoint.init_ckpt_path}") + ckpt_fs = get_fs( + args.checkpoint.init_ckpt_path, s3_profile=args.checkpoint.s3_profile + ) load_from_checkpoint( - args.checkpoint.init_ckpt_path, model, model_key="model" + ckpt_fs, args.checkpoint.init_ckpt_path, model, model_key="model" ) # Put model_key="" if its directly the model checkpoint model.rope_embeddings.reset_parameters() # For RoPe initialization since it's a buffer it might not be loaded else: @@ -351,13 +359,14 @@ def train(args: TrainArgs): checkpoint.load(model, optimizer, train_state, world_mesh) # Either load from latest checkpoint or start from scratch if args.probe_freq is not None: + # TODO: Convert this to fsspec compatible if get_is_master(): - os.makedirs(Path(args.dump_dir) / "probe", exist_ok=True) + os.makedirs(os.path.join(args.dump_dir, "probe"), exist_ok=True) torch.distributed.barrier() probe = AutoProbeD( model, ( - Path(args.dump_dir) / "probe" / f"probe.{dp_rank}.jsonl" + os.path.join(args.dump_dir, "probe", f"probe.{dp_rank}.jsonl") if (dp_rank % 128 == 0) else None ), @@ -369,7 +378,7 @@ def train(args: TrainArgs): # train loop model.train() metric_logger = context_stack.enter_context( - MetricLogger(Path(args.dump_dir) / "metrics.jsonl", args) + MetricLogger(os.path.join(args.dump_dir, "metrics.jsonl"), args, fs=dump_fs) ) data_loader = train_state.data_loader_state.build() batch_iterator = data_loader.create_iter() From f058373889d55fc657bfa6bc34cf039e0832207e Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Thu, 6 Feb 2025 17:37:20 +0000 Subject: [PATCH 33/59] Update checkpointing to use fsspec Summary: - Make the data/checkpoint code fsspec compatible - Still will not work with s3 saves, due to `torch.distributed.checkpoint.save` not being out of the box workable with `fsspec`. 
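For the artifacts that are not part of the distributed save, writes now go through the filesystem object; a hedged sketch of that save-side pattern (directory and file names below are illustrative stand-ins, not the repo's constants):

```python
# Illustrative sketch of the fsspec write path used for config and train state
# in this patch.
import json
import os

import fsspec

fs = fsspec.filesystem("file")
save_dir = "/tmp/checkpoints/0000000100"            # hypothetical step folder
fs.mkdirs(save_dir, exist_ok=True)

fs.write_text(os.path.join(save_dir, "config.yaml"), "name: demo\n")
with fs.open(os.path.join(save_dir, "train_state_0.json"), "w") as f:
    json.dump({"step": 100, "acc_step": 0}, f)
```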
Will implement in followup PR Test Plan: Run unit tests and the commands below ``` python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 ``` ``` torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 ``` These currently won't work due to the torch distributed save, but theses hould be tested at a later date ``` python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 dump_dir=s3://blt/scratch/checkpoint-test/ ``` ``` torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 dump_dir=s3://blt/scratch/checkpoint-test/ ``` --- bytelatent/args.py | 14 +++-- bytelatent/checkpoint.py | 115 +++++++++++++++++++++------------------ bytelatent/logger.py | 9 ++- bytelatent/metrics.py | 15 ++++- bytelatent/train.py | 31 +++++++---- 5 files changed, 112 insertions(+), 72 deletions(-) diff --git a/bytelatent/args.py b/bytelatent/args.py index d1bac46..fc72b32 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -294,6 +294,14 @@ class TrainArgs(BaseModel): def dump_to_yaml_file( self, path: str, log_config: bool = True, sort_keys: bool = True ): + yaml_str = self.dump_to_yaml_str(sort_keys=sort_keys) + with open(path, "w") as f: + if log_config: + logger.info("Using the following config for this run:") + logger.info(yaml_str) + f.write(yaml_str) + + def dump_to_yaml_str(self, sort_keys: bool = True): model_dict = self.model_dump(mode="json") yaml_str = yaml.dump( model_dict, @@ -301,8 +309,4 @@ class TrainArgs(BaseModel): sort_keys=sort_keys, default_flow_style=False, ) - with open(path, "w") as f: - if log_config: - logger.info("Using the following config for this run:") - logger.info(yaml_str) - f.write(yaml_str) + return yaml_str diff --git a/bytelatent/checkpoint.py b/bytelatent/checkpoint.py index f213c84..1668c88 100644 --- a/bytelatent/checkpoint.py +++ b/bytelatent/checkpoint.py @@ -4,10 +4,9 @@ import json import logging import os import re -from pathlib import Path -from typing import List, Optional, Tuple import fsspec +import s3fs import torch import torch.distributed as dist import torch.distributed.checkpoint as dcp @@ -70,26 +69,29 @@ def consolidate_checkpoints(fs: fsspec.AbstractFileSystem, ckpt_dir: str): Returns the path to the consolidated checkpoint """ - consolidate_path = Path(ckpt_dir) / CONSOLIDATE_FOLDER - if not (consolidate_path / CONSOLIDATE_NAME).exists(): - consolidate_path.mkdir(exist_ok=True) - logger.info(f"Consolidating to: {str(consolidate_path)}") - dcp_to_torch_save(ckpt_dir, str(consolidate_path / CONSOLIDATE_NAME)) - (consolidate_path / CONFIG_NAME).write_text( - (Path(ckpt_dir) / CONFIG_NAME).read_text() + consolidate_path = os.path.join(ckpt_dir, CONSOLIDATE_FOLDER) + consolidate_name = os.path.join(consolidate_path, CONSOLIDATE_NAME) + if not fs.exists(consolidate_name): + fs.mkdirs(consolidate_path, exist_ok=True) + logger.info(f"Consolidating to: {consolidate_path}") + dcp_to_torch_save(ckpt_dir, consolidate_name) + fs.write_text( + os.path.join(consolidate_path, CONFIG_NAME), + fs.read_text(os.path.join(ckpt_dir, CONFIG_NAME)), ) logger.info("Consolidated !") return consolidate_path def load_from_checkpoint( + fs: fsspec.AbstractFileSystem, ckpt_dir: str, model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, + optimizer: torch.optim.Optimizer | None = None, model_key: str = "model", optim_key: str = "optim", ): - if 
not (Path(ckpt_dir) / ".metadata").exists(): + if not fs.exists(os.path.join(ckpt_dir, ".metadata")): raise ValueError( f"Please convert the checkpoint distcp format using `torch.distributed.checkpoint.format_utils.torch_save_to_dcp` before loading it" ) @@ -115,19 +117,24 @@ class CheckpointManager: self.init_ckpt_path = args.init_ckpt_path self.continue_training_from_init = args.continue_training_from_init - assert self.fs.exists( - self.path - ), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)" + if not isinstance(self.fs, s3fs.S3FileSystem): + # S3 does not have a concept of directories + assert self.fs.exists( + self.path + ), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)" self.existing_saves = self.get_existing_saves() - def get_existing_saves(self) -> List[Path]: - folders = [ - p - for p in Path(self.path).iterdir() - if p.is_dir() and re.match(RE_FOLDER, p.name) - ] - folders.sort(key=lambda p: _get_key_step(p.name)) + def get_existing_saves(self) -> list[str]: + if self.fs.exists(self.path) and self.fs.isdir(self.path): + folders = [ + p + for p in self.fs.ls(self.path) + if self.fs.isdir(p) and re.match(RE_FOLDER, os.path.basename(p)) + ] + else: + folders = [] + folders.sort(key=lambda p: _get_key_step(os.path.basename(p))) return folders def clean_up(self): @@ -136,8 +143,9 @@ class CheckpointManager: eval_folders = [] other_folders = [] for p in self.existing_saves: - is_dump = _get_key_step(p.name) % self.dump_every.every == 0 - is_eval = _get_key_step(p.name) % self.eval_every.every == 0 + assert isinstance(p, str), f"Base path type: {p}" + is_dump = _get_key_step(os.path.basename(p)) % self.dump_every.every == 0 + is_eval = _get_key_step(os.path.basename(p)) % self.eval_every.every == 0 if is_dump: dump_folders.append(p) if is_eval: @@ -161,40 +169,39 @@ class CheckpointManager: if dist.get_rank() == 0: for folder in folder_to_remove: - for file in folder.iterdir(): - if file.is_file(): - file.unlink() - elif file.is_dir(): - assert file.name in [CONSOLIDATE_FOLDER] - for f in file.iterdir(): - f.unlink() - file.rmdir() - folder.rmdir() + for file in self.fs.ls(folder): + if self.fs.isfile(file): + self.fs.rm_file(file) + elif self.fs.isdir(file): + assert os.path.name(file) in [CONSOLIDATE_FOLDER] + for f in self.fs.ls(file): + self.fs.rm(f) + self.fs.rmdir(file) + self.fs.rmdir(folder) dist.barrier() self.existing_saves = list(folder_to_keep) - self.existing_saves.sort(key=lambda p: _get_key_step(p.name)) + self.existing_saves.sort(key=lambda p: _get_key_step(os.path.basename(p))) - def get_last_step_path(self, dp_rank: int = 0) -> Optional[Path]: + def get_last_step_path(self, dp_rank: int = 0) -> str | None: path = None for p in reversed(self.existing_saves): - if (p / TRAIN_STATE_NAME.format(dp_rank)).is_file(): + + if self.fs.isfile(os.path.join(p, TRAIN_STATE_NAME.format(dp_rank))): path = p break return path - def _create_folder(self, base_path: Path, folder_name: str) -> Path: - folder = base_path / folder_name + def _create_folder(self, base_path: str, folder_name: str) -> str: + folder = os.path.join(base_path, folder_name) if get_is_master(): - folder.mkdir(parents=False, exist_ok=True) + self.fs.mkdirs(folder, exist_ok=True) if dist.is_initialized(): dist.barrier() return folder - def _get_dp_tp_mesh( - self, device_mesh: Optional[DeviceMesh] = None - ) -> Tuple[int, int]: + def _get_dp_tp_mesh(self, device_mesh: 
DeviceMesh | None = None) -> tuple[int, int]: dp_rank = 0 tp_rank = 0 if device_mesh is not None: @@ -222,14 +229,14 @@ class CheckpointManager: model, optimizer, train_state, - config, - device_mesh: Optional[DeviceMesh] = None, + config: BaseModel, + device_mesh: DeviceMesh | None = None, ) -> bool: # When creating directory check if only rank0 or is there other solution - path = Path(self.path) + path = self.path curr_save_dir = self._create_folder(path, FOLDER_NAME.format(train_state.step)) - logger.info(f"Saving to: {str(curr_save_dir)}") + logger.info(f"Saving to: {curr_save_dir}") if dist.is_initialized(): dist.barrier() @@ -242,17 +249,19 @@ class CheckpointManager: if dist.is_initialized(): dist.barrier() + print("config type", type(config)) if get_is_master(): - config.dump_to_yaml_file(curr_save_dir / CONFIG_NAME) + self.fs.write_text( + os.path.join(curr_save_dir, CONFIG_NAME), config.model_dump_json() + ) # Add json dump here dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh) if tp_rank == 0: train_state_name = TRAIN_STATE_NAME.format(dp_rank) - logger.info( - f"Saving train state to: {str(curr_save_dir / train_state_name)}" - ) - with open(curr_save_dir / train_state_name, "w") as f: + train_state_full_path = os.path.join(curr_save_dir, train_state_name) + logger.info(f"Saving train state to: {train_state_full_path}") + with self.fs.open(train_state_full_path, "w") as f: json.dump(train_state.state_dict(), f) logger.info("Train state saved !") @@ -271,7 +280,7 @@ class CheckpointManager: optimizer, train_state, device_mesh: DeviceMesh, - path: Optional[Path] = None, + path: str | None = None, ): dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh) # Loading tries to load the provided path, if not available the last saved step and finally from the init path @@ -284,12 +293,12 @@ class CheckpointManager: # Only load train state if it's provided, the files exist and we're not loading from init path train_state_name = TRAIN_STATE_NAME.format(dp_rank) logger.info("Reloading train state") - with open(path / train_state_name, "r") as f: + with self.fs.open(os.path.join(path, train_state_name), "r") as f: train_state_dict = json.load(f) train_state.load_state_dict(train_state_dict) logger.info("Train state reloaded") - logger.info(f"Loading from: {str(path)}") + logger.info(f"Loading from: {path}") state_dict = self.get_state_dict( model=model, optimizer=optimizer, diff --git a/bytelatent/logger.py b/bytelatent/logger.py index 87f04cc..6f9a397 100644 --- a/bytelatent/logger.py +++ b/bytelatent/logger.py @@ -6,6 +6,8 @@ import sys import time from datetime import timedelta +import fsspec + from bytelatent.distributed import get_global_rank, get_is_slurm_job @@ -92,6 +94,7 @@ def init_logger( *, name: str | None = None, level: str = "INFO", + fs: fsspec.AbstractFileSystem | None = None, ): """ Setup logging. 
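Further below, MetricLogger gets the same treatment: when a filesystem is provided, the JSONL writer is opened through it instead of the builtin open(). A minimal sketch, with a hypothetical output path:

```python
# Append one metrics record through an fsspec handle, mirroring the
# MetricLogger.open() change in this patch.
import json

import fsspec

fs = fsspec.filesystem("file")
metrics_path = "/tmp/metrics.jsonl"                 # hypothetical output file

with fs.open(metrics_path, "a") as writer:          # same append semantics as open(path, "a")
    writer.write(json.dumps({"global_step": 1, "loss": 2.5}) + "\n")
```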
@@ -121,7 +124,11 @@ def init_logger( if log_file is not None and get_global_rank() == 0: # build file handler - file_handler = logging.FileHandler(log_file, "a") + if fs is None: + file_handler = logging.FileHandler(log_file, "a") + else: + file_stream = fs.open(log_file, mode="a") + file_handler = logging.StreamHandler(file_stream) file_handler.setLevel(logging.NOTSET) file_handler.setFormatter(LogFormatter()) # update logger diff --git a/bytelatent/metrics.py b/bytelatent/metrics.py index e746e4f..fb443d7 100644 --- a/bytelatent/metrics.py +++ b/bytelatent/metrics.py @@ -8,6 +8,7 @@ from datetime import datetime, timezone from pathlib import Path from typing import Any, Union +import fsspec import torch import torch.nn as nn import wandb @@ -53,14 +54,24 @@ class LoggingArgs(BaseModel): class MetricLogger: - def __init__(self, outdir: Path, args: Any | None = None): + def __init__( + self, + outdir: Path, + # args: TrainArgs + args: Any | None = None, + fs: fsspec.AbstractFileSystem | None = None, + ): self.outdir = outdir self.jsonl_writer = None + self.fs = fs self.args = args def open(self): if self.jsonl_writer is None: - self.jsonl_writer = open(self.outdir, "a") + if self.fs is None: + self.jsonl_writer = open(self.outdir, "a") + else: + self.jsonl_writer = self.fs.open(self.outdir, "a") if ( self.args is not None and self.args.logging.wandb is not None diff --git a/bytelatent/train.py b/bytelatent/train.py index bb8307a..9bfe12a 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -8,7 +8,6 @@ import sys from contextlib import ExitStack from copy import deepcopy from dataclasses import asdict, dataclass -from pathlib import Path from timeit import default_timer as timer from typing import Any, TypeVar @@ -18,13 +17,13 @@ import torch.nn.functional import torch.nn.functional as F import wandb import xformers.profiler -from omegaconf import OmegaConf from torch.distributed._tensor import DTensor from torch.distributed.checkpoint.stateful import Stateful from torch.optim import lr_scheduler from bytelatent.args import TrainArgs, parse_args from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint +from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.multiprocess_iterator import ( MultiprocessIterator, MultiprocessIteratorState, @@ -136,11 +135,12 @@ def validate_train_args(args: TrainArgs, output_size: int): if args.checkpoint.path is None: logger.info(f"Setting checkpoint path to {args.checkpoint.path}") - args.checkpoint.path = str(Path(args.dump_dir) / "checkpoints") + args.checkpoint.path = os.path.join(args.dump_dir, "checkpoints") + data_fs = get_fs(args.data.root_dir, s3_profile=args.data.s3_profile) for source in args.data.sources: data_path = os.path.join(args.data.root_dir, source) - assert os.path.exists(data_path), f"{data_path} doesn't exist" + assert data_fs.exists(data_path), f"{data_path} doesn't exist" if ( args.distributed.dp_replicate @@ -255,10 +255,15 @@ def train(args: TrainArgs): args, tokenizer.n_words, ) + dump_fs = get_fs(args.dump_dir, s3_profile=args.checkpoint.s3_profile) if get_is_master(): - os.makedirs(args.dump_dir, exist_ok=True) - args.dump_to_yaml_file(Path(args.dump_dir) / "config.yaml") - init_logger(Path(args.dump_dir) / "train.log") + dump_fs.mkdirs(args.dump_dir, exist_ok=True) + config_yaml_str = args.dump_to_yaml_str() + logging.info("TrainArgs: \n%s", config_yaml_str) + dump_fs.write_text( + os.path.join(args.dump_dir, "config.yaml"), config_yaml_str + ) + 
init_logger(os.path.join(args.dump_dir, "train.log"), fs=dump_fs) init_signal_handler(set_preemption_flag) # For handling preemption signals. setup_env(args.env) setup_torch_distributed(args.distributed) @@ -313,8 +318,11 @@ def train(args: TrainArgs): if args.checkpoint.init_ckpt_path: logger.info(f"Loading initial model from {args.checkpoint.init_ckpt_path}") + ckpt_fs = get_fs( + args.checkpoint.init_ckpt_path, s3_profile=args.checkpoint.s3_profile + ) load_from_checkpoint( - args.checkpoint.init_ckpt_path, model, model_key="model" + ckpt_fs, args.checkpoint.init_ckpt_path, model, model_key="model" ) # Put model_key="" if its directly the model checkpoint model.rope_embeddings.reset_parameters() # For RoPe initialization since it's a buffer it might not be loaded else: @@ -352,13 +360,14 @@ def train(args: TrainArgs): checkpoint.load(model, optimizer, train_state, world_mesh) # Either load from latest checkpoint or start from scratch if args.probe_freq is not None: + # TODO: Convert this to fsspec compatible if get_is_master(): - os.makedirs(Path(args.dump_dir) / "probe", exist_ok=True) + os.makedirs(os.path.join(args.dump_dir, "probe"), exist_ok=True) torch.distributed.barrier() probe = AutoProbeD( model, ( - Path(args.dump_dir) / "probe" / f"probe.{dp_rank}.jsonl" + os.path.join(args.dump_dir, "probe", f"probe.{dp_rank}.jsonl") if (dp_rank % 128 == 0) else None ), @@ -370,7 +379,7 @@ def train(args: TrainArgs): # train loop model.train() metric_logger = context_stack.enter_context( - MetricLogger(Path(args.dump_dir) / "metrics.jsonl", args) + MetricLogger(os.path.join(args.dump_dir, "metrics.jsonl"), args, fs=dump_fs) ) data_loader = train_state.data_loader_state.build() batch_iterator = data_loader.create_iter() From 8d26140970e73327eae3e99f73d7d816375c1134 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Thu, 6 Feb 2025 17:42:32 +0000 Subject: [PATCH 34/59] Allow ArrowIterator to read from json Summary: Test Plan: --- bytelatent/args.py | 38 +++++++- bytelatent/data/iterators/arrow_iterator.py | 97 +++++++++++-------- bytelatent/preprocess/preprocess_entropies.py | 26 +++-- 3 files changed, 104 insertions(+), 57 deletions(-) diff --git a/bytelatent/args.py b/bytelatent/args.py index fc72b32..263e8e3 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -1,8 +1,10 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
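The args.py hunk in this patch moves chunk discovery here and makes the file pattern explicit; a hedged sketch of the same logic as a standalone function (the filesystem choice is simplified to local, whereas the real code resolves it from the dataset path and an optional s3 profile):

```python
# Sketch of the chunk-discovery helper added below.
import os

import fsspec


def discover_chunks(dataset_path: str, world_size: int,
                    file_pattern: str = "*.chunk.*.jsonl") -> list[str]:
    fs = fsspec.filesystem("file")                  # simplification for this sketch
    chunks = fs.glob(os.path.join(dataset_path, file_pattern))
    assert len(chunks) > 0, f"No valid chunks in {dataset_path}"
    if len(chunks) > world_size:
        return chunks[:world_size]                  # keep at most one chunk per rank
    assert world_size % len(chunks) == 0, "World size should be a multiple of number of chunks"
    return chunks
```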
+import json import logging import os from typing import Any +import fsspec import numpy as np import yaml from omegaconf import OmegaConf @@ -10,11 +12,9 @@ from pydantic import BaseModel, ConfigDict from bytelatent.checkpoint import CheckpointArgs from bytelatent.data.data_types import Batch +from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import StatefulIterator -from bytelatent.data.iterators.arrow_iterator import ( - ArrowFileIterator, - find_and_sanitize_chunks, -) +from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator from bytelatent.data.iterators.looping_iterator import LoopingIterator from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator @@ -53,6 +53,33 @@ def parse_args(args_cls): return pydantic_args +TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" + + +def find_and_sanitize_chunks( + dataset_path: str, + world_size: int, + file_pattern: str, + s3_profile: str | None = None, +): + fs = get_fs(dataset_path, s3_profile=s3_profile) + path_with_glob = os.path.join(dataset_path, file_pattern) + dataset_chunks = fs.glob(path_with_glob) + n_chunks = len(dataset_chunks) + + if n_chunks > world_size: + n_discard = n_chunks - world_size + dataset_chunks = dataset_chunks[:world_size] + else: + assert ( + world_size % n_chunks == 0 + ), "World size should be a multiple of number of chunks" + + assert n_chunks > 0, f"No valid chunks in {dataset_path}" + + return dataset_chunks + + def distribute_data_to_rank( *, dataset_path: str, @@ -62,9 +89,10 @@ def distribute_data_to_rank( rank: int, world_size: int, s3_profile: str | None = None, + file_pattern: str = TRAIN_DATA_FILE_PATTERN, ) -> ArrowFileIterator: dataset_chunks = find_and_sanitize_chunks( - dataset_path, world_size, s3_profile=s3_profile + dataset_path, world_size, file_pattern, s3_profile=s3_profile ) n_workers_per_chunk = world_size // len(dataset_chunks) rank_to_arrow_iterator_params = [] diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index 4e7b99e..1c68d3a 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -16,6 +16,7 @@ from bytelatent import ByteLatentError from bytelatent.data.data_types import BltExample from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator +from bytelatent.preprocess.preprocess_entropies import get_id_key, get_text logger = getLogger(__name__) @@ -32,6 +33,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState): arrow_batch_size: int = 100 s3_profile: str | None filesystem_type: str | None = None + file_format: str def build(self) -> "ArrowFileIterator": arrow_file = ArrowFileIterator( @@ -44,6 +46,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState): dataset_files=self.dataset_files, s3_profile=self.s3_profile, filesystem_type=self.filesystem_type, + file_format=self.file_format, ) if self.row_num != 0: arrow_file._set_row_num(self.row_num) @@ -70,6 +73,7 @@ class ArrowFileIterator(StatefulIterator): dataset_files: list[str] | None = None, s3_profile: str | None = None, filesystem_type: str | None = None, + file_format: str = "arrow", ): assert 0 <= worker_id < num_workers, (worker_id, num_workers) if file_path is None and dataset_files is None: @@ -87,12 +91,16 @@ class ArrowFileIterator(StatefulIterator): self.arrow_batch_size = 
arrow_batch_size self.s3_profile = s3_profile self.filesystem_type = filesystem_type + self.file_format = file_format self.fs = None if self.filesystem_type is not None: if self.filesystem_type == "file": self.fs = fsspec.filesystem("file") elif self.filesystem_type == "s3": self.fs = fsspec.filesystem("s3", profile=s3_profile) + else: + raise ValueError("Unknown filesystem") + logger.info("Arrow iterator using fs=%s", self.fs) if dataset_files is None: # Prepare arrow shards @@ -153,6 +161,7 @@ class ArrowFileIterator(StatefulIterator): dataset_files=self.dataset_files, s3_profile=self.s3_profile, filesystem_type=self.filesystem_type, + file_format=self.file_format, ) def create_iter( @@ -164,7 +173,7 @@ class ArrowFileIterator(StatefulIterator): else: filesystem = None self.dataset = pa.dataset.dataset( - self.dataset_files, format="arrow", filesystem=filesystem + self.dataset_files, format=self.file_format, filesystem=filesystem ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size @@ -173,13 +182,22 @@ class ArrowFileIterator(StatefulIterator): if self.batch_to_consume is not None: batch_columns: dict[str, list] = self.batch_to_consume self.batch_to_consume = None - sample_ids = batch_columns["sample_id"] - texts = batch_columns["text"] - entropies = batch_columns["entropies"] + if self.file_format == "arrow": + sample_ids = batch_columns["sample_id"] + texts = batch_columns["text"] + entropies = batch_columns["entropies"] + elif self.file_format == "json": + # This data hasn't been preprocessed to a uniform format, + # so we have to do it now and omit entropies + sample_ids = batch_columns[get_id_key(batch_columns)] + texts = get_text(batch_columns) + entropies = None + else: + raise ValueError(f"Unknown file format: {self.file_format}") for i in range(len(sample_ids)): out = BltExample( sample_id=sample_ids[i], - entropies=entropies[i], + entropies=entropies[i] if entropies is not None else None, text=texts[i], tokens=None, mask=None, @@ -191,13 +209,22 @@ class ArrowFileIterator(StatefulIterator): for batch in self.batch_iterator: batch_columns = batch.to_pydict() - sample_ids = batch_columns["sample_id"] - texts = batch_columns["text"] - entropies = batch_columns["entropies"] + if self.file_format == "arrow": + sample_ids = batch_columns["sample_id"] + texts = batch_columns["text"] + entropies = batch_columns["entropies"] + elif self.file_format == "json": + # This data hasn't been preprocessed to a uniform format, + # so we have to do it now and omit entropies + sample_ids = batch_columns[get_id_key(batch_columns)] + texts = get_text(batch_columns) + entropies = None + else: + raise ValueError(f"Unknown file format: {self.file_format}") for i in range(len(sample_ids)): out = BltExample( sample_id=sample_ids[i], - entropies=entropies[i], + entropies=entropies[i] if entropies is not None else None, text=texts[i], tokens=None, mask=None, @@ -231,13 +258,24 @@ class ArrowFileIterator(StatefulIterator): for batch in self.batch_iterator: if len(batch) > curr_remaining: batch_columns: dict[str, list] = batch.to_pydict() - batch_columns["sample_id"] = batch_columns["sample_id"][ - curr_remaining: - ] - batch_columns["entropies"] = batch_columns["entropies"][ - curr_remaining: - ] - batch_columns["text"] = batch_columns["text"][curr_remaining:] + if self.file_format == "arrow": + leftover_sample_ids = batch_columns["sample_id"][ + curr_remaining: + ] + leftover_entropies = batch_columns["entropies"][curr_remaining:] + leftover_texts = 
batch_columns["text"][curr_remaining:] + elif self.file_format == "json": + leftover_sample_ids = batch_columns[get_id_key(batch_columns)][ + curr_remaining: + ] + leftover_entropies = None + leftover_texts = get_text(batch_columns)[curr_remaining:] + else: + raise ValueError(f"Unknown file format: {self.file_format}") + + batch_columns["sample_id"] = leftover_sample_ids + batch_columns["entropies"] = leftover_entropies + batch_columns["text"] = leftover_texts self.batch_to_consume = batch_columns break elif len(batch) == curr_remaining: @@ -250,30 +288,3 @@ class ArrowFileIterator(StatefulIterator): logger.info( f"Finished setting arrow position to {target_row_num} for {self.dataset_files}" ) - - -TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" - - -def find_and_sanitize_chunks( - dataset_path: str, - world_size: int, - file_pattern: str = TRAIN_DATA_FILE_PATTERN, - s3_profile: str | None = None, -): - fs = get_fs(dataset_path, s3_profile=s3_profile) - path_with_glob = os.path.join(dataset_path, file_pattern) - dataset_chunks = fs.glob(path_with_glob) - n_chunks = len(dataset_chunks) - - if n_chunks > world_size: - n_discard = n_chunks - world_size - dataset_chunks = dataset_chunks[:world_size] - else: - assert ( - world_size % n_chunks == 0 - ), "World size should be a multiple of number of chunks" - - assert n_chunks > 0, f"No valid chunks in {dataset_path}" - - return dataset_chunks diff --git a/bytelatent/preprocess/preprocess_entropies.py b/bytelatent/preprocess/preprocess_entropies.py index 519da94..31a4802 100644 --- a/bytelatent/preprocess/preprocess_entropies.py +++ b/bytelatent/preprocess/preprocess_entropies.py @@ -15,29 +15,37 @@ from bytelatent.entropy_model import load_entropy_model from bytelatent.tokenizers.build_tokenizer import TokenizerArgs -def get_id_from_doc(doc: dict) -> int: +def get_id_key(doc: dict) -> int: """ We need a reliable way to ensure that samples from jsonl and arrow are the same, but there is no unique id field, so derive the best possible """ if "sample_id" in doc: - sample_id = doc["sample_id"] + return "sample_id" elif "title" in doc: - sample_id = doc["title"] + return "title" elif "qid" in doc: - sample_id = doc["qid"] + return "qid" elif "paper_id" in doc: - sample_id = doc["paper_id"] + return "paper_id" elif "path" in doc: - sample_id = doc["path"] + return "path" elif "url" in doc: - sample_id = doc["url"] + return "url" elif "id" in doc: - sample_id = doc["id"] + return "id" else: raise ValueError(f"Could not find a id key from: {doc.keys()}") - return str(sample_id) + + +def get_id_from_doc(doc: dict) -> int: + """ + We need a reliable way to ensure that samples from jsonl + and arrow are the same, but there is no unique id field, + so derive the best possible + """ + return str(doc[get_id_key(doc)]) def get_text(doc: dict): From 0e9421af079162f508872c6d60a85741e3919f6a Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Thu, 6 Feb 2025 17:43:10 +0000 Subject: [PATCH 35/59] Allow ArrowIterator to read from json Summary: Test Plan: --- bytelatent/args.py | 38 +++++++- bytelatent/data/iterators/arrow_iterator.py | 97 +++++++++++-------- bytelatent/preprocess/preprocess_entropies.py | 26 +++-- bytelatent/stool.py | 2 +- 4 files changed, 105 insertions(+), 58 deletions(-) diff --git a/bytelatent/args.py b/bytelatent/args.py index fc72b32..263e8e3 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -1,8 +1,10 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
+import json import logging import os from typing import Any +import fsspec import numpy as np import yaml from omegaconf import OmegaConf @@ -10,11 +12,9 @@ from pydantic import BaseModel, ConfigDict from bytelatent.checkpoint import CheckpointArgs from bytelatent.data.data_types import Batch +from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import StatefulIterator -from bytelatent.data.iterators.arrow_iterator import ( - ArrowFileIterator, - find_and_sanitize_chunks, -) +from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator from bytelatent.data.iterators.looping_iterator import LoopingIterator from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator @@ -53,6 +53,33 @@ def parse_args(args_cls): return pydantic_args +TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" + + +def find_and_sanitize_chunks( + dataset_path: str, + world_size: int, + file_pattern: str, + s3_profile: str | None = None, +): + fs = get_fs(dataset_path, s3_profile=s3_profile) + path_with_glob = os.path.join(dataset_path, file_pattern) + dataset_chunks = fs.glob(path_with_glob) + n_chunks = len(dataset_chunks) + + if n_chunks > world_size: + n_discard = n_chunks - world_size + dataset_chunks = dataset_chunks[:world_size] + else: + assert ( + world_size % n_chunks == 0 + ), "World size should be a multiple of number of chunks" + + assert n_chunks > 0, f"No valid chunks in {dataset_path}" + + return dataset_chunks + + def distribute_data_to_rank( *, dataset_path: str, @@ -62,9 +89,10 @@ def distribute_data_to_rank( rank: int, world_size: int, s3_profile: str | None = None, + file_pattern: str = TRAIN_DATA_FILE_PATTERN, ) -> ArrowFileIterator: dataset_chunks = find_and_sanitize_chunks( - dataset_path, world_size, s3_profile=s3_profile + dataset_path, world_size, file_pattern, s3_profile=s3_profile ) n_workers_per_chunk = world_size // len(dataset_chunks) rank_to_arrow_iterator_params = [] diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index 4e7b99e..1c68d3a 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -16,6 +16,7 @@ from bytelatent import ByteLatentError from bytelatent.data.data_types import BltExample from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator +from bytelatent.preprocess.preprocess_entropies import get_id_key, get_text logger = getLogger(__name__) @@ -32,6 +33,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState): arrow_batch_size: int = 100 s3_profile: str | None filesystem_type: str | None = None + file_format: str def build(self) -> "ArrowFileIterator": arrow_file = ArrowFileIterator( @@ -44,6 +46,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState): dataset_files=self.dataset_files, s3_profile=self.s3_profile, filesystem_type=self.filesystem_type, + file_format=self.file_format, ) if self.row_num != 0: arrow_file._set_row_num(self.row_num) @@ -70,6 +73,7 @@ class ArrowFileIterator(StatefulIterator): dataset_files: list[str] | None = None, s3_profile: str | None = None, filesystem_type: str | None = None, + file_format: str = "arrow", ): assert 0 <= worker_id < num_workers, (worker_id, num_workers) if file_path is None and dataset_files is None: @@ -87,12 +91,16 @@ class ArrowFileIterator(StatefulIterator): self.arrow_batch_size = 
arrow_batch_size self.s3_profile = s3_profile self.filesystem_type = filesystem_type + self.file_format = file_format self.fs = None if self.filesystem_type is not None: if self.filesystem_type == "file": self.fs = fsspec.filesystem("file") elif self.filesystem_type == "s3": self.fs = fsspec.filesystem("s3", profile=s3_profile) + else: + raise ValueError("Unknown filesystem") + logger.info("Arrow iterator using fs=%s", self.fs) if dataset_files is None: # Prepare arrow shards @@ -153,6 +161,7 @@ class ArrowFileIterator(StatefulIterator): dataset_files=self.dataset_files, s3_profile=self.s3_profile, filesystem_type=self.filesystem_type, + file_format=self.file_format, ) def create_iter( @@ -164,7 +173,7 @@ class ArrowFileIterator(StatefulIterator): else: filesystem = None self.dataset = pa.dataset.dataset( - self.dataset_files, format="arrow", filesystem=filesystem + self.dataset_files, format=self.file_format, filesystem=filesystem ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size @@ -173,13 +182,22 @@ class ArrowFileIterator(StatefulIterator): if self.batch_to_consume is not None: batch_columns: dict[str, list] = self.batch_to_consume self.batch_to_consume = None - sample_ids = batch_columns["sample_id"] - texts = batch_columns["text"] - entropies = batch_columns["entropies"] + if self.file_format == "arrow": + sample_ids = batch_columns["sample_id"] + texts = batch_columns["text"] + entropies = batch_columns["entropies"] + elif self.file_format == "json": + # This data hasn't been preprocessed to a uniform format, + # so we have to do it now and omit entropies + sample_ids = batch_columns[get_id_key(batch_columns)] + texts = get_text(batch_columns) + entropies = None + else: + raise ValueError(f"Unknown file format: {self.file_format}") for i in range(len(sample_ids)): out = BltExample( sample_id=sample_ids[i], - entropies=entropies[i], + entropies=entropies[i] if entropies is not None else None, text=texts[i], tokens=None, mask=None, @@ -191,13 +209,22 @@ class ArrowFileIterator(StatefulIterator): for batch in self.batch_iterator: batch_columns = batch.to_pydict() - sample_ids = batch_columns["sample_id"] - texts = batch_columns["text"] - entropies = batch_columns["entropies"] + if self.file_format == "arrow": + sample_ids = batch_columns["sample_id"] + texts = batch_columns["text"] + entropies = batch_columns["entropies"] + elif self.file_format == "json": + # This data hasn't been preprocessed to a uniform format, + # so we have to do it now and omit entropies + sample_ids = batch_columns[get_id_key(batch_columns)] + texts = get_text(batch_columns) + entropies = None + else: + raise ValueError(f"Unknown file format: {self.file_format}") for i in range(len(sample_ids)): out = BltExample( sample_id=sample_ids[i], - entropies=entropies[i], + entropies=entropies[i] if entropies is not None else None, text=texts[i], tokens=None, mask=None, @@ -231,13 +258,24 @@ class ArrowFileIterator(StatefulIterator): for batch in self.batch_iterator: if len(batch) > curr_remaining: batch_columns: dict[str, list] = batch.to_pydict() - batch_columns["sample_id"] = batch_columns["sample_id"][ - curr_remaining: - ] - batch_columns["entropies"] = batch_columns["entropies"][ - curr_remaining: - ] - batch_columns["text"] = batch_columns["text"][curr_remaining:] + if self.file_format == "arrow": + leftover_sample_ids = batch_columns["sample_id"][ + curr_remaining: + ] + leftover_entropies = batch_columns["entropies"][curr_remaining:] + leftover_texts = 
batch_columns["text"][curr_remaining:] + elif self.file_format == "json": + leftover_sample_ids = batch_columns[get_id_key(batch_columns)][ + curr_remaining: + ] + leftover_entropies = None + leftover_texts = get_text(batch_columns)[curr_remaining:] + else: + raise ValueError(f"Unknown file format: {self.file_format}") + + batch_columns["sample_id"] = leftover_sample_ids + batch_columns["entropies"] = leftover_entropies + batch_columns["text"] = leftover_texts self.batch_to_consume = batch_columns break elif len(batch) == curr_remaining: @@ -250,30 +288,3 @@ class ArrowFileIterator(StatefulIterator): logger.info( f"Finished setting arrow position to {target_row_num} for {self.dataset_files}" ) - - -TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" - - -def find_and_sanitize_chunks( - dataset_path: str, - world_size: int, - file_pattern: str = TRAIN_DATA_FILE_PATTERN, - s3_profile: str | None = None, -): - fs = get_fs(dataset_path, s3_profile=s3_profile) - path_with_glob = os.path.join(dataset_path, file_pattern) - dataset_chunks = fs.glob(path_with_glob) - n_chunks = len(dataset_chunks) - - if n_chunks > world_size: - n_discard = n_chunks - world_size - dataset_chunks = dataset_chunks[:world_size] - else: - assert ( - world_size % n_chunks == 0 - ), "World size should be a multiple of number of chunks" - - assert n_chunks > 0, f"No valid chunks in {dataset_path}" - - return dataset_chunks diff --git a/bytelatent/preprocess/preprocess_entropies.py b/bytelatent/preprocess/preprocess_entropies.py index 519da94..31a4802 100644 --- a/bytelatent/preprocess/preprocess_entropies.py +++ b/bytelatent/preprocess/preprocess_entropies.py @@ -15,29 +15,37 @@ from bytelatent.entropy_model import load_entropy_model from bytelatent.tokenizers.build_tokenizer import TokenizerArgs -def get_id_from_doc(doc: dict) -> int: +def get_id_key(doc: dict) -> int: """ We need a reliable way to ensure that samples from jsonl and arrow are the same, but there is no unique id field, so derive the best possible """ if "sample_id" in doc: - sample_id = doc["sample_id"] + return "sample_id" elif "title" in doc: - sample_id = doc["title"] + return "title" elif "qid" in doc: - sample_id = doc["qid"] + return "qid" elif "paper_id" in doc: - sample_id = doc["paper_id"] + return "paper_id" elif "path" in doc: - sample_id = doc["path"] + return "path" elif "url" in doc: - sample_id = doc["url"] + return "url" elif "id" in doc: - sample_id = doc["id"] + return "id" else: raise ValueError(f"Could not find a id key from: {doc.keys()}") - return str(sample_id) + + +def get_id_from_doc(doc: dict) -> int: + """ + We need a reliable way to ensure that samples from jsonl + and arrow are the same, but there is no unique id field, + so derive the best possible + """ + return str(doc[get_id_key(doc)]) def get_text(doc: dict): diff --git a/bytelatent/stool.py b/bytelatent/stool.py index b156177..b47ddc7 100644 --- a/bytelatent/stool.py +++ b/bytelatent/stool.py @@ -4,10 +4,10 @@ import json import os import shutil import subprocess -from pydantic import BaseModel from typing import Any, Dict from omegaconf import OmegaConf +from pydantic import BaseModel class StoolArgs(BaseModel): From 9c3c997cae9eff047acca9b19c9a8c69dd5f9102 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Thu, 6 Feb 2025 17:43:20 +0000 Subject: [PATCH 36/59] Allow ArrowIterator to read from json Summary: Currently, arrow iterator can only read arrow files. However, the pyarrow library can read other formats, including jsonlines. 
This allows the same ArrowIterator to read from jsonlines, so we can read from the original source data, and simply omit the entropy column when doing so Test Plan: Run train script until dataloader starts --- bytelatent/args.py | 38 +++++++- bytelatent/data/iterators/arrow_iterator.py | 97 +++++++++++-------- bytelatent/preprocess/preprocess_entropies.py | 26 +++-- bytelatent/stool.py | 2 +- 4 files changed, 105 insertions(+), 58 deletions(-) diff --git a/bytelatent/args.py b/bytelatent/args.py index fc72b32..263e8e3 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -1,8 +1,10 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. +import json import logging import os from typing import Any +import fsspec import numpy as np import yaml from omegaconf import OmegaConf @@ -10,11 +12,9 @@ from pydantic import BaseModel, ConfigDict from bytelatent.checkpoint import CheckpointArgs from bytelatent.data.data_types import Batch +from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import StatefulIterator -from bytelatent.data.iterators.arrow_iterator import ( - ArrowFileIterator, - find_and_sanitize_chunks, -) +from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator from bytelatent.data.iterators.looping_iterator import LoopingIterator from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator @@ -53,6 +53,33 @@ def parse_args(args_cls): return pydantic_args +TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" + + +def find_and_sanitize_chunks( + dataset_path: str, + world_size: int, + file_pattern: str, + s3_profile: str | None = None, +): + fs = get_fs(dataset_path, s3_profile=s3_profile) + path_with_glob = os.path.join(dataset_path, file_pattern) + dataset_chunks = fs.glob(path_with_glob) + n_chunks = len(dataset_chunks) + + if n_chunks > world_size: + n_discard = n_chunks - world_size + dataset_chunks = dataset_chunks[:world_size] + else: + assert ( + world_size % n_chunks == 0 + ), "World size should be a multiple of number of chunks" + + assert n_chunks > 0, f"No valid chunks in {dataset_path}" + + return dataset_chunks + + def distribute_data_to_rank( *, dataset_path: str, @@ -62,9 +89,10 @@ def distribute_data_to_rank( rank: int, world_size: int, s3_profile: str | None = None, + file_pattern: str = TRAIN_DATA_FILE_PATTERN, ) -> ArrowFileIterator: dataset_chunks = find_and_sanitize_chunks( - dataset_path, world_size, s3_profile=s3_profile + dataset_path, world_size, file_pattern, s3_profile=s3_profile ) n_workers_per_chunk = world_size // len(dataset_chunks) rank_to_arrow_iterator_params = [] diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index 4e7b99e..1c68d3a 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -16,6 +16,7 @@ from bytelatent import ByteLatentError from bytelatent.data.data_types import BltExample from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator +from bytelatent.preprocess.preprocess_entropies import get_id_key, get_text logger = getLogger(__name__) @@ -32,6 +33,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState): arrow_batch_size: int = 100 s3_profile: str | None filesystem_type: str | None = None + file_format: str def build(self) -> "ArrowFileIterator": arrow_file = ArrowFileIterator( @@ -44,6 +46,7 @@ 
class ArrowFileIteratorState(BaseModel, IteratorState): dataset_files=self.dataset_files, s3_profile=self.s3_profile, filesystem_type=self.filesystem_type, + file_format=self.file_format, ) if self.row_num != 0: arrow_file._set_row_num(self.row_num) @@ -70,6 +73,7 @@ class ArrowFileIterator(StatefulIterator): dataset_files: list[str] | None = None, s3_profile: str | None = None, filesystem_type: str | None = None, + file_format: str = "arrow", ): assert 0 <= worker_id < num_workers, (worker_id, num_workers) if file_path is None and dataset_files is None: @@ -87,12 +91,16 @@ class ArrowFileIterator(StatefulIterator): self.arrow_batch_size = arrow_batch_size self.s3_profile = s3_profile self.filesystem_type = filesystem_type + self.file_format = file_format self.fs = None if self.filesystem_type is not None: if self.filesystem_type == "file": self.fs = fsspec.filesystem("file") elif self.filesystem_type == "s3": self.fs = fsspec.filesystem("s3", profile=s3_profile) + else: + raise ValueError("Unknown filesystem") + logger.info("Arrow iterator using fs=%s", self.fs) if dataset_files is None: # Prepare arrow shards @@ -153,6 +161,7 @@ class ArrowFileIterator(StatefulIterator): dataset_files=self.dataset_files, s3_profile=self.s3_profile, filesystem_type=self.filesystem_type, + file_format=self.file_format, ) def create_iter( @@ -164,7 +173,7 @@ class ArrowFileIterator(StatefulIterator): else: filesystem = None self.dataset = pa.dataset.dataset( - self.dataset_files, format="arrow", filesystem=filesystem + self.dataset_files, format=self.file_format, filesystem=filesystem ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size @@ -173,13 +182,22 @@ class ArrowFileIterator(StatefulIterator): if self.batch_to_consume is not None: batch_columns: dict[str, list] = self.batch_to_consume self.batch_to_consume = None - sample_ids = batch_columns["sample_id"] - texts = batch_columns["text"] - entropies = batch_columns["entropies"] + if self.file_format == "arrow": + sample_ids = batch_columns["sample_id"] + texts = batch_columns["text"] + entropies = batch_columns["entropies"] + elif self.file_format == "json": + # This data hasn't been preprocessed to a uniform format, + # so we have to do it now and omit entropies + sample_ids = batch_columns[get_id_key(batch_columns)] + texts = get_text(batch_columns) + entropies = None + else: + raise ValueError(f"Unknown file format: {self.file_format}") for i in range(len(sample_ids)): out = BltExample( sample_id=sample_ids[i], - entropies=entropies[i], + entropies=entropies[i] if entropies is not None else None, text=texts[i], tokens=None, mask=None, @@ -191,13 +209,22 @@ class ArrowFileIterator(StatefulIterator): for batch in self.batch_iterator: batch_columns = batch.to_pydict() - sample_ids = batch_columns["sample_id"] - texts = batch_columns["text"] - entropies = batch_columns["entropies"] + if self.file_format == "arrow": + sample_ids = batch_columns["sample_id"] + texts = batch_columns["text"] + entropies = batch_columns["entropies"] + elif self.file_format == "json": + # This data hasn't been preprocessed to a uniform format, + # so we have to do it now and omit entropies + sample_ids = batch_columns[get_id_key(batch_columns)] + texts = get_text(batch_columns) + entropies = None + else: + raise ValueError(f"Unknown file format: {self.file_format}") for i in range(len(sample_ids)): out = BltExample( sample_id=sample_ids[i], - entropies=entropies[i], + entropies=entropies[i] if entropies is not None else None, text=texts[i], 
tokens=None, mask=None, @@ -231,13 +258,24 @@ class ArrowFileIterator(StatefulIterator): for batch in self.batch_iterator: if len(batch) > curr_remaining: batch_columns: dict[str, list] = batch.to_pydict() - batch_columns["sample_id"] = batch_columns["sample_id"][ - curr_remaining: - ] - batch_columns["entropies"] = batch_columns["entropies"][ - curr_remaining: - ] - batch_columns["text"] = batch_columns["text"][curr_remaining:] + if self.file_format == "arrow": + leftover_sample_ids = batch_columns["sample_id"][ + curr_remaining: + ] + leftover_entropies = batch_columns["entropies"][curr_remaining:] + leftover_texts = batch_columns["text"][curr_remaining:] + elif self.file_format == "json": + leftover_sample_ids = batch_columns[get_id_key(batch_columns)][ + curr_remaining: + ] + leftover_entropies = None + leftover_texts = get_text(batch_columns)[curr_remaining:] + else: + raise ValueError(f"Unknown file format: {self.file_format}") + + batch_columns["sample_id"] = leftover_sample_ids + batch_columns["entropies"] = leftover_entropies + batch_columns["text"] = leftover_texts self.batch_to_consume = batch_columns break elif len(batch) == curr_remaining: @@ -250,30 +288,3 @@ class ArrowFileIterator(StatefulIterator): logger.info( f"Finished setting arrow position to {target_row_num} for {self.dataset_files}" ) - - -TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" - - -def find_and_sanitize_chunks( - dataset_path: str, - world_size: int, - file_pattern: str = TRAIN_DATA_FILE_PATTERN, - s3_profile: str | None = None, -): - fs = get_fs(dataset_path, s3_profile=s3_profile) - path_with_glob = os.path.join(dataset_path, file_pattern) - dataset_chunks = fs.glob(path_with_glob) - n_chunks = len(dataset_chunks) - - if n_chunks > world_size: - n_discard = n_chunks - world_size - dataset_chunks = dataset_chunks[:world_size] - else: - assert ( - world_size % n_chunks == 0 - ), "World size should be a multiple of number of chunks" - - assert n_chunks > 0, f"No valid chunks in {dataset_path}" - - return dataset_chunks diff --git a/bytelatent/preprocess/preprocess_entropies.py b/bytelatent/preprocess/preprocess_entropies.py index 519da94..31a4802 100644 --- a/bytelatent/preprocess/preprocess_entropies.py +++ b/bytelatent/preprocess/preprocess_entropies.py @@ -15,29 +15,37 @@ from bytelatent.entropy_model import load_entropy_model from bytelatent.tokenizers.build_tokenizer import TokenizerArgs -def get_id_from_doc(doc: dict) -> int: +def get_id_key(doc: dict) -> int: """ We need a reliable way to ensure that samples from jsonl and arrow are the same, but there is no unique id field, so derive the best possible """ if "sample_id" in doc: - sample_id = doc["sample_id"] + return "sample_id" elif "title" in doc: - sample_id = doc["title"] + return "title" elif "qid" in doc: - sample_id = doc["qid"] + return "qid" elif "paper_id" in doc: - sample_id = doc["paper_id"] + return "paper_id" elif "path" in doc: - sample_id = doc["path"] + return "path" elif "url" in doc: - sample_id = doc["url"] + return "url" elif "id" in doc: - sample_id = doc["id"] + return "id" else: raise ValueError(f"Could not find a id key from: {doc.keys()}") - return str(sample_id) + + +def get_id_from_doc(doc: dict) -> int: + """ + We need a reliable way to ensure that samples from jsonl + and arrow are the same, but there is no unique id field, + so derive the best possible + """ + return str(doc[get_id_key(doc)]) def get_text(doc: dict): diff --git a/bytelatent/stool.py b/bytelatent/stool.py index b156177..b47ddc7 100644 --- 
a/bytelatent/stool.py +++ b/bytelatent/stool.py @@ -4,10 +4,10 @@ import json import os import shutil import subprocess -from pydantic import BaseModel from typing import Any, Dict from omegaconf import OmegaConf +from pydantic import BaseModel class StoolArgs(BaseModel): From 4e2ed0aa050b7d9ec610c69cb7010d59586fcf8d Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Thu, 6 Feb 2025 18:08:01 +0000 Subject: [PATCH 37/59] Add bpb and n_bytes to metric logging Summary: Test Plan: --- bytelatent/distributed.py | 16 +++++++++++--- bytelatent/train.py | 46 +++++++++++++++++++++++++++++++++++++-- 2 files changed, 57 insertions(+), 5 deletions(-) diff --git a/bytelatent/distributed.py b/bytelatent/distributed.py index 168cb7c..b3f31a6 100644 --- a/bytelatent/distributed.py +++ b/bytelatent/distributed.py @@ -127,6 +127,16 @@ def dist_max(x: Union[int, float], mesh: DeviceMesh = None): return tensor +def dist_sum( + x: Union[int, float], mesh: DeviceMesh = None, reduce_dtype: torch.dtype = None +): + tensor = torch.tensor(x).cuda() + if reduce_dtype is not None: + tensor = tensor.to(reduce_dtype) + dist.all_reduce(tensor, op=ReduceOp.SUM, group=mesh.get_group() if mesh else None) + return tensor + + def dist_mean(x: Union[int, float], mesh: DeviceMesh = None): tensor = torch.tensor(x).cuda() dist.all_reduce(tensor, op=ReduceOp.AVG, group=mesh.get_group() if mesh else None) @@ -236,7 +246,7 @@ def setup_env(env_args: EnvironmentArgs): logger.warning(f"WARNING: Setting {name} to {value}") -def setup_torch_distributed(dist_args): +def setup_torch_distributed(dist_args: DistributedArgs): """ Handle single and multi-GPU / multi-node / SLURM jobs. Initialize the following variables: @@ -388,14 +398,14 @@ def clean_env(): def parallelize_model( - model, + model: torch.nn.Module, device_mesh, model_args, distributed_args: DistributedArgs, fsdp_grouping_plan: Optional[List[Tuple[str, bool]]] = None, tp_parallelize=None, no_recompute_ops=None, -): +) -> torch.nn.Module: if distributed_args.tp_size > 1: assert ( distributed_args.fsdp_type == "full_shard" diff --git a/bytelatent/train.py b/bytelatent/train.py index 9bfe12a..5b31e0c 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -3,6 +3,7 @@ import gc import logging +import math import os import sys from contextlib import ExitStack @@ -11,6 +12,7 @@ from dataclasses import asdict, dataclass from timeit import default_timer as timer from typing import Any, TypeVar +import numpy as np import torch import torch.distributed import torch.nn.functional @@ -32,7 +34,9 @@ from bytelatent.data.iterators.packing_iterator import PackingIteratorState from bytelatent.distributed import ( check_model_value_range, clean_env, + dist_mean, dist_mean_dict, + dist_sum, get_device_mesh, get_is_master, get_world_size, @@ -392,6 +396,9 @@ def train(args: TrainArgs): time_last_log = timer() gc.collect() saved = False + step_losses: list[float] = [] + step_tok_losses: list[float] = [] + n_bytes: int = 0 while train_state.step < args.steps and ( args.max_steps is None or train_state.step < args.max_steps ): @@ -413,6 +420,24 @@ def train(args: TrainArgs): batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda() mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda() + if args.data.tokenizer_args.name in ["bytes", "blt"]: + if mask is None: + n_bytes += batch_y.numel() + else: + n_bytes += mask.sum() + elif args.data.tokenizer_args.name in ["sp", "tiktoken"]: + for example in batch.y: + target_tokens = tokenizer.decode(example.tolist(), 
cut_at_eos=False) + n_bytes += ( + len(bytes(target_tokens, encoding="utf-8", errors="ignore")) + + sum(example == tokenizer.eos_id) + + sum(example == tokenizer.bos_id) + ) + else: + raise ValueError( + f"Unexpected tokenizer to count n_bytes for: {args.data.tokenizer_args.name}" + ) + if ( not args.train_entropy_model and args.model.encoder_enable_byte_ngrams @@ -487,7 +512,7 @@ def train(args: TrainArgs): batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids ) - loss, _ = compute_loss(pred, batch_y, mask, train_state.scale) + loss, tok_loss = compute_loss(pred, batch_y, mask, train_state.scale) # We scale loss with grad_acc_steps so the gradient is the same # regardless of grad_acc_steps @@ -498,6 +523,10 @@ def train(args: TrainArgs): # For logging we undo that scaling loss = loss.detach() * args.grad_acc_steps + # Undo loss scaling so downstream down't need to worry about it + step_losses.append((loss / train_state.scale).item()) + step_tok_losses.append(tok_loss / train_state.scale) + world_size = get_world_size() if 1 < world_size <= 8: # For some reason, there are errors in reduces due to @@ -598,20 +627,33 @@ def train(args: TrainArgs): gpu_memory_monitor.reset_peak_stats() nwords_since_last_log = 0 time_last_log = timer() + stacked_tok_loss = torch.cat(step_tok_losses, dim=0) + total_tok_loss = dist_sum( + stacked_tok_loss.sum().item(), reduce_dtype=torch.bfloat16 + ) + total_n_bytes = dist_sum(n_bytes, reduce_dtype=torch.bfloat16) + avg_bpb = total_tok_loss / math.log(2) / total_n_bytes + avg_loss = dist_mean(np.mean(step_losses).item()) logger.info( f"step: {train_state.step}" f" acc: {train_state.acc_step}" - f" loss: {round(loss.item(),4):>7}" + f" loss: step={round(loss.item(),4):>7} avg={avg_loss}" + f" bpb: {avg_bpb:3f}" f" grad: {grad_norm:.2e}" f" flops: {FLOPS:.2e}" f" wps: {wps:.2e}" f" iter: {curr_iter_time:>7}" f" data: {data_load_time:>5}" f" lr: {curr_lr:.2e}" + f" n_bytes={total_n_bytes}" f" mem: {gpu_mem_stats.max_active_pct:.0f}%" f" pow: {gpu_mem_stats.power_draw/1000} W" ) + n_bytes = 0 + step_losses = [] + step_tok_losses = [] + if every_n_steps( train_state, args.checkpoint.dump.every, acc_step=0 ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0): From b6396eb0f47982a9fd8d527665d55b5263e405d1 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Fri, 7 Feb 2025 00:26:00 +0000 Subject: [PATCH 38/59] Add bpb and n_bytes to metric logging Summary: Test Plan: --- bytelatent/distributed.py | 16 ++++- bytelatent/metrics.py | 1 - bytelatent/train.py | 133 ++++++++++++++++++++++++++++++-------- 3 files changed, 120 insertions(+), 30 deletions(-) diff --git a/bytelatent/distributed.py b/bytelatent/distributed.py index 168cb7c..b3f31a6 100644 --- a/bytelatent/distributed.py +++ b/bytelatent/distributed.py @@ -127,6 +127,16 @@ def dist_max(x: Union[int, float], mesh: DeviceMesh = None): return tensor +def dist_sum( + x: Union[int, float], mesh: DeviceMesh = None, reduce_dtype: torch.dtype = None +): + tensor = torch.tensor(x).cuda() + if reduce_dtype is not None: + tensor = tensor.to(reduce_dtype) + dist.all_reduce(tensor, op=ReduceOp.SUM, group=mesh.get_group() if mesh else None) + return tensor + + def dist_mean(x: Union[int, float], mesh: DeviceMesh = None): tensor = torch.tensor(x).cuda() dist.all_reduce(tensor, op=ReduceOp.AVG, group=mesh.get_group() if mesh else None) @@ -236,7 +246,7 @@ def setup_env(env_args: EnvironmentArgs): logger.warning(f"WARNING: Setting {name} to {value}") -def setup_torch_distributed(dist_args): +def 
setup_torch_distributed(dist_args: DistributedArgs): """ Handle single and multi-GPU / multi-node / SLURM jobs. Initialize the following variables: @@ -388,14 +398,14 @@ def clean_env(): def parallelize_model( - model, + model: torch.nn.Module, device_mesh, model_args, distributed_args: DistributedArgs, fsdp_grouping_plan: Optional[List[Tuple[str, bool]]] = None, tp_parallelize=None, no_recompute_ops=None, -): +) -> torch.nn.Module: if distributed_args.tp_size > 1: assert ( distributed_args.fsdp_type == "full_shard" diff --git a/bytelatent/metrics.py b/bytelatent/metrics.py index fb443d7..ed805e5 100644 --- a/bytelatent/metrics.py +++ b/bytelatent/metrics.py @@ -49,7 +49,6 @@ class LoggingArgs(BaseModel): model_config = ConfigDict(extra="forbid") freq: int = 10 # Log every freq optimizer steps acc_freq: int | None = None # Log every acc_freq gradient accumulation steps - wandb: WandbArgs | None = None diff --git a/bytelatent/train.py b/bytelatent/train.py index 9bfe12a..4441cf6 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -3,6 +3,7 @@ import gc import logging +import math import os import sys from contextlib import ExitStack @@ -11,6 +12,7 @@ from dataclasses import asdict, dataclass from timeit import default_timer as timer from typing import Any, TypeVar +import numpy as np import torch import torch.distributed import torch.nn.functional @@ -32,7 +34,9 @@ from bytelatent.data.iterators.packing_iterator import PackingIteratorState from bytelatent.distributed import ( check_model_value_range, clean_env, + dist_mean, dist_mean_dict, + dist_sum, get_device_mesh, get_is_master, get_world_size, @@ -392,6 +396,9 @@ def train(args: TrainArgs): time_last_log = timer() gc.collect() saved = False + step_losses: list[float] = [] + step_tok_losses: list[float] = [] + n_bytes: int = 0 while train_state.step < args.steps and ( args.max_steps is None or train_state.step < args.max_steps ): @@ -413,6 +420,21 @@ def train(args: TrainArgs): batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda() mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda() + if args.data.tokenizer_args.name in ["bytes", "blt"]: + n_bytes += batch_y.numel() if mask is None else mask.sum() + elif args.data.tokenizer_args.name in ["sp", "tiktoken"]: + for example in batch.y: + target_tokens = tokenizer.decode(example.tolist(), cut_at_eos=False) + n_bytes += ( + len(bytes(target_tokens, encoding="utf-8", errors="ignore")) + + sum(example == tokenizer.eos_id) + + sum(example == tokenizer.bos_id) + ) + else: + raise ValueError( + f"Unexpected tokenizer to count n_bytes for: {args.data.tokenizer_args.name}" + ) + if ( not args.train_entropy_model and args.model.encoder_enable_byte_ngrams @@ -487,7 +509,7 @@ def train(args: TrainArgs): batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids ) - loss, _ = compute_loss(pred, batch_y, mask, train_state.scale) + loss, tok_loss = compute_loss(pred, batch_y, mask, train_state.scale) # We scale loss with grad_acc_steps so the gradient is the same # regardless of grad_acc_steps @@ -498,6 +520,10 @@ def train(args: TrainArgs): # For logging we undo that scaling loss = loss.detach() * args.grad_acc_steps + # Undo loss scaling so downstream down't need to worry about it + step_losses.append((loss / train_state.scale).item()) + step_tok_losses.append(tok_loss / train_state.scale) + world_size = get_world_size() if 1 < world_size <= 8: # For some reason, there are errors in reduces due to @@ -568,50 +594,105 @@ def train(args: TrainArgs): * wps 
) - metrics = flatten_dict( - { - "global_step": train_state.step, - "acc_step": train_state.acc_step, - "speed": { - "wps": wps, - "FLOPS": FLOPS, - "curr_iter_time": curr_iter_time, - "data_load_time": data_load_time, - }, - "optim": { - "grad_norm": grad_norm, - "lr": curr_lr, - "total_tokens": total_tokens, - }, - "memory": gpu_mem_stats._asdict(), - }, - sep="/", + # Below, semantics are: + # per_gpu: Metrics on a given rank + # across_gpus: Metrics averaged/summed across all ranks + # step: Metric at a step + # interval: Metric averaged/summed across all steps since the last log interval. + # Typically, this is 10 + step_loss_per_gpu = loss.item() + step_loss_across_gpus = dist_mean(step_loss_per_gpu).item() + interval_loss_per_gpu = np.mean(step_losses).item() + interval_loss_across_gpus = dist_mean(interval_loss_per_gpu).item() + + stacked_tok_loss = torch.cat(step_tok_losses, dim=0) + interval_total_tok_loss_per_gpu = stacked_tok_loss.sum().item() + interval_total_tok_loss_across_gpus = dist_sum( + interval_total_tok_loss_per_gpu, reduce_dtype=torch.bfloat16 + ).item() + interval_total_n_bytes_per_gpu = n_bytes + interval_total_n_bytes_across_gpus = dist_sum( + n_bytes, reduce_dtype=torch.bfloat16 + ).item() + + interval_bpb_per_gpu = ( + interval_total_tok_loss_per_gpu + / math.log(2) + / interval_total_n_bytes_per_gpu + ) + interval_bpb_across_gpus = ( + interval_total_tok_loss_across_gpus + / math.log(2) + / interval_total_n_bytes_across_gpus ) - to_sync = {} - to_sync["loss/out"] = loss.item() - metrics.update(dist_mean_dict(to_sync)) + metric_dict = { + "global_step": train_state.step, + "acc_step": train_state.acc_step, + "speed": { + "wps": wps, + "FLOPS": FLOPS, + "curr_iter_time": curr_iter_time, + "data_load_time": data_load_time, + }, + "optim": { + "grad_norm": grad_norm, + "lr": curr_lr, + "total_tokens": total_tokens, + }, + "memory": gpu_mem_stats._asdict(), + "loss": { + "step_per_gpu": step_loss_per_gpu, + "step_across_gpu": step_loss_across_gpus, + "interval_per_gpu": interval_loss_per_gpu, + "interval_across_gpu": interval_loss_across_gpus, + }, + "bpb": { + "interval_per_gpu": interval_bpb_per_gpu, + "interval_across_gpus": interval_bpb_across_gpus, + }, + "n_bytes": { + "interval_per_gpu": interval_total_n_bytes_per_gpu, + "interval_across_gpus": interval_total_n_bytes_across_gpus, + }, + } + + metrics = flatten_dict( + metric_dict, + sep="/", + ) if get_is_master(): metric_logger.log(metrics) - gpu_memory_monitor.reset_peak_stats() - nwords_since_last_log = 0 - time_last_log = timer() + # Below semantics are: + # step=Metrics at a step + # interval=Metrics averaged across the logging interval + # local=On one rank + # global=Across all ranks logger.info( f"step: {train_state.step}" f" acc: {train_state.acc_step}" - f" loss: {round(loss.item(),4):>7}" + f" loss: [step_local={round(step_loss_per_gpu, 4):>7} interval_local={round(interval_loss_per_gpu, 4):>7} step_global={round(step_loss_across_gpus, 4):>7} interval_global={round(interval_loss_across_gpus, 4):>7}]" + f" bpb: [interval_local={interval_bpb_per_gpu:3f} interval_global={interval_bpb_across_gpus:3f}]" f" grad: {grad_norm:.2e}" f" flops: {FLOPS:.2e}" f" wps: {wps:.2e}" f" iter: {curr_iter_time:>7}" f" data: {data_load_time:>5}" f" lr: {curr_lr:.2e}" + f" n_bytes: [interval_local={int(interval_total_n_bytes_per_gpu)} interval_global={int(interval_total_n_bytes_across_gpus)}]" f" mem: {gpu_mem_stats.max_active_pct:.0f}%" f" pow: {gpu_mem_stats.power_draw/1000} W" ) + n_bytes = 0 + step_losses = [] + 
step_tok_losses = [] + gpu_memory_monitor.reset_peak_stats() + nwords_since_last_log = 0 + time_last_log = timer() + if every_n_steps( train_state, args.checkpoint.dump.every, acc_step=0 ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0): From 8d7338308e46a14b45b30de499dd714b2ffe8f69 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Fri, 7 Feb 2025 00:26:00 +0000 Subject: [PATCH 39/59] Add bpb and n_bytes to metric logging Summary: Test Plan: --- .gitignore | 1 + bytelatent/distributed.py | 16 ++++- bytelatent/metrics.py | 1 - bytelatent/train.py | 136 ++++++++++++++++++++++++++++++-------- 4 files changed, 124 insertions(+), 30 deletions(-) diff --git a/.gitignore b/.gitignore index d1d7c2a..2d0f075 100644 --- a/.gitignore +++ b/.gitignore @@ -167,3 +167,4 @@ figures/ .DS_Store internal/ jobs_parallel-copy/ +wandb/ diff --git a/bytelatent/distributed.py b/bytelatent/distributed.py index 168cb7c..b3f31a6 100644 --- a/bytelatent/distributed.py +++ b/bytelatent/distributed.py @@ -127,6 +127,16 @@ def dist_max(x: Union[int, float], mesh: DeviceMesh = None): return tensor +def dist_sum( + x: Union[int, float], mesh: DeviceMesh = None, reduce_dtype: torch.dtype = None +): + tensor = torch.tensor(x).cuda() + if reduce_dtype is not None: + tensor = tensor.to(reduce_dtype) + dist.all_reduce(tensor, op=ReduceOp.SUM, group=mesh.get_group() if mesh else None) + return tensor + + def dist_mean(x: Union[int, float], mesh: DeviceMesh = None): tensor = torch.tensor(x).cuda() dist.all_reduce(tensor, op=ReduceOp.AVG, group=mesh.get_group() if mesh else None) @@ -236,7 +246,7 @@ def setup_env(env_args: EnvironmentArgs): logger.warning(f"WARNING: Setting {name} to {value}") -def setup_torch_distributed(dist_args): +def setup_torch_distributed(dist_args: DistributedArgs): """ Handle single and multi-GPU / multi-node / SLURM jobs. 
Initialize the following variables: @@ -388,14 +398,14 @@ def clean_env(): def parallelize_model( - model, + model: torch.nn.Module, device_mesh, model_args, distributed_args: DistributedArgs, fsdp_grouping_plan: Optional[List[Tuple[str, bool]]] = None, tp_parallelize=None, no_recompute_ops=None, -): +) -> torch.nn.Module: if distributed_args.tp_size > 1: assert ( distributed_args.fsdp_type == "full_shard" diff --git a/bytelatent/metrics.py b/bytelatent/metrics.py index fb443d7..ed805e5 100644 --- a/bytelatent/metrics.py +++ b/bytelatent/metrics.py @@ -49,7 +49,6 @@ class LoggingArgs(BaseModel): model_config = ConfigDict(extra="forbid") freq: int = 10 # Log every freq optimizer steps acc_freq: int | None = None # Log every acc_freq gradient accumulation steps - wandb: WandbArgs | None = None diff --git a/bytelatent/train.py b/bytelatent/train.py index 9bfe12a..ed84233 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -3,6 +3,7 @@ import gc import logging +import math import os import sys from contextlib import ExitStack @@ -11,6 +12,7 @@ from dataclasses import asdict, dataclass from timeit import default_timer as timer from typing import Any, TypeVar +import numpy as np import torch import torch.distributed import torch.nn.functional @@ -32,7 +34,9 @@ from bytelatent.data.iterators.packing_iterator import PackingIteratorState from bytelatent.distributed import ( check_model_value_range, clean_env, + dist_mean, dist_mean_dict, + dist_sum, get_device_mesh, get_is_master, get_world_size, @@ -392,6 +396,9 @@ def train(args: TrainArgs): time_last_log = timer() gc.collect() saved = False + step_losses: list[float] = [] + step_tok_losses: list[float] = [] + n_bytes: int = 0 while train_state.step < args.steps and ( args.max_steps is None or train_state.step < args.max_steps ): @@ -413,6 +420,21 @@ def train(args: TrainArgs): batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda() mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda() + if args.data.tokenizer_args.name in ["bytes", "blt"]: + n_bytes += batch_y.numel() if mask is None else mask.sum() + elif args.data.tokenizer_args.name in ["sp", "tiktoken"]: + for example in batch.y: + target_tokens = tokenizer.decode(example.tolist(), cut_at_eos=False) + n_bytes += ( + len(bytes(target_tokens, encoding="utf-8", errors="ignore")) + + sum(example == tokenizer.eos_id) + + sum(example == tokenizer.bos_id) + ) + else: + raise ValueError( + f"Unexpected tokenizer to count n_bytes for: {args.data.tokenizer_args.name}" + ) + if ( not args.train_entropy_model and args.model.encoder_enable_byte_ngrams @@ -487,7 +509,7 @@ def train(args: TrainArgs): batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids ) - loss, _ = compute_loss(pred, batch_y, mask, train_state.scale) + loss, tok_loss = compute_loss(pred, batch_y, mask, train_state.scale) # We scale loss with grad_acc_steps so the gradient is the same # regardless of grad_acc_steps @@ -498,6 +520,10 @@ def train(args: TrainArgs): # For logging we undo that scaling loss = loss.detach() * args.grad_acc_steps + # Undo loss scaling so downstream down't need to worry about it + step_losses.append((loss / train_state.scale).item()) + step_tok_losses.append(tok_loss / train_state.scale) + world_size = get_world_size() if 1 < world_size <= 8: # For some reason, there are errors in reduces due to @@ -568,50 +594,108 @@ def train(args: TrainArgs): * wps ) - metrics = flatten_dict( - { - "global_step": train_state.step, - "acc_step": train_state.acc_step, - 
"speed": { - "wps": wps, - "FLOPS": FLOPS, - "curr_iter_time": curr_iter_time, - "data_load_time": data_load_time, - }, - "optim": { - "grad_norm": grad_norm, - "lr": curr_lr, - "total_tokens": total_tokens, - }, - "memory": gpu_mem_stats._asdict(), - }, - sep="/", + # Below, semantics are: + # per_gpu: Metrics on a given rank + # across_gpus: Metrics averaged/summed across all ranks + # step: Metric at a step + # interval: Metric averaged/summed across all steps since the last log interval. + # Typically, this is 10 + step_loss_per_gpu = loss.item() + step_loss_across_gpus = dist_mean(step_loss_per_gpu).item() + interval_loss_per_gpu = np.mean(step_losses).item() + interval_loss_across_gpus = dist_mean(interval_loss_per_gpu).item() + + stacked_tok_loss = torch.cat(step_tok_losses, dim=0) + interval_total_tok_loss_per_gpu = stacked_tok_loss.sum().item() + interval_total_tok_loss_across_gpus = dist_sum( + interval_total_tok_loss_per_gpu, reduce_dtype=torch.bfloat16 + ).item() + interval_total_n_bytes_per_gpu = n_bytes + interval_total_n_bytes_across_gpus = dist_sum( + n_bytes, reduce_dtype=torch.bfloat16 + ).item() + + interval_bpb_per_gpu = ( + interval_total_tok_loss_per_gpu + / math.log(2) + / interval_total_n_bytes_per_gpu + ) + interval_bpb_across_gpus = ( + interval_total_tok_loss_across_gpus + / math.log(2) + / interval_total_n_bytes_across_gpus ) - to_sync = {} - to_sync["loss/out"] = loss.item() - metrics.update(dist_mean_dict(to_sync)) + metric_dict = { + "global_step": train_state.step, + "acc_step": train_state.acc_step, + "speed": { + "wps": wps, + "FLOPS": FLOPS, + "curr_iter_time": curr_iter_time, + "data_load_time": data_load_time, + }, + "optim": { + "grad_norm": grad_norm, + "lr": curr_lr, + "total_tokens": total_tokens, + }, + "memory": gpu_mem_stats._asdict(), + "loss": { + "step_per_gpu": step_loss_per_gpu, + "step_across_gpu": step_loss_across_gpus, + "interval_per_gpu": interval_loss_per_gpu, + "interval_across_gpu": interval_loss_across_gpus, + }, + "bpb": { + "interval_per_gpu": interval_bpb_per_gpu, + "interval_across_gpus": interval_bpb_across_gpus, + }, + "n_bytes": { + "interval_per_gpu": interval_total_n_bytes_per_gpu, + "interval_across_gpus": interval_total_n_bytes_across_gpus, + }, + } + + metrics = flatten_dict( + metric_dict, + sep="/", + ) if get_is_master(): metric_logger.log(metrics) - gpu_memory_monitor.reset_peak_stats() - nwords_since_last_log = 0 - time_last_log = timer() + # Below semantics are: + # step=Metrics at a step + # interval=Metrics averaged across the logging interval + # local=On one rank + # global=Across all ranks logger.info( f"step: {train_state.step}" f" acc: {train_state.acc_step}" - f" loss: {round(loss.item(),4):>7}" + f" loss_gpu: {round(interval_loss_per_gpu, 4):>7}" + f" loss_avg: {round(interval_loss_across_gpus, 4):>7}" + f" bpb_gpu: {interval_bpb_per_gpu:3f}" + f" bpb_avg: {interval_bpb_across_gpus:3f}" f" grad: {grad_norm:.2e}" f" flops: {FLOPS:.2e}" f" wps: {wps:.2e}" f" iter: {curr_iter_time:>7}" f" data: {data_load_time:>5}" f" lr: {curr_lr:.2e}" + f" n_bytes_gpu: {int(interval_total_n_bytes_per_gpu)}" + f" n_bytes_sum: {int(interval_total_n_bytes_across_gpus)}" f" mem: {gpu_mem_stats.max_active_pct:.0f}%" f" pow: {gpu_mem_stats.power_draw/1000} W" ) + n_bytes = 0 + step_losses = [] + step_tok_losses = [] + gpu_memory_monitor.reset_peak_stats() + nwords_since_last_log = 0 + time_last_log = timer() + if every_n_steps( train_state, args.checkpoint.dump.every, acc_step=0 ) or every_n_steps(train_state, 
args.checkpoint.eval.every, acc_step=0): From 5c8fb4f1b30b9dec2b3bd5be30d91808a0e4ec45 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Fri, 7 Feb 2025 23:26:48 +0000 Subject: [PATCH 40/59] Fix multiprocessing dataloader checkpointing and use it in the train script Summary: Test Plan: --- bytelatent/data/iterators/arrow_iterator.py | 4 +-- .../data/iterators/multiprocess_iterator.py | 27 +++++++++++++------ bytelatent/train.py | 18 +++++++++++++ 3 files changed, 39 insertions(+), 10 deletions(-) diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index 1c68d3a..7e43360 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -236,7 +236,7 @@ class ArrowFileIterator(StatefulIterator): def _set_row_num(self, target_row_num: int): logger.info( - f"Setting arrow position to {target_row_num} for {self.dataset_files}" + f"Setting arrow position to {target_row_num} for {str(self.dataset_files)[:200]}" ) if target_row_num is None or target_row_num == 0: self.row_num = 0 @@ -286,5 +286,5 @@ class ArrowFileIterator(StatefulIterator): curr_remaining -= len(batch) self.row_num = target_row_num logger.info( - f"Finished setting arrow position to {target_row_num} for {self.dataset_files}" + f"Finished setting arrow position to {target_row_num} for {str(self.dataset_files)[:200]}" ) diff --git a/bytelatent/data/iterators/multiprocess_iterator.py b/bytelatent/data/iterators/multiprocess_iterator.py index 49d99ac..33bde94 100644 --- a/bytelatent/data/iterators/multiprocess_iterator.py +++ b/bytelatent/data/iterators/multiprocess_iterator.py @@ -54,9 +54,10 @@ def start_work_from_state( if stop_event.is_set(): # Signal the end of output, this ensures that even if the queue takes a while to # buffer, that the main thread receives everything (and tosses this fake batch) - logging.info( + logging.debug( "Worker thread: Stop event detected, outputting is_final=True batch" ) + logging.debug("Worker thread: batch_queue full=%s", batch_queue.full()) batch_queue.put( Batch( x=np.zeros((1, 1)), @@ -67,14 +68,17 @@ def start_work_from_state( ngram_ids=None, ) ) + logging.debug( + "Worker thread: is_final=True batch put in queue, breaking from loop." + ) break try: - logging.info("Worker thread: outputting state") - state_queue.put(iterator.get_state(), timeout=1) - logging.info("Worker thread: state dump complete") + logging.debug("Worker thread: outputting state") + state_queue.put(stateful_iterator.get_state(), timeout=1) + logging.debug("Worker thread: state dump complete") state_dumped_event.set() - logging.info("Worker thread: set state_dump_event") + logging.debug("Worker thread: set state_dump_event") except Full: raise ValueError( "Attempted to dump state into the state queue, but it was full" @@ -156,16 +160,20 @@ class MultiprocessIterator(StatefulIterator): serialized_prefetch_buffer=serialized_prefetch_buffer, ) else: - logging.info("Main thread: Sending stop iteration event") + logging.debug("Main thread: Sending stop iteration event") self.stop_iterating_event.set() - logging.info("Main thread: Waiting for state_dumped event") - self.state_dumped_event.wait() + logging.debug( + "Main thread: Emptying the batch_queue until batch.is_final=True is found." 
+ ) self.prefetch_buffer = [] final_batch_received = False while True: try: batch = self.batch_queue.get(timeout=1) if batch.is_final: + logging.debug( + "Main thread: is_final=True batch found, stopping fetch from batch_queue" + ) final_batch_received = True break self.prefetch_buffer.append(batch) @@ -173,6 +181,9 @@ class MultiprocessIterator(StatefulIterator): logging.warning("Main thread: batch_queue is abnormally empty") assert final_batch_received + logging.debug("Main thread: Waiting for state_dumped event") + self.state_dumped_event.wait() + try: base_iterator_state = self.state_queue.get(timeout=1) assert isinstance(base_iterator_state, IteratorState) diff --git a/bytelatent/train.py b/bytelatent/train.py index ed84233..229d904 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -699,6 +699,12 @@ def train(args: TrainArgs): if every_n_steps( train_state, args.checkpoint.dump.every, acc_step=0 ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0): + # Re-init dataloader and iterator is necessary since get_state() + # on mp iterator shuts down MP to correctly persist state and it needs + # to be restarted. + train_state.data_loader_state = data_loader.get_state() + data_loader = train_state.data_loader_state.build() + batch_iterator = data_loader.create_iter() saved = checkpoint.save( model, optimizer, @@ -740,6 +746,12 @@ def train(args: TrainArgs): if preemption_flag["flag"]: if not saved: + # Re-init dataloader and iterator is necessary since get_state() + # on mp iterator shuts down MP to correctly persist state and it needs + # to be restarted. + train_state.data_loader_state = data_loader.get_state() + data_loader = train_state.data_loader_state.build() + batch_iterator = data_loader.create_iter() checkpoint.save( model, optimizer, @@ -751,6 +763,12 @@ def train(args: TrainArgs): sys.exit(0) if not saved: + # Re-init dataloader and iterator is necessary since get_state() + # on mp iterator shuts down MP to correctly persist state and it needs + # to be restarted. + train_state.data_loader_state = data_loader.get_state() + data_loader = train_state.data_loader_state.build() + batch_iterator = data_loader.create_iter() checkpoint.save( model, optimizer, From 38cc67a9535f2a19ae3dee526114ee77d291d76f Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Tue, 11 Feb 2025 22:56:25 +0000 Subject: [PATCH 41/59] Fix multiprocessing dataloader checkpointing and use it in the train script Summary: Test Plan: --- .../data/iterators/abstract_iterator.py | 10 +++++++ bytelatent/data/iterators/arrow_iterator.py | 4 +-- .../data/iterators/multiprocess_iterator.py | 27 +++++++++++++------ bytelatent/train.py | 10 +++++++ 4 files changed, 41 insertions(+), 10 deletions(-) diff --git a/bytelatent/data/iterators/abstract_iterator.py b/bytelatent/data/iterators/abstract_iterator.py index 7fb442b..8ac7f19 100644 --- a/bytelatent/data/iterators/abstract_iterator.py +++ b/bytelatent/data/iterators/abstract_iterator.py @@ -21,3 +21,13 @@ class IteratorState(Generic[C]): @abc.abstractmethod def build(self) -> StatefulIterator[T, C]: pass + + +def get_state_and_refresh(iterator: StatefulIterator): + # Re-init dataloader and iterator is necessary since get_state() + # on mp iterator shuts down MP to correctly persist state and it needs + # to be restarted. 
+ state = iterator.get_state() + data_loader = state.build() + py_iterator = data_loader.create_iter() + return state, data_loader, py_iterator diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index 1c68d3a..7e43360 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -236,7 +236,7 @@ class ArrowFileIterator(StatefulIterator): def _set_row_num(self, target_row_num: int): logger.info( - f"Setting arrow position to {target_row_num} for {self.dataset_files}" + f"Setting arrow position to {target_row_num} for {str(self.dataset_files)[:200]}" ) if target_row_num is None or target_row_num == 0: self.row_num = 0 @@ -286,5 +286,5 @@ class ArrowFileIterator(StatefulIterator): curr_remaining -= len(batch) self.row_num = target_row_num logger.info( - f"Finished setting arrow position to {target_row_num} for {self.dataset_files}" + f"Finished setting arrow position to {target_row_num} for {str(self.dataset_files)[:200]}" ) diff --git a/bytelatent/data/iterators/multiprocess_iterator.py b/bytelatent/data/iterators/multiprocess_iterator.py index 49d99ac..33bde94 100644 --- a/bytelatent/data/iterators/multiprocess_iterator.py +++ b/bytelatent/data/iterators/multiprocess_iterator.py @@ -54,9 +54,10 @@ def start_work_from_state( if stop_event.is_set(): # Signal the end of output, this ensures that even if the queue takes a while to # buffer, that the main thread receives everything (and tosses this fake batch) - logging.info( + logging.debug( "Worker thread: Stop event detected, outputting is_final=True batch" ) + logging.debug("Worker thread: batch_queue full=%s", batch_queue.full()) batch_queue.put( Batch( x=np.zeros((1, 1)), @@ -67,14 +68,17 @@ def start_work_from_state( ngram_ids=None, ) ) + logging.debug( + "Worker thread: is_final=True batch put in queue, breaking from loop." + ) break try: - logging.info("Worker thread: outputting state") - state_queue.put(iterator.get_state(), timeout=1) - logging.info("Worker thread: state dump complete") + logging.debug("Worker thread: outputting state") + state_queue.put(stateful_iterator.get_state(), timeout=1) + logging.debug("Worker thread: state dump complete") state_dumped_event.set() - logging.info("Worker thread: set state_dump_event") + logging.debug("Worker thread: set state_dump_event") except Full: raise ValueError( "Attempted to dump state into the state queue, but it was full" @@ -156,16 +160,20 @@ class MultiprocessIterator(StatefulIterator): serialized_prefetch_buffer=serialized_prefetch_buffer, ) else: - logging.info("Main thread: Sending stop iteration event") + logging.debug("Main thread: Sending stop iteration event") self.stop_iterating_event.set() - logging.info("Main thread: Waiting for state_dumped event") - self.state_dumped_event.wait() + logging.debug( + "Main thread: Emptying the batch_queue until batch.is_final=True is found." 
+ ) self.prefetch_buffer = [] final_batch_received = False while True: try: batch = self.batch_queue.get(timeout=1) if batch.is_final: + logging.debug( + "Main thread: is_final=True batch found, stopping fetch from batch_queue" + ) final_batch_received = True break self.prefetch_buffer.append(batch) @@ -173,6 +181,9 @@ class MultiprocessIterator(StatefulIterator): logging.warning("Main thread: batch_queue is abnormally empty") assert final_batch_received + logging.debug("Main thread: Waiting for state_dumped event") + self.state_dumped_event.wait() + try: base_iterator_state = self.state_queue.get(timeout=1) assert isinstance(base_iterator_state, IteratorState) diff --git a/bytelatent/train.py b/bytelatent/train.py index ed84233..8b667f1 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -26,6 +26,7 @@ from torch.optim import lr_scheduler from bytelatent.args import TrainArgs, parse_args from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint from bytelatent.data.file_util import get_fs +from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh from bytelatent.data.iterators.multiprocess_iterator import ( MultiprocessIterator, MultiprocessIteratorState, @@ -699,6 +700,9 @@ def train(args: TrainArgs): if every_n_steps( train_state, args.checkpoint.dump.every, acc_step=0 ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0): + train_state.data_loader_state, data_loader, batch_iterator = ( + get_state_and_refresh(data_loader) + ) saved = checkpoint.save( model, optimizer, @@ -740,6 +744,9 @@ def train(args: TrainArgs): if preemption_flag["flag"]: if not saved: + train_state.data_loader_state, data_loader, batch_iterator = ( + get_state_and_refresh(data_loader) + ) checkpoint.save( model, optimizer, @@ -751,6 +758,9 @@ def train(args: TrainArgs): sys.exit(0) if not saved: + train_state.data_loader_state, data_loader, batch_iterator = ( + get_state_and_refresh(data_loader) + ) checkpoint.save( model, optimizer, From 92af9b3f564748f9f3799bed4ad7511f3c225fb7 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Wed, 12 Feb 2025 18:07:21 +0000 Subject: [PATCH 42/59] Test first batch matches Summary: Test Plan: --- .../data/iterators/test_arrow_iterator.py | 3 ++ bytelatent/data/test_data.py | 46 +++++++++++++++++++ bytelatent/test_entropy_model.py | 1 + 3 files changed, 50 insertions(+) create mode 100644 bytelatent/data/test_data.py diff --git a/bytelatent/data/iterators/test_arrow_iterator.py b/bytelatent/data/iterators/test_arrow_iterator.py index fd448eb..064217e 100644 --- a/bytelatent/data/iterators/test_arrow_iterator.py +++ b/bytelatent/data/iterators/test_arrow_iterator.py @@ -28,6 +28,7 @@ def test_basic_arrow_file(): row_num=0, arrow_batch_size=100, s3_profile=None, + file_format="arrow", ) arrow_file = initial_state.build() start_state = arrow_file.get_state() @@ -57,6 +58,7 @@ def test_basic_arrow_file(): row_num=251, arrow_batch_size=100, s3_profile=None, + file_format="arrow", ) arrow_file = resumed_state.build() for example in arrow_file.create_iter(): @@ -77,6 +79,7 @@ def test_basic_arrow_file(): row_num=0, arrow_batch_size=100, s3_profile=None, + file_format="arrow", ) arrow_file = rank_state.build() expected_ids = [] diff --git a/bytelatent/data/test_data.py b/bytelatent/data/test_data.py new file mode 100644 index 0000000..15f2996 --- /dev/null +++ b/bytelatent/data/test_data.py @@ -0,0 +1,46 @@ +import os +import pickle +import pytest +from omegaconf import OmegaConf +from bytelatent.args import TrainArgs +from 
bytelatent.constants import BLT_DATA + + +def get_test_config(): + if "BLT_INTERNAL" in os.environ: + internal_dir = os.environ["BLT_INTERNAL"] + else: + internal_dir = "../internal-blt/configs" + test_config = os.path.join(internal_dir, "tests.yaml") + return test_config + + +@pytest.mark.skipif( + not os.path.exists(get_test_config()), + reason="Skipping since internal config is missing", +) +def test_first_batch_matches(): + test_config_path = get_test_config() + default_cfg = OmegaConf.create(TrainArgs().model_dump()) + file_cfg = OmegaConf.load(test_config_path) + merged_cfg = OmegaConf.merge(default_cfg, file_cfg) + merged_cfg = OmegaConf.to_container(merged_cfg, resolve=True, throw_on_missing=True) + train_args = TrainArgs.model_validate(merged_cfg) + # MP doesn't work with async very well, but it doesn't change logic + train_args.data.load_async = False + + # Test data created by pickling first batch in train loop then exiting + with open(os.path.join(BLT_DATA, "fixtures", "first_batch_0.pickle"), "rb") as f: + first_batch = pickle.load(f) + + # Emulate 1 node, 8 gpu training + data_loader = train_args.data.build_from_rank(0, 8) + batch_iterator = data_loader.create_iter() + print("Getting first batch") + batch = next(batch_iterator) + assert (batch.x == first_batch.x).all() + assert (batch.y == first_batch.y).all() + assert (batch.mask == first_batch.mask).all() + assert (batch.patch_lengths == first_batch.patch_lengths).all() + assert batch.ngram_ids is None and first_batch.ngram_ids is None + assert batch.is_final == False and batch.is_final == False diff --git a/bytelatent/test_entropy_model.py b/bytelatent/test_entropy_model.py index 9db7ff6..8623eb1 100644 --- a/bytelatent/test_entropy_model.py +++ b/bytelatent/test_entropy_model.py @@ -25,6 +25,7 @@ def test_entropy_model(): row_num=0, arrow_batch_size=100, s3_profile=None, + file_format="arrow", ) arrow_file = initial_state.build() tokenizer_args = TokenizerArgs( From 4cee32ea8ceab6ca003e8bb79a3a0bde87b4666d Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Wed, 12 Feb 2025 18:07:21 +0000 Subject: [PATCH 43/59] Test first batch matches Summary: Test Plan: --- apps/main/lingua_train.py | 3 +- .../data/iterators/test_arrow_iterator.py | 3 ++ bytelatent/data/test_data.py | 48 +++++++++++++++++++ bytelatent/metrics.py | 6 ++- bytelatent/profiling.py | 3 +- bytelatent/test_entropy_model.py | 1 + bytelatent/train.py | 3 +- 7 files changed, 62 insertions(+), 5 deletions(-) create mode 100644 bytelatent/data/test_data.py diff --git a/apps/main/lingua_train.py b/apps/main/lingua_train.py index 7925ec6..323b101 100644 --- a/apps/main/lingua_train.py +++ b/apps/main/lingua_train.py @@ -14,7 +14,6 @@ from typing import Any, Dict, Optional import torch import torch.distributed -import wandb import xformers.profiler from lingua.args import dataclass_from_dict, dump_config, flatten_dict from lingua.data import ( @@ -70,6 +69,8 @@ from bytelatent.transformer import ( tp_parallelize, ) +import wandb + logger = logging.getLogger() diff --git a/bytelatent/data/iterators/test_arrow_iterator.py b/bytelatent/data/iterators/test_arrow_iterator.py index fd448eb..064217e 100644 --- a/bytelatent/data/iterators/test_arrow_iterator.py +++ b/bytelatent/data/iterators/test_arrow_iterator.py @@ -28,6 +28,7 @@ def test_basic_arrow_file(): row_num=0, arrow_batch_size=100, s3_profile=None, + file_format="arrow", ) arrow_file = initial_state.build() start_state = arrow_file.get_state() @@ -57,6 +58,7 @@ def test_basic_arrow_file(): row_num=251, 
arrow_batch_size=100, s3_profile=None, + file_format="arrow", ) arrow_file = resumed_state.build() for example in arrow_file.create_iter(): @@ -77,6 +79,7 @@ def test_basic_arrow_file(): row_num=0, arrow_batch_size=100, s3_profile=None, + file_format="arrow", ) arrow_file = rank_state.build() expected_ids = [] diff --git a/bytelatent/data/test_data.py b/bytelatent/data/test_data.py new file mode 100644 index 0000000..efb8bcf --- /dev/null +++ b/bytelatent/data/test_data.py @@ -0,0 +1,48 @@ +import os +import pickle + +import pytest +from omegaconf import OmegaConf + +from bytelatent.args import TrainArgs +from bytelatent.constants import BLT_DATA + + +def get_test_config(): + if "BLT_INTERNAL" in os.environ: + internal_dir = os.environ["BLT_INTERNAL"] + else: + internal_dir = "../internal-blt/configs" + test_config = os.path.join(internal_dir, "tests.yaml") + return test_config + + +@pytest.mark.skipif( + not os.path.exists(get_test_config()), + reason="Skipping since internal config is missing", +) +def test_first_batch_matches(): + test_config_path = get_test_config() + default_cfg = OmegaConf.create(TrainArgs().model_dump()) + file_cfg = OmegaConf.load(test_config_path) + merged_cfg = OmegaConf.merge(default_cfg, file_cfg) + merged_cfg = OmegaConf.to_container(merged_cfg, resolve=True, throw_on_missing=True) + train_args = TrainArgs.model_validate(merged_cfg) + # MP doesn't work with async very well, but it doesn't change logic + train_args.data.load_async = False + + # Test data created by pickling first batch in train loop then exiting + with open(os.path.join(BLT_DATA, "fixtures", "first_batch_0.pickle"), "rb") as f: + first_batch = pickle.load(f) + + # Emulate 1 node, 8 gpu training + data_loader = train_args.data.build_from_rank(0, 8) + batch_iterator = data_loader.create_iter() + print("Getting first batch") + batch = next(batch_iterator) + assert (batch.x == first_batch.x).all() + assert (batch.y == first_batch.y).all() + assert (batch.mask == first_batch.mask).all() + assert (batch.patch_lengths == first_batch.patch_lengths).all() + assert batch.ngram_ids is None and first_batch.ngram_ids is None + assert batch.is_final == False and batch.is_final == False diff --git a/bytelatent/metrics.py b/bytelatent/metrics.py index ed805e5..86c6910 100644 --- a/bytelatent/metrics.py +++ b/bytelatent/metrics.py @@ -11,11 +11,12 @@ from typing import Any, Union import fsspec import torch import torch.nn as nn -import wandb from pydantic import BaseModel, ConfigDict from bytelatent.distributed import get_is_master +import wandb + logger = logging.getLogger() @@ -198,9 +199,10 @@ def upload_train_to_wandb( import json from pathlib import Path - import wandb from omegaconf import OmegaConf + import wandb + cfg = OmegaConf.load(Path(ckpt_dir) / "config.yaml") cfg = OmegaConf.to_container(cfg) diff --git a/bytelatent/profiling.py b/bytelatent/profiling.py index da3c90d..66bcbd6 100644 --- a/bytelatent/profiling.py +++ b/bytelatent/profiling.py @@ -7,7 +7,6 @@ import os from pathlib import Path import torch.distributed -import wandb import xformers.profiler from pydantic import BaseModel from torch.profiler.profiler import profile @@ -15,6 +14,8 @@ from xformers.profiler import MemSnapshotsProfiler, PyTorchProfiler from bytelatent.distributed import get_is_master +import wandb + class ProfilerArgs(BaseModel): run: bool = False diff --git a/bytelatent/test_entropy_model.py b/bytelatent/test_entropy_model.py index 9db7ff6..8623eb1 100644 --- a/bytelatent/test_entropy_model.py +++ 
b/bytelatent/test_entropy_model.py @@ -25,6 +25,7 @@ def test_entropy_model(): row_num=0, arrow_batch_size=100, s3_profile=None, + file_format="arrow", ) arrow_file = initial_state.build() tokenizer_args = TokenizerArgs( diff --git a/bytelatent/train.py b/bytelatent/train.py index ed84233..a5c0a83 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -17,7 +17,6 @@ import torch import torch.distributed import torch.nn.functional import torch.nn.functional as F -import wandb import xformers.profiler from torch.distributed._tensor import DTensor from torch.distributed.checkpoint.stateful import Stateful @@ -63,6 +62,8 @@ from bytelatent.transformer import ( tp_parallelize, ) +import wandb + logger = logging.getLogger() T = TypeVar("T") From bd3cf61bb995c5c6240de3b59025615b00be4c46 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Wed, 12 Feb 2025 18:09:26 +0000 Subject: [PATCH 44/59] Fix multiprocessing dataloader checkpointing and use it in the train script Summary: Test Plan: --- bytelatent/args.py | 2 -- .../data/iterators/abstract_iterator.py | 10 +++++++ bytelatent/data/iterators/arrow_iterator.py | 4 +-- .../data/iterators/multiprocess_iterator.py | 27 +++++++++++++------ bytelatent/train.py | 11 +++++++- 5 files changed, 41 insertions(+), 13 deletions(-) diff --git a/bytelatent/args.py b/bytelatent/args.py index 263e8e3..47bd0f9 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -1,10 +1,8 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. -import json import logging import os from typing import Any -import fsspec import numpy as np import yaml from omegaconf import OmegaConf diff --git a/bytelatent/data/iterators/abstract_iterator.py b/bytelatent/data/iterators/abstract_iterator.py index 7fb442b..8ac7f19 100644 --- a/bytelatent/data/iterators/abstract_iterator.py +++ b/bytelatent/data/iterators/abstract_iterator.py @@ -21,3 +21,13 @@ class IteratorState(Generic[C]): @abc.abstractmethod def build(self) -> StatefulIterator[T, C]: pass + + +def get_state_and_refresh(iterator: StatefulIterator): + # Re-init dataloader and iterator is necessary since get_state() + # on mp iterator shuts down MP to correctly persist state and it needs + # to be restarted. 
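A minimal sketch (not part of the diffs in this series) of the call pattern the comment above describes, mirroring the train.py hunks later in this patch; `save_checkpoint` is a stand-in for the real `checkpoint.save(...)` call:

```python
from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh


def checkpoint_with_refreshed_loader(data_loader, train_state, save_checkpoint):
    # get_state() on the multiprocess iterator shuts its worker down, so a fresh
    # data loader and Python iterator are rebuilt from the captured state.
    train_state.data_loader_state, data_loader, batch_iterator = get_state_and_refresh(
        data_loader
    )
    # Save as usual; the caller keeps using the refreshed loader and iterator.
    save_checkpoint()
    return data_loader, batch_iterator
```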
+ state = iterator.get_state() + data_loader = state.build() + py_iterator = data_loader.create_iter() + return state, data_loader, py_iterator diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index 1c68d3a..7e43360 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -236,7 +236,7 @@ class ArrowFileIterator(StatefulIterator): def _set_row_num(self, target_row_num: int): logger.info( - f"Setting arrow position to {target_row_num} for {self.dataset_files}" + f"Setting arrow position to {target_row_num} for {str(self.dataset_files)[:200]}" ) if target_row_num is None or target_row_num == 0: self.row_num = 0 @@ -286,5 +286,5 @@ class ArrowFileIterator(StatefulIterator): curr_remaining -= len(batch) self.row_num = target_row_num logger.info( - f"Finished setting arrow position to {target_row_num} for {self.dataset_files}" + f"Finished setting arrow position to {target_row_num} for {str(self.dataset_files)[:200]}" ) diff --git a/bytelatent/data/iterators/multiprocess_iterator.py b/bytelatent/data/iterators/multiprocess_iterator.py index 49d99ac..33bde94 100644 --- a/bytelatent/data/iterators/multiprocess_iterator.py +++ b/bytelatent/data/iterators/multiprocess_iterator.py @@ -54,9 +54,10 @@ def start_work_from_state( if stop_event.is_set(): # Signal the end of output, this ensures that even if the queue takes a while to # buffer, that the main thread receives everything (and tosses this fake batch) - logging.info( + logging.debug( "Worker thread: Stop event detected, outputting is_final=True batch" ) + logging.debug("Worker thread: batch_queue full=%s", batch_queue.full()) batch_queue.put( Batch( x=np.zeros((1, 1)), @@ -67,14 +68,17 @@ def start_work_from_state( ngram_ids=None, ) ) + logging.debug( + "Worker thread: is_final=True batch put in queue, breaking from loop." + ) break try: - logging.info("Worker thread: outputting state") - state_queue.put(iterator.get_state(), timeout=1) - logging.info("Worker thread: state dump complete") + logging.debug("Worker thread: outputting state") + state_queue.put(stateful_iterator.get_state(), timeout=1) + logging.debug("Worker thread: state dump complete") state_dumped_event.set() - logging.info("Worker thread: set state_dump_event") + logging.debug("Worker thread: set state_dump_event") except Full: raise ValueError( "Attempted to dump state into the state queue, but it was full" @@ -156,16 +160,20 @@ class MultiprocessIterator(StatefulIterator): serialized_prefetch_buffer=serialized_prefetch_buffer, ) else: - logging.info("Main thread: Sending stop iteration event") + logging.debug("Main thread: Sending stop iteration event") self.stop_iterating_event.set() - logging.info("Main thread: Waiting for state_dumped event") - self.state_dumped_event.wait() + logging.debug( + "Main thread: Emptying the batch_queue until batch.is_final=True is found." 
+ ) self.prefetch_buffer = [] final_batch_received = False while True: try: batch = self.batch_queue.get(timeout=1) if batch.is_final: + logging.debug( + "Main thread: is_final=True batch found, stopping fetch from batch_queue" + ) final_batch_received = True break self.prefetch_buffer.append(batch) @@ -173,6 +181,9 @@ class MultiprocessIterator(StatefulIterator): logging.warning("Main thread: batch_queue is abnormally empty") assert final_batch_received + logging.debug("Main thread: Waiting for state_dumped event") + self.state_dumped_event.wait() + try: base_iterator_state = self.state_queue.get(timeout=1) assert isinstance(base_iterator_state, IteratorState) diff --git a/bytelatent/train.py b/bytelatent/train.py index a5c0a83..9b35e58 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -25,6 +25,7 @@ from torch.optim import lr_scheduler from bytelatent.args import TrainArgs, parse_args from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint from bytelatent.data.file_util import get_fs +from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh from bytelatent.data.iterators.multiprocess_iterator import ( MultiprocessIterator, MultiprocessIteratorState, @@ -34,7 +35,6 @@ from bytelatent.distributed import ( check_model_value_range, clean_env, dist_mean, - dist_mean_dict, dist_sum, get_device_mesh, get_is_master, @@ -700,6 +700,9 @@ def train(args: TrainArgs): if every_n_steps( train_state, args.checkpoint.dump.every, acc_step=0 ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0): + train_state.data_loader_state, data_loader, batch_iterator = ( + get_state_and_refresh(data_loader) + ) saved = checkpoint.save( model, optimizer, @@ -741,6 +744,9 @@ def train(args: TrainArgs): if preemption_flag["flag"]: if not saved: + train_state.data_loader_state, data_loader, batch_iterator = ( + get_state_and_refresh(data_loader) + ) checkpoint.save( model, optimizer, @@ -752,6 +758,9 @@ def train(args: TrainArgs): sys.exit(0) if not saved: + train_state.data_loader_state, data_loader, batch_iterator = ( + get_state_and_refresh(data_loader) + ) checkpoint.save( model, optimizer, From c54c9f05173e80e75c31251940e239c7e93c9ae8 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Wed, 12 Feb 2025 18:07:21 +0000 Subject: [PATCH 45/59] Test first batch matches Summary: Test Plan: --- .../data/iterators/test_arrow_iterator.py | 3 ++ bytelatent/data/test_data.py | 48 +++++++++++++++++++ bytelatent/test_entropy_model.py | 1 + pyproject.toml | 1 + 4 files changed, 53 insertions(+) create mode 100644 bytelatent/data/test_data.py diff --git a/bytelatent/data/iterators/test_arrow_iterator.py b/bytelatent/data/iterators/test_arrow_iterator.py index fd448eb..064217e 100644 --- a/bytelatent/data/iterators/test_arrow_iterator.py +++ b/bytelatent/data/iterators/test_arrow_iterator.py @@ -28,6 +28,7 @@ def test_basic_arrow_file(): row_num=0, arrow_batch_size=100, s3_profile=None, + file_format="arrow", ) arrow_file = initial_state.build() start_state = arrow_file.get_state() @@ -57,6 +58,7 @@ def test_basic_arrow_file(): row_num=251, arrow_batch_size=100, s3_profile=None, + file_format="arrow", ) arrow_file = resumed_state.build() for example in arrow_file.create_iter(): @@ -77,6 +79,7 @@ def test_basic_arrow_file(): row_num=0, arrow_batch_size=100, s3_profile=None, + file_format="arrow", ) arrow_file = rank_state.build() expected_ids = [] diff --git a/bytelatent/data/test_data.py b/bytelatent/data/test_data.py new file mode 100644 index 0000000..efb8bcf 
--- /dev/null +++ b/bytelatent/data/test_data.py @@ -0,0 +1,48 @@ +import os +import pickle + +import pytest +from omegaconf import OmegaConf + +from bytelatent.args import TrainArgs +from bytelatent.constants import BLT_DATA + + +def get_test_config(): + if "BLT_INTERNAL" in os.environ: + internal_dir = os.environ["BLT_INTERNAL"] + else: + internal_dir = "../internal-blt/configs" + test_config = os.path.join(internal_dir, "tests.yaml") + return test_config + + +@pytest.mark.skipif( + not os.path.exists(get_test_config()), + reason="Skipping since internal config is missing", +) +def test_first_batch_matches(): + test_config_path = get_test_config() + default_cfg = OmegaConf.create(TrainArgs().model_dump()) + file_cfg = OmegaConf.load(test_config_path) + merged_cfg = OmegaConf.merge(default_cfg, file_cfg) + merged_cfg = OmegaConf.to_container(merged_cfg, resolve=True, throw_on_missing=True) + train_args = TrainArgs.model_validate(merged_cfg) + # MP doesn't work with async very well, but it doesn't change logic + train_args.data.load_async = False + + # Test data created by pickling first batch in train loop then exiting + with open(os.path.join(BLT_DATA, "fixtures", "first_batch_0.pickle"), "rb") as f: + first_batch = pickle.load(f) + + # Emulate 1 node, 8 gpu training + data_loader = train_args.data.build_from_rank(0, 8) + batch_iterator = data_loader.create_iter() + print("Getting first batch") + batch = next(batch_iterator) + assert (batch.x == first_batch.x).all() + assert (batch.y == first_batch.y).all() + assert (batch.mask == first_batch.mask).all() + assert (batch.patch_lengths == first_batch.patch_lengths).all() + assert batch.ngram_ids is None and first_batch.ngram_ids is None + assert batch.is_final == False and batch.is_final == False diff --git a/bytelatent/test_entropy_model.py b/bytelatent/test_entropy_model.py index 9db7ff6..8623eb1 100644 --- a/bytelatent/test_entropy_model.py +++ b/bytelatent/test_entropy_model.py @@ -25,6 +25,7 @@ def test_entropy_model(): row_num=0, arrow_batch_size=100, s3_profile=None, + file_format="arrow", ) arrow_file = initial_state.build() tokenizer_args = TokenizerArgs( diff --git a/pyproject.toml b/pyproject.toml index e2ecd0d..814d8d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,4 +2,5 @@ profile = "black" known_bytelatent = "bytelatent" known_apps = "apps" +known_third_party = "wandb" sections = "FUTURE,STDLIB,THIRDPARTY,BYTELATENT,APPS,FIRSTPARTY,LOCALFOLDER" From 3e3193c1d41b7945dca25c3644b53dd3a5f0fade Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Wed, 12 Feb 2025 18:24:40 +0000 Subject: [PATCH 46/59] Fix multiprocessing dataloader checkpointing and use it in the train script Summary: Test Plan: --- bytelatent/args.py | 2 -- .../data/iterators/abstract_iterator.py | 10 +++++++ bytelatent/data/iterators/arrow_iterator.py | 4 +-- .../data/iterators/multiprocess_iterator.py | 27 +++++++++++++------ bytelatent/train.py | 11 +++++++- 5 files changed, 41 insertions(+), 13 deletions(-) diff --git a/bytelatent/args.py b/bytelatent/args.py index 263e8e3..47bd0f9 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -1,10 +1,8 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
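The test above relies on a fixture described only as "pickling first batch in train loop then exiting"; a plausible sketch of how such a fixture could be produced (the snippet below is an assumption, not code from this series, and the helper name is made up):

```python
import os
import pickle

from bytelatent.constants import BLT_DATA


def dump_first_batch_fixture(batch, rank: int) -> None:
    # Hypothetical helper: persist the first batch seen by this rank so that
    # test_first_batch_matches can compare future runs against it.
    path = os.path.join(BLT_DATA, "fixtures", f"first_batch_{rank}.pickle")
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(batch, f)
```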
-import json import logging import os from typing import Any -import fsspec import numpy as np import yaml from omegaconf import OmegaConf diff --git a/bytelatent/data/iterators/abstract_iterator.py b/bytelatent/data/iterators/abstract_iterator.py index 7fb442b..8ac7f19 100644 --- a/bytelatent/data/iterators/abstract_iterator.py +++ b/bytelatent/data/iterators/abstract_iterator.py @@ -21,3 +21,13 @@ class IteratorState(Generic[C]): @abc.abstractmethod def build(self) -> StatefulIterator[T, C]: pass + + +def get_state_and_refresh(iterator: StatefulIterator): + # Re-init dataloader and iterator is necessary since get_state() + # on mp iterator shuts down MP to correctly persist state and it needs + # to be restarted. + state = iterator.get_state() + data_loader = state.build() + py_iterator = data_loader.create_iter() + return state, data_loader, py_iterator diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index 1c68d3a..7e43360 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -236,7 +236,7 @@ class ArrowFileIterator(StatefulIterator): def _set_row_num(self, target_row_num: int): logger.info( - f"Setting arrow position to {target_row_num} for {self.dataset_files}" + f"Setting arrow position to {target_row_num} for {str(self.dataset_files)[:200]}" ) if target_row_num is None or target_row_num == 0: self.row_num = 0 @@ -286,5 +286,5 @@ class ArrowFileIterator(StatefulIterator): curr_remaining -= len(batch) self.row_num = target_row_num logger.info( - f"Finished setting arrow position to {target_row_num} for {self.dataset_files}" + f"Finished setting arrow position to {target_row_num} for {str(self.dataset_files)[:200]}" ) diff --git a/bytelatent/data/iterators/multiprocess_iterator.py b/bytelatent/data/iterators/multiprocess_iterator.py index 49d99ac..33bde94 100644 --- a/bytelatent/data/iterators/multiprocess_iterator.py +++ b/bytelatent/data/iterators/multiprocess_iterator.py @@ -54,9 +54,10 @@ def start_work_from_state( if stop_event.is_set(): # Signal the end of output, this ensures that even if the queue takes a while to # buffer, that the main thread receives everything (and tosses this fake batch) - logging.info( + logging.debug( "Worker thread: Stop event detected, outputting is_final=True batch" ) + logging.debug("Worker thread: batch_queue full=%s", batch_queue.full()) batch_queue.put( Batch( x=np.zeros((1, 1)), @@ -67,14 +68,17 @@ def start_work_from_state( ngram_ids=None, ) ) + logging.debug( + "Worker thread: is_final=True batch put in queue, breaking from loop." 
+ ) break try: - logging.info("Worker thread: outputting state") - state_queue.put(iterator.get_state(), timeout=1) - logging.info("Worker thread: state dump complete") + logging.debug("Worker thread: outputting state") + state_queue.put(stateful_iterator.get_state(), timeout=1) + logging.debug("Worker thread: state dump complete") state_dumped_event.set() - logging.info("Worker thread: set state_dump_event") + logging.debug("Worker thread: set state_dump_event") except Full: raise ValueError( "Attempted to dump state into the state queue, but it was full" @@ -156,16 +160,20 @@ class MultiprocessIterator(StatefulIterator): serialized_prefetch_buffer=serialized_prefetch_buffer, ) else: - logging.info("Main thread: Sending stop iteration event") + logging.debug("Main thread: Sending stop iteration event") self.stop_iterating_event.set() - logging.info("Main thread: Waiting for state_dumped event") - self.state_dumped_event.wait() + logging.debug( + "Main thread: Emptying the batch_queue until batch.is_final=True is found." + ) self.prefetch_buffer = [] final_batch_received = False while True: try: batch = self.batch_queue.get(timeout=1) if batch.is_final: + logging.debug( + "Main thread: is_final=True batch found, stopping fetch from batch_queue" + ) final_batch_received = True break self.prefetch_buffer.append(batch) @@ -173,6 +181,9 @@ class MultiprocessIterator(StatefulIterator): logging.warning("Main thread: batch_queue is abnormally empty") assert final_batch_received + logging.debug("Main thread: Waiting for state_dumped event") + self.state_dumped_event.wait() + try: base_iterator_state = self.state_queue.get(timeout=1) assert isinstance(base_iterator_state, IteratorState) diff --git a/bytelatent/train.py b/bytelatent/train.py index ed84233..d5b6de0 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -26,6 +26,7 @@ from torch.optim import lr_scheduler from bytelatent.args import TrainArgs, parse_args from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint from bytelatent.data.file_util import get_fs +from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh from bytelatent.data.iterators.multiprocess_iterator import ( MultiprocessIterator, MultiprocessIteratorState, @@ -35,7 +36,6 @@ from bytelatent.distributed import ( check_model_value_range, clean_env, dist_mean, - dist_mean_dict, dist_sum, get_device_mesh, get_is_master, @@ -699,6 +699,9 @@ def train(args: TrainArgs): if every_n_steps( train_state, args.checkpoint.dump.every, acc_step=0 ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0): + train_state.data_loader_state, data_loader, batch_iterator = ( + get_state_and_refresh(data_loader) + ) saved = checkpoint.save( model, optimizer, @@ -740,6 +743,9 @@ def train(args: TrainArgs): if preemption_flag["flag"]: if not saved: + train_state.data_loader_state, data_loader, batch_iterator = ( + get_state_and_refresh(data_loader) + ) checkpoint.save( model, optimizer, @@ -751,6 +757,9 @@ def train(args: TrainArgs): sys.exit(0) if not saved: + train_state.data_loader_state, data_loader, batch_iterator = ( + get_state_and_refresh(data_loader) + ) checkpoint.save( model, optimizer, From ece82cb96064e6bf55c09f9ac28807e4e82564b1 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Wed, 12 Feb 2025 19:24:49 +0000 Subject: [PATCH 47/59] Make it possible to specify multiple config files Summary: Test Plan: Test that this interpolates in the right order, config -> configs -> cli args ``` # All three sources python -m 
bytelatent.print_config config=bytelatent/configs/debug.yaml configs=[internal/configs/s3_debug.yaml] eval=null # What worked before python -m bytelatent.print_config config=internal/configs/s3_debug.yaml eval=null ``` --- bytelatent/args.py | 19 +++++++++++++++---- bytelatent/configs/debug.yaml | 4 +--- bytelatent/print_config.py | 10 ++++++++++ 3 files changed, 26 insertions(+), 7 deletions(-) create mode 100644 bytelatent/print_config.py diff --git a/bytelatent/args.py b/bytelatent/args.py index 47bd0f9..6125ae7 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -40,12 +40,23 @@ def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]: def parse_args(args_cls): cli_args = OmegaConf.from_cli() - file_cfg = OmegaConf.load(cli_args.config) - # We remove 'config' attribute from config as the underlying DataClass does not have it - del cli_args.config + file_cfgs = [] + if "config" in cli_args: + file_cfg = OmegaConf.load(cli_args["config"]) + del cli_args["config"] + file_cfgs.append(file_cfg) + + if "configs" in cli_args: + for c in cli_args["configs"]: + extra_file_cfg = OmegaConf.load(c) + file_cfgs.append(extra_file_cfg) + del cli_args["configs"] default_cfg = OmegaConf.create(args_cls().model_dump()) - cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) + to_merge = [default_cfg] + to_merge.extend(file_cfgs) + to_merge.append(cli_args) + cfg = OmegaConf.merge(*to_merge) cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) pydantic_args = args_cls.model_validate(cfg) return pydantic_args diff --git a/bytelatent/configs/debug.yaml b/bytelatent/configs/debug.yaml index 07d489f..1369364 100644 --- a/bytelatent/configs/debug.yaml +++ b/bytelatent/configs/debug.yaml @@ -56,13 +56,11 @@ model: recompute_attn: false custom_bwd: false layer_ckpt: "none" - patch_only_encoder: false - patch_only_decoder: false use_local_encoder_transformer: true init_use_gaussian: true init_use_depth: "current" - attn_bias_type: "block_causal" attn_impl: "xformers" + attn_bias_type: "block_causal" alpha_depth: "disabled" max_length: 256 local_attention_window_len: 512 diff --git a/bytelatent/print_config.py b/bytelatent/print_config.py new file mode 100644 index 0000000..3fd509a --- /dev/null +++ b/bytelatent/print_config.py @@ -0,0 +1,10 @@ +from bytelatent.args import TrainArgs, parse_args + + +def main(): + train_args = parse_args(TrainArgs) + print(train_args.model_dump_json(indent=4)) + + +if __name__ == "__main__": + main() From ab8f8a4412dd3b203d0d6bbc5fb73cfc1abbda97 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Thu, 13 Feb 2025 18:04:30 +0000 Subject: [PATCH 48/59] Test first batch matches Summary: Test Plan: --- .../data/iterators/test_arrow_iterator.py | 3 ++ bytelatent/data/test_data.py | 48 +++++++++++++++++++ bytelatent/test_entropy_model.py | 1 + pyproject.toml | 1 + 4 files changed, 53 insertions(+) create mode 100644 bytelatent/data/test_data.py diff --git a/bytelatent/data/iterators/test_arrow_iterator.py b/bytelatent/data/iterators/test_arrow_iterator.py index fd448eb..064217e 100644 --- a/bytelatent/data/iterators/test_arrow_iterator.py +++ b/bytelatent/data/iterators/test_arrow_iterator.py @@ -28,6 +28,7 @@ def test_basic_arrow_file(): row_num=0, arrow_batch_size=100, s3_profile=None, + file_format="arrow", ) arrow_file = initial_state.build() start_state = arrow_file.get_state() @@ -57,6 +58,7 @@ def test_basic_arrow_file(): row_num=251, arrow_batch_size=100, s3_profile=None, + file_format="arrow", ) arrow_file = resumed_state.build() 
for example in arrow_file.create_iter(): @@ -77,6 +79,7 @@ def test_basic_arrow_file(): row_num=0, arrow_batch_size=100, s3_profile=None, + file_format="arrow", ) arrow_file = rank_state.build() expected_ids = [] diff --git a/bytelatent/data/test_data.py b/bytelatent/data/test_data.py new file mode 100644 index 0000000..efb8bcf --- /dev/null +++ b/bytelatent/data/test_data.py @@ -0,0 +1,48 @@ +import os +import pickle + +import pytest +from omegaconf import OmegaConf + +from bytelatent.args import TrainArgs +from bytelatent.constants import BLT_DATA + + +def get_test_config(): + if "BLT_INTERNAL" in os.environ: + internal_dir = os.environ["BLT_INTERNAL"] + else: + internal_dir = "../internal-blt/configs" + test_config = os.path.join(internal_dir, "tests.yaml") + return test_config + + +@pytest.mark.skipif( + not os.path.exists(get_test_config()), + reason="Skipping since internal config is missing", +) +def test_first_batch_matches(): + test_config_path = get_test_config() + default_cfg = OmegaConf.create(TrainArgs().model_dump()) + file_cfg = OmegaConf.load(test_config_path) + merged_cfg = OmegaConf.merge(default_cfg, file_cfg) + merged_cfg = OmegaConf.to_container(merged_cfg, resolve=True, throw_on_missing=True) + train_args = TrainArgs.model_validate(merged_cfg) + # MP doesn't work with async very well, but it doesn't change logic + train_args.data.load_async = False + + # Test data created by pickling first batch in train loop then exiting + with open(os.path.join(BLT_DATA, "fixtures", "first_batch_0.pickle"), "rb") as f: + first_batch = pickle.load(f) + + # Emulate 1 node, 8 gpu training + data_loader = train_args.data.build_from_rank(0, 8) + batch_iterator = data_loader.create_iter() + print("Getting first batch") + batch = next(batch_iterator) + assert (batch.x == first_batch.x).all() + assert (batch.y == first_batch.y).all() + assert (batch.mask == first_batch.mask).all() + assert (batch.patch_lengths == first_batch.patch_lengths).all() + assert batch.ngram_ids is None and first_batch.ngram_ids is None + assert batch.is_final == False and first_batch.is_final == False diff --git a/bytelatent/test_entropy_model.py b/bytelatent/test_entropy_model.py index 9db7ff6..8623eb1 100644 --- a/bytelatent/test_entropy_model.py +++ b/bytelatent/test_entropy_model.py @@ -25,6 +25,7 @@ def test_entropy_model(): row_num=0, arrow_batch_size=100, s3_profile=None, + file_format="arrow", ) arrow_file = initial_state.build() tokenizer_args = TokenizerArgs( diff --git a/pyproject.toml b/pyproject.toml index e2ecd0d..814d8d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,4 +2,5 @@ profile = "black" known_bytelatent = "bytelatent" known_apps = "apps" +known_third_party = "wandb" sections = "FUTURE,STDLIB,THIRDPARTY,BYTELATENT,APPS,FIRSTPARTY,LOCALFOLDER" From 0c6cb995a0aecb1355f57408102f37871ff3f825 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Thu, 13 Feb 2025 18:38:58 +0000 Subject: [PATCH 49/59] Fix multiprocessing dataloader checkpointing and use it in the train script Summary: Test Plan: --- bytelatent/args.py | 2 -- .../data/iterators/abstract_iterator.py | 10 +++++++ bytelatent/data/iterators/arrow_iterator.py | 4 +-- .../data/iterators/multiprocess_iterator.py | 27 +++++++++++++------ bytelatent/train.py | 11 +++++++- 5 files changed, 41 insertions(+), 13 deletions(-) diff --git a/bytelatent/args.py b/bytelatent/args.py index 263e8e3..47bd0f9 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -1,10 +1,8 @@ # Copyright (c) Meta Platforms, Inc. and affiliates.
-import json import logging import os from typing import Any -import fsspec import numpy as np import yaml from omegaconf import OmegaConf diff --git a/bytelatent/data/iterators/abstract_iterator.py b/bytelatent/data/iterators/abstract_iterator.py index 7fb442b..8ac7f19 100644 --- a/bytelatent/data/iterators/abstract_iterator.py +++ b/bytelatent/data/iterators/abstract_iterator.py @@ -21,3 +21,13 @@ class IteratorState(Generic[C]): @abc.abstractmethod def build(self) -> StatefulIterator[T, C]: pass + + +def get_state_and_refresh(iterator: StatefulIterator): + # Re-init dataloader and iterator is necessary since get_state() + # on mp iterator shuts down MP to correctly persist state and it needs + # to be restarted. + state = iterator.get_state() + data_loader = state.build() + py_iterator = data_loader.create_iter() + return state, data_loader, py_iterator diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index 1c68d3a..7e43360 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -236,7 +236,7 @@ class ArrowFileIterator(StatefulIterator): def _set_row_num(self, target_row_num: int): logger.info( - f"Setting arrow position to {target_row_num} for {self.dataset_files}" + f"Setting arrow position to {target_row_num} for {str(self.dataset_files)[:200]}" ) if target_row_num is None or target_row_num == 0: self.row_num = 0 @@ -286,5 +286,5 @@ class ArrowFileIterator(StatefulIterator): curr_remaining -= len(batch) self.row_num = target_row_num logger.info( - f"Finished setting arrow position to {target_row_num} for {self.dataset_files}" + f"Finished setting arrow position to {target_row_num} for {str(self.dataset_files)[:200]}" ) diff --git a/bytelatent/data/iterators/multiprocess_iterator.py b/bytelatent/data/iterators/multiprocess_iterator.py index 49d99ac..33bde94 100644 --- a/bytelatent/data/iterators/multiprocess_iterator.py +++ b/bytelatent/data/iterators/multiprocess_iterator.py @@ -54,9 +54,10 @@ def start_work_from_state( if stop_event.is_set(): # Signal the end of output, this ensures that even if the queue takes a while to # buffer, that the main thread receives everything (and tosses this fake batch) - logging.info( + logging.debug( "Worker thread: Stop event detected, outputting is_final=True batch" ) + logging.debug("Worker thread: batch_queue full=%s", batch_queue.full()) batch_queue.put( Batch( x=np.zeros((1, 1)), @@ -67,14 +68,17 @@ def start_work_from_state( ngram_ids=None, ) ) + logging.debug( + "Worker thread: is_final=True batch put in queue, breaking from loop." 
+ ) break try: - logging.info("Worker thread: outputting state") - state_queue.put(iterator.get_state(), timeout=1) - logging.info("Worker thread: state dump complete") + logging.debug("Worker thread: outputting state") + state_queue.put(stateful_iterator.get_state(), timeout=1) + logging.debug("Worker thread: state dump complete") state_dumped_event.set() - logging.info("Worker thread: set state_dump_event") + logging.debug("Worker thread: set state_dump_event") except Full: raise ValueError( "Attempted to dump state into the state queue, but it was full" @@ -156,16 +160,20 @@ class MultiprocessIterator(StatefulIterator): serialized_prefetch_buffer=serialized_prefetch_buffer, ) else: - logging.info("Main thread: Sending stop iteration event") + logging.debug("Main thread: Sending stop iteration event") self.stop_iterating_event.set() - logging.info("Main thread: Waiting for state_dumped event") - self.state_dumped_event.wait() + logging.debug( + "Main thread: Emptying the batch_queue until batch.is_final=True is found." + ) self.prefetch_buffer = [] final_batch_received = False while True: try: batch = self.batch_queue.get(timeout=1) if batch.is_final: + logging.debug( + "Main thread: is_final=True batch found, stopping fetch from batch_queue" + ) final_batch_received = True break self.prefetch_buffer.append(batch) @@ -173,6 +181,9 @@ class MultiprocessIterator(StatefulIterator): logging.warning("Main thread: batch_queue is abnormally empty") assert final_batch_received + logging.debug("Main thread: Waiting for state_dumped event") + self.state_dumped_event.wait() + try: base_iterator_state = self.state_queue.get(timeout=1) assert isinstance(base_iterator_state, IteratorState) diff --git a/bytelatent/train.py b/bytelatent/train.py index 0ee87df..4c2dc35 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -26,6 +26,7 @@ from torch.optim import lr_scheduler from bytelatent.args import TrainArgs, parse_args from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint from bytelatent.data.file_util import get_fs +from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh from bytelatent.data.iterators.multiprocess_iterator import ( MultiprocessIterator, MultiprocessIteratorState, @@ -35,7 +36,6 @@ from bytelatent.distributed import ( check_model_value_range, clean_env, dist_mean, - dist_mean_dict, dist_sum, get_device_mesh, get_is_master, @@ -702,6 +702,9 @@ def train(args: TrainArgs): if every_n_steps( train_state, args.checkpoint.dump.every, acc_step=0 ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0): + train_state.data_loader_state, data_loader, batch_iterator = ( + get_state_and_refresh(data_loader) + ) saved = checkpoint.save( model, optimizer, @@ -743,6 +746,9 @@ def train(args: TrainArgs): if preemption_flag["flag"]: if not saved: + train_state.data_loader_state, data_loader, batch_iterator = ( + get_state_and_refresh(data_loader) + ) checkpoint.save( model, optimizer, @@ -754,6 +760,9 @@ def train(args: TrainArgs): sys.exit(0) if not saved: + train_state.data_loader_state, data_loader, batch_iterator = ( + get_state_and_refresh(data_loader) + ) checkpoint.save( model, optimizer, From 53529dcc78bcd275f6b45f9998021251e09a5e7a Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Thu, 13 Feb 2025 19:01:48 +0000 Subject: [PATCH 50/59] Fix multiprocessing dataloader checkpointing and use it in the train script Summary: Test Plan: --- bytelatent/args.py | 2 - .../data/iterators/abstract_iterator.py | 10 ++++ 
bytelatent/data/iterators/arrow_iterator.py | 15 +++-- .../data/iterators/multiprocess_iterator.py | 27 ++++++--- bytelatent/train.py | 56 ++++++++++++------- 5 files changed, 77 insertions(+), 33 deletions(-) diff --git a/bytelatent/args.py b/bytelatent/args.py index 263e8e3..47bd0f9 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -1,10 +1,8 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. -import json import logging import os from typing import Any -import fsspec import numpy as np import yaml from omegaconf import OmegaConf diff --git a/bytelatent/data/iterators/abstract_iterator.py b/bytelatent/data/iterators/abstract_iterator.py index 7fb442b..8ac7f19 100644 --- a/bytelatent/data/iterators/abstract_iterator.py +++ b/bytelatent/data/iterators/abstract_iterator.py @@ -21,3 +21,13 @@ class IteratorState(Generic[C]): @abc.abstractmethod def build(self) -> StatefulIterator[T, C]: pass + + +def get_state_and_refresh(iterator: StatefulIterator): + # Re-init dataloader and iterator is necessary since get_state() + # on mp iterator shuts down MP to correctly persist state and it needs + # to be restarted. + state = iterator.get_state() + data_loader = state.build() + py_iterator = data_loader.create_iter() + return state, data_loader, py_iterator diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index 1c68d3a..995cd02 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -60,6 +60,13 @@ def shard_sort_key(file: str): return shard_number +def maybe_truncate_string(text: str, max_length: int): + if len(text) <= max_length: + return text + else: + return text[:max_length] + "..." + + class ArrowFileIterator(StatefulIterator): def __init__( self, @@ -235,9 +242,8 @@ class ArrowFileIterator(StatefulIterator): yield out def _set_row_num(self, target_row_num: int): - logger.info( - f"Setting arrow position to {target_row_num} for {self.dataset_files}" - ) + data_str = maybe_truncate_string(str(self.dataset_files), 200) + logger.info(f"Setting arrow position to {target_row_num} for {data_str}") if target_row_num is None or target_row_num == 0: self.row_num = 0 self.dataset = None @@ -285,6 +291,7 @@ class ArrowFileIterator(StatefulIterator): else: curr_remaining -= len(batch) self.row_num = target_row_num + data_str = maybe_truncate_string(str(self.dataset_files), 200) logger.info( - f"Finished setting arrow position to {target_row_num} for {self.dataset_files}" + f"Finished setting arrow position to {target_row_num} for {data_str}" ) diff --git a/bytelatent/data/iterators/multiprocess_iterator.py b/bytelatent/data/iterators/multiprocess_iterator.py index 49d99ac..33bde94 100644 --- a/bytelatent/data/iterators/multiprocess_iterator.py +++ b/bytelatent/data/iterators/multiprocess_iterator.py @@ -54,9 +54,10 @@ def start_work_from_state( if stop_event.is_set(): # Signal the end of output, this ensures that even if the queue takes a while to # buffer, that the main thread receives everything (and tosses this fake batch) - logging.info( + logging.debug( "Worker thread: Stop event detected, outputting is_final=True batch" ) + logging.debug("Worker thread: batch_queue full=%s", batch_queue.full()) batch_queue.put( Batch( x=np.zeros((1, 1)), @@ -67,14 +68,17 @@ def start_work_from_state( ngram_ids=None, ) ) + logging.debug( + "Worker thread: is_final=True batch put in queue, breaking from loop." 
+ ) break try: - logging.info("Worker thread: outputting state") - state_queue.put(iterator.get_state(), timeout=1) - logging.info("Worker thread: state dump complete") + logging.debug("Worker thread: outputting state") + state_queue.put(stateful_iterator.get_state(), timeout=1) + logging.debug("Worker thread: state dump complete") state_dumped_event.set() - logging.info("Worker thread: set state_dump_event") + logging.debug("Worker thread: set state_dump_event") except Full: raise ValueError( "Attempted to dump state into the state queue, but it was full" @@ -156,16 +160,20 @@ class MultiprocessIterator(StatefulIterator): serialized_prefetch_buffer=serialized_prefetch_buffer, ) else: - logging.info("Main thread: Sending stop iteration event") + logging.debug("Main thread: Sending stop iteration event") self.stop_iterating_event.set() - logging.info("Main thread: Waiting for state_dumped event") - self.state_dumped_event.wait() + logging.debug( + "Main thread: Emptying the batch_queue until batch.is_final=True is found." + ) self.prefetch_buffer = [] final_batch_received = False while True: try: batch = self.batch_queue.get(timeout=1) if batch.is_final: + logging.debug( + "Main thread: is_final=True batch found, stopping fetch from batch_queue" + ) final_batch_received = True break self.prefetch_buffer.append(batch) @@ -173,6 +181,9 @@ class MultiprocessIterator(StatefulIterator): logging.warning("Main thread: batch_queue is abnormally empty") assert final_batch_received + logging.debug("Main thread: Waiting for state_dumped event") + self.state_dumped_event.wait() + try: base_iterator_state = self.state_queue.get(timeout=1) assert isinstance(base_iterator_state, IteratorState) diff --git a/bytelatent/train.py b/bytelatent/train.py index 0ee87df..3669167 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -26,6 +26,7 @@ from torch.optim import lr_scheduler from bytelatent.args import TrainArgs, parse_args from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint from bytelatent.data.file_util import get_fs +from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh from bytelatent.data.iterators.multiprocess_iterator import ( MultiprocessIterator, MultiprocessIteratorState, @@ -35,7 +36,6 @@ from bytelatent.distributed import ( check_model_value_range, clean_env, dist_mean, - dist_mean_dict, dist_sum, get_device_mesh, get_is_master, @@ -88,6 +88,13 @@ def get_iterator_state_name(iterator_state): raise ValueError(f"Unsupported iterator to get name from: {iterator_state}") +def to_py_num(num: int | float | torch.Tensor | np.ndarray) -> int | float: + if isinstance(num, (torch.Tensor, np.ndarray)): + return num.item() + else: + return num + + # TODO: Make this pydantic based instead of data class based # TODO: Generalize this to any iterator state @dataclass @@ -603,20 +610,20 @@ def train(args: TrainArgs): # step: Metric at a step # interval: Metric averaged/summed across all steps since the last log interval. 
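The hunks above keep per-step losses as tensors and only convert them with `to_py_num` when the metrics dictionary is assembled, presumably so that no host-device transfer happens on every step; a standalone illustration of that pattern (made-up values, not code from this series) might look like:

```python
import torch

# Accumulate step losses as tensors; no .item() call (and no sync) per step.
step_losses = [torch.tensor(0.52), torch.tensor(0.49), torch.tensor(0.47)]
interval_loss = torch.stack(step_losses).mean()

# Convert to a plain Python number only once, at logging time.
metrics = {"loss/interval_per_gpu": round(float(interval_loss), 4)}
print(metrics)  # {'loss/interval_per_gpu': 0.4933}
```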
# Typically, this is 10 - step_loss_per_gpu = loss.item() - step_loss_across_gpus = dist_mean(step_loss_per_gpu).item() - interval_loss_per_gpu = np.mean(step_losses).item() - interval_loss_across_gpus = dist_mean(interval_loss_per_gpu).item() + step_loss_per_gpu = loss + step_loss_across_gpus = dist_mean(step_loss_per_gpu) + interval_loss_per_gpu = np.mean(step_losses) + interval_loss_across_gpus = dist_mean(interval_loss_per_gpu) stacked_tok_loss = torch.cat(step_tok_losses, dim=0) - interval_total_tok_loss_per_gpu = stacked_tok_loss.sum().item() + interval_total_tok_loss_per_gpu = stacked_tok_loss.sum() interval_total_tok_loss_across_gpus = dist_sum( interval_total_tok_loss_per_gpu, reduce_dtype=torch.bfloat16 - ).item() - interval_total_n_bytes_per_gpu = n_bytes.item() + ) + interval_total_n_bytes_per_gpu = n_bytes interval_total_n_bytes_across_gpus = dist_sum( n_bytes, reduce_dtype=torch.bfloat16 - ).item() + ) interval_bpb_per_gpu = ( interval_total_tok_loss_per_gpu @@ -645,18 +652,20 @@ def train(args: TrainArgs): }, "memory": gpu_mem_stats._asdict(), "loss": { - "step_per_gpu": step_loss_per_gpu, - "step_across_gpu": step_loss_across_gpus, - "interval_per_gpu": interval_loss_per_gpu, - "interval_across_gpu": interval_loss_across_gpus, + "step_per_gpu": to_py_num(step_loss_per_gpu), + "step_across_gpu": to_py_num(step_loss_across_gpus), + "interval_per_gpu": to_py_num(interval_loss_per_gpu), + "interval_across_gpu": to_py_num(interval_loss_across_gpus), }, "bpb": { - "interval_per_gpu": interval_bpb_per_gpu, - "interval_across_gpus": interval_bpb_across_gpus, + "interval_per_gpu": to_py_num(interval_bpb_per_gpu), + "interval_across_gpus": to_py_num(interval_bpb_across_gpus), }, "n_bytes": { - "interval_per_gpu": interval_total_n_bytes_per_gpu, - "interval_across_gpus": interval_total_n_bytes_across_gpus, + "interval_per_gpu": to_py_num(interval_total_n_bytes_per_gpu), + "interval_across_gpus": to_py_num( + interval_total_n_bytes_across_gpus + ), }, } @@ -676,8 +685,8 @@ def train(args: TrainArgs): logger.info( f"step: {train_state.step}" f" acc: {train_state.acc_step}" - f" loss_gpu: {round(interval_loss_per_gpu, 4):>7}" - f" loss_avg: {round(interval_loss_across_gpus, 4):>7}" + f" loss_gpu: {round(to_py_num(interval_loss_per_gpu), 4):>7}" + f" loss_avg: {round(to_py_num(interval_loss_across_gpus), 4):>7}" f" bpb_gpu: {interval_bpb_per_gpu:3f}" f" bpb_avg: {interval_bpb_across_gpus:3f}" f" grad: {grad_norm:.2e}" @@ -702,6 +711,9 @@ def train(args: TrainArgs): if every_n_steps( train_state, args.checkpoint.dump.every, acc_step=0 ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0): + train_state.data_loader_state, data_loader, batch_iterator = ( + get_state_and_refresh(data_loader) + ) saved = checkpoint.save( model, optimizer, @@ -743,6 +755,9 @@ def train(args: TrainArgs): if preemption_flag["flag"]: if not saved: + train_state.data_loader_state, data_loader, batch_iterator = ( + get_state_and_refresh(data_loader) + ) checkpoint.save( model, optimizer, @@ -754,6 +769,9 @@ def train(args: TrainArgs): sys.exit(0) if not saved: + train_state.data_loader_state, data_loader, batch_iterator = ( + get_state_and_refresh(data_loader) + ) checkpoint.save( model, optimizer, From be3ff12cfe379d5219026305f692e4644e07c9ba Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Fri, 14 Feb 2025 21:03:25 +0000 Subject: [PATCH 51/59] Make it possible to specify multiple config files Summary: Test Plan: Test that this interpolates in the right order, config -> configs -> cli args ```
# All three sources python -m bytelatent.print_config config=bytelatent/configs/debug.yaml configs=[internal/configs/s3_debug.yaml] eval=null # What worked before python -m bytelatent.print_config config=internal/configs/s3_debug.yaml eval=null ``` --- bytelatent/args.py | 14 -- bytelatent/config_parser.py | 70 ++++++++++ bytelatent/configs/debug.yaml | 4 +- bytelatent/configs/entropy_model.yaml | 4 +- bytelatent/eval.py | 8 +- bytelatent/print_config.py | 11 ++ bytelatent/test_config_parser.py | 180 ++++++++++++++++++++++++++ bytelatent/train.py | 5 +- fixtures/test-cfgs/list.yaml | 1 + fixtures/test-cfgs/middle.yaml | 3 + fixtures/test-cfgs/override.yaml | 1 + fixtures/test-cfgs/root.yaml | 6 + fixtures/test-cfgs/top.yaml | 3 + 13 files changed, 283 insertions(+), 27 deletions(-) create mode 100644 bytelatent/config_parser.py create mode 100644 bytelatent/print_config.py create mode 100644 bytelatent/test_config_parser.py create mode 100644 fixtures/test-cfgs/list.yaml create mode 100644 fixtures/test-cfgs/middle.yaml create mode 100644 fixtures/test-cfgs/override.yaml create mode 100644 fixtures/test-cfgs/root.yaml create mode 100644 fixtures/test-cfgs/top.yaml diff --git a/bytelatent/args.py b/bytelatent/args.py index 47bd0f9..dd1fef5 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -5,7 +5,6 @@ from typing import Any import numpy as np import yaml -from omegaconf import OmegaConf from pydantic import BaseModel, ConfigDict from bytelatent.checkpoint import CheckpointArgs @@ -38,19 +37,6 @@ def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]: return np.random.default_rng((seed, rank, world_size)).bit_generator.state -def parse_args(args_cls): - cli_args = OmegaConf.from_cli() - file_cfg = OmegaConf.load(cli_args.config) - # We remove 'config' attribute from config as the underlying DataClass does not have it - del cli_args.config - - default_cfg = OmegaConf.create(args_cls().model_dump()) - cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) - cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) - pydantic_args = args_cls.model_validate(cfg) - return pydantic_args - - TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" diff --git a/bytelatent/config_parser.py b/bytelatent/config_parser.py new file mode 100644 index 0000000..6e60972 --- /dev/null +++ b/bytelatent/config_parser.py @@ -0,0 +1,70 @@ +import copy +from typing import Type + +import omegaconf +from omegaconf import DictConfig, OmegaConf +from pydantic import BaseModel + + +def parse_file_config(path: str) -> DictConfig: + file_cfg = OmegaConf.load(path) + if not isinstance(file_cfg, DictConfig): + raise ValueError( + f"File paths must parse to DictConfig, but it was: {type(file_cfg)}" + ) + return file_cfg + + +def recursively_parse_config(cfg: DictConfig) -> list[DictConfig]: + if "config" not in cfg: + return [cfg] + + ordered_cfgs = [] + cfg = copy.deepcopy(cfg) + config_arg = cfg["config"] + del cfg["config"] + ordered_cfgs.append(cfg) + + if isinstance(config_arg, str): + file_cfg = parse_file_config(config_arg) + sub_configs = recursively_parse_config(file_cfg) + ordered_cfgs = sub_configs + ordered_cfgs + elif isinstance(config_arg, omegaconf.listconfig.ListConfig): + sub_configs = [] + for c in config_arg: + if not isinstance(c, str): + raise ValueError( + f'If "config" is specified, it must be either a string path or a list of string paths. 
It was config={config_arg}' + ) + config_to_parse = parse_file_config(c) + sub_configs.extend(recursively_parse_config(config_to_parse)) + ordered_cfgs = sub_configs + ordered_cfgs + else: + raise ValueError( + f'If "config" is specified, it must be either a string path or a list of string paths, it was config={config_arg}' + ) + return ordered_cfgs + + +def parse_args_with_default( + *, default_cfg: DictConfig | None = None, cli_args: DictConfig | None = None +): + if cli_args is None: + cli_args = OmegaConf.from_cli() + assert isinstance( + cli_args, DictConfig + ), f"CLI Args must be a DictConfig, not {type(cli_args)}" + ordered_cfgs = recursively_parse_config(cli_args) + if default_cfg is not None: + ordered_cfgs.insert(0, default_cfg) + cfg = OmegaConf.merge(*ordered_cfgs) + return OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) + + +def parse_args_to_pydantic_model( + args_cls: Type[BaseModel], cli_args: DictConfig | None = None +): + default_cfg = OmegaConf.create(args_cls().model_dump()) + parsed_cfg = parse_args_with_default(default_cfg=default_cfg, cli_args=cli_args) + pydantic_args = args_cls.model_validate(parsed_cfg) + return pydantic_args diff --git a/bytelatent/configs/debug.yaml b/bytelatent/configs/debug.yaml index 07d489f..1369364 100644 --- a/bytelatent/configs/debug.yaml +++ b/bytelatent/configs/debug.yaml @@ -56,13 +56,11 @@ model: recompute_attn: false custom_bwd: false layer_ckpt: "none" - patch_only_encoder: false - patch_only_decoder: false use_local_encoder_transformer: true init_use_gaussian: true init_use_depth: "current" - attn_bias_type: "block_causal" attn_impl: "xformers" + attn_bias_type: "block_causal" alpha_depth: "disabled" max_length: 256 local_attention_window_len: 512 diff --git a/bytelatent/configs/entropy_model.yaml b/bytelatent/configs/entropy_model.yaml index d7c27b7..79cc85b 100644 --- a/bytelatent/configs/entropy_model.yaml +++ b/bytelatent/configs/entropy_model.yaml @@ -2,9 +2,10 @@ # Evals can be activated by uncommenting its config # python -m launchers.stool config=apps/main/configs/debug.yaml nodes=8 account=fair_amaia_cw_codegen qos=lowest -dump_dir: /tmp/ +dump_dir: /tmp/blt-entropy name: "debug" steps: 100_000 +max_steps: null probe_freq: null seed: 777 optim: @@ -35,7 +36,6 @@ entropy_model: attn_impl: "xformers" data: - s3_profile: blt root_dir: ??? 
sources: dclm_baseline_1.0: 1.0 diff --git a/bytelatent/eval.py b/bytelatent/eval.py index ae73066..50e17cd 100644 --- a/bytelatent/eval.py +++ b/bytelatent/eval.py @@ -5,18 +5,15 @@ import logging import os from collections import defaultdict from datetime import datetime -from pathlib import Path -from typing import Any import torch from lm_eval import simple_evaluate from lm_eval.api.instance import Instance from lm_eval.api.model import LM -from omegaconf import OmegaConf -from pydantic import BaseModel, ConfigDict -from bytelatent.args import EvalArgs, ValidationArgs, parse_args +from bytelatent.args import EvalArgs, ValidationArgs from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints +from bytelatent.config_parser import parse_args_to_pydantic_model from bytelatent.data.file_util import get_fs from bytelatent.distributed import ( DistributedArgs, @@ -29,7 +26,6 @@ from bytelatent.generate import ( PackedCausalTransformerGenerator, load_consolidated_model_and_tokenizer, ) -from bytelatent.transformer import LMTransformer, LMTransformerArgs EVAL_FOLDER_NAME = "{:010d}" diff --git a/bytelatent/print_config.py b/bytelatent/print_config.py new file mode 100644 index 0000000..0bc99e7 --- /dev/null +++ b/bytelatent/print_config.py @@ -0,0 +1,11 @@ +from bytelatent.args import TrainArgs +from bytelatent.config_parser import parse_args_to_pydantic_model + + +def main(): + train_args = parse_args_to_pydantic_model(TrainArgs) + print(train_args.model_dump_json(indent=4)) + + +if __name__ == "__main__": + main() diff --git a/bytelatent/test_config_parser.py b/bytelatent/test_config_parser.py new file mode 100644 index 0000000..c1ec99b --- /dev/null +++ b/bytelatent/test_config_parser.py @@ -0,0 +1,180 @@ +import os + +import pytest +from omegaconf import DictConfig, MissingMandatoryValue, OmegaConf +from pydantic import BaseModel, ConfigDict + +from bytelatent.config_parser import ( + parse_args_to_pydantic_model, + parse_file_config, + recursively_parse_config, +) + +FIXTURE_DIR = "fixtures/test-cfgs" + + +def test_parse_file_config(): + with pytest.raises(ValueError): + cfg = parse_file_config(os.path.join(FIXTURE_DIR, "list.yaml")) + assert isinstance(cfg, DictConfig) + + +def test_nop(): + cfg = OmegaConf.create({"a": 1}) + parsed_cfgs = recursively_parse_config(cfg) + assert len(parsed_cfgs) == 1 + assert parsed_cfgs[0] == cfg + + +def test_root(): + cli_cfg = OmegaConf.create({"config": os.path.join(FIXTURE_DIR, "root.yaml")}) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 2 + assert len(parsed_cfgs[1]) == 0 + assert parsed_cfgs[0]["seed"] == -1 + with pytest.raises(MissingMandatoryValue): + assert parsed_cfgs[0]["b"]["y"] is not None + + # Test basic cli override + cli_cfg = OmegaConf.create( + {"config": os.path.join(FIXTURE_DIR, "root.yaml"), "seed": 42} + ) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert parsed_cfgs[1]["seed"] == 42 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["seed"] == 42 + + +def test_one_level_include(): + cli_cfg = OmegaConf.create({"config": os.path.join(FIXTURE_DIR, "middle.yaml")}) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 3 + assert parsed_cfgs[0]["seed"] == -1 + assert parsed_cfgs[1]["b"]["y"] == 10 + assert len(parsed_cfgs[2]) == 0 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["b"]["y"] == 10 + + cli_cfg = OmegaConf.create( + {"config": os.path.join(FIXTURE_DIR, "middle.yaml"), "b": {"y": 100}} + ) + parsed_cfgs = 
recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 3 + assert parsed_cfgs[0]["seed"] == -1 + assert parsed_cfgs[1]["b"]["y"] == 10 + assert parsed_cfgs[2]["b"]["y"] == 100 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["b"]["y"] == 100 + + +def test_two_level_include(): + cli_cfg = OmegaConf.create( + {"config": os.path.join(FIXTURE_DIR, "top.yaml"), "p": 500, "b": {"z": -2}} + ) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 4 + assert parsed_cfgs[0]["seed"] == -1 + assert parsed_cfgs[1]["b"]["y"] == 10 + assert parsed_cfgs[2]["hello"] == "world" + assert parsed_cfgs[3]["p"] == 500 + assert parsed_cfgs[3]["b"]["z"] == -2 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["a"] == 1 + assert cfg["seed"] == -1 + assert cfg["b"]["x"] == 0 + assert cfg["b"]["y"] == 10 + assert cfg["b"]["z"] == -2 + assert cfg["hello"] == "world" + + +def test_multiple_includes(): + cli_cfg = OmegaConf.create( + { + "config": [ + os.path.join(FIXTURE_DIR, "top.yaml"), + os.path.join(FIXTURE_DIR, "override.yaml"), + ], + "p": 500, + "b": {"z": -2}, + } + ) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 5 + assert parsed_cfgs[0]["seed"] == -1 + assert parsed_cfgs[1]["b"]["y"] == 10 + assert parsed_cfgs[2]["hello"] == "world" + assert parsed_cfgs[3]["a"] == 100 + assert parsed_cfgs[4]["p"] == 500 + assert parsed_cfgs[4]["b"]["z"] == -2 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["a"] == 100 + assert cfg["seed"] == -1 + assert cfg["b"]["x"] == 0 + assert cfg["b"]["y"] == 10 + assert cfg["b"]["z"] == -2 + assert cfg["hello"] == "world" + + cli_cfg = OmegaConf.create( + { + "config": [ + os.path.join(FIXTURE_DIR, "top.yaml"), + os.path.join(FIXTURE_DIR, "override.yaml"), + ], + "p": 500, + "b": {"z": -2}, + "a": 1000, + } + ) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 5 + assert parsed_cfgs[0]["seed"] == -1 + assert parsed_cfgs[1]["b"]["y"] == 10 + assert parsed_cfgs[2]["hello"] == "world" + assert parsed_cfgs[3]["a"] == 100 + assert parsed_cfgs[4]["p"] == 500 + assert parsed_cfgs[4]["b"]["z"] == -2 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["a"] == 1000 + assert cfg["seed"] == -1 + assert cfg["b"]["x"] == 0 + assert cfg["b"]["y"] == 10 + assert cfg["b"]["z"] == -2 + assert cfg["hello"] == "world" + + +class SubConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + x: int = -100 + y: int = -100 + z: int = -5 + + +class SampleConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + a: int = -100 + seed: int = -100 + b: SubConfig = SubConfig() + hello: str = "" + p: int = -100 + + +def test_pydantic_parse(): + cli_cfg = OmegaConf.create( + { + "config": [ + os.path.join(FIXTURE_DIR, "top.yaml"), + os.path.join(FIXTURE_DIR, "override.yaml"), + ], + "p": 500, + "a": 1000, + } + ) + cfg = parse_args_to_pydantic_model(SampleConfig, cli_args=cli_cfg) + assert isinstance(cfg, SampleConfig) + assert cfg.a == 1000 + assert cfg.p == 500 + assert cfg.seed == -1 + assert cfg.b.x == 0 + assert cfg.b.y == 10 + assert cfg.b.z == -5 + assert cfg.hello == "world" diff --git a/bytelatent/train.py b/bytelatent/train.py index 3669167..ad74b44 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -23,8 +23,9 @@ from torch.distributed._tensor import DTensor from torch.distributed.checkpoint.stateful import Stateful from torch.optim import lr_scheduler -from bytelatent.args import TrainArgs, parse_args +from bytelatent.args import TrainArgs from bytelatent.checkpoint import 
CheckpointManager, load_from_checkpoint +from bytelatent.config_parser import parse_args_to_pydantic_model from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh from bytelatent.data.iterators.multiprocess_iterator import ( @@ -824,7 +825,7 @@ def main(): Plus all the default values in TrainArgs dataclass. """ - train_args = parse_args(TrainArgs) + train_args = parse_args_to_pydantic_model(TrainArgs) if train_args.debug_dynamo: import torch._dynamo diff --git a/fixtures/test-cfgs/list.yaml b/fixtures/test-cfgs/list.yaml new file mode 100644 index 0000000..b5d8bb5 --- /dev/null +++ b/fixtures/test-cfgs/list.yaml @@ -0,0 +1 @@ +[1, 2, 3] diff --git a/fixtures/test-cfgs/middle.yaml b/fixtures/test-cfgs/middle.yaml new file mode 100644 index 0000000..a476d8d --- /dev/null +++ b/fixtures/test-cfgs/middle.yaml @@ -0,0 +1,3 @@ +config: fixtures/test-cfgs/root.yaml +b: + y: 10 diff --git a/fixtures/test-cfgs/override.yaml b/fixtures/test-cfgs/override.yaml new file mode 100644 index 0000000..456df7b --- /dev/null +++ b/fixtures/test-cfgs/override.yaml @@ -0,0 +1 @@ +a: 100 diff --git a/fixtures/test-cfgs/root.yaml b/fixtures/test-cfgs/root.yaml new file mode 100644 index 0000000..dc4d285 --- /dev/null +++ b/fixtures/test-cfgs/root.yaml @@ -0,0 +1,6 @@ +seed: -1 +a: 1 +b: + x: 0 + y: ??? + z: ??? diff --git a/fixtures/test-cfgs/top.yaml b/fixtures/test-cfgs/top.yaml new file mode 100644 index 0000000..632866c --- /dev/null +++ b/fixtures/test-cfgs/top.yaml @@ -0,0 +1,3 @@ +config: fixtures/test-cfgs/middle.yaml + +hello: world From bec016482091e1e7442907a3a39b20a56b527a3a Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Fri, 14 Feb 2025 21:03:56 +0000 Subject: [PATCH 52/59] Make it possible to specify multiple config files Summary: Test Plan: Test that this iterpolates in the right order, config -> configs -> cli args ``` # All three sources python -m bytelatent.print_config config=bytelatent/configs/debug.yaml configs=[internal/configs/s3_debug.yaml] eval=null # What worked before python -m bytelatent.print_config config=internal/configs/s3_debug.yaml eval=null ``` --- bytelatent/args.py | 14 -- bytelatent/config_parser.py | 70 ++++++++++ bytelatent/configs/debug.yaml | 4 +- bytelatent/configs/entropy_model.yaml | 4 +- bytelatent/eval.py | 8 +- bytelatent/print_config.py | 11 ++ bytelatent/test_config_parser.py | 180 ++++++++++++++++++++++++++ bytelatent/train.py | 5 +- fixtures/test-cfgs/list.yaml | 1 + fixtures/test-cfgs/middle.yaml | 3 + fixtures/test-cfgs/override.yaml | 1 + fixtures/test-cfgs/root.yaml | 6 + fixtures/test-cfgs/top.yaml | 3 + 13 files changed, 283 insertions(+), 27 deletions(-) create mode 100644 bytelatent/config_parser.py create mode 100644 bytelatent/print_config.py create mode 100644 bytelatent/test_config_parser.py create mode 100644 fixtures/test-cfgs/list.yaml create mode 100644 fixtures/test-cfgs/middle.yaml create mode 100644 fixtures/test-cfgs/override.yaml create mode 100644 fixtures/test-cfgs/root.yaml create mode 100644 fixtures/test-cfgs/top.yaml diff --git a/bytelatent/args.py b/bytelatent/args.py index 47bd0f9..dd1fef5 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -5,7 +5,6 @@ from typing import Any import numpy as np import yaml -from omegaconf import OmegaConf from pydantic import BaseModel, ConfigDict from bytelatent.checkpoint import CheckpointArgs @@ -38,19 +37,6 @@ def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]: return 
np.random.default_rng((seed, rank, world_size)).bit_generator.state -def parse_args(args_cls): - cli_args = OmegaConf.from_cli() - file_cfg = OmegaConf.load(cli_args.config) - # We remove 'config' attribute from config as the underlying DataClass does not have it - del cli_args.config - - default_cfg = OmegaConf.create(args_cls().model_dump()) - cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) - cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) - pydantic_args = args_cls.model_validate(cfg) - return pydantic_args - - TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" diff --git a/bytelatent/config_parser.py b/bytelatent/config_parser.py new file mode 100644 index 0000000..6e60972 --- /dev/null +++ b/bytelatent/config_parser.py @@ -0,0 +1,70 @@ +import copy +from typing import Type + +import omegaconf +from omegaconf import DictConfig, OmegaConf +from pydantic import BaseModel + + +def parse_file_config(path: str) -> DictConfig: + file_cfg = OmegaConf.load(path) + if not isinstance(file_cfg, DictConfig): + raise ValueError( + f"File paths must parse to DictConfig, but it was: {type(file_cfg)}" + ) + return file_cfg + + +def recursively_parse_config(cfg: DictConfig) -> list[DictConfig]: + if "config" not in cfg: + return [cfg] + + ordered_cfgs = [] + cfg = copy.deepcopy(cfg) + config_arg = cfg["config"] + del cfg["config"] + ordered_cfgs.append(cfg) + + if isinstance(config_arg, str): + file_cfg = parse_file_config(config_arg) + sub_configs = recursively_parse_config(file_cfg) + ordered_cfgs = sub_configs + ordered_cfgs + elif isinstance(config_arg, omegaconf.listconfig.ListConfig): + sub_configs = [] + for c in config_arg: + if not isinstance(c, str): + raise ValueError( + f'If "config" is specified, it must be either a string path or a list of string paths. 
It was config={config_arg}' + ) + config_to_parse = parse_file_config(c) + sub_configs.extend(recursively_parse_config(config_to_parse)) + ordered_cfgs = sub_configs + ordered_cfgs + else: + raise ValueError( + f'If "config" is specified, it must be either a string path or a list of string paths, it was config={config_arg}' + ) + return ordered_cfgs + + +def parse_args_with_default( + *, default_cfg: DictConfig | None = None, cli_args: DictConfig | None = None +): + if cli_args is None: + cli_args = OmegaConf.from_cli() + assert isinstance( + cli_args, DictConfig + ), f"CLI Args must be a DictConfig, not {type(cli_args)}" + ordered_cfgs = recursively_parse_config(cli_args) + if default_cfg is not None: + ordered_cfgs.insert(0, default_cfg) + cfg = OmegaConf.merge(*ordered_cfgs) + return OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) + + +def parse_args_to_pydantic_model( + args_cls: Type[BaseModel], cli_args: DictConfig | None = None +): + default_cfg = OmegaConf.create(args_cls().model_dump()) + parsed_cfg = parse_args_with_default(default_cfg=default_cfg, cli_args=cli_args) + pydantic_args = args_cls.model_validate(parsed_cfg) + return pydantic_args diff --git a/bytelatent/configs/debug.yaml b/bytelatent/configs/debug.yaml index 07d489f..1369364 100644 --- a/bytelatent/configs/debug.yaml +++ b/bytelatent/configs/debug.yaml @@ -56,13 +56,11 @@ model: recompute_attn: false custom_bwd: false layer_ckpt: "none" - patch_only_encoder: false - patch_only_decoder: false use_local_encoder_transformer: true init_use_gaussian: true init_use_depth: "current" - attn_bias_type: "block_causal" attn_impl: "xformers" + attn_bias_type: "block_causal" alpha_depth: "disabled" max_length: 256 local_attention_window_len: 512 diff --git a/bytelatent/configs/entropy_model.yaml b/bytelatent/configs/entropy_model.yaml index d7c27b7..79cc85b 100644 --- a/bytelatent/configs/entropy_model.yaml +++ b/bytelatent/configs/entropy_model.yaml @@ -2,9 +2,10 @@ # Evals can be activated by uncommenting its config # python -m launchers.stool config=apps/main/configs/debug.yaml nodes=8 account=fair_amaia_cw_codegen qos=lowest -dump_dir: /tmp/ +dump_dir: /tmp/blt-entropy name: "debug" steps: 100_000 +max_steps: null probe_freq: null seed: 777 optim: @@ -35,7 +36,6 @@ entropy_model: attn_impl: "xformers" data: - s3_profile: blt root_dir: ??? 
sources: dclm_baseline_1.0: 1.0 diff --git a/bytelatent/eval.py b/bytelatent/eval.py index ae73066..50e17cd 100644 --- a/bytelatent/eval.py +++ b/bytelatent/eval.py @@ -5,18 +5,15 @@ import logging import os from collections import defaultdict from datetime import datetime -from pathlib import Path -from typing import Any import torch from lm_eval import simple_evaluate from lm_eval.api.instance import Instance from lm_eval.api.model import LM -from omegaconf import OmegaConf -from pydantic import BaseModel, ConfigDict -from bytelatent.args import EvalArgs, ValidationArgs, parse_args +from bytelatent.args import EvalArgs, ValidationArgs from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints +from bytelatent.config_parser import parse_args_to_pydantic_model from bytelatent.data.file_util import get_fs from bytelatent.distributed import ( DistributedArgs, @@ -29,7 +26,6 @@ from bytelatent.generate import ( PackedCausalTransformerGenerator, load_consolidated_model_and_tokenizer, ) -from bytelatent.transformer import LMTransformer, LMTransformerArgs EVAL_FOLDER_NAME = "{:010d}" diff --git a/bytelatent/print_config.py b/bytelatent/print_config.py new file mode 100644 index 0000000..0bc99e7 --- /dev/null +++ b/bytelatent/print_config.py @@ -0,0 +1,11 @@ +from bytelatent.args import TrainArgs +from bytelatent.config_parser import parse_args_to_pydantic_model + + +def main(): + train_args = parse_args_to_pydantic_model(TrainArgs) + print(train_args.model_dump_json(indent=4)) + + +if __name__ == "__main__": + main() diff --git a/bytelatent/test_config_parser.py b/bytelatent/test_config_parser.py new file mode 100644 index 0000000..c1ec99b --- /dev/null +++ b/bytelatent/test_config_parser.py @@ -0,0 +1,180 @@ +import os + +import pytest +from omegaconf import DictConfig, MissingMandatoryValue, OmegaConf +from pydantic import BaseModel, ConfigDict + +from bytelatent.config_parser import ( + parse_args_to_pydantic_model, + parse_file_config, + recursively_parse_config, +) + +FIXTURE_DIR = "fixtures/test-cfgs" + + +def test_parse_file_config(): + with pytest.raises(ValueError): + cfg = parse_file_config(os.path.join(FIXTURE_DIR, "list.yaml")) + assert isinstance(cfg, DictConfig) + + +def test_nop(): + cfg = OmegaConf.create({"a": 1}) + parsed_cfgs = recursively_parse_config(cfg) + assert len(parsed_cfgs) == 1 + assert parsed_cfgs[0] == cfg + + +def test_root(): + cli_cfg = OmegaConf.create({"config": os.path.join(FIXTURE_DIR, "root.yaml")}) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 2 + assert len(parsed_cfgs[1]) == 0 + assert parsed_cfgs[0]["seed"] == -1 + with pytest.raises(MissingMandatoryValue): + assert parsed_cfgs[0]["b"]["y"] is not None + + # Test basic cli override + cli_cfg = OmegaConf.create( + {"config": os.path.join(FIXTURE_DIR, "root.yaml"), "seed": 42} + ) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert parsed_cfgs[1]["seed"] == 42 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["seed"] == 42 + + +def test_one_level_include(): + cli_cfg = OmegaConf.create({"config": os.path.join(FIXTURE_DIR, "middle.yaml")}) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 3 + assert parsed_cfgs[0]["seed"] == -1 + assert parsed_cfgs[1]["b"]["y"] == 10 + assert len(parsed_cfgs[2]) == 0 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["b"]["y"] == 10 + + cli_cfg = OmegaConf.create( + {"config": os.path.join(FIXTURE_DIR, "middle.yaml"), "b": {"y": 100}} + ) + parsed_cfgs = 
recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 3 + assert parsed_cfgs[0]["seed"] == -1 + assert parsed_cfgs[1]["b"]["y"] == 10 + assert parsed_cfgs[2]["b"]["y"] == 100 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["b"]["y"] == 100 + + +def test_two_level_include(): + cli_cfg = OmegaConf.create( + {"config": os.path.join(FIXTURE_DIR, "top.yaml"), "p": 500, "b": {"z": -2}} + ) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 4 + assert parsed_cfgs[0]["seed"] == -1 + assert parsed_cfgs[1]["b"]["y"] == 10 + assert parsed_cfgs[2]["hello"] == "world" + assert parsed_cfgs[3]["p"] == 500 + assert parsed_cfgs[3]["b"]["z"] == -2 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["a"] == 1 + assert cfg["seed"] == -1 + assert cfg["b"]["x"] == 0 + assert cfg["b"]["y"] == 10 + assert cfg["b"]["z"] == -2 + assert cfg["hello"] == "world" + + +def test_multiple_includes(): + cli_cfg = OmegaConf.create( + { + "config": [ + os.path.join(FIXTURE_DIR, "top.yaml"), + os.path.join(FIXTURE_DIR, "override.yaml"), + ], + "p": 500, + "b": {"z": -2}, + } + ) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 5 + assert parsed_cfgs[0]["seed"] == -1 + assert parsed_cfgs[1]["b"]["y"] == 10 + assert parsed_cfgs[2]["hello"] == "world" + assert parsed_cfgs[3]["a"] == 100 + assert parsed_cfgs[4]["p"] == 500 + assert parsed_cfgs[4]["b"]["z"] == -2 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["a"] == 100 + assert cfg["seed"] == -1 + assert cfg["b"]["x"] == 0 + assert cfg["b"]["y"] == 10 + assert cfg["b"]["z"] == -2 + assert cfg["hello"] == "world" + + cli_cfg = OmegaConf.create( + { + "config": [ + os.path.join(FIXTURE_DIR, "top.yaml"), + os.path.join(FIXTURE_DIR, "override.yaml"), + ], + "p": 500, + "b": {"z": -2}, + "a": 1000, + } + ) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 5 + assert parsed_cfgs[0]["seed"] == -1 + assert parsed_cfgs[1]["b"]["y"] == 10 + assert parsed_cfgs[2]["hello"] == "world" + assert parsed_cfgs[3]["a"] == 100 + assert parsed_cfgs[4]["p"] == 500 + assert parsed_cfgs[4]["b"]["z"] == -2 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["a"] == 1000 + assert cfg["seed"] == -1 + assert cfg["b"]["x"] == 0 + assert cfg["b"]["y"] == 10 + assert cfg["b"]["z"] == -2 + assert cfg["hello"] == "world" + + +class SubConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + x: int = -100 + y: int = -100 + z: int = -5 + + +class SampleConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + a: int = -100 + seed: int = -100 + b: SubConfig = SubConfig() + hello: str = "" + p: int = -100 + + +def test_pydantic_parse(): + cli_cfg = OmegaConf.create( + { + "config": [ + os.path.join(FIXTURE_DIR, "top.yaml"), + os.path.join(FIXTURE_DIR, "override.yaml"), + ], + "p": 500, + "a": 1000, + } + ) + cfg = parse_args_to_pydantic_model(SampleConfig, cli_args=cli_cfg) + assert isinstance(cfg, SampleConfig) + assert cfg.a == 1000 + assert cfg.p == 500 + assert cfg.seed == -1 + assert cfg.b.x == 0 + assert cfg.b.y == 10 + assert cfg.b.z == -5 + assert cfg.hello == "world" diff --git a/bytelatent/train.py b/bytelatent/train.py index 3669167..ad74b44 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -23,8 +23,9 @@ from torch.distributed._tensor import DTensor from torch.distributed.checkpoint.stateful import Stateful from torch.optim import lr_scheduler -from bytelatent.args import TrainArgs, parse_args +from bytelatent.args import TrainArgs from bytelatent.checkpoint import 
CheckpointManager, load_from_checkpoint +from bytelatent.config_parser import parse_args_to_pydantic_model from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh from bytelatent.data.iterators.multiprocess_iterator import ( @@ -824,7 +825,7 @@ def main(): Plus all the default values in TrainArgs dataclass. """ - train_args = parse_args(TrainArgs) + train_args = parse_args_to_pydantic_model(TrainArgs) if train_args.debug_dynamo: import torch._dynamo diff --git a/fixtures/test-cfgs/list.yaml b/fixtures/test-cfgs/list.yaml new file mode 100644 index 0000000..b5d8bb5 --- /dev/null +++ b/fixtures/test-cfgs/list.yaml @@ -0,0 +1 @@ +[1, 2, 3] diff --git a/fixtures/test-cfgs/middle.yaml b/fixtures/test-cfgs/middle.yaml new file mode 100644 index 0000000..a476d8d --- /dev/null +++ b/fixtures/test-cfgs/middle.yaml @@ -0,0 +1,3 @@ +config: fixtures/test-cfgs/root.yaml +b: + y: 10 diff --git a/fixtures/test-cfgs/override.yaml b/fixtures/test-cfgs/override.yaml new file mode 100644 index 0000000..456df7b --- /dev/null +++ b/fixtures/test-cfgs/override.yaml @@ -0,0 +1 @@ +a: 100 diff --git a/fixtures/test-cfgs/root.yaml b/fixtures/test-cfgs/root.yaml new file mode 100644 index 0000000..dc4d285 --- /dev/null +++ b/fixtures/test-cfgs/root.yaml @@ -0,0 +1,6 @@ +seed: -1 +a: 1 +b: + x: 0 + y: ??? + z: ??? diff --git a/fixtures/test-cfgs/top.yaml b/fixtures/test-cfgs/top.yaml new file mode 100644 index 0000000..632866c --- /dev/null +++ b/fixtures/test-cfgs/top.yaml @@ -0,0 +1,3 @@ +config: fixtures/test-cfgs/middle.yaml + +hello: world From aa78c96ea435f60567f48f9c1adb27c6dc04f911 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Fri, 14 Feb 2025 21:04:16 +0000 Subject: [PATCH 53/59] Make it possible to specify multiple config files Summary: Make it possible to specify multiple config files. Parsing CLI is not a special case anymore, just uses the same config inheritance method. Test Plan: Test that this iterpolates in the right order via unit tests Sample usage, loads the internal config, which references bytelatent/configs/entropy_model.yaml. 
The precendence order is: - Default pydantic args - Included configs, eg `config` - CLI args ``` python -m bytelatent.print_config config=internal/configs/entropy_model.yaml eval=null ``` --- bytelatent/args.py | 14 -- bytelatent/config_parser.py | 70 ++++++++++ bytelatent/configs/debug.yaml | 4 +- bytelatent/configs/entropy_model.yaml | 4 +- bytelatent/eval.py | 8 +- bytelatent/print_config.py | 11 ++ bytelatent/test_config_parser.py | 180 ++++++++++++++++++++++++++ bytelatent/train.py | 5 +- fixtures/test-cfgs/list.yaml | 1 + fixtures/test-cfgs/middle.yaml | 3 + fixtures/test-cfgs/override.yaml | 1 + fixtures/test-cfgs/root.yaml | 6 + fixtures/test-cfgs/top.yaml | 3 + 13 files changed, 283 insertions(+), 27 deletions(-) create mode 100644 bytelatent/config_parser.py create mode 100644 bytelatent/print_config.py create mode 100644 bytelatent/test_config_parser.py create mode 100644 fixtures/test-cfgs/list.yaml create mode 100644 fixtures/test-cfgs/middle.yaml create mode 100644 fixtures/test-cfgs/override.yaml create mode 100644 fixtures/test-cfgs/root.yaml create mode 100644 fixtures/test-cfgs/top.yaml diff --git a/bytelatent/args.py b/bytelatent/args.py index 47bd0f9..dd1fef5 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -5,7 +5,6 @@ from typing import Any import numpy as np import yaml -from omegaconf import OmegaConf from pydantic import BaseModel, ConfigDict from bytelatent.checkpoint import CheckpointArgs @@ -38,19 +37,6 @@ def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]: return np.random.default_rng((seed, rank, world_size)).bit_generator.state -def parse_args(args_cls): - cli_args = OmegaConf.from_cli() - file_cfg = OmegaConf.load(cli_args.config) - # We remove 'config' attribute from config as the underlying DataClass does not have it - del cli_args.config - - default_cfg = OmegaConf.create(args_cls().model_dump()) - cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) - cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) - pydantic_args = args_cls.model_validate(cfg) - return pydantic_args - - TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" diff --git a/bytelatent/config_parser.py b/bytelatent/config_parser.py new file mode 100644 index 0000000..6e60972 --- /dev/null +++ b/bytelatent/config_parser.py @@ -0,0 +1,70 @@ +import copy +from typing import Type + +import omegaconf +from omegaconf import DictConfig, OmegaConf +from pydantic import BaseModel + + +def parse_file_config(path: str) -> DictConfig: + file_cfg = OmegaConf.load(path) + if not isinstance(file_cfg, DictConfig): + raise ValueError( + f"File paths must parse to DictConfig, but it was: {type(file_cfg)}" + ) + return file_cfg + + +def recursively_parse_config(cfg: DictConfig) -> list[DictConfig]: + if "config" not in cfg: + return [cfg] + + ordered_cfgs = [] + cfg = copy.deepcopy(cfg) + config_arg = cfg["config"] + del cfg["config"] + ordered_cfgs.append(cfg) + + if isinstance(config_arg, str): + file_cfg = parse_file_config(config_arg) + sub_configs = recursively_parse_config(file_cfg) + ordered_cfgs = sub_configs + ordered_cfgs + elif isinstance(config_arg, omegaconf.listconfig.ListConfig): + sub_configs = [] + for c in config_arg: + if not isinstance(c, str): + raise ValueError( + f'If "config" is specified, it must be either a string path or a list of string paths. 
It was config={config_arg}' + ) + config_to_parse = parse_file_config(c) + sub_configs.extend(recursively_parse_config(config_to_parse)) + ordered_cfgs = sub_configs + ordered_cfgs + else: + raise ValueError( + f'If "config" is specified, it must be either a string path or a list of string paths, it was config={config_arg}' + ) + return ordered_cfgs + + +def parse_args_with_default( + *, default_cfg: DictConfig | None = None, cli_args: DictConfig | None = None +): + if cli_args is None: + cli_args = OmegaConf.from_cli() + assert isinstance( + cli_args, DictConfig + ), f"CLI Args must be a DictConfig, not {type(cli_args)}" + ordered_cfgs = recursively_parse_config(cli_args) + if default_cfg is not None: + ordered_cfgs.insert(0, default_cfg) + cfg = OmegaConf.merge(*ordered_cfgs) + return OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) + + +def parse_args_to_pydantic_model( + args_cls: Type[BaseModel], cli_args: DictConfig | None = None +): + default_cfg = OmegaConf.create(args_cls().model_dump()) + parsed_cfg = parse_args_with_default(default_cfg=default_cfg, cli_args=cli_args) + pydantic_args = args_cls.model_validate(parsed_cfg) + return pydantic_args diff --git a/bytelatent/configs/debug.yaml b/bytelatent/configs/debug.yaml index 07d489f..1369364 100644 --- a/bytelatent/configs/debug.yaml +++ b/bytelatent/configs/debug.yaml @@ -56,13 +56,11 @@ model: recompute_attn: false custom_bwd: false layer_ckpt: "none" - patch_only_encoder: false - patch_only_decoder: false use_local_encoder_transformer: true init_use_gaussian: true init_use_depth: "current" - attn_bias_type: "block_causal" attn_impl: "xformers" + attn_bias_type: "block_causal" alpha_depth: "disabled" max_length: 256 local_attention_window_len: 512 diff --git a/bytelatent/configs/entropy_model.yaml b/bytelatent/configs/entropy_model.yaml index d7c27b7..79cc85b 100644 --- a/bytelatent/configs/entropy_model.yaml +++ b/bytelatent/configs/entropy_model.yaml @@ -2,9 +2,10 @@ # Evals can be activated by uncommenting its config # python -m launchers.stool config=apps/main/configs/debug.yaml nodes=8 account=fair_amaia_cw_codegen qos=lowest -dump_dir: /tmp/ +dump_dir: /tmp/blt-entropy name: "debug" steps: 100_000 +max_steps: null probe_freq: null seed: 777 optim: @@ -35,7 +36,6 @@ entropy_model: attn_impl: "xformers" data: - s3_profile: blt root_dir: ??? 
sources: dclm_baseline_1.0: 1.0 diff --git a/bytelatent/eval.py b/bytelatent/eval.py index ae73066..50e17cd 100644 --- a/bytelatent/eval.py +++ b/bytelatent/eval.py @@ -5,18 +5,15 @@ import logging import os from collections import defaultdict from datetime import datetime -from pathlib import Path -from typing import Any import torch from lm_eval import simple_evaluate from lm_eval.api.instance import Instance from lm_eval.api.model import LM -from omegaconf import OmegaConf -from pydantic import BaseModel, ConfigDict -from bytelatent.args import EvalArgs, ValidationArgs, parse_args +from bytelatent.args import EvalArgs, ValidationArgs from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints +from bytelatent.config_parser import parse_args_to_pydantic_model from bytelatent.data.file_util import get_fs from bytelatent.distributed import ( DistributedArgs, @@ -29,7 +26,6 @@ from bytelatent.generate import ( PackedCausalTransformerGenerator, load_consolidated_model_and_tokenizer, ) -from bytelatent.transformer import LMTransformer, LMTransformerArgs EVAL_FOLDER_NAME = "{:010d}" diff --git a/bytelatent/print_config.py b/bytelatent/print_config.py new file mode 100644 index 0000000..0bc99e7 --- /dev/null +++ b/bytelatent/print_config.py @@ -0,0 +1,11 @@ +from bytelatent.args import TrainArgs +from bytelatent.config_parser import parse_args_to_pydantic_model + + +def main(): + train_args = parse_args_to_pydantic_model(TrainArgs) + print(train_args.model_dump_json(indent=4)) + + +if __name__ == "__main__": + main() diff --git a/bytelatent/test_config_parser.py b/bytelatent/test_config_parser.py new file mode 100644 index 0000000..c1ec99b --- /dev/null +++ b/bytelatent/test_config_parser.py @@ -0,0 +1,180 @@ +import os + +import pytest +from omegaconf import DictConfig, MissingMandatoryValue, OmegaConf +from pydantic import BaseModel, ConfigDict + +from bytelatent.config_parser import ( + parse_args_to_pydantic_model, + parse_file_config, + recursively_parse_config, +) + +FIXTURE_DIR = "fixtures/test-cfgs" + + +def test_parse_file_config(): + with pytest.raises(ValueError): + cfg = parse_file_config(os.path.join(FIXTURE_DIR, "list.yaml")) + assert isinstance(cfg, DictConfig) + + +def test_nop(): + cfg = OmegaConf.create({"a": 1}) + parsed_cfgs = recursively_parse_config(cfg) + assert len(parsed_cfgs) == 1 + assert parsed_cfgs[0] == cfg + + +def test_root(): + cli_cfg = OmegaConf.create({"config": os.path.join(FIXTURE_DIR, "root.yaml")}) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 2 + assert len(parsed_cfgs[1]) == 0 + assert parsed_cfgs[0]["seed"] == -1 + with pytest.raises(MissingMandatoryValue): + assert parsed_cfgs[0]["b"]["y"] is not None + + # Test basic cli override + cli_cfg = OmegaConf.create( + {"config": os.path.join(FIXTURE_DIR, "root.yaml"), "seed": 42} + ) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert parsed_cfgs[1]["seed"] == 42 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["seed"] == 42 + + +def test_one_level_include(): + cli_cfg = OmegaConf.create({"config": os.path.join(FIXTURE_DIR, "middle.yaml")}) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 3 + assert parsed_cfgs[0]["seed"] == -1 + assert parsed_cfgs[1]["b"]["y"] == 10 + assert len(parsed_cfgs[2]) == 0 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["b"]["y"] == 10 + + cli_cfg = OmegaConf.create( + {"config": os.path.join(FIXTURE_DIR, "middle.yaml"), "b": {"y": 100}} + ) + parsed_cfgs = 
recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 3 + assert parsed_cfgs[0]["seed"] == -1 + assert parsed_cfgs[1]["b"]["y"] == 10 + assert parsed_cfgs[2]["b"]["y"] == 100 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["b"]["y"] == 100 + + +def test_two_level_include(): + cli_cfg = OmegaConf.create( + {"config": os.path.join(FIXTURE_DIR, "top.yaml"), "p": 500, "b": {"z": -2}} + ) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 4 + assert parsed_cfgs[0]["seed"] == -1 + assert parsed_cfgs[1]["b"]["y"] == 10 + assert parsed_cfgs[2]["hello"] == "world" + assert parsed_cfgs[3]["p"] == 500 + assert parsed_cfgs[3]["b"]["z"] == -2 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["a"] == 1 + assert cfg["seed"] == -1 + assert cfg["b"]["x"] == 0 + assert cfg["b"]["y"] == 10 + assert cfg["b"]["z"] == -2 + assert cfg["hello"] == "world" + + +def test_multiple_includes(): + cli_cfg = OmegaConf.create( + { + "config": [ + os.path.join(FIXTURE_DIR, "top.yaml"), + os.path.join(FIXTURE_DIR, "override.yaml"), + ], + "p": 500, + "b": {"z": -2}, + } + ) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 5 + assert parsed_cfgs[0]["seed"] == -1 + assert parsed_cfgs[1]["b"]["y"] == 10 + assert parsed_cfgs[2]["hello"] == "world" + assert parsed_cfgs[3]["a"] == 100 + assert parsed_cfgs[4]["p"] == 500 + assert parsed_cfgs[4]["b"]["z"] == -2 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["a"] == 100 + assert cfg["seed"] == -1 + assert cfg["b"]["x"] == 0 + assert cfg["b"]["y"] == 10 + assert cfg["b"]["z"] == -2 + assert cfg["hello"] == "world" + + cli_cfg = OmegaConf.create( + { + "config": [ + os.path.join(FIXTURE_DIR, "top.yaml"), + os.path.join(FIXTURE_DIR, "override.yaml"), + ], + "p": 500, + "b": {"z": -2}, + "a": 1000, + } + ) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 5 + assert parsed_cfgs[0]["seed"] == -1 + assert parsed_cfgs[1]["b"]["y"] == 10 + assert parsed_cfgs[2]["hello"] == "world" + assert parsed_cfgs[3]["a"] == 100 + assert parsed_cfgs[4]["p"] == 500 + assert parsed_cfgs[4]["b"]["z"] == -2 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["a"] == 1000 + assert cfg["seed"] == -1 + assert cfg["b"]["x"] == 0 + assert cfg["b"]["y"] == 10 + assert cfg["b"]["z"] == -2 + assert cfg["hello"] == "world" + + +class SubConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + x: int = -100 + y: int = -100 + z: int = -5 + + +class SampleConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + a: int = -100 + seed: int = -100 + b: SubConfig = SubConfig() + hello: str = "" + p: int = -100 + + +def test_pydantic_parse(): + cli_cfg = OmegaConf.create( + { + "config": [ + os.path.join(FIXTURE_DIR, "top.yaml"), + os.path.join(FIXTURE_DIR, "override.yaml"), + ], + "p": 500, + "a": 1000, + } + ) + cfg = parse_args_to_pydantic_model(SampleConfig, cli_args=cli_cfg) + assert isinstance(cfg, SampleConfig) + assert cfg.a == 1000 + assert cfg.p == 500 + assert cfg.seed == -1 + assert cfg.b.x == 0 + assert cfg.b.y == 10 + assert cfg.b.z == -5 + assert cfg.hello == "world" diff --git a/bytelatent/train.py b/bytelatent/train.py index 3669167..ad74b44 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -23,8 +23,9 @@ from torch.distributed._tensor import DTensor from torch.distributed.checkpoint.stateful import Stateful from torch.optim import lr_scheduler -from bytelatent.args import TrainArgs, parse_args +from bytelatent.args import TrainArgs from bytelatent.checkpoint import 
CheckpointManager, load_from_checkpoint +from bytelatent.config_parser import parse_args_to_pydantic_model from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh from bytelatent.data.iterators.multiprocess_iterator import ( @@ -824,7 +825,7 @@ def main(): Plus all the default values in TrainArgs dataclass. """ - train_args = parse_args(TrainArgs) + train_args = parse_args_to_pydantic_model(TrainArgs) if train_args.debug_dynamo: import torch._dynamo diff --git a/fixtures/test-cfgs/list.yaml b/fixtures/test-cfgs/list.yaml new file mode 100644 index 0000000..b5d8bb5 --- /dev/null +++ b/fixtures/test-cfgs/list.yaml @@ -0,0 +1 @@ +[1, 2, 3] diff --git a/fixtures/test-cfgs/middle.yaml b/fixtures/test-cfgs/middle.yaml new file mode 100644 index 0000000..a476d8d --- /dev/null +++ b/fixtures/test-cfgs/middle.yaml @@ -0,0 +1,3 @@ +config: fixtures/test-cfgs/root.yaml +b: + y: 10 diff --git a/fixtures/test-cfgs/override.yaml b/fixtures/test-cfgs/override.yaml new file mode 100644 index 0000000..456df7b --- /dev/null +++ b/fixtures/test-cfgs/override.yaml @@ -0,0 +1 @@ +a: 100 diff --git a/fixtures/test-cfgs/root.yaml b/fixtures/test-cfgs/root.yaml new file mode 100644 index 0000000..dc4d285 --- /dev/null +++ b/fixtures/test-cfgs/root.yaml @@ -0,0 +1,6 @@ +seed: -1 +a: 1 +b: + x: 0 + y: ??? + z: ??? diff --git a/fixtures/test-cfgs/top.yaml b/fixtures/test-cfgs/top.yaml new file mode 100644 index 0000000..632866c --- /dev/null +++ b/fixtures/test-cfgs/top.yaml @@ -0,0 +1,3 @@ +config: fixtures/test-cfgs/middle.yaml + +hello: world From f94babc94eeaa4b21ab1300b95c3e69c2a30df23 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Fri, 14 Feb 2025 22:50:23 +0000 Subject: [PATCH 54/59] Make it possible to specify multiple config files Summary: Make it possible to specify multiple config files. Parsing CLI is not a special case anymore, just uses the same config inheritance method. Test Plan: Test that this iterpolates in the right order via unit tests Sample usage, loads the internal config, which references bytelatent/configs/entropy_model.yaml. 
The precendence order is: - Default pydantic args - Included configs, eg `config` - CLI args ``` python -m bytelatent.print_config config=internal/configs/entropy_model.yaml eval=null ``` Summary: Test Plan: --- bytelatent/args.py | 14 -- bytelatent/config_parser.py | 73 +++++++++++ bytelatent/configs/debug.yaml | 4 +- bytelatent/configs/entropy_model.yaml | 4 +- bytelatent/eval.py | 8 +- bytelatent/print_config.py | 11 ++ bytelatent/test_config_parser.py | 180 ++++++++++++++++++++++++++ bytelatent/train.py | 5 +- fixtures/test-cfgs/list.yaml | 1 + fixtures/test-cfgs/middle.yaml | 3 + fixtures/test-cfgs/override.yaml | 1 + fixtures/test-cfgs/root.yaml | 6 + fixtures/test-cfgs/top.yaml | 3 + 13 files changed, 286 insertions(+), 27 deletions(-) create mode 100644 bytelatent/config_parser.py create mode 100644 bytelatent/print_config.py create mode 100644 bytelatent/test_config_parser.py create mode 100644 fixtures/test-cfgs/list.yaml create mode 100644 fixtures/test-cfgs/middle.yaml create mode 100644 fixtures/test-cfgs/override.yaml create mode 100644 fixtures/test-cfgs/root.yaml create mode 100644 fixtures/test-cfgs/top.yaml diff --git a/bytelatent/args.py b/bytelatent/args.py index 47bd0f9..dd1fef5 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -5,7 +5,6 @@ from typing import Any import numpy as np import yaml -from omegaconf import OmegaConf from pydantic import BaseModel, ConfigDict from bytelatent.checkpoint import CheckpointArgs @@ -38,19 +37,6 @@ def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]: return np.random.default_rng((seed, rank, world_size)).bit_generator.state -def parse_args(args_cls): - cli_args = OmegaConf.from_cli() - file_cfg = OmegaConf.load(cli_args.config) - # We remove 'config' attribute from config as the underlying DataClass does not have it - del cli_args.config - - default_cfg = OmegaConf.create(args_cls().model_dump()) - cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) - cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) - pydantic_args = args_cls.model_validate(cfg) - return pydantic_args - - TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" diff --git a/bytelatent/config_parser.py b/bytelatent/config_parser.py new file mode 100644 index 0000000..2e310a2 --- /dev/null +++ b/bytelatent/config_parser.py @@ -0,0 +1,73 @@ +import copy +from typing import Type, TypeVar + +import omegaconf +from omegaconf import DictConfig, OmegaConf +from pydantic import BaseModel + + +def parse_file_config(path: str) -> DictConfig: + file_cfg = OmegaConf.load(path) + if not isinstance(file_cfg, DictConfig): + raise ValueError( + f"File paths must parse to DictConfig, but it was: {type(file_cfg)}" + ) + return file_cfg + + +def recursively_parse_config(cfg: DictConfig) -> list[DictConfig]: + if "config" not in cfg: + return [cfg] + + ordered_cfgs = [] + cfg = copy.deepcopy(cfg) + config_arg = cfg["config"] + del cfg["config"] + ordered_cfgs.append(cfg) + + if isinstance(config_arg, str): + file_cfg = parse_file_config(config_arg) + sub_configs = recursively_parse_config(file_cfg) + ordered_cfgs = sub_configs + ordered_cfgs + elif isinstance(config_arg, omegaconf.listconfig.ListConfig): + sub_configs = [] + for c in config_arg: + if not isinstance(c, str): + raise ValueError( + f'If "config" is specified, it must be either a string path or a list of string paths. 
It was config={config_arg}' + ) + config_to_parse = parse_file_config(c) + sub_configs.extend(recursively_parse_config(config_to_parse)) + ordered_cfgs = sub_configs + ordered_cfgs + else: + raise ValueError( + f'If "config" is specified, it must be either a string path or a list of string paths, it was config={config_arg}' + ) + return ordered_cfgs + + +def parse_args_with_default( + *, default_cfg: DictConfig | None = None, cli_args: DictConfig | None = None +): + if cli_args is None: + cli_args = OmegaConf.from_cli() + assert isinstance( + cli_args, DictConfig + ), f"CLI Args must be a DictConfig, not {type(cli_args)}" + ordered_cfgs = recursively_parse_config(cli_args) + if default_cfg is not None: + ordered_cfgs.insert(0, default_cfg) + cfg = OmegaConf.merge(*ordered_cfgs) + return OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) + + +T = TypeVar("T", bound=BaseModel) + + +def parse_args_to_pydantic_model( + args_cls: Type[T], cli_args: DictConfig | None = None +) -> T: + default_cfg = OmegaConf.create(args_cls().model_dump()) + parsed_cfg = parse_args_with_default(default_cfg=default_cfg, cli_args=cli_args) + pydantic_args = args_cls.model_validate(parsed_cfg) + return pydantic_args diff --git a/bytelatent/configs/debug.yaml b/bytelatent/configs/debug.yaml index 07d489f..1369364 100644 --- a/bytelatent/configs/debug.yaml +++ b/bytelatent/configs/debug.yaml @@ -56,13 +56,11 @@ model: recompute_attn: false custom_bwd: false layer_ckpt: "none" - patch_only_encoder: false - patch_only_decoder: false use_local_encoder_transformer: true init_use_gaussian: true init_use_depth: "current" - attn_bias_type: "block_causal" attn_impl: "xformers" + attn_bias_type: "block_causal" alpha_depth: "disabled" max_length: 256 local_attention_window_len: 512 diff --git a/bytelatent/configs/entropy_model.yaml b/bytelatent/configs/entropy_model.yaml index d7c27b7..79cc85b 100644 --- a/bytelatent/configs/entropy_model.yaml +++ b/bytelatent/configs/entropy_model.yaml @@ -2,9 +2,10 @@ # Evals can be activated by uncommenting its config # python -m launchers.stool config=apps/main/configs/debug.yaml nodes=8 account=fair_amaia_cw_codegen qos=lowest -dump_dir: /tmp/ +dump_dir: /tmp/blt-entropy name: "debug" steps: 100_000 +max_steps: null probe_freq: null seed: 777 optim: @@ -35,7 +36,6 @@ entropy_model: attn_impl: "xformers" data: - s3_profile: blt root_dir: ??? 
sources: dclm_baseline_1.0: 1.0 diff --git a/bytelatent/eval.py b/bytelatent/eval.py index ae73066..50e17cd 100644 --- a/bytelatent/eval.py +++ b/bytelatent/eval.py @@ -5,18 +5,15 @@ import logging import os from collections import defaultdict from datetime import datetime -from pathlib import Path -from typing import Any import torch from lm_eval import simple_evaluate from lm_eval.api.instance import Instance from lm_eval.api.model import LM -from omegaconf import OmegaConf -from pydantic import BaseModel, ConfigDict -from bytelatent.args import EvalArgs, ValidationArgs, parse_args +from bytelatent.args import EvalArgs, ValidationArgs from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints +from bytelatent.config_parser import parse_args_to_pydantic_model from bytelatent.data.file_util import get_fs from bytelatent.distributed import ( DistributedArgs, @@ -29,7 +26,6 @@ from bytelatent.generate import ( PackedCausalTransformerGenerator, load_consolidated_model_and_tokenizer, ) -from bytelatent.transformer import LMTransformer, LMTransformerArgs EVAL_FOLDER_NAME = "{:010d}" diff --git a/bytelatent/print_config.py b/bytelatent/print_config.py new file mode 100644 index 0000000..0bc99e7 --- /dev/null +++ b/bytelatent/print_config.py @@ -0,0 +1,11 @@ +from bytelatent.args import TrainArgs +from bytelatent.config_parser import parse_args_to_pydantic_model + + +def main(): + train_args = parse_args_to_pydantic_model(TrainArgs) + print(train_args.model_dump_json(indent=4)) + + +if __name__ == "__main__": + main() diff --git a/bytelatent/test_config_parser.py b/bytelatent/test_config_parser.py new file mode 100644 index 0000000..c1ec99b --- /dev/null +++ b/bytelatent/test_config_parser.py @@ -0,0 +1,180 @@ +import os + +import pytest +from omegaconf import DictConfig, MissingMandatoryValue, OmegaConf +from pydantic import BaseModel, ConfigDict + +from bytelatent.config_parser import ( + parse_args_to_pydantic_model, + parse_file_config, + recursively_parse_config, +) + +FIXTURE_DIR = "fixtures/test-cfgs" + + +def test_parse_file_config(): + with pytest.raises(ValueError): + cfg = parse_file_config(os.path.join(FIXTURE_DIR, "list.yaml")) + assert isinstance(cfg, DictConfig) + + +def test_nop(): + cfg = OmegaConf.create({"a": 1}) + parsed_cfgs = recursively_parse_config(cfg) + assert len(parsed_cfgs) == 1 + assert parsed_cfgs[0] == cfg + + +def test_root(): + cli_cfg = OmegaConf.create({"config": os.path.join(FIXTURE_DIR, "root.yaml")}) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 2 + assert len(parsed_cfgs[1]) == 0 + assert parsed_cfgs[0]["seed"] == -1 + with pytest.raises(MissingMandatoryValue): + assert parsed_cfgs[0]["b"]["y"] is not None + + # Test basic cli override + cli_cfg = OmegaConf.create( + {"config": os.path.join(FIXTURE_DIR, "root.yaml"), "seed": 42} + ) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert parsed_cfgs[1]["seed"] == 42 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["seed"] == 42 + + +def test_one_level_include(): + cli_cfg = OmegaConf.create({"config": os.path.join(FIXTURE_DIR, "middle.yaml")}) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 3 + assert parsed_cfgs[0]["seed"] == -1 + assert parsed_cfgs[1]["b"]["y"] == 10 + assert len(parsed_cfgs[2]) == 0 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["b"]["y"] == 10 + + cli_cfg = OmegaConf.create( + {"config": os.path.join(FIXTURE_DIR, "middle.yaml"), "b": {"y": 100}} + ) + parsed_cfgs = 
recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 3 + assert parsed_cfgs[0]["seed"] == -1 + assert parsed_cfgs[1]["b"]["y"] == 10 + assert parsed_cfgs[2]["b"]["y"] == 100 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["b"]["y"] == 100 + + +def test_two_level_include(): + cli_cfg = OmegaConf.create( + {"config": os.path.join(FIXTURE_DIR, "top.yaml"), "p": 500, "b": {"z": -2}} + ) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 4 + assert parsed_cfgs[0]["seed"] == -1 + assert parsed_cfgs[1]["b"]["y"] == 10 + assert parsed_cfgs[2]["hello"] == "world" + assert parsed_cfgs[3]["p"] == 500 + assert parsed_cfgs[3]["b"]["z"] == -2 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["a"] == 1 + assert cfg["seed"] == -1 + assert cfg["b"]["x"] == 0 + assert cfg["b"]["y"] == 10 + assert cfg["b"]["z"] == -2 + assert cfg["hello"] == "world" + + +def test_multiple_includes(): + cli_cfg = OmegaConf.create( + { + "config": [ + os.path.join(FIXTURE_DIR, "top.yaml"), + os.path.join(FIXTURE_DIR, "override.yaml"), + ], + "p": 500, + "b": {"z": -2}, + } + ) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 5 + assert parsed_cfgs[0]["seed"] == -1 + assert parsed_cfgs[1]["b"]["y"] == 10 + assert parsed_cfgs[2]["hello"] == "world" + assert parsed_cfgs[3]["a"] == 100 + assert parsed_cfgs[4]["p"] == 500 + assert parsed_cfgs[4]["b"]["z"] == -2 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["a"] == 100 + assert cfg["seed"] == -1 + assert cfg["b"]["x"] == 0 + assert cfg["b"]["y"] == 10 + assert cfg["b"]["z"] == -2 + assert cfg["hello"] == "world" + + cli_cfg = OmegaConf.create( + { + "config": [ + os.path.join(FIXTURE_DIR, "top.yaml"), + os.path.join(FIXTURE_DIR, "override.yaml"), + ], + "p": 500, + "b": {"z": -2}, + "a": 1000, + } + ) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 5 + assert parsed_cfgs[0]["seed"] == -1 + assert parsed_cfgs[1]["b"]["y"] == 10 + assert parsed_cfgs[2]["hello"] == "world" + assert parsed_cfgs[3]["a"] == 100 + assert parsed_cfgs[4]["p"] == 500 + assert parsed_cfgs[4]["b"]["z"] == -2 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["a"] == 1000 + assert cfg["seed"] == -1 + assert cfg["b"]["x"] == 0 + assert cfg["b"]["y"] == 10 + assert cfg["b"]["z"] == -2 + assert cfg["hello"] == "world" + + +class SubConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + x: int = -100 + y: int = -100 + z: int = -5 + + +class SampleConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + a: int = -100 + seed: int = -100 + b: SubConfig = SubConfig() + hello: str = "" + p: int = -100 + + +def test_pydantic_parse(): + cli_cfg = OmegaConf.create( + { + "config": [ + os.path.join(FIXTURE_DIR, "top.yaml"), + os.path.join(FIXTURE_DIR, "override.yaml"), + ], + "p": 500, + "a": 1000, + } + ) + cfg = parse_args_to_pydantic_model(SampleConfig, cli_args=cli_cfg) + assert isinstance(cfg, SampleConfig) + assert cfg.a == 1000 + assert cfg.p == 500 + assert cfg.seed == -1 + assert cfg.b.x == 0 + assert cfg.b.y == 10 + assert cfg.b.z == -5 + assert cfg.hello == "world" diff --git a/bytelatent/train.py b/bytelatent/train.py index 3669167..ad74b44 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -23,8 +23,9 @@ from torch.distributed._tensor import DTensor from torch.distributed.checkpoint.stateful import Stateful from torch.optim import lr_scheduler -from bytelatent.args import TrainArgs, parse_args +from bytelatent.args import TrainArgs from bytelatent.checkpoint import 
CheckpointManager, load_from_checkpoint +from bytelatent.config_parser import parse_args_to_pydantic_model from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh from bytelatent.data.iterators.multiprocess_iterator import ( @@ -824,7 +825,7 @@ def main(): Plus all the default values in TrainArgs dataclass. """ - train_args = parse_args(TrainArgs) + train_args = parse_args_to_pydantic_model(TrainArgs) if train_args.debug_dynamo: import torch._dynamo diff --git a/fixtures/test-cfgs/list.yaml b/fixtures/test-cfgs/list.yaml new file mode 100644 index 0000000..b5d8bb5 --- /dev/null +++ b/fixtures/test-cfgs/list.yaml @@ -0,0 +1 @@ +[1, 2, 3] diff --git a/fixtures/test-cfgs/middle.yaml b/fixtures/test-cfgs/middle.yaml new file mode 100644 index 0000000..a476d8d --- /dev/null +++ b/fixtures/test-cfgs/middle.yaml @@ -0,0 +1,3 @@ +config: fixtures/test-cfgs/root.yaml +b: + y: 10 diff --git a/fixtures/test-cfgs/override.yaml b/fixtures/test-cfgs/override.yaml new file mode 100644 index 0000000..456df7b --- /dev/null +++ b/fixtures/test-cfgs/override.yaml @@ -0,0 +1 @@ +a: 100 diff --git a/fixtures/test-cfgs/root.yaml b/fixtures/test-cfgs/root.yaml new file mode 100644 index 0000000..dc4d285 --- /dev/null +++ b/fixtures/test-cfgs/root.yaml @@ -0,0 +1,6 @@ +seed: -1 +a: 1 +b: + x: 0 + y: ??? + z: ??? diff --git a/fixtures/test-cfgs/top.yaml b/fixtures/test-cfgs/top.yaml new file mode 100644 index 0000000..632866c --- /dev/null +++ b/fixtures/test-cfgs/top.yaml @@ -0,0 +1,3 @@ +config: fixtures/test-cfgs/middle.yaml + +hello: world From a3e0647d03986aacdcc016326f3f1f1b40892893 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Fri, 14 Feb 2025 23:45:11 +0000 Subject: [PATCH 55/59] Make apex logs less noisy Summary: Test Plan: --- bytelatent/base_transformer.py | 6 ++++-- bytelatent/model/latent_transformer.py | 5 ++--- bytelatent/model/local_models.py | 7 +++---- bytelatent/transformer.py | 8 +++++--- 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/bytelatent/base_transformer.py b/bytelatent/base_transformer.py index 7b76b9e..d44676d 100644 --- a/bytelatent/base_transformer.py +++ b/bytelatent/base_transformer.py @@ -1,4 +1,5 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. +import logging import os from enum import Enum from typing import Optional, Tuple, Union @@ -14,15 +15,16 @@ from torch.nn.attention.flex_attention import ( ) from xformers.ops import AttentionBias, fmha -from bytelatent import probe from bytelatent.tokenizers.constants import EOS_ID +logger = logging.getLogger() + try: from apex.normalization.fused_layer_norm import FusedRMSNorm RMSNorm = FusedRMSNorm except (ImportError, ModuleNotFoundError): - print("Apex not found. Using nn.RMSNorm") + logging.debug("Apex not found. Using nn.RMSNorm") RMSNorm = nn.RMSNorm if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0: diff --git a/bytelatent/model/latent_transformer.py b/bytelatent/model/latent_transformer.py index 95b6d8b..a6cabdc 100644 --- a/bytelatent/model/latent_transformer.py +++ b/bytelatent/model/latent_transformer.py @@ -17,16 +17,15 @@ from bytelatent.base_transformer import ( ) from bytelatent.model.utils import create_causal_mask +logger = logging.getLogger() try: from apex.normalization.fused_layer_norm import FusedRMSNorm RMSNorm = FusedRMSNorm except (ImportError, ModuleNotFoundError): - print("Apex not found. Using nn.RMSNorm") + logging.debug("Apex not found. 
Using nn.RMSNorm") RMSNorm = nn.RMSNorm -logger = logging.getLogger() - class CrossAttention(nn.Module): """ diff --git a/bytelatent/model/local_models.py b/bytelatent/model/local_models.py index 353c878..09a5a19 100644 --- a/bytelatent/model/local_models.py +++ b/bytelatent/model/local_models.py @@ -6,7 +6,7 @@ from typing import Any, List, Optional, Tuple, Union import torch import torch.nn import torch.nn as nn -from pydantic import BaseModel, ConfigDict +from pydantic import ConfigDict from torch.nn import functional as F from torch.nn.attention.flex_attention import BlockMask from xformers.ops import AttentionBias @@ -21,16 +21,15 @@ from bytelatent.model.latent_transformer import CrossAttention from bytelatent.model.utils import create_causal_mask, downsample from bytelatent.tokenizers.blt_tokenizer import BOE_ID +logger = logging.getLogger() try: from apex.normalization.fused_layer_norm import FusedRMSNorm RMSNorm = FusedRMSNorm except (ImportError, ModuleNotFoundError): - print("Apex not found. Using nn.RMSNorm") + logging.debug("Apex not found. Using nn.RMSNorm") RMSNorm = nn.RMSNorm -logger = logging.getLogger() - class LocalModelArgs(BaseTransformerArgs): model_config = ConfigDict(extra="forbid") diff --git a/bytelatent/transformer.py b/bytelatent/transformer.py index 2e45ea5..da03761 100644 --- a/bytelatent/transformer.py +++ b/bytelatent/transformer.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. -from dataclasses import dataclass +import logging from typing import Optional, Tuple, Union import torch @@ -14,7 +14,7 @@ from torch.distributed.tensor.parallel import ( parallelize_module, ) from torch.nn.attention.flex_attention import BlockMask, create_block_mask -from xformers.ops import AttentionBias, fmha +from xformers.ops import AttentionBias from bytelatent.base_transformer import ( BaseTransformer, @@ -23,12 +23,14 @@ from bytelatent.base_transformer import ( ) from bytelatent.model.utils import create_causal_mask +logger = logging.getLogger() + try: from apex.normalization.fused_layer_norm import FusedRMSNorm RMSNorm = FusedRMSNorm except (ImportError, ModuleNotFoundError): - print("Apex not found. Using nn.RMSNorm") + logging.debug("Apex not found. 
Using nn.RMSNorm") RMSNorm = nn.RMSNorm From 655eca670d45bfc9221148580872c4a6462ca66f Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Fri, 14 Feb 2025 23:44:54 +0000 Subject: [PATCH 56/59] Minimal working eval Summary: Test Plan: --- bytelatent/args.py | 4 +-- bytelatent/eval.py | 66 +++++++++++++++++++++++++++++------------- bytelatent/generate.py | 6 ++-- bytelatent/train.py | 4 ++- 4 files changed, 54 insertions(+), 26 deletions(-) diff --git a/bytelatent/args.py b/bytelatent/args.py index dd1fef5..15e44f1 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -248,8 +248,8 @@ class ValidationArgs(BaseModel): class EvalArgs(BaseModel): model_config = ConfigDict(extra="forbid") - dump_dir: str - ckpt_dir: str + dump_dir: str | None = None + ckpt_dir: str | None = None metric_log_dir: str | None = None generator: PackedCausalTransformerGeneratorArgs = ( PackedCausalTransformerGeneratorArgs() diff --git a/bytelatent/eval.py b/bytelatent/eval.py index 50e17cd..3ffe0ae 100644 --- a/bytelatent/eval.py +++ b/bytelatent/eval.py @@ -11,10 +11,16 @@ from lm_eval import simple_evaluate from lm_eval.api.instance import Instance from lm_eval.api.model import LM -from bytelatent.args import EvalArgs, ValidationArgs +from bytelatent.args import ( + EvalArgs, + TrainArgs, + ValidationArgs, + find_and_sanitize_chunks, +) from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints from bytelatent.config_parser import parse_args_to_pydantic_model from bytelatent.data.file_util import get_fs +from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator from bytelatent.distributed import ( DistributedArgs, dist_mean_dict, @@ -113,19 +119,40 @@ class EvalHarnessLM(LM): return results -def eval_on_val(generator, val_args: ValidationArgs, train_cfg): - srcs = {} +def eval_on_val(generator, val_args: ValidationArgs, train_cfg: TrainArgs): + srcs = [] for src in val_args.sources: path = os.path.join(val_args.root_dir, src) - srcs[path] = 1.0 + srcs.append(path) + for src in train_cfg.data.sources: path = os.path.join(train_cfg.data.root_dir, src) - srcs[path] = 1.0 + srcs.append(path) - multi_state = init_choice_state( - "", srcs, 0, get_global_rank(), get_world_size(), "*.val.jsonl" - ) - path_to_iter = setup_sources(multi_state) + path_to_iter = {} + for path in srcs: + chunks = find_and_sanitize_chunks( + path, + world_size=1, + file_pattern="*.val.jsonl", + s3_profile=train_cfg.data.s3_profile, + ) + assert ( + len(chunks) == 1 + ), f"There should be only 1 chunk per validation file, but found: {chunks}" + chunk = chunks[0] + iterator = ArrowFileIterator( + dataset_files=[chunk], + file_path=None, + preprocess_dir=None, + entropy_model_name=None, + worker_id=0, + num_workers=1, + arrow_batch_size=train_cfg.data.arrow_batch_size, + s3_profile=train_cfg.data.s3_profile, + file_format="json", + ) + path_to_iter[path] = iterator max_gen_len = generator.max_gen_len # We temporarily lower max gen len @@ -133,16 +160,11 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg): all_val_metrics = {} for src in path_to_iter: - jsonl_iterator = path_to_iter[src] + example_iterator = path_to_iter[src].create_iter() texts = [] logger.info(f"Running validation on {src}...") - for step, (content, state) in enumerate(jsonl_iterator): - if state["current_iter"] > 0 or ( - val_args.max_steps is not None and step >= val_args.max_steps - ): - break - content_key = "text" if ("text" in content) else "content" - texts.append(content[content_key]) + for step, example in 
enumerate(example_iterator): + texts.append(example.text) _, loglikelihood, _ = generator.generate(texts) @@ -187,7 +209,7 @@ def launch_eval(eval_args: EvalArgs): else: consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER) if not fs.exists(consolidate_path) and get_global_rank() == 0: - consolidate_path = consolidate_checkpoints(eval_args.ckpt_dir) + consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir) fs.mkdirs(eval_args.dump_dir, exist_ok=True) with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f: @@ -206,10 +228,13 @@ def launch_eval(eval_args: EvalArgs): wrap = EvalHarnessLM(generator) # Redo - results = simple_evaluate(wrap, eval_args.harness.model_dump()) + # results = simple_evaluate(wrap, **eval_args.harness.model_dump()) + results = {"results": []} + val_results = None if eval_args.validation: val_results = eval_on_val(generator, eval_args.validation, train_cfg) + if get_global_rank() == 0: with fs.open(os.path.join(eval_args.dump_dir, "results.json"), "w") as f: f.write(json.dumps(results)) @@ -218,6 +243,7 @@ def launch_eval(eval_args: EvalArgs): with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f: f.write(json.dumps(val_results)) logger.info(f"All validation results: {val_results}") + if eval_args.metric_log_dir and get_global_rank() == 0: metric_log_path = os.path.join(eval_args.metric_log_dir, "metrics.eval.jsonl") @@ -247,7 +273,7 @@ def launch_eval(eval_args: EvalArgs): def main(): - eval_args = parse_args(EvalArgs) + eval_args = parse_args_to_pydantic_model(EvalArgs) launch_eval(eval_args) diff --git a/bytelatent/generate.py b/bytelatent/generate.py index eb79d81..9d44f30 100644 --- a/bytelatent/generate.py +++ b/bytelatent/generate.py @@ -387,8 +387,7 @@ def load_consolidated_model_and_tokenizer( ): train_args_path = os.path.join(consolidated_path, "params.json") fs = get_fs(train_args_path) - with fs.open(train_args_path) as f: - train_args = TrainArgs.model_validate_json(f.read()) + train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path)) if train_args.train_entropy_model: model_args = train_args.entropy_model @@ -401,7 +400,8 @@ def load_consolidated_model_and_tokenizer( train_args.distributed.model_dtype ] tokenizer = train_args.data.tokenizer_args.build() - st_dict = torch.load(consolidated_path / CONSOLIDATE_NAME, weights_only=True) + with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f: + st_dict = torch.load(f, weights_only=True) model.load_state_dict(st_dict["model"]) model = model.cuda().eval() for param in model.parameters(): diff --git a/bytelatent/train.py b/bytelatent/train.py index ad74b44..af1c694 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -241,7 +241,9 @@ def set_preemption_flag(signum, frame): preemption_flag["flag"] = True -def every_n_steps(train_state, freq, acc_step=None, acc_freq=None): +def every_n_steps(train_state, freq: int, acc_step=None, acc_freq=None): + if freq < 0: + return False test = train_state.step % freq == 0 if acc_step is not None: test = test and (train_state.acc_step == acc_step) From 3117ac1f1f43bd5189b8b4dc23496d52398ff1f6 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Tue, 18 Feb 2025 18:41:02 +0000 Subject: [PATCH 57/59] Make it possible to specify multiple config files Summary: Make it possible to specify multiple config files. Parsing CLI is not a special case anymore, just uses the same config inheritance method. 
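As a rough sketch of that inheritance (illustrative only, not part of the patch; the keys `seed`, `b.x`, `b.y` are invented), the new parser ultimately reduces to a single OmegaConf merge in which later sources win:

```
# Stand-in for what the new parse_args_with_default does once all included
# configs are flattened: defaults first, then included files, then CLI args.
from omegaconf import OmegaConf

defaults = OmegaConf.create({"seed": 0, "b": {"x": 0, "y": 1}})   # pydantic defaults
included = OmegaConf.create({"seed": 7, "b": {"y": 10}})          # from config=... file(s)
cli = OmegaConf.create({"b": {"y": 100}})                         # CLI overrides

merged = OmegaConf.merge(defaults, included, cli)  # later sources win
assert merged.seed == 7 and merged.b.x == 0 and merged.b.y == 100
```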
Test Plan: Test that this interpolates in the right order via unit tests. Sample usage loads the internal config, which references bytelatent/configs/entropy_model.yaml. The precedence order is: - Default pydantic args - Included configs, e.g. `config` - CLI args ``` python -m bytelatent.print_config config=internal/configs/entropy_model.yaml eval=null ``` --- bytelatent/args.py | 14 -- bytelatent/config_parser.py | 73 +++++++++++ bytelatent/configs/debug.yaml | 4 +- bytelatent/configs/entropy_model.yaml | 4 +- bytelatent/eval.py | 8 +- bytelatent/print_config.py | 11 ++ bytelatent/test_config_parser.py | 180 ++++++++++++++++++++++++++ bytelatent/train.py | 5 +- fixtures/test-cfgs/list.yaml | 1 + fixtures/test-cfgs/middle.yaml | 3 + fixtures/test-cfgs/override.yaml | 1 + fixtures/test-cfgs/root.yaml | 6 + fixtures/test-cfgs/top.yaml | 3 + 13 files changed, 286 insertions(+), 27 deletions(-) create mode 100644 bytelatent/config_parser.py create mode 100644 bytelatent/print_config.py create mode 100644 bytelatent/test_config_parser.py create mode 100644 fixtures/test-cfgs/list.yaml create mode 100644 fixtures/test-cfgs/middle.yaml create mode 100644 fixtures/test-cfgs/override.yaml create mode 100644 fixtures/test-cfgs/root.yaml create mode 100644 fixtures/test-cfgs/top.yaml diff --git a/bytelatent/args.py b/bytelatent/args.py index 47bd0f9..dd1fef5 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -5,7 +5,6 @@ from typing import Any import numpy as np import yaml -from omegaconf import OmegaConf from pydantic import BaseModel, ConfigDict from bytelatent.checkpoint import CheckpointArgs @@ -38,19 +37,6 @@ def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]: return np.random.default_rng((seed, rank, world_size)).bit_generator.state -def parse_args(args_cls): - cli_args = OmegaConf.from_cli() - file_cfg = OmegaConf.load(cli_args.config) - # We remove 'config' attribute from config as the underlying DataClass does not have it - del cli_args.config - - default_cfg = OmegaConf.create(args_cls().model_dump()) - cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) - cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) - pydantic_args = args_cls.model_validate(cfg) - return pydantic_args - - TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" diff --git a/bytelatent/config_parser.py b/bytelatent/config_parser.py new file mode 100644 index 0000000..2e310a2 --- /dev/null +++ b/bytelatent/config_parser.py @@ -0,0 +1,73 @@ +import copy +from typing import Type, TypeVar + +import omegaconf +from omegaconf import DictConfig, OmegaConf +from pydantic import BaseModel + + +def parse_file_config(path: str) -> DictConfig: + file_cfg = OmegaConf.load(path) + if not isinstance(file_cfg, DictConfig): + raise ValueError( + f"File paths must parse to DictConfig, but it was: {type(file_cfg)}" + ) + return file_cfg + + +def recursively_parse_config(cfg: DictConfig) -> list[DictConfig]: + if "config" not in cfg: + return [cfg] + + ordered_cfgs = [] + cfg = copy.deepcopy(cfg) + config_arg = cfg["config"] + del cfg["config"] + ordered_cfgs.append(cfg) + + if isinstance(config_arg, str): + file_cfg = parse_file_config(config_arg) + sub_configs = recursively_parse_config(file_cfg) + ordered_cfgs = sub_configs + ordered_cfgs + elif isinstance(config_arg, omegaconf.listconfig.ListConfig): + sub_configs = [] + for c in config_arg: + if not isinstance(c, str): + raise ValueError( + f'If "config" is specified, it must be either a string path or a list of string 
paths. It was config={config_arg}' + ) + config_to_parse = parse_file_config(c) + sub_configs.extend(recursively_parse_config(config_to_parse)) + ordered_cfgs = sub_configs + ordered_cfgs + else: + raise ValueError( + f'If "config" is specified, it must be either a string path or a list of string paths, it was config={config_arg}' + ) + return ordered_cfgs + + +def parse_args_with_default( + *, default_cfg: DictConfig | None = None, cli_args: DictConfig | None = None +): + if cli_args is None: + cli_args = OmegaConf.from_cli() + assert isinstance( + cli_args, DictConfig + ), f"CLI Args must be a DictConfig, not {type(cli_args)}" + ordered_cfgs = recursively_parse_config(cli_args) + if default_cfg is not None: + ordered_cfgs.insert(0, default_cfg) + cfg = OmegaConf.merge(*ordered_cfgs) + return OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) + + +T = TypeVar("T", bound=BaseModel) + + +def parse_args_to_pydantic_model( + args_cls: Type[T], cli_args: DictConfig | None = None +) -> T: + default_cfg = OmegaConf.create(args_cls().model_dump()) + parsed_cfg = parse_args_with_default(default_cfg=default_cfg, cli_args=cli_args) + pydantic_args = args_cls.model_validate(parsed_cfg) + return pydantic_args diff --git a/bytelatent/configs/debug.yaml b/bytelatent/configs/debug.yaml index 07d489f..1369364 100644 --- a/bytelatent/configs/debug.yaml +++ b/bytelatent/configs/debug.yaml @@ -56,13 +56,11 @@ model: recompute_attn: false custom_bwd: false layer_ckpt: "none" - patch_only_encoder: false - patch_only_decoder: false use_local_encoder_transformer: true init_use_gaussian: true init_use_depth: "current" - attn_bias_type: "block_causal" attn_impl: "xformers" + attn_bias_type: "block_causal" alpha_depth: "disabled" max_length: 256 local_attention_window_len: 512 diff --git a/bytelatent/configs/entropy_model.yaml b/bytelatent/configs/entropy_model.yaml index d7c27b7..79cc85b 100644 --- a/bytelatent/configs/entropy_model.yaml +++ b/bytelatent/configs/entropy_model.yaml @@ -2,9 +2,10 @@ # Evals can be activated by uncommenting its config # python -m launchers.stool config=apps/main/configs/debug.yaml nodes=8 account=fair_amaia_cw_codegen qos=lowest -dump_dir: /tmp/ +dump_dir: /tmp/blt-entropy name: "debug" steps: 100_000 +max_steps: null probe_freq: null seed: 777 optim: @@ -35,7 +36,6 @@ entropy_model: attn_impl: "xformers" data: - s3_profile: blt root_dir: ??? 
sources: dclm_baseline_1.0: 1.0 diff --git a/bytelatent/eval.py b/bytelatent/eval.py index ae73066..50e17cd 100644 --- a/bytelatent/eval.py +++ b/bytelatent/eval.py @@ -5,18 +5,15 @@ import logging import os from collections import defaultdict from datetime import datetime -from pathlib import Path -from typing import Any import torch from lm_eval import simple_evaluate from lm_eval.api.instance import Instance from lm_eval.api.model import LM -from omegaconf import OmegaConf -from pydantic import BaseModel, ConfigDict -from bytelatent.args import EvalArgs, ValidationArgs, parse_args +from bytelatent.args import EvalArgs, ValidationArgs from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints +from bytelatent.config_parser import parse_args_to_pydantic_model from bytelatent.data.file_util import get_fs from bytelatent.distributed import ( DistributedArgs, @@ -29,7 +26,6 @@ from bytelatent.generate import ( PackedCausalTransformerGenerator, load_consolidated_model_and_tokenizer, ) -from bytelatent.transformer import LMTransformer, LMTransformerArgs EVAL_FOLDER_NAME = "{:010d}" diff --git a/bytelatent/print_config.py b/bytelatent/print_config.py new file mode 100644 index 0000000..0bc99e7 --- /dev/null +++ b/bytelatent/print_config.py @@ -0,0 +1,11 @@ +from bytelatent.args import TrainArgs +from bytelatent.config_parser import parse_args_to_pydantic_model + + +def main(): + train_args = parse_args_to_pydantic_model(TrainArgs) + print(train_args.model_dump_json(indent=4)) + + +if __name__ == "__main__": + main() diff --git a/bytelatent/test_config_parser.py b/bytelatent/test_config_parser.py new file mode 100644 index 0000000..c1ec99b --- /dev/null +++ b/bytelatent/test_config_parser.py @@ -0,0 +1,180 @@ +import os + +import pytest +from omegaconf import DictConfig, MissingMandatoryValue, OmegaConf +from pydantic import BaseModel, ConfigDict + +from bytelatent.config_parser import ( + parse_args_to_pydantic_model, + parse_file_config, + recursively_parse_config, +) + +FIXTURE_DIR = "fixtures/test-cfgs" + + +def test_parse_file_config(): + with pytest.raises(ValueError): + cfg = parse_file_config(os.path.join(FIXTURE_DIR, "list.yaml")) + assert isinstance(cfg, DictConfig) + + +def test_nop(): + cfg = OmegaConf.create({"a": 1}) + parsed_cfgs = recursively_parse_config(cfg) + assert len(parsed_cfgs) == 1 + assert parsed_cfgs[0] == cfg + + +def test_root(): + cli_cfg = OmegaConf.create({"config": os.path.join(FIXTURE_DIR, "root.yaml")}) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 2 + assert len(parsed_cfgs[1]) == 0 + assert parsed_cfgs[0]["seed"] == -1 + with pytest.raises(MissingMandatoryValue): + assert parsed_cfgs[0]["b"]["y"] is not None + + # Test basic cli override + cli_cfg = OmegaConf.create( + {"config": os.path.join(FIXTURE_DIR, "root.yaml"), "seed": 42} + ) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert parsed_cfgs[1]["seed"] == 42 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["seed"] == 42 + + +def test_one_level_include(): + cli_cfg = OmegaConf.create({"config": os.path.join(FIXTURE_DIR, "middle.yaml")}) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 3 + assert parsed_cfgs[0]["seed"] == -1 + assert parsed_cfgs[1]["b"]["y"] == 10 + assert len(parsed_cfgs[2]) == 0 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["b"]["y"] == 10 + + cli_cfg = OmegaConf.create( + {"config": os.path.join(FIXTURE_DIR, "middle.yaml"), "b": {"y": 100}} + ) + parsed_cfgs = 
recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 3 + assert parsed_cfgs[0]["seed"] == -1 + assert parsed_cfgs[1]["b"]["y"] == 10 + assert parsed_cfgs[2]["b"]["y"] == 100 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["b"]["y"] == 100 + + +def test_two_level_include(): + cli_cfg = OmegaConf.create( + {"config": os.path.join(FIXTURE_DIR, "top.yaml"), "p": 500, "b": {"z": -2}} + ) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 4 + assert parsed_cfgs[0]["seed"] == -1 + assert parsed_cfgs[1]["b"]["y"] == 10 + assert parsed_cfgs[2]["hello"] == "world" + assert parsed_cfgs[3]["p"] == 500 + assert parsed_cfgs[3]["b"]["z"] == -2 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["a"] == 1 + assert cfg["seed"] == -1 + assert cfg["b"]["x"] == 0 + assert cfg["b"]["y"] == 10 + assert cfg["b"]["z"] == -2 + assert cfg["hello"] == "world" + + +def test_multiple_includes(): + cli_cfg = OmegaConf.create( + { + "config": [ + os.path.join(FIXTURE_DIR, "top.yaml"), + os.path.join(FIXTURE_DIR, "override.yaml"), + ], + "p": 500, + "b": {"z": -2}, + } + ) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 5 + assert parsed_cfgs[0]["seed"] == -1 + assert parsed_cfgs[1]["b"]["y"] == 10 + assert parsed_cfgs[2]["hello"] == "world" + assert parsed_cfgs[3]["a"] == 100 + assert parsed_cfgs[4]["p"] == 500 + assert parsed_cfgs[4]["b"]["z"] == -2 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["a"] == 100 + assert cfg["seed"] == -1 + assert cfg["b"]["x"] == 0 + assert cfg["b"]["y"] == 10 + assert cfg["b"]["z"] == -2 + assert cfg["hello"] == "world" + + cli_cfg = OmegaConf.create( + { + "config": [ + os.path.join(FIXTURE_DIR, "top.yaml"), + os.path.join(FIXTURE_DIR, "override.yaml"), + ], + "p": 500, + "b": {"z": -2}, + "a": 1000, + } + ) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 5 + assert parsed_cfgs[0]["seed"] == -1 + assert parsed_cfgs[1]["b"]["y"] == 10 + assert parsed_cfgs[2]["hello"] == "world" + assert parsed_cfgs[3]["a"] == 100 + assert parsed_cfgs[4]["p"] == 500 + assert parsed_cfgs[4]["b"]["z"] == -2 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["a"] == 1000 + assert cfg["seed"] == -1 + assert cfg["b"]["x"] == 0 + assert cfg["b"]["y"] == 10 + assert cfg["b"]["z"] == -2 + assert cfg["hello"] == "world" + + +class SubConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + x: int = -100 + y: int = -100 + z: int = -5 + + +class SampleConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + a: int = -100 + seed: int = -100 + b: SubConfig = SubConfig() + hello: str = "" + p: int = -100 + + +def test_pydantic_parse(): + cli_cfg = OmegaConf.create( + { + "config": [ + os.path.join(FIXTURE_DIR, "top.yaml"), + os.path.join(FIXTURE_DIR, "override.yaml"), + ], + "p": 500, + "a": 1000, + } + ) + cfg = parse_args_to_pydantic_model(SampleConfig, cli_args=cli_cfg) + assert isinstance(cfg, SampleConfig) + assert cfg.a == 1000 + assert cfg.p == 500 + assert cfg.seed == -1 + assert cfg.b.x == 0 + assert cfg.b.y == 10 + assert cfg.b.z == -5 + assert cfg.hello == "world" diff --git a/bytelatent/train.py b/bytelatent/train.py index 3669167..ad74b44 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -23,8 +23,9 @@ from torch.distributed._tensor import DTensor from torch.distributed.checkpoint.stateful import Stateful from torch.optim import lr_scheduler -from bytelatent.args import TrainArgs, parse_args +from bytelatent.args import TrainArgs from bytelatent.checkpoint import 
CheckpointManager, load_from_checkpoint +from bytelatent.config_parser import parse_args_to_pydantic_model from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh from bytelatent.data.iterators.multiprocess_iterator import ( @@ -824,7 +825,7 @@ def main(): Plus all the default values in TrainArgs dataclass. """ - train_args = parse_args(TrainArgs) + train_args = parse_args_to_pydantic_model(TrainArgs) if train_args.debug_dynamo: import torch._dynamo diff --git a/fixtures/test-cfgs/list.yaml b/fixtures/test-cfgs/list.yaml new file mode 100644 index 0000000..b5d8bb5 --- /dev/null +++ b/fixtures/test-cfgs/list.yaml @@ -0,0 +1 @@ +[1, 2, 3] diff --git a/fixtures/test-cfgs/middle.yaml b/fixtures/test-cfgs/middle.yaml new file mode 100644 index 0000000..a476d8d --- /dev/null +++ b/fixtures/test-cfgs/middle.yaml @@ -0,0 +1,3 @@ +config: fixtures/test-cfgs/root.yaml +b: + y: 10 diff --git a/fixtures/test-cfgs/override.yaml b/fixtures/test-cfgs/override.yaml new file mode 100644 index 0000000..456df7b --- /dev/null +++ b/fixtures/test-cfgs/override.yaml @@ -0,0 +1 @@ +a: 100 diff --git a/fixtures/test-cfgs/root.yaml b/fixtures/test-cfgs/root.yaml new file mode 100644 index 0000000..dc4d285 --- /dev/null +++ b/fixtures/test-cfgs/root.yaml @@ -0,0 +1,6 @@ +seed: -1 +a: 1 +b: + x: 0 + y: ??? + z: ??? diff --git a/fixtures/test-cfgs/top.yaml b/fixtures/test-cfgs/top.yaml new file mode 100644 index 0000000..632866c --- /dev/null +++ b/fixtures/test-cfgs/top.yaml @@ -0,0 +1,3 @@ +config: fixtures/test-cfgs/middle.yaml + +hello: world From 2f247263b9873c262312c79f9c11fb42700d0083 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Tue, 18 Feb 2025 18:43:06 +0000 Subject: [PATCH 58/59] Make apex logs less noisy Summary: Test Plan: --- bytelatent/base_transformer.py | 6 ++++-- bytelatent/model/latent_transformer.py | 5 ++--- bytelatent/model/local_models.py | 7 +++---- bytelatent/transformer.py | 8 +++++--- 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/bytelatent/base_transformer.py b/bytelatent/base_transformer.py index 7b76b9e..d44676d 100644 --- a/bytelatent/base_transformer.py +++ b/bytelatent/base_transformer.py @@ -1,4 +1,5 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. +import logging import os from enum import Enum from typing import Optional, Tuple, Union @@ -14,15 +15,16 @@ from torch.nn.attention.flex_attention import ( ) from xformers.ops import AttentionBias, fmha -from bytelatent import probe from bytelatent.tokenizers.constants import EOS_ID +logger = logging.getLogger() + try: from apex.normalization.fused_layer_norm import FusedRMSNorm RMSNorm = FusedRMSNorm except (ImportError, ModuleNotFoundError): - print("Apex not found. Using nn.RMSNorm") + logging.debug("Apex not found. Using nn.RMSNorm") RMSNorm = nn.RMSNorm if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0: diff --git a/bytelatent/model/latent_transformer.py b/bytelatent/model/latent_transformer.py index 95b6d8b..a6cabdc 100644 --- a/bytelatent/model/latent_transformer.py +++ b/bytelatent/model/latent_transformer.py @@ -17,16 +17,15 @@ from bytelatent.base_transformer import ( ) from bytelatent.model.utils import create_causal_mask +logger = logging.getLogger() try: from apex.normalization.fused_layer_norm import FusedRMSNorm RMSNorm = FusedRMSNorm except (ImportError, ModuleNotFoundError): - print("Apex not found. Using nn.RMSNorm") + logging.debug("Apex not found. 
Using nn.RMSNorm") RMSNorm = nn.RMSNorm -logger = logging.getLogger() - class CrossAttention(nn.Module): """ diff --git a/bytelatent/model/local_models.py b/bytelatent/model/local_models.py index 353c878..09a5a19 100644 --- a/bytelatent/model/local_models.py +++ b/bytelatent/model/local_models.py @@ -6,7 +6,7 @@ from typing import Any, List, Optional, Tuple, Union import torch import torch.nn import torch.nn as nn -from pydantic import BaseModel, ConfigDict +from pydantic import ConfigDict from torch.nn import functional as F from torch.nn.attention.flex_attention import BlockMask from xformers.ops import AttentionBias @@ -21,16 +21,15 @@ from bytelatent.model.latent_transformer import CrossAttention from bytelatent.model.utils import create_causal_mask, downsample from bytelatent.tokenizers.blt_tokenizer import BOE_ID +logger = logging.getLogger() try: from apex.normalization.fused_layer_norm import FusedRMSNorm RMSNorm = FusedRMSNorm except (ImportError, ModuleNotFoundError): - print("Apex not found. Using nn.RMSNorm") + logging.debug("Apex not found. Using nn.RMSNorm") RMSNorm = nn.RMSNorm -logger = logging.getLogger() - class LocalModelArgs(BaseTransformerArgs): model_config = ConfigDict(extra="forbid") diff --git a/bytelatent/transformer.py b/bytelatent/transformer.py index 2e45ea5..da03761 100644 --- a/bytelatent/transformer.py +++ b/bytelatent/transformer.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. -from dataclasses import dataclass +import logging from typing import Optional, Tuple, Union import torch @@ -14,7 +14,7 @@ from torch.distributed.tensor.parallel import ( parallelize_module, ) from torch.nn.attention.flex_attention import BlockMask, create_block_mask -from xformers.ops import AttentionBias, fmha +from xformers.ops import AttentionBias from bytelatent.base_transformer import ( BaseTransformer, @@ -23,12 +23,14 @@ from bytelatent.base_transformer import ( ) from bytelatent.model.utils import create_causal_mask +logger = logging.getLogger() + try: from apex.normalization.fused_layer_norm import FusedRMSNorm RMSNorm = FusedRMSNorm except (ImportError, ModuleNotFoundError): - print("Apex not found. Using nn.RMSNorm") + logging.debug("Apex not found. 
Using nn.RMSNorm") RMSNorm = nn.RMSNorm From 2a717d6b40c715ee4283539e7e94de2af0499cdf Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Thu, 20 Feb 2025 00:35:04 +0000 Subject: [PATCH 59/59] Update iterators --- .gitignore | 2 + bytelatent/args.py | 10 ++- .../data/iterators/abstract_iterator.py | 6 ++ bytelatent/data/iterators/arrow_iterator.py | 72 +++++++++-------- bytelatent/data/iterators/dev_iterators.py | 78 +++++++++++++++++++ bytelatent/data/iterators/limit_iterator.py | 47 +++++++++++ bytelatent/data/iterators/looping_iterator.py | 8 +- .../data/iterators/multiprocess_iterator.py | 10 ++- bytelatent/data/iterators/packing_iterator.py | 7 +- .../data/iterators/preprocess_iterator.py | 19 +++-- .../data/iterators/sampling_iterator.py | 9 ++- .../data/iterators/sequence_iterator.py | 28 +++++-- bytelatent/data/iterators/test_iters.py | 76 +----------------- .../data/iterators/test_limit_iterator.py | 45 +++++++++++ 14 files changed, 285 insertions(+), 132 deletions(-) create mode 100644 bytelatent/data/iterators/dev_iterators.py create mode 100644 bytelatent/data/iterators/limit_iterator.py create mode 100644 bytelatent/data/iterators/test_limit_iterator.py diff --git a/.gitignore b/.gitignore index 2d0f075..cef4d53 100644 --- a/.gitignore +++ b/.gitignore @@ -168,3 +168,5 @@ figures/ internal/ jobs_parallel-copy/ wandb/ +*.ipynb + diff --git a/bytelatent/args.py b/bytelatent/args.py index dd1fef5..8ffa717 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -72,6 +72,7 @@ def distribute_data_to_rank( arrow_batch_size: int, rank: int, world_size: int, + file_format: str, s3_profile: str | None = None, file_pattern: str = TRAIN_DATA_FILE_PATTERN, ) -> ArrowFileIterator: @@ -85,6 +86,7 @@ def distribute_data_to_rank( rank_to_arrow_iterator_params.append( ArrowFileIterator( file_path=chunk_path, + file_format=file_format, worker_id=worker_id, num_workers=n_workers_per_chunk, preprocess_dir=preprocess_dir, @@ -130,6 +132,7 @@ class DataloaderArgs(BaseModel): entropy_model_name: str | None = "transformer_100m" arrow_batch_size: int = 100 buffer_size: int = 64 + file_format: str = "arrow" pad_to_max_length: bool = True max_encoder_seq_length: int = 12288 @@ -151,6 +154,7 @@ class DataloaderArgs(BaseModel): for dataset_path in self.sources: shuffle_rng_state = get_rng_state(self.seed + 1, rank, world_size) arrow_iterator = distribute_data_to_rank( + file_format=self.file_format, dataset_path=os.path.join(self.root_dir, dataset_path), preprocess_dir=self.preprocess_dir, entropy_model_name=self.entropy_model_name, @@ -238,7 +242,7 @@ class LMHarnessArgs(BaseModel): class ValidationArgs(BaseModel): model_config = ConfigDict(extra="forbid") - max_steps: int | None = ( + max_n_docs: int | None = ( None # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu) ) use_val_from_train_src: bool = True # Use the validation set from training sources @@ -248,8 +252,8 @@ class ValidationArgs(BaseModel): class EvalArgs(BaseModel): model_config = ConfigDict(extra="forbid") - dump_dir: str - ckpt_dir: str + dump_dir: str | None = None + ckpt_dir: str | None = None metric_log_dir: str | None = None generator: PackedCausalTransformerGeneratorArgs = ( PackedCausalTransformerGeneratorArgs() diff --git a/bytelatent/data/iterators/abstract_iterator.py b/bytelatent/data/iterators/abstract_iterator.py index 8ac7f19..e80edd3 100644 --- a/bytelatent/data/iterators/abstract_iterator.py +++ 
b/bytelatent/data/iterators/abstract_iterator.py @@ -2,6 +2,8 @@ import abc from typing import Any, Generator, Generic, TypeVar +import pydantic + T = TypeVar("T") C = TypeVar("C") @@ -23,6 +25,10 @@ class IteratorState(Generic[C]): pass +class PydanticIteratorState(pydantic.BaseModel, IteratorState): + model_config = pydantic.ConfigDict(extra="forbid") + + def get_state_and_refresh(iterator: StatefulIterator): # Re-init dataloader and iterator is necessary since get_state() # on mp iterator shuts down MP to correctly persist state and it needs diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index 995cd02..e6f60c5 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -15,13 +15,16 @@ from pydantic import BaseModel, ConfigDict from bytelatent import ByteLatentError from bytelatent.data.data_types import BltExample from bytelatent.data.file_util import get_fs -from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator +from bytelatent.data.iterators.abstract_iterator import ( + PydanticIteratorState, + StatefulIterator, +) from bytelatent.preprocess.preprocess_entropies import get_id_key, get_text logger = getLogger(__name__) -class ArrowFileIteratorState(BaseModel, IteratorState): +class ArrowFileIteratorState(PydanticIteratorState): model_config = ConfigDict(extra="forbid") file_path: str | None row_num: int @@ -110,39 +113,42 @@ class ArrowFileIterator(StatefulIterator): logger.info("Arrow iterator using fs=%s", self.fs) if dataset_files is None: - # Prepare arrow shards - jsonl_file = file_path - parts = re.match( - r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file) - ) - assert parts is not None - dataset = parts.group(1) - data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name) - data_dir_with_glob = os.path.join( - data_dir, f"{os.path.basename(jsonl_file)}.shard_*.arrow" - ) - if self.fs is None: - self.fs = get_fs(data_dir_with_glob, s3_profile=s3_profile) - if isinstance(self.fs, s3fs.S3FileSystem): - self.filesystem_type = "s3" - else: - self.filesystem_type = "file" - shard_files = self.fs.glob(data_dir_with_glob) - - for s in shard_files: - complete_file = os.path.join( - data_dir, f"{os.path.basename(s)}.complete" + if file_format == "json": + self.dataset_files = [file_path] + else: + # Prepare arrow shards + jsonl_file = file_path + parts = re.match( + r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file) ) - - if not self.fs.exists(complete_file): - raise ValueError(f"Missing .complete for input file: {s}") - - shard_files = sorted(shard_files, key=shard_sort_key) - if len(shard_files) == 0: - raise ByteLatentError( - f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow" + assert parts is not None + dataset = parts.group(1) + data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name) + data_dir_with_glob = os.path.join( + data_dir, f"{os.path.basename(jsonl_file)}.shard_*.arrow" ) - self.dataset_files = [f for f in shard_files] + if self.fs is None: + self.fs = get_fs(data_dir_with_glob, s3_profile=s3_profile) + if isinstance(self.fs, s3fs.S3FileSystem): + self.filesystem_type = "s3" + else: + self.filesystem_type = "file" + shard_files = self.fs.glob(data_dir_with_glob) + + for s in shard_files: + complete_file = os.path.join( + 
data_dir, f"{os.path.basename(s)}.complete" + ) + + if not self.fs.exists(complete_file): + raise ValueError(f"Missing .complete for input file: {s}") + + shard_files = sorted(shard_files, key=shard_sort_key) + if len(shard_files) == 0: + raise ByteLatentError( + f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow" + ) + self.dataset_files = [f for f in shard_files] else: self.preprocess_dir = None self.dataset_files = dataset_files diff --git a/bytelatent/data/iterators/dev_iterators.py b/bytelatent/data/iterators/dev_iterators.py new file mode 100644 index 0000000..1b33e3d --- /dev/null +++ b/bytelatent/data/iterators/dev_iterators.py @@ -0,0 +1,78 @@ +import pandas as pd +from pydantic import ConfigDict + +from bytelatent.data.data_types import BltExample +from bytelatent.data.iterators.abstract_iterator import ( + PydanticIteratorState, + StatefulIterator, +) + + +class BltTestIteratorState(PydanticIteratorState): + model_config = ConfigDict(extra="forbid") + position: int + total: int + + def build(self): + blt_iter = BltTestIteratorState(total=self.total) + blt_iter.position = self.position + return blt_iter + + +class BltTestIterator(StatefulIterator): + def __init__(self, total: int): + self.position = 0 + self.total = total + + def get_state(self): + return BltTestIteratorState(position=self.position, total=self.total) + + def create_iter(self): + for i in range(self.total): + self.position += 1 + yield BltExample( + sample_id=f"test_{i}", + text=f"This is some test {i} text.", + tokens=None, + mask=None, + entropies=None, + patch_lengths=None, + ) + + +class BltTestWithEntropiesIteratorState(PydanticIteratorState): + model_config = ConfigDict(extra="forbid") + position: int + total: int + + def build(self): + blt_iter = BltTestWithEntropiesIteratorState(total=self.total) + blt_iter.position = self.position + return blt_iter + + +class BltTestWithEntropiesIterator(StatefulIterator): + def __init__(self, total: int): + self.position = 0 + self.total = total + + def get_state(self): + return BltTestIteratorState(position=self.position, total=self.total) + + def create_iter(self): + text = "Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin." 
+ df = pd.read_json("fixtures/tokens_with_entropies.json") + tokens = df["token_ids"].tolist() + entropies = df["entropies"].tolist() + # BOS and EOS + assert len(tokens) == len(text) + 2 + for i in range(self.total): + self.position += 1 + yield BltExample( + sample_id=f"test_{i}", + text=text, + tokens=tokens, + mask=[True] * len(tokens), + entropies=entropies, + patch_lengths=None, + ) diff --git a/bytelatent/data/iterators/limit_iterator.py b/bytelatent/data/iterators/limit_iterator.py new file mode 100644 index 0000000..4ca43a9 --- /dev/null +++ b/bytelatent/data/iterators/limit_iterator.py @@ -0,0 +1,47 @@ +from pydantic import ConfigDict + +from bytelatent.data.iterators.abstract_iterator import ( + PydanticIteratorState, + StatefulIterator, +) +from bytelatent.data.iterators.arrow_iterator import ArrowFileIteratorState +from bytelatent.data.iterators.dev_iterators import BltTestIteratorState + + +class LimitIteratorState(PydanticIteratorState): + model_config = ConfigDict(extra="forbid") + base_iterator_state: ( + BltTestIteratorState | ArrowFileIteratorState | PydanticIteratorState + ) + n_yielded: int + limit: int + + def build(self) -> "LimitIterator": + return LimitIterator( + base_iterator=self.base_iterator_state.build(), + n_yielded=self.n_yielded, + limit=self.limit, + ) + + +class LimitIterator(StatefulIterator): + def __init__(self, base_iterator: StatefulIterator, limit: int, n_yielded: int = 0): + self.base_iterator = base_iterator + self.n_yielded = n_yielded + self.limit = limit + + def get_state(self): + return LimitIteratorState( + base_iterator_state=self.base_iterator.get_state(), + n_yielded=self.n_yielded, + limit=self.limit, + ) + + def create_iter(self): + iterator = self.base_iterator.create_iter() + try: + while self.n_yielded < self.limit or self.limit < 0: + yield next(iterator) + self.n_yielded += 1 + except StopIteration: + pass diff --git a/bytelatent/data/iterators/looping_iterator.py b/bytelatent/data/iterators/looping_iterator.py index 2eff38c..7406f61 100644 --- a/bytelatent/data/iterators/looping_iterator.py +++ b/bytelatent/data/iterators/looping_iterator.py @@ -1,14 +1,16 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
-from pydantic import BaseModel -from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator +from bytelatent.data.iterators.abstract_iterator import ( + PydanticIteratorState, + StatefulIterator, +) from bytelatent.data.iterators.arrow_iterator import ( ArrowFileIterator, ArrowFileIteratorState, ) -class LoopingIteratorState(BaseModel, IteratorState): +class LoopingIteratorState(PydanticIteratorState): file_iterator_state: ArrowFileIteratorState epoch: int diff --git a/bytelatent/data/iterators/multiprocess_iterator.py b/bytelatent/data/iterators/multiprocess_iterator.py index 33bde94..b4df945 100644 --- a/bytelatent/data/iterators/multiprocess_iterator.py +++ b/bytelatent/data/iterators/multiprocess_iterator.py @@ -6,16 +6,20 @@ from multiprocessing.synchronize import Event as EventClass from queue import Empty, Full import numpy as np -from pydantic import BaseModel, ConfigDict +from pydantic import ConfigDict from bytelatent.data.data_types import Batch -from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator +from bytelatent.data.iterators.abstract_iterator import ( + IteratorState, + PydanticIteratorState, + StatefulIterator, +) from bytelatent.data.iterators.packing_iterator import PackingIteratorState logger = logging.getLogger() -class MultiprocessIteratorState(BaseModel, IteratorState): +class MultiprocessIteratorState(PydanticIteratorState): model_config = ConfigDict(extra="forbid") base_iterator_state: PackingIteratorState n_batches_to_prefetch: int diff --git a/bytelatent/data/iterators/packing_iterator.py b/bytelatent/data/iterators/packing_iterator.py index fa29149..5ed280d 100644 --- a/bytelatent/data/iterators/packing_iterator.py +++ b/bytelatent/data/iterators/packing_iterator.py @@ -5,7 +5,10 @@ import numpy as np from pydantic import BaseModel, ConfigDict from bytelatent.data.data_types import Batch, BltSequence -from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator +from bytelatent.data.iterators.abstract_iterator import ( + PydanticIteratorState, + StatefulIterator, +) from bytelatent.data.iterators.sampling_iterator import SamplingIteratorState @@ -20,7 +23,7 @@ class PackingArgs(BaseModel): tokenizer_name: str -class PackingIteratorState(BaseModel, IteratorState): +class PackingIteratorState(PydanticIteratorState): model_config = ConfigDict(extra="forbid") sequence_iterator_state: SamplingIteratorState packing_args: PackingArgs diff --git a/bytelatent/data/iterators/preprocess_iterator.py b/bytelatent/data/iterators/preprocess_iterator.py index 8eeba41..f72364d 100644 --- a/bytelatent/data/iterators/preprocess_iterator.py +++ b/bytelatent/data/iterators/preprocess_iterator.py @@ -5,20 +5,29 @@ import torch from pydantic import BaseModel, ConfigDict from bytelatent.data.data_types import BltExample -from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator +from bytelatent.data.iterators.abstract_iterator import ( + PydanticIteratorState, + StatefulIterator, +) from bytelatent.data.iterators.arrow_iterator import ( ArrowFileIterator, ArrowFileIteratorState, ) -from bytelatent.data.iterators.looping_iterator import LoopingIteratorState +from bytelatent.data.iterators.limit_iterator import LimitIterator, LimitIteratorState +from bytelatent.data.iterators.looping_iterator import ( + LoopingIterator, + LoopingIteratorState, +) from bytelatent.data.patcher import Patcher, PatcherArgs, PatchingModeEnum from bytelatent.tokenizers.blt_tokenizer import 
BltTokenizer from bytelatent.tokenizers.build_tokenizer import TokenizerArgs -class PreprocessIteratorState(BaseModel, IteratorState): +class PreprocessIteratorState(PydanticIteratorState): model_config = ConfigDict(extra="forbid") - arrow_file_iterator_state: ArrowFileIteratorState | LoopingIteratorState + arrow_file_iterator_state: ( + ArrowFileIteratorState | LoopingIteratorState | LimitIteratorState + ) add_tokens: bool add_patches: bool tokenizer_args: TokenizerArgs @@ -43,7 +52,7 @@ class PreprocessIterator(StatefulIterator): def __init__( self, - arrow_iterator: ArrowFileIterator, + arrow_iterator: ArrowFileIterator | LoopingIterator | LimitIterator, *, patcher_args: PatcherArgs, tokenizer_args: TokenizerArgs, diff --git a/bytelatent/data/iterators/sampling_iterator.py b/bytelatent/data/iterators/sampling_iterator.py index 6474bf6..170f215 100644 --- a/bytelatent/data/iterators/sampling_iterator.py +++ b/bytelatent/data/iterators/sampling_iterator.py @@ -2,13 +2,16 @@ from typing import Any import numpy as np -from pydantic import BaseModel, ConfigDict +from pydantic import ConfigDict -from bytelatent.data.iterators.abstract_iterator import StatefulIterator +from bytelatent.data.iterators.abstract_iterator import ( + PydanticIteratorState, + StatefulIterator, +) from bytelatent.data.iterators.sequence_iterator import SequenceIteratorState -class SamplingIteratorState(BaseModel): +class SamplingIteratorState(PydanticIteratorState): model_config = ConfigDict(extra="forbid") rng_state: dict[str, Any] source_to_weight: dict[str, float] diff --git a/bytelatent/data/iterators/sequence_iterator.py b/bytelatent/data/iterators/sequence_iterator.py index d90ea31..0a492be 100644 --- a/bytelatent/data/iterators/sequence_iterator.py +++ b/bytelatent/data/iterators/sequence_iterator.py @@ -6,7 +6,10 @@ import numpy as np from pydantic import BaseModel, ConfigDict from bytelatent.data.data_types import BltSequence -from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator +from bytelatent.data.iterators.abstract_iterator import ( + PydanticIteratorState, + StatefulIterator, +) from bytelatent.data.iterators.preprocess_iterator import ( PreprocessIterator, PreprocessIteratorState, @@ -21,11 +24,12 @@ class SequencePackingArgs(BaseModel): buffer_size: int -class SequenceIteratorState(BaseModel, IteratorState): +class SequenceIteratorState(PydanticIteratorState): model_config = ConfigDict(extra="forbid") sequence_packing_args: SequencePackingArgs preprocess_iterator_state: PreprocessIteratorState - rng_state: dict[str, Any] + # If None, rng is disabled. 
+ rng_state: dict[str, Any] | None def build(self): preprocess_iterator = self.preprocess_iterator_state.build() @@ -41,22 +45,25 @@ class SequenceIterator(StatefulIterator): self, preprocess_iterator: PreprocessIterator, *, - rng_state: dict[str, Any], + rng_state: dict[str, Any] | None, sequence_packing_args: SequencePackingArgs, ): self.preprocess_iterator = preprocess_iterator self.sequence_packing_args = sequence_packing_args self.output_seq_len = sequence_packing_args.output_seq_len self.buffer_size = sequence_packing_args.buffer_size - self.rng = np.random.default_rng() - self.rng.bit_generator.state = rng_state + if rng_state is None: + self.rng = None + else: + self.rng = np.random.default_rng() + self.rng.bit_generator.state = rng_state def get_state(self): # TODO: need to also perist the current shuffle buffer return SequenceIteratorState( sequence_packing_args=self.sequence_packing_args, preprocess_iterator_state=self.preprocess_iterator.get_state(), - rng_state=self.rng.bit_generator.state, + rng_state=None if self.rng is None else self.rng.bit_generator.state, ) def create_iter(self): @@ -114,7 +121,12 @@ class SequenceIterator(StatefulIterator): seq_patch_lengths: list[list[int]] = x_patches.tolist() assert len(seq_patch_lengths) == self.buffer_size - for idx in self.rng.permutation(len(seq_patch_lengths)): + if self.rng is None: + permutations = list(range(len(seq_patch_lengths))) + else: + permutations = self.rng.permutation(len(seq_patch_lengths)) + + for idx in permutations: assert len(seq_patch_lengths[idx]) == self.output_seq_len assert ( sum(seq_patch_lengths[idx]) diff --git a/bytelatent/data/iterators/test_iters.py b/bytelatent/data/iterators/test_iters.py index 9bc9d59..4749c8a 100644 --- a/bytelatent/data/iterators/test_iters.py +++ b/bytelatent/data/iterators/test_iters.py @@ -1,83 +1,15 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
-import pandas as pd -from pydantic import BaseModel from bytelatent.constants import BLT_DATA -from bytelatent.data.data_types import BltExample -from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator +from bytelatent.data.iterators.dev_iterators import ( + BltTestIterator, + BltTestWithEntropiesIterator, +) from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum from bytelatent.tokenizers.build_tokenizer import TokenizerArgs -class BltTestIteratorState(BaseModel, IteratorState): - position: int - total: int - - def build(self): - blt_iter = BltTestIteratorState(total=self.total) - blt_iter.position = self.position - return blt_iter - - -class BltTestIterator(StatefulIterator): - def __init__(self, total: int): - self.position = 0 - self.total = total - - def get_state(self): - return BltTestIteratorState(position=self.position, total=self.total) - - def create_iter(self): - for i in range(self.total): - self.position += 1 - yield BltExample( - sample_id=f"test_{i}", - text=f"This is some test {i} text.", - tokens=None, - mask=None, - entropies=None, - patch_lengths=None, - ) - - -class BltTestWithEntropiesIteratorState(BaseModel, IteratorState): - position: int - total: int - - def build(self): - blt_iter = BltTestWithEntropiesIteratorState(total=self.total) - blt_iter.position = self.position - return blt_iter - - -class BltTestWithEntropiesIterator(StatefulIterator): - def __init__(self, total: int): - self.position = 0 - self.total = total - - def get_state(self): - return BltTestIteratorState(position=self.position, total=self.total) - - def create_iter(self): - text = "Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin." 
- df = pd.read_json("fixtures/tokens_with_entropies.json") - tokens = df["token_ids"].tolist() - entropies = df["entropies"].tolist() - # BOS and EOS - assert len(tokens) == len(text) + 2 - for i in range(self.total): - self.position += 1 - yield BltExample( - sample_id=f"test_{i}", - text=text, - tokens=tokens, - mask=[True] * len(tokens), - entropies=entropies, - patch_lengths=None, - ) - - def test_preprocess_iter(): total = 3 tokenizer_args = TokenizerArgs( diff --git a/bytelatent/data/iterators/test_limit_iterator.py b/bytelatent/data/iterators/test_limit_iterator.py new file mode 100644 index 0000000..47d5c27 --- /dev/null +++ b/bytelatent/data/iterators/test_limit_iterator.py @@ -0,0 +1,45 @@ +from bytelatent.data.iterators.dev_iterators import BltTestIterator +from bytelatent.data.iterators.limit_iterator import LimitIterator + + +def test_limit_iterator(): + total = 10 + limit = 5 + base_iterator = BltTestIterator(total=total) + limit_iterator = LimitIterator(base_iterator, limit=limit) + iterator = limit_iterator.create_iter() + n = 0 + for example in iterator: + assert example.sample_id == f"test_{n}" + n += 1 + assert n == limit + + limit = 10 + base_iterator = BltTestIterator(total=total) + limit_iterator = LimitIterator(base_iterator, limit=limit) + iterator = limit_iterator.create_iter() + n = 0 + for example in iterator: + assert example.sample_id == f"test_{n}" + n += 1 + assert n == limit == total + + limit = 20 + base_iterator = BltTestIterator(total=total) + limit_iterator = LimitIterator(base_iterator, limit=limit) + iterator = limit_iterator.create_iter() + n = 0 + for example in iterator: + assert example.sample_id == f"test_{n}" + n += 1 + assert n == total + + limit = -1 + base_iterator = BltTestIterator(total=total) + limit_iterator = LimitIterator(base_iterator, limit=limit) + iterator = limit_iterator.create_iter() + n = 0 + for example in iterator: + assert example.sample_id == f"test_{n}" + n += 1 + assert n == total
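As a closing illustration of how the new iterator pieces fit together (a sketch only, assuming the patched bytelatent package from this series is importable; the total and limit values are arbitrary), LimitIterator can wrap any StatefulIterator, and because the state classes now derive from PydanticIteratorState, the combined state serializes to JSON directly:

```
# Sketch: cap a dev-only BltTestIterator at 3 examples and dump its state.
from bytelatent.data.iterators.dev_iterators import BltTestIterator
from bytelatent.data.iterators.limit_iterator import LimitIterator

limited = LimitIterator(BltTestIterator(total=10), limit=3)  # limit < 0 would mean "no cap"
examples = list(limited.create_iter())
assert [ex.sample_id for ex in examples] == ["test_0", "test_1", "test_2"]

state = limited.get_state()             # LimitIteratorState, a pydantic model
print(state.model_dump_json(indent=2))  # n_yielded, limit, and the nested base iterator state
```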