From 48cf4dfee14d060c6be1615484c85d34ac34296f Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Wed, 5 Feb 2025 22:47:00 +0000 Subject: [PATCH 1/3] Allow ArrowIterator to read from json Summary: Test Plan: --- bytelatent/args.py | 38 +++++++- bytelatent/data/iterators/arrow_iterator.py | 97 +++++++++++-------- bytelatent/preprocess/preprocess_entropies.py | 26 +++-- 3 files changed, 104 insertions(+), 57 deletions(-) diff --git a/bytelatent/args.py b/bytelatent/args.py index fc72b32..263e8e3 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -1,8 +1,10 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. +import json import logging import os from typing import Any +import fsspec import numpy as np import yaml from omegaconf import OmegaConf @@ -10,11 +12,9 @@ from pydantic import BaseModel, ConfigDict from bytelatent.checkpoint import CheckpointArgs from bytelatent.data.data_types import Batch +from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import StatefulIterator -from bytelatent.data.iterators.arrow_iterator import ( - ArrowFileIterator, - find_and_sanitize_chunks, -) +from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator from bytelatent.data.iterators.looping_iterator import LoopingIterator from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator @@ -53,6 +53,33 @@ def parse_args(args_cls): return pydantic_args +TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" + + +def find_and_sanitize_chunks( + dataset_path: str, + world_size: int, + file_pattern: str, + s3_profile: str | None = None, +): + fs = get_fs(dataset_path, s3_profile=s3_profile) + path_with_glob = os.path.join(dataset_path, file_pattern) + dataset_chunks = fs.glob(path_with_glob) + n_chunks = len(dataset_chunks) + + if n_chunks > world_size: + n_discard = n_chunks - world_size + dataset_chunks = dataset_chunks[:world_size] + else: + assert ( + world_size % n_chunks == 0 + ), "World size should be a multiple of number of chunks" + + assert n_chunks > 0, f"No valid chunks in {dataset_path}" + + return dataset_chunks + + def distribute_data_to_rank( *, dataset_path: str, @@ -62,9 +89,10 @@ def distribute_data_to_rank( rank: int, world_size: int, s3_profile: str | None = None, + file_pattern: str = TRAIN_DATA_FILE_PATTERN, ) -> ArrowFileIterator: dataset_chunks = find_and_sanitize_chunks( - dataset_path, world_size, s3_profile=s3_profile + dataset_path, world_size, file_pattern, s3_profile=s3_profile ) n_workers_per_chunk = world_size // len(dataset_chunks) rank_to_arrow_iterator_params = [] diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index 4e7b99e..1c68d3a 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -16,6 +16,7 @@ from bytelatent import ByteLatentError from bytelatent.data.data_types import BltExample from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator +from bytelatent.preprocess.preprocess_entropies import get_id_key, get_text logger = getLogger(__name__) @@ -32,6 +33,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState): arrow_batch_size: int = 100 s3_profile: str | None filesystem_type: str | None = None + file_format: str def build(self) -> "ArrowFileIterator": arrow_file = ArrowFileIterator( @@ -44,6 +46,7 @@ class 
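# Illustrative sketch, not part of the patch: how the relocated
# find_and_sanitize_chunks feeds distribute_data_to_rank. The file names and
# world size below are invented, and the rank -> (chunk, worker_id) assignment
# shown is one scheme consistent with
# n_workers_per_chunk = world_size // len(dataset_chunks); the repository's
# exact ordering may differ.
chunks = ["data/corpus.chunk.0.jsonl", "data/corpus.chunk.1.jsonl"]  # fs.glob result
world_size = 4
n_workers_per_chunk = world_size // len(chunks)  # here: 2 ranks share each chunk
for rank in range(world_size):
    chunk = chunks[rank // n_workers_per_chunk]
    worker_id = rank % n_workers_per_chunk
    print(f"rank={rank} -> {chunk} (worker {worker_id} of {n_workers_per_chunk})")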
ArrowFileIteratorState(BaseModel, IteratorState): dataset_files=self.dataset_files, s3_profile=self.s3_profile, filesystem_type=self.filesystem_type, + file_format=self.file_format, ) if self.row_num != 0: arrow_file._set_row_num(self.row_num) @@ -70,6 +73,7 @@ class ArrowFileIterator(StatefulIterator): dataset_files: list[str] | None = None, s3_profile: str | None = None, filesystem_type: str | None = None, + file_format: str = "arrow", ): assert 0 <= worker_id < num_workers, (worker_id, num_workers) if file_path is None and dataset_files is None: @@ -87,12 +91,16 @@ class ArrowFileIterator(StatefulIterator): self.arrow_batch_size = arrow_batch_size self.s3_profile = s3_profile self.filesystem_type = filesystem_type + self.file_format = file_format self.fs = None if self.filesystem_type is not None: if self.filesystem_type == "file": self.fs = fsspec.filesystem("file") elif self.filesystem_type == "s3": self.fs = fsspec.filesystem("s3", profile=s3_profile) + else: + raise ValueError("Unknown filesystem") + logger.info("Arrow iterator using fs=%s", self.fs) if dataset_files is None: # Prepare arrow shards @@ -153,6 +161,7 @@ class ArrowFileIterator(StatefulIterator): dataset_files=self.dataset_files, s3_profile=self.s3_profile, filesystem_type=self.filesystem_type, + file_format=self.file_format, ) def create_iter( @@ -164,7 +173,7 @@ class ArrowFileIterator(StatefulIterator): else: filesystem = None self.dataset = pa.dataset.dataset( - self.dataset_files, format="arrow", filesystem=filesystem + self.dataset_files, format=self.file_format, filesystem=filesystem ) self.batch_iterator = self.dataset.to_batches( batch_size=self.arrow_batch_size @@ -173,13 +182,22 @@ class ArrowFileIterator(StatefulIterator): if self.batch_to_consume is not None: batch_columns: dict[str, list] = self.batch_to_consume self.batch_to_consume = None - sample_ids = batch_columns["sample_id"] - texts = batch_columns["text"] - entropies = batch_columns["entropies"] + if self.file_format == "arrow": + sample_ids = batch_columns["sample_id"] + texts = batch_columns["text"] + entropies = batch_columns["entropies"] + elif self.file_format == "json": + # This data hasn't been preprocessed to a uniform format, + # so we have to do it now and omit entropies + sample_ids = batch_columns[get_id_key(batch_columns)] + texts = get_text(batch_columns) + entropies = None + else: + raise ValueError(f"Unknown file format: {self.file_format}") for i in range(len(sample_ids)): out = BltExample( sample_id=sample_ids[i], - entropies=entropies[i], + entropies=entropies[i] if entropies is not None else None, text=texts[i], tokens=None, mask=None, @@ -191,13 +209,22 @@ class ArrowFileIterator(StatefulIterator): for batch in self.batch_iterator: batch_columns = batch.to_pydict() - sample_ids = batch_columns["sample_id"] - texts = batch_columns["text"] - entropies = batch_columns["entropies"] + if self.file_format == "arrow": + sample_ids = batch_columns["sample_id"] + texts = batch_columns["text"] + entropies = batch_columns["entropies"] + elif self.file_format == "json": + # This data hasn't been preprocessed to a uniform format, + # so we have to do it now and omit entropies + sample_ids = batch_columns[get_id_key(batch_columns)] + texts = get_text(batch_columns) + entropies = None + else: + raise ValueError(f"Unknown file format: {self.file_format}") for i in range(len(sample_ids)): out = BltExample( sample_id=sample_ids[i], - entropies=entropies[i], + entropies=entropies[i] if entropies is not None else None, text=texts[i], tokens=None, 
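# Illustrative sketch, not part of the patch: what the new file_format == "json"
# branches do with one raw jsonl batch. get_id_key picks the first id-like
# column present, get_text picks the "text"/"content" column, and entropies
# stay None because jsonl inputs have not been preprocessed. The batch is
# invented and assumes the bytelatent package is importable.
from bytelatent.preprocess.preprocess_entropies import get_id_key, get_text

batch_columns = {"title": ["doc-1"], "text": ["hello world"]}
sample_ids = batch_columns[get_id_key(batch_columns)]  # ["doc-1"] via the "title" fallback
texts = get_text(batch_columns)                        # ["hello world"]
entropies = None                                       # omitted for jsonl input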
mask=None, @@ -231,13 +258,24 @@ class ArrowFileIterator(StatefulIterator): for batch in self.batch_iterator: if len(batch) > curr_remaining: batch_columns: dict[str, list] = batch.to_pydict() - batch_columns["sample_id"] = batch_columns["sample_id"][ - curr_remaining: - ] - batch_columns["entropies"] = batch_columns["entropies"][ - curr_remaining: - ] - batch_columns["text"] = batch_columns["text"][curr_remaining:] + if self.file_format == "arrow": + leftover_sample_ids = batch_columns["sample_id"][ + curr_remaining: + ] + leftover_entropies = batch_columns["entropies"][curr_remaining:] + leftover_texts = batch_columns["text"][curr_remaining:] + elif self.file_format == "json": + leftover_sample_ids = batch_columns[get_id_key(batch_columns)][ + curr_remaining: + ] + leftover_entropies = None + leftover_texts = get_text(batch_columns)[curr_remaining:] + else: + raise ValueError(f"Unknown file format: {self.file_format}") + + batch_columns["sample_id"] = leftover_sample_ids + batch_columns["entropies"] = leftover_entropies + batch_columns["text"] = leftover_texts self.batch_to_consume = batch_columns break elif len(batch) == curr_remaining: @@ -250,30 +288,3 @@ class ArrowFileIterator(StatefulIterator): logger.info( f"Finished setting arrow position to {target_row_num} for {self.dataset_files}" ) - - -TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" - - -def find_and_sanitize_chunks( - dataset_path: str, - world_size: int, - file_pattern: str = TRAIN_DATA_FILE_PATTERN, - s3_profile: str | None = None, -): - fs = get_fs(dataset_path, s3_profile=s3_profile) - path_with_glob = os.path.join(dataset_path, file_pattern) - dataset_chunks = fs.glob(path_with_glob) - n_chunks = len(dataset_chunks) - - if n_chunks > world_size: - n_discard = n_chunks - world_size - dataset_chunks = dataset_chunks[:world_size] - else: - assert ( - world_size % n_chunks == 0 - ), "World size should be a multiple of number of chunks" - - assert n_chunks > 0, f"No valid chunks in {dataset_path}" - - return dataset_chunks diff --git a/bytelatent/preprocess/preprocess_entropies.py b/bytelatent/preprocess/preprocess_entropies.py index 519da94..31a4802 100644 --- a/bytelatent/preprocess/preprocess_entropies.py +++ b/bytelatent/preprocess/preprocess_entropies.py @@ -15,29 +15,37 @@ from bytelatent.entropy_model import load_entropy_model from bytelatent.tokenizers.build_tokenizer import TokenizerArgs -def get_id_from_doc(doc: dict) -> int: +def get_id_key(doc: dict) -> int: """ We need a reliable way to ensure that samples from jsonl and arrow are the same, but there is no unique id field, so derive the best possible """ if "sample_id" in doc: - sample_id = doc["sample_id"] + return "sample_id" elif "title" in doc: - sample_id = doc["title"] + return "title" elif "qid" in doc: - sample_id = doc["qid"] + return "qid" elif "paper_id" in doc: - sample_id = doc["paper_id"] + return "paper_id" elif "path" in doc: - sample_id = doc["path"] + return "path" elif "url" in doc: - sample_id = doc["url"] + return "url" elif "id" in doc: - sample_id = doc["id"] + return "id" else: raise ValueError(f"Could not find a id key from: {doc.keys()}") - return str(sample_id) + + +def get_id_from_doc(doc: dict) -> int: + """ + We need a reliable way to ensure that samples from jsonl + and arrow are the same, but there is no unique id field, + so derive the best possible + """ + return str(doc[get_id_key(doc)]) def get_text(doc: dict): From 8f1a9a858e06c225b540e257b451b9223f5672eb Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Wed, 5 Feb 2025 
22:47:01 +0000 Subject: [PATCH 2/3] Minimal working eval Summary: Test Plan: --- bytelatent/eval.py | 64 +++++++++++++++++++++++++++++------------- bytelatent/generate.py | 6 ++-- 2 files changed, 48 insertions(+), 22 deletions(-) diff --git a/bytelatent/eval.py b/bytelatent/eval.py index ae73066..1943a33 100644 --- a/bytelatent/eval.py +++ b/bytelatent/eval.py @@ -15,9 +15,16 @@ from lm_eval.api.model import LM from omegaconf import OmegaConf from pydantic import BaseModel, ConfigDict -from bytelatent.args import EvalArgs, ValidationArgs, parse_args +from bytelatent.args import ( + EvalArgs, + TrainArgs, + ValidationArgs, + find_and_sanitize_chunks, + parse_args, +) from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints from bytelatent.data.file_util import get_fs +from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator from bytelatent.distributed import ( DistributedArgs, dist_mean_dict, @@ -117,19 +124,40 @@ class EvalHarnessLM(LM): return results -def eval_on_val(generator, val_args: ValidationArgs, train_cfg): - srcs = {} +def eval_on_val(generator, val_args: ValidationArgs, train_cfg: TrainArgs): + srcs = [] for src in val_args.sources: path = os.path.join(val_args.root_dir, src) - srcs[path] = 1.0 + srcs.append(path) + for src in train_cfg.data.sources: path = os.path.join(train_cfg.data.root_dir, src) - srcs[path] = 1.0 + srcs.append(path) - multi_state = init_choice_state( - "", srcs, 0, get_global_rank(), get_world_size(), "*.val.jsonl" - ) - path_to_iter = setup_sources(multi_state) + path_to_iter = {} + for path in srcs: + chunks = find_and_sanitize_chunks( + path, + world_size=1, + file_pattern="*.val.jsonl", + s3_profile=train_cfg.data.s3_profile, + ) + assert ( + len(chunks) == 1 + ), f"There should be only 1 chunk per validation file, but found: {chunks}" + chunk = chunks[0] + iterator = ArrowFileIterator( + dataset_files=[chunk], + file_path=None, + preprocess_dir=None, + entropy_model_name=None, + worker_id=0, + num_workers=1, + arrow_batch_size=train_cfg.data.arrow_batch_size, + s3_profile=train_cfg.data.s3_profile, + file_format="json", + ) + path_to_iter[path] = iterator max_gen_len = generator.max_gen_len # We temporarily lower max gen len @@ -137,16 +165,11 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg): all_val_metrics = {} for src in path_to_iter: - jsonl_iterator = path_to_iter[src] + example_iterator = path_to_iter[src].create_iter() texts = [] logger.info(f"Running validation on {src}...") - for step, (content, state) in enumerate(jsonl_iterator): - if state["current_iter"] > 0 or ( - val_args.max_steps is not None and step >= val_args.max_steps - ): - break - content_key = "text" if ("text" in content) else "content" - texts.append(content[content_key]) + for step, example in enumerate(example_iterator): + texts.append(example.text) _, loglikelihood, _ = generator.generate(texts) @@ -191,7 +214,7 @@ def launch_eval(eval_args: EvalArgs): else: consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER) if not fs.exists(consolidate_path) and get_global_rank() == 0: - consolidate_path = consolidate_checkpoints(eval_args.ckpt_dir) + consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir) fs.mkdirs(eval_args.dump_dir, exist_ok=True) with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f: @@ -210,10 +233,12 @@ def launch_eval(eval_args: EvalArgs): wrap = EvalHarnessLM(generator) # Redo - results = simple_evaluate(wrap, eval_args.harness.model_dump()) + results = 
simple_evaluate(wrap, **eval_args.harness.model_dump()) + val_results = None if eval_args.validation: val_results = eval_on_val(generator, eval_args.validation, train_cfg) + if get_global_rank() == 0: with fs.open(os.path.join(eval_args.dump_dir, "results.json"), "w") as f: f.write(json.dumps(results)) @@ -222,6 +247,7 @@ def launch_eval(eval_args: EvalArgs): with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f: f.write(json.dumps(val_results)) logger.info(f"All validation results: {val_results}") + if eval_args.metric_log_dir and get_global_rank() == 0: metric_log_path = os.path.join(eval_args.metric_log_dir, "metrics.eval.jsonl") diff --git a/bytelatent/generate.py b/bytelatent/generate.py index eb79d81..9d44f30 100644 --- a/bytelatent/generate.py +++ b/bytelatent/generate.py @@ -387,8 +387,7 @@ def load_consolidated_model_and_tokenizer( ): train_args_path = os.path.join(consolidated_path, "params.json") fs = get_fs(train_args_path) - with fs.open(train_args_path) as f: - train_args = TrainArgs.model_validate_json(f.read()) + train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path)) if train_args.train_entropy_model: model_args = train_args.entropy_model @@ -401,7 +400,8 @@ def load_consolidated_model_and_tokenizer( train_args.distributed.model_dtype ] tokenizer = train_args.data.tokenizer_args.build() - st_dict = torch.load(consolidated_path / CONSOLIDATE_NAME, weights_only=True) + with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f: + st_dict = torch.load(f, weights_only=True) model.load_state_dict(st_dict["model"]) model = model.cuda().eval() for param in model.parameters(): From 45bfe94c1eb16d89ac4553a040c0647573638de3 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Thu, 6 Feb 2025 01:27:15 +0000 Subject: [PATCH 3/3] Broken train reproducing bf16 error Summary: Test Plan: --- bytelatent/broken_train.py | 623 +++++++++++++++++++++++++++++++++++++ bytelatent/train.py | 2 +- 2 files changed, 624 insertions(+), 1 deletion(-) create mode 100644 bytelatent/broken_train.py diff --git a/bytelatent/broken_train.py b/bytelatent/broken_train.py new file mode 100644 index 0000000..e1630a7 --- /dev/null +++ b/bytelatent/broken_train.py @@ -0,0 +1,623 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
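# Illustrative sketch, not part of this new file: the access pattern the
# generate.py hunk above switches to, reading params.json and the consolidated
# checkpoint through fsspec so local and S3 checkpoints share one code path.
# The path and payload are placeholders.
import json

import fsspec

fs = fsspec.filesystem("file")
params_path = "/tmp/params.json"  # placeholder path
with fs.open(params_path, "w") as f:
    f.write(json.dumps({"dim": 8}))
train_args_json = fs.read_text(params_path)  # what TrainArgs.model_validate_json consumes
print(json.loads(train_args_json))
# torch.load accepts a file-like object, so the checkpoint itself can be read
# the same way:
#   with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f:
#       st_dict = torch.load(f, weights_only=True)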
+ +from datetime import timedelta +from enum import Enum +from functools import lru_cache +import logging +import math +import sys +import time +from typing import Dict, List, Optional, Tuple +from torch import Tensor +import os +import pickle + +import fsspec +import torch +import torch.distributed +import torch.nn.functional +import torch.nn.functional as F +from torch.distributed._tensor import DTensor + +from torch.distributed.device_mesh import init_device_mesh +from torch.utils._foreach_utils import ( + _device_has_foreach_support, + _group_tensors_by_device_and_dtype, + _has_foreach_support, +) + +from bytelatent.args import TrainArgs +from bytelatent.distributed import ( + DistributedArgs, + check_model_value_range, + parallelize_model, + setup_env, + setup_torch_distributed, +) + +logger = logging.getLogger() + + +def set_root_log_level(log_level: str): + logger = logging.getLogger() + level: int | str = log_level.upper() + try: + level = int(log_level) + except ValueError: + pass + try: + logger.setLevel(level) # type: ignore + except Exception: + logger.warning( + f"Failed to set logging level to {log_level}, using default 'NOTSET'" + ) + logger.setLevel(logging.NOTSET) + + +class LogFormatter(logging.Formatter): + """ + Custom logger for distributed jobs, displaying rank + and preserving indent from the custom prefix format. + """ + + def __init__(self): + self.start_time = time.time() + self.rank = get_global_rank() + self.show_rank = not get_is_slurm_job() # srun has --label + + def formatTime(self, record): + subsecond, seconds = math.modf(record.created) + curr_date = ( + time.strftime("%y-%m-%d %H:%M:%S", time.localtime(seconds)) + + f".{int(subsecond * 1_000_000):06d}" + ) + delta = timedelta(seconds=round(record.created - self.start_time)) + return f"{curr_date} - {delta}" + + def formatPrefix(self, record): + fmt_time = self.formatTime(record) + if self.show_rank: + return f"{self.rank}: {record.levelname:<7} {fmt_time} - " + else: + return f"{record.levelname:<7} {fmt_time} - " + + def formatMessage(self, record, indent: str): + content = record.getMessage() + content = content.replace("\n", "\n" + indent) + # Exception handling as in the default formatter, albeit with indenting + # according to our custom prefix + if record.exc_info: + # Cache the traceback text to avoid converting it multiple times + # (it's constant anyway) + if not record.exc_text: + record.exc_text = self.formatException(record.exc_info) + if record.exc_text: + if content[-1:] != "\n": + content = content + "\n" + indent + content = content + indent.join( + [l + "\n" for l in record.exc_text.splitlines()] + ) + if content[-1:] == "\n": + content = content[:-1] + if record.stack_info: + if content[-1:] != "\n": + content = content + "\n" + indent + stack_text = self.formatStack(record.stack_info) + content = content + indent.join([l + "\n" for l in stack_text.splitlines()]) + if content[-1:] == "\n": + content = content[:-1] + + return content + + def format(self, record): + prefix = self.formatPrefix(record) + indent = " " * len(prefix) + content = self.formatMessage(record, indent) + return prefix + content + + +def init_logger( + log_file: str | None = None, + *, + name: str | None = None, + level: str = "INFO", + fs: fsspec.AbstractFileSystem | None = None, +): + """ + Setup logging. + + Args: + log_file: A file name to save file logs to. + name: The name of the logger to configure, by default the root logger. + level: The logging level to use. 
+ """ + set_root_log_level(level) + logger = logging.getLogger(name) + + # stdout: everything + stdout_handler = logging.StreamHandler(sys.stdout) + stdout_handler.setLevel(logging.NOTSET) + stdout_handler.setFormatter(LogFormatter()) + + # stderr: warnings / errors and above + stderr_handler = logging.StreamHandler(sys.stderr) + stderr_handler.setLevel(logging.WARNING) + stderr_handler.setFormatter(LogFormatter()) + + # set stream handlers + logger.handlers.clear() + logger.handlers.append(stdout_handler) + logger.handlers.append(stderr_handler) + + +@torch.no_grad() +def fixed_clip_grad_norm_( + parameters: torch.Tensor | list[torch.Tensor], + max_norm: float, + norm_type: float = 2.0, + error_if_nonfinite: bool = False, + foreach: Optional[bool] = None, +) -> torch.Tensor: + r"""Clip the gradient norm of an iterable of parameters. + + The norm is computed over the norms of the individual gradients of all parameters, + as if the norms of the individual gradients were concatenated into a single vector. + Gradients are modified in-place. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float): max norm of the gradients + norm_type (float): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + error_if_nonfinite (bool): if True, an error is thrown if the total + norm of the gradients from :attr:`parameters` is ``nan``, + ``inf``, or ``-inf``. Default: False (will switch to True in the future) + foreach (bool): use the faster foreach-based implementation. + If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently + fall back to the slow implementation for other device types. + Default: ``None`` + + Returns: + Total norm of the parameter gradients (viewed as a single vector). + """ + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + grads = [p.grad.to(torch.bfloat16) for p in parameters if p.grad is not None] + max_norm = float(max_norm) + norm_type = float(norm_type) + if len(grads) == 0: + return torch.tensor(0.0) + first_device = grads[0].device + grouped_grads: Dict[ + Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]] + ] = _group_tensors_by_device_and_dtype( + [grads] + ) # type: ignore[assignment] + + norms: List[Tensor] = [] + for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment] + if (foreach is None and _has_foreach_support(device_grads, device)) or ( + foreach and _device_has_foreach_support(device) + ): + norms.extend(torch._foreach_norm(device_grads, norm_type)) + elif foreach: + raise RuntimeError( + f"foreach=True was passed, but can't use the foreach API on {device.type} tensors" + ) + else: + norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads]) + + total_norm = torch.linalg.vector_norm( + torch.stack([norm.to(first_device) for norm in norms]), norm_type + ) + + if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()): + raise RuntimeError( + f"The total norm of order {norm_type} for gradients from " + "`parameters` is non-finite, so it cannot be clipped. 
To disable " + "this error and scale the gradients by the non-finite norm anyway, " + "set `error_if_nonfinite=False`" + ) + clip_coef = max_norm / (total_norm + 1e-6) + # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so + # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization + # when the gradients do not reside in CPU memory. + clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment] + if (foreach is None and _has_foreach_support(device_grads, device)) or ( + foreach and _device_has_foreach_support(device) + ): + torch._foreach_mul_(device_grads, clip_coef_clamped.to(device)) + elif foreach: + raise RuntimeError( + f"foreach=True was passed, but can't use the foreach API on {device.type} tensors" + ) + else: + clip_coef_clamped_device = clip_coef_clamped.to(device) + for g in device_grads: + g.mul_(clip_coef_clamped_device) + + return total_norm + + +def get_no_recompute_ops(): + return None + + +@lru_cache() +def get_is_torch_run() -> bool: + return os.environ.get("LOCAL_RANK") is not None + + +@lru_cache() +def get_is_slurm_job() -> bool: + return "SLURM_JOB_ID" in os.environ and not get_is_torch_run() + + +@lru_cache() +def get_global_rank() -> int: + if get_is_torch_run(): + return int(os.environ["RANK"]) + elif get_is_slurm_job(): + return int(os.environ["SLURM_PROCID"]) + else: + return 0 + + +@lru_cache() +def get_local_rank() -> int: + if get_is_torch_run(): + return int(os.environ["LOCAL_RANK"]) + elif get_is_slurm_job(): + return int(os.environ["SLURM_LOCALID"]) + else: + return 0 + + +@lru_cache() +def get_world_size() -> int: + if get_is_torch_run(): + return int(os.environ["WORLD_SIZE"]) + elif get_is_slurm_job(): + return int(os.environ["SLURM_NTASKS"]) + else: + return 1 + + +@lru_cache() +def get_is_master() -> bool: + return get_global_rank() == 0 + + +def validate_train_args(args: TrainArgs, output_size: int): + # assert args.model is not None or args.entropy_model is not None + if args.entropy_model is not None: + logger.info(f"Setting model output size to {args.entropy_model.vocab_size}") + args.entropy_model.vocab_size = output_size + + assert args.dump_dir, "Dump dir not set" + + if ( + args.distributed.dp_replicate + * args.distributed.dp_shard + * args.distributed.tp_size + != get_world_size() + ): + logging.info("Modifying TrainArgs distributed config") + assert get_world_size() % args.distributed.dp_shard == 0 + logging.info("World size: %s", get_world_size()) + logging.info( + "Existing setting: train_args.distributed.dp_shard=%s", + args.distributed.dp_shard, + ) + logging.info( + "Setting train_args.distributed.dp_replicate=%s, was dp_replicate=%s", + get_world_size() // args.distributed.dp_shard, + args.distributed.dp_replicate, + ) + args.distributed.dp_replicate = get_world_size() // args.distributed.dp_shard + + logging.info( + "Changing dp_replicate from %s to %s, to account for tp_size=%s", + args.distributed.dp_replicate, + args.distributed.dp_replicate // args.distributed.tp_size, + args.distributed.tp_size, + ) + assert args.distributed.dp_replicate % args.distributed.tp_size == 0 + args.distributed.dp_replicate = ( + args.distributed.dp_replicate // args.distributed.tp_size + ) + + logger.warning( + f"Setting Data Parallel size to {args.distributed.dp_replicate * args.distributed.dp_shard}" + ) + assert ( + args.distributed.dp_replicate + * args.distributed.dp_shard + * 
args.distributed.tp_size + == get_world_size() + ) + + if args.distributed.fsdp_type == "no_shard": + assert ( + args.distributed.dp_shard == 1 + and args.distributed.dp_replicate == get_world_size() + ) + + if args.model is not None: + args.model.max_seqlen = args.data.seq_len + if args.entropy_model is not None: + args.entropy_model.max_seqlen = args.data.seq_len + + if args.distributed.tp_size == 1: + logger.warning( + "Tensor parallelism has not been tested for a while, use at your own risk" + ) + + assert ( + args.probe_freq != args.profiling.mem_steps + ), "Don't profile during probe step" + assert ( + args.probe_freq != args.profiling.profile_steps + ), "Don't profile during probe step" + if args.logging.wandb is not None: + args.logging.wandb.name = args.name + + if args.probe_freq is not None: + assert ( + args.distributed.tp_size == 1 + ), "Probing not supported with tensor parallelism" + assert ( + args.distributed.selective_activation_checkpointing is False + ), "Probing not supported with selective activation checkpointing" + + +def compute_loss(p, y, mask, scale): + tok_loss = scale * F.cross_entropy( + p.flatten(0, 1), y.flatten(0, 1), reduction="none" + ) + if mask is None: + loss = tok_loss.mean() + else: + mask = mask.flatten(0, 1) + tok_loss = tok_loss * mask + loss = tok_loss.sum() / (mask.sum() + 1e-6) + return loss, tok_loss + + +def get_device_mesh(distributed_args): + tp_size = distributed_args.tp_size + dp_replicate = distributed_args.dp_replicate + dp_shard = distributed_args.dp_shard + + assert ( + dp_replicate * dp_shard * tp_size == get_world_size() + ), f"dp_replicate * dp_shard * tp_size ({dp_replicate} * {dp_shard} * {tp_size}) != world_size ({get_world_size()})" + + dims = [] + names = [] + if dp_replicate >= 1: + dims.append(dp_replicate) + names.append("dp_replicate") + if dp_shard > 1 or distributed_args.fsdp_type == "no_shard": + dims.append(dp_shard) + names.append("dp_shard") + if tp_size > 1: + dims.append(tp_size) + names.append("tp") + dims = tuple(dims) + names = tuple(names) + + return init_device_mesh("cuda", mesh_shape=dims, mesh_dim_names=names) + + +def build_fsdp_grouping_plan(): + group_plan: tuple[int, bool] = [] + + # Grouping and output seperately + group_plan.append(("tok_embeddings", False)) + + # Grouping by layers + # for i in range(model_args.n_layers): + # group_plan.append((f"layers.{i}", False)) + + group_plan.append(("output", True)) + + return group_plan + + +class MinimalModel(torch.nn.Module): + def __init__(self, dim: int, vocab_size: int): + super().__init__() + self.tok_embeddings = torch.nn.Embedding(vocab_size, dim) + + # self.norm = RMSNorm(args.dim, eps=args.norm_eps) + # self.layers = torch.nn.ModuleList() + # for _ in range(args.n_layers): + # self.layers.append(TransformerBlock(args)) + + self.output = torch.nn.Linear( + dim, + vocab_size, + bias=False, + ) + + def forward(self, tokens): + h = self.tok_embeddings(tokens) + logits = self.output(h) + # logits = self.output(self.norm(h)) + return logits + + def reset_parameters(self, init_std=None): + pass + + def init_weights(self): + pass + + +def train(): + args = TrainArgs( + dump_dir="/tmp", + name="debug_bf16", + model=None, + entropy_model=None, + distributed=DistributedArgs( + fsdp_type="full_shard", + model_dtype="bf16", + matmul_allow_tf32=False, + selective_activation_checkpointing=False, + tp_size=1, + ), + ) + tokenizer = args.data.tokenizer_args.build() + validate_train_args( + args, + tokenizer.n_words, + ) + dump_fs = fsspec.filesystem("file") + 
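# Illustrative sketch, not part of this file: a worked instance of what
# get_device_mesh above builds. With dp_replicate=2, dp_shard=4, tp_size=1 on
# 8 ranks, the mesh is 2 x 4 named ("dp_replicate", "dp_shard"); the tp axis is
# only added when tp_size > 1 (the real function also keeps the dp_shard axis
# when fsdp_type == "no_shard"). The sizes here are invented.
dp_replicate, dp_shard, tp_size = 2, 4, 1
dims, names = [], []
if dp_replicate >= 1:
    dims.append(dp_replicate)
    names.append("dp_replicate")
if dp_shard > 1:
    dims.append(dp_shard)
    names.append("dp_shard")
if tp_size > 1:
    dims.append(tp_size)
    names.append("tp")
print(tuple(dims), tuple(names))  # (2, 4) ('dp_replicate', 'dp_shard')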
init_logger(os.path.join(args.dump_dir, "train.log"), fs=dump_fs) + setup_env(args.env) + setup_torch_distributed(args.distributed) + world_mesh = get_device_mesh(args.distributed) + logger.info(f"Starting job: {args.name}") + + # build dataloader + # need dp world size and rank + dp_mesh = world_mesh["dp_replicate"] + dp_degree = dp_mesh.size() + dp_rank = dp_mesh.get_local_rank() + if args.distributed.dp_shard > 1: + dp_rank = dp_rank * dp_degree + world_mesh["dp_shard"].get_local_rank() + dp_degree *= world_mesh["dp_shard"].size() + + logger.info(f"Running on dp rank : {dp_rank}") + logger.info(f"Running on dp size : {dp_degree}") + + torch.manual_seed(args.seed) + logger.info("Building model") + + # Initializing Model in meta device allows us to initialize models much bigger than 1 gpu's memory + with torch.device("meta"): + model = MinimalModel(768, tokenizer.n_words) + + model = parallelize_model( + model, + world_mesh, + args.model, + args.distributed, + fsdp_grouping_plan=build_fsdp_grouping_plan(), + tp_parallelize=None, + no_recompute_ops=get_no_recompute_ops(), + ) + + # Once we shard the model on different gpus we can actually initialize the model + # First we create empty tensors of the correct shapes + model = model.to_empty(device="cuda") + # Then we init the model. Please make sure this function initializes *ALL* parameters + # and buffers, otherwise you will have random values in the unitialized tensors + # which will silently fail (give nan gradients for example) + + with torch.random.fork_rng(devices=[torch.cuda.current_device()]): + torch.manual_seed(42) + model.init_weights() + check_model_value_range(model, range=10.0, std=1.0) + + # data_loader = args.data.build_from_rank(dp_rank, dp_degree) + + # train loop + model.train() + # data_loader = train_state.data_loader_state.build() + # batch_iterator = data_loader.create_iter() + # batch = next(batch_iterator) + # with open(f"/storage/home/par/toy-data/batch_{dp_rank}.pickle", "wb") as f: + # pickle.dump(batch, f) + with open(f"/storage/home/par/toy-data/batch_{dp_rank}.pickle", "rb") as f: + batch = pickle.load(f) + + batch_x = torch.from_numpy( + batch.x, + ).cuda() + batch_y = torch.from_numpy(batch.y).cuda() + mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda() + pred = model(batch_x) + loss, _ = compute_loss(pred, batch_y, mask, 1.0) + + # We scale loss with grad_acc_steps so the gradient is the same + # regardless of grad_acc_steps + loss = loss / args.grad_acc_steps + + # backward on scaled loss to create scaled gradients + loss.backward() + # For logging we undo that scaling + loss = loss.detach() * args.grad_acc_steps + + world_size = get_world_size() + if 1 < world_size <= 8 and False: + # For some reason, there are errors in reduces due to + # not working for non-bf16 numbers. This function is a patched + # version that converts gradients to bf16 before computing norms. 
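# Worked illustration (invented gradients, not part of this script): the total
# norm is the norm of the per-parameter gradient norms, and every gradient is
# then scaled by min(1, max_norm / total_norm); the patched helper above does
# the same after first casting each gradient to bf16.
import torch

g1 = torch.tensor([3.0, 4.0])  # ||g1|| = 5
g2 = torch.tensor([12.0])      # ||g2|| = 12
total = torch.linalg.vector_norm(torch.stack([g1.norm(), g2.norm()]))  # sqrt(25 + 144) = 13
clip_coef = torch.clamp(1.0 / (total + 1e-6), max=1.0)  # with max_norm = 1.0 for the example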
+ # The error only happens in distributed training on one node, + # hence the guard + grad_norm = fixed_clip_grad_norm_( + model.parameters(), max_norm=args.optim.clip, foreach=True + ) + else: + grad_norm = torch.nn.utils.clip_grad_norm_( + model.parameters(), max_norm=args.optim.clip, foreach=True + ) + + grad_norm = ( + grad_norm.full_tensor() if isinstance(grad_norm, DTensor) else grad_norm + ).item() + + # if isinstance(data_loader, MultiprocessIterator): + # logger.info("Closing MP iterator before exiting") + # data_loader.shutdown() + + +def main(): + """ + The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments + This accepts arguments as a dot list + So if the dataclass looks like + + @dataclass + class DummyArgs: + name: str + model: LMTransformerArgsgs + + @dataclass + class LMTransformerArgsgs: + dim: int + + Then you can pass model.dim=32 to change values in LMTransformerArgsgs + or just name=tictac for top level attributes. + + The behavior here is as follows: + 1. We instantiate TrainArgs with its default values + 2. We override those default values with the ones in the provided config file + 3. We override the result with the additional arguments provided through command line + + For example, if the config is the following + + model: + dim: 128 + n_layers: 4 + + and you call train.py with train.py model.dim=64 + + Then the final TrainArgs will have + + model: + dim: 64 + n_layers: 4 + + Plus all the default values in TrainArgs dataclass. + """ + train() + + +if __name__ == "__main__": + main() diff --git a/bytelatent/train.py b/bytelatent/train.py index 4641746..a775e46 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -527,7 +527,7 @@ def train(args: TrainArgs): step_tok_losses.append(tok_loss / train_state.scale) world_size = get_world_size() - if 1 < world_size <= 8: + if 1 < world_size <= 8 and False: # For some reason, there are errors in reduces due to # not working for non-bf16 numbers. This function is a patched # version that converts gradients to bf16 before computing norms.
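# Illustrative sketch of the override order described in main()'s docstring
# above, using OmegaConf directly; the config keys are invented.
from omegaconf import OmegaConf

defaults = OmegaConf.create({"name": "debug", "model": {"dim": 128, "n_layers": 4}})
from_file = OmegaConf.create({"model": {"dim": 128}})  # values from the config file
from_cli = OmegaConf.from_dotlist(["model.dim=64"])    # e.g. train.py model.dim=64
cfg = OmegaConf.merge(defaults, from_file, from_cli)
print(cfg.model.dim, cfg.model.n_layers)               # 64 4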