From c79b1fdbd0dc8a275a69a4c770fccae66c455a21 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Tue, 4 Feb 2025 16:53:50 -0800 Subject: [PATCH 1/2] Fix distributed all reduce grad norm (#40) Summary: With >1 GPU, but only 1 node, all reduces fail when inputs are not bf16. This uses a modified copy of torch's grad norm to avoid failures Test Plan: - Run unit tests: - Run single gpu training: `python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100` - Run 1 node, multi-gpu training `torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100` --- bytelatent/norms.py | 100 ++++++++++++++++++++++++++++++++++++++++++++ bytelatent/train.py | 35 ++++++++++++++-- 2 files changed, 132 insertions(+), 3 deletions(-) create mode 100644 bytelatent/norms.py diff --git a/bytelatent/norms.py b/bytelatent/norms.py new file mode 100644 index 0000000..81d1652 --- /dev/null +++ b/bytelatent/norms.py @@ -0,0 +1,100 @@ +from typing import Dict, List, Optional, Tuple + +import torch +from torch import Tensor +from torch.utils._foreach_utils import ( + _device_has_foreach_support, + _group_tensors_by_device_and_dtype, + _has_foreach_support, +) + + +@torch.no_grad() +def fixed_clip_grad_norm_( + parameters: torch.Tensor | list[torch.Tensor], + max_norm: float, + norm_type: float = 2.0, + error_if_nonfinite: bool = False, + foreach: Optional[bool] = None, +) -> torch.Tensor: + r"""Clip the gradient norm of an iterable of parameters. + + The norm is computed over the norms of the individual gradients of all parameters, + as if the norms of the individual gradients were concatenated into a single vector. + Gradients are modified in-place. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float): max norm of the gradients + norm_type (float): type of the used p-norm. 
Can be ``'inf'`` for + infinity norm. + error_if_nonfinite (bool): if True, an error is thrown if the total + norm of the gradients from :attr:`parameters` is ``nan``, + ``inf``, or ``-inf``. Default: False (will switch to True in the future) + foreach (bool): use the faster foreach-based implementation. + If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently + fall back to the slow implementation for other device types. + Default: ``None`` + + Returns: + Total norm of the parameter gradients (viewed as a single vector). + """ + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + grads = [p.grad for p in parameters if p.grad is not None]  # keep the real grads; bf16 copies are made only for the norms so clipping stays in-place + max_norm = float(max_norm) + norm_type = float(norm_type) + if len(grads) == 0: + return torch.tensor(0.0) + first_device = grads[0].device + grouped_grads: Dict[ + Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]] + ] = _group_tensors_by_device_and_dtype( + [grads] + ) # type: ignore[assignment] + + norms: List[Tensor] = [] + for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment] + if (foreach is None and _has_foreach_support(device_grads, device)) or ( + foreach and _device_has_foreach_support(device) + ): + norms.extend(torch._foreach_norm([g.to(torch.bfloat16) for g in device_grads], norm_type)) + elif foreach: + raise RuntimeError( + f"foreach=True was passed, but can't use the foreach API on {device.type} tensors" + ) + else: + norms.extend([torch.linalg.vector_norm(g.to(torch.bfloat16), norm_type) for g in device_grads]) + + total_norm = torch.linalg.vector_norm( + torch.stack([norm.to(first_device) for norm in norms]), norm_type + ) + + if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()): + raise RuntimeError( + f"The total norm of order {norm_type} for gradients from " + "`parameters` is non-finite, so it cannot be clipped. 
To disable " + "this error and scale the gradients by the non-finite norm anyway, " + "set `error_if_nonfinite=False`" + ) + clip_coef = max_norm / (total_norm + 1e-6) + # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so + # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization + # when the gradients do not reside in CPU memory. + clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment] + if (foreach is None and _has_foreach_support(device_grads, device)) or ( + foreach and _device_has_foreach_support(device) + ): + torch._foreach_mul_(device_grads, clip_coef_clamped.to(device)) + elif foreach: + raise RuntimeError( + f"foreach=True was passed, but can't use the foreach API on {device.type} tensors" + ) + else: + clip_coef_clamped_device = clip_coef_clamped.to(device) + for g in device_grads: + g.mul_(clip_coef_clamped_device) + + return total_norm diff --git a/bytelatent/train.py b/bytelatent/train.py index 6b20ecd..86d1c7a 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -47,6 +47,7 @@ from bytelatent.eval import EVAL_FOLDER_NAME, launch_eval from bytelatent.logger import init_logger from bytelatent.metrics import GPUMemoryMonitor, MetricLogger, get_num_params from bytelatent.model.blt import ByteLatentTransformer +from bytelatent.norms import fixed_clip_grad_norm_ from bytelatent.optim import build_optimizer from bytelatent.probe import AutoProbeD from bytelatent.profiling import maybe_run_profiler @@ -147,9 +148,26 @@ def validate_train_args(args: TrainArgs, output_size: int): * args.distributed.tp_size != get_world_size() ): + logging.info("Modifying TrainArgs distributed config") assert get_world_size() % args.distributed.dp_shard == 0 + logging.info("World size: %s", get_world_size()) + logging.info( + "Existing setting: train_args.distributed.dp_shard=%s", + 
+ args.distributed.dp_shard, + ) + logging.info( + "Setting train_args.distributed.dp_replicate=%s, was dp_replicate=%s", + get_world_size() // args.distributed.dp_shard, + args.distributed.dp_replicate, + ) args.distributed.dp_replicate = get_world_size() // args.distributed.dp_shard + logging.info( + "Changing dp_replicate from %s to %s, to account for tp_size=%s", + args.distributed.dp_replicate, + args.distributed.dp_replicate // args.distributed.tp_size, + args.distributed.tp_size, + ) assert args.distributed.dp_replicate % args.distributed.tp_size == 0 args.distributed.dp_replicate = ( args.distributed.dp_replicate // args.distributed.tp_size @@ -470,9 +488,20 @@ def train(args: TrainArgs): # For logging we undo that scaling loss = loss.detach() * args.grad_acc_steps - grad_norm = torch.nn.utils.clip_grad_norm_( - model.parameters(), max_norm=args.optim.clip, foreach=True - ) + world_size = get_world_size() + if 1 < world_size <= 8: + # For some reason, there are errors in reduces due to + # not working for non-bf16 numbers. This function is a patched + # version that converts gradients to bf16 before computing norms. 
+ # The error only happens in distributed training on one node, + # hence the guard + grad_norm = fixed_clip_grad_norm_( + model.parameters(), max_norm=args.optim.clip, foreach=True + ) + else: + grad_norm = torch.nn.utils.clip_grad_norm_( + model.parameters(), max_norm=args.optim.clip, foreach=True + ) grad_norm = ( grad_norm.full_tensor() if isinstance(grad_norm, DTensor) else grad_norm From b6e53f1d4c2418775cc2ad4050e06b6ac0d6c401 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Wed, 5 Feb 2025 00:55:18 +0000 Subject: [PATCH 2/2] Update checkpointing to use fsspec Summary: - Make the data/checkpoint code fsspec compatible Test Plan: Run unit tests and the commands below ``` python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 ``` ``` torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 ``` ``` python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 dump_dir=s3://blt/scratch/checkpoint-test/ ``` ``` torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 dump_dir=s3://blt/scratch/checkpoint-test/ ``` --- bytelatent/checkpoint.py | 97 +++++++++++++++++++++------------------- bytelatent/train.py | 6 ++- 2 files changed, 55 insertions(+), 48 deletions(-) diff --git a/bytelatent/checkpoint.py b/bytelatent/checkpoint.py index f213c84..6631673 100644 --- a/bytelatent/checkpoint.py +++ b/bytelatent/checkpoint.py @@ -4,8 +4,6 @@ import json import logging import os import re -from pathlib import Path -from typing import List, Optional, Tuple import fsspec import torch @@ -70,26 +68,29 @@ def consolidate_checkpoints(fs: fsspec.AbstractFileSystem, ckpt_dir: str): Returns the path to the consolidated checkpoint """ - consolidate_path = Path(ckpt_dir) / CONSOLIDATE_FOLDER - if not (consolidate_path / CONSOLIDATE_NAME).exists(): - 
consolidate_path.mkdir(exist_ok=True) - logger.info(f"Consolidating to: {str(consolidate_path)}") - dcp_to_torch_save(ckpt_dir, str(consolidate_path / CONSOLIDATE_NAME)) - (consolidate_path / CONFIG_NAME).write_text( - (Path(ckpt_dir) / CONFIG_NAME).read_text() + consolidate_path = os.path.join(ckpt_dir, CONSOLIDATE_FOLDER) + consolidate_name = os.path.join(consolidate_path, CONSOLIDATE_NAME) + if not fs.exists(consolidate_name): + fs.mkdirs(consolidate_path, exist_ok=True) + logger.info(f"Consolidating to: {consolidate_path}") + dcp_to_torch_save(ckpt_dir, consolidate_name) + fs.write_text( + os.path.join(consolidate_path, CONFIG_NAME), + fs.read_text(os.path.join(ckpt_dir, CONFIG_NAME)), ) logger.info("Consolidated !") return consolidate_path def load_from_checkpoint( + fs: fsspec.AbstractFileSystem, ckpt_dir: str, model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, + optimizer: torch.optim.Optimizer | None = None, model_key: str = "model", optim_key: str = "optim", ): - if not (Path(ckpt_dir) / ".metadata").exists(): + if not fs.exists(os.path.join(ckpt_dir, ".metadata")): raise ValueError( f"Please convert the checkpoint distcp format using `torch.distributed.checkpoint.format_utils.torch_save_to_dcp` before loading it" ) @@ -121,13 +122,13 @@ class CheckpointManager: self.existing_saves = self.get_existing_saves() - def get_existing_saves(self) -> List[Path]: + def get_existing_saves(self) -> list[str]: folders = [ p - for p in Path(self.path).iterdir() - if p.is_dir() and re.match(RE_FOLDER, p.name) + for p in self.fs.ls(self.path) + if self.fs.isdir(p) and re.match(RE_FOLDER, os.path.basename(p)) ] - folders.sort(key=lambda p: _get_key_step(p.name)) + folders.sort(key=lambda p: _get_key_step(os.path.basename(p))) return folders def clean_up(self): @@ -136,8 +137,9 @@ class CheckpointManager: eval_folders = [] other_folders = [] for p in self.existing_saves: - is_dump = _get_key_step(p.name) % self.dump_every.every == 0 - is_eval = 
_get_key_step(p.name) % self.eval_every.every == 0 + assert isinstance(p, str), f"Base path type: {p}" + is_dump = _get_key_step(os.path.basename(p)) % self.dump_every.every == 0 + is_eval = _get_key_step(os.path.basename(p)) % self.eval_every.every == 0 if is_dump: dump_folders.append(p) if is_eval: @@ -161,40 +163,39 @@ class CheckpointManager: if dist.get_rank() == 0: for folder in folder_to_remove: - for file in folder.iterdir(): - if file.is_file(): - file.unlink() - elif file.is_dir(): - assert file.name in [CONSOLIDATE_FOLDER] - for f in file.iterdir(): - f.unlink() - file.rmdir() - folder.rmdir() + for file in self.fs.ls(folder): + if self.fs.isfile(file): + self.fs.rm_file(file) + elif self.fs.isdir(file): + assert os.path.basename(file) in [CONSOLIDATE_FOLDER] + for f in self.fs.ls(file): + self.fs.rm(f) + self.fs.rmdir(file) + self.fs.rmdir(folder) dist.barrier() self.existing_saves = list(folder_to_keep) - self.existing_saves.sort(key=lambda p: _get_key_step(p.name)) + self.existing_saves.sort(key=lambda p: _get_key_step(os.path.basename(p))) - def get_last_step_path(self, dp_rank: int = 0) -> Optional[Path]: + def get_last_step_path(self, dp_rank: int = 0) -> str | None: path = None for p in reversed(self.existing_saves): - if (p / TRAIN_STATE_NAME.format(dp_rank)).is_file(): + + if self.fs.isfile(os.path.join(p, TRAIN_STATE_NAME.format(dp_rank))): path = p break return path - def _create_folder(self, base_path: Path, folder_name: str) -> Path: - folder = base_path / folder_name + def _create_folder(self, base_path: str, folder_name: str) -> str: + folder = os.path.join(base_path, folder_name) if get_is_master(): - folder.mkdir(parents=False, exist_ok=True) + self.fs.mkdirs(folder, exist_ok=True) if dist.is_initialized(): dist.barrier() return folder - def _get_dp_tp_mesh( - self, device_mesh: Optional[DeviceMesh] = None - ) -> Tuple[int, int]: + def _get_dp_tp_mesh(self, device_mesh: DeviceMesh | None = None) -> tuple[int, int]: dp_rank = 0 tp_rank = 0 if 
device_mesh is not None: @@ -222,14 +223,14 @@ class CheckpointManager: model, optimizer, train_state, - config, - device_mesh: Optional[DeviceMesh] = None, + config: BaseModel, + device_mesh: DeviceMesh | None = None, ) -> bool: # When creating directory check if only rank0 or is there other solution - path = Path(self.path) + path = self.path curr_save_dir = self._create_folder(path, FOLDER_NAME.format(train_state.step)) - logger.info(f"Saving to: {str(curr_save_dir)}") + logger.info(f"Saving to: {curr_save_dir}") if dist.is_initialized(): dist.barrier() @@ -242,17 +243,19 @@ class CheckpointManager: if dist.is_initialized(): dist.barrier() + logger.debug("config type: %s", type(config)) if get_is_master(): - config.dump_to_yaml_file(curr_save_dir / CONFIG_NAME) + self.fs.write_text( + os.path.join(curr_save_dir, CONFIG_NAME), config.model_dump_json() + ) # Add json dump here dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh) if tp_rank == 0: train_state_name = TRAIN_STATE_NAME.format(dp_rank) - logger.info( - f"Saving train state to: {str(curr_save_dir / train_state_name)}" - ) - with open(curr_save_dir / train_state_name, "w") as f: + train_state_full_path = os.path.join(curr_save_dir, train_state_name) + logger.info(f"Saving train state to: {train_state_full_path}") + with self.fs.open(train_state_full_path, "w") as f: json.dump(train_state.state_dict(), f) logger.info("Train state saved !") @@ -271,7 +274,7 @@ class CheckpointManager: optimizer, train_state, device_mesh: DeviceMesh, - path: Optional[Path] = None, + path: str | None = None, ): dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh) # Loading tries to load the provided path, if not available the last saved step and finally from the init path @@ -284,12 +287,12 @@ class CheckpointManager: # Only load train state if it's provided, the files exist and we're not loading from init path train_state_name = TRAIN_STATE_NAME.format(dp_rank) logger.info("Reloading train state") - with open(path / train_state_name, 
"r") as f: + with self.fs.open(os.path.join(path, train_state_name), "r") as f: train_state_dict = json.load(f) train_state.load_state_dict(train_state_dict) logger.info("Train state reloaded") - logger.info(f"Loading from: {str(path)}") + logger.info(f"Loading from: {path}") state_dict = self.get_state_dict( model=model, optimizer=optimizer, diff --git a/bytelatent/train.py b/bytelatent/train.py index 86d1c7a..c80a74c 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -25,6 +25,7 @@ from torch.optim import lr_scheduler from bytelatent.args import TrainArgs, parse_args from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint +from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.multiprocess_iterator import ( MultiprocessIterator, MultiprocessIteratorState, @@ -313,8 +314,11 @@ def train(args: TrainArgs): if args.checkpoint.init_ckpt_path: logger.info(f"Loading initial model from {args.checkpoint.init_ckpt_path}") + ckpt_fs = get_fs( + args.checkpoint.init_ckpt_path, s3_profile=args.checkpoint.s3_profile + ) load_from_checkpoint( - args.checkpoint.init_ckpt_path, model, model_key="model" + ckpt_fs, args.checkpoint.init_ckpt_path, model, model_key="model" ) # Put model_key="" if its directly the model checkpoint model.rope_embeddings.reset_parameters() # For RoPe initialization since it's a buffer it might not be loaded else: