From c79b1fdbd0dc8a275a69a4c770fccae66c455a21 Mon Sep 17 00:00:00 2001
From: Pedro Rodriguez
Date: Tue, 4 Feb 2025 16:53:50 -0800
Subject: [PATCH] Fix distributed all reduce grad norm (#40)

Summary:
With >1 GPU but only 1 node, all-reduces fail when inputs are not bf16. This
uses a modified copy of torch's grad norm clipping to avoid the failures.

Test Plan:
- Run unit tests:
- Run single gpu training: `python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100`
- Run 1 node, multi-gpu training: `torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100`
---
 bytelatent/norms.py | 100 ++++++++++++++++++++++++++++++++++++++++++++
 bytelatent/train.py |  35 ++++++++++++++++---
 2 files changed, 132 insertions(+), 3 deletions(-)
 create mode 100644 bytelatent/norms.py

diff --git a/bytelatent/norms.py b/bytelatent/norms.py
new file mode 100644
index 0000000..81d1652
--- /dev/null
+++ b/bytelatent/norms.py
@@ -0,0 +1,100 @@
+from typing import Dict, List, Optional, Tuple
+
+import torch
+from torch import Tensor
+from torch.utils._foreach_utils import (
+    _device_has_foreach_support,
+    _group_tensors_by_device_and_dtype,
+    _has_foreach_support,
+)
+
+
+@torch.no_grad()
+def fixed_clip_grad_norm_(
+    parameters: torch.Tensor | list[torch.Tensor],
+    max_norm: float,
+    norm_type: float = 2.0,
+    error_if_nonfinite: bool = False,
+    foreach: Optional[bool] = None,
+) -> torch.Tensor:
+    r"""Clip the gradient norm of an iterable of parameters.
+
+    The norm is computed over the norms of the individual gradients of all parameters,
+    as if the norms of the individual gradients were concatenated into a single vector.
+    Gradients are modified in-place.
+
+    Args:
+        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+            single Tensor that will have gradients normalized
+        max_norm (float): max norm of the gradients
+        norm_type (float): type of the used p-norm. Can be ``'inf'`` for
+            infinity norm.
+        error_if_nonfinite (bool): if True, an error is thrown if the total
+            norm of the gradients from :attr:`parameters` is ``nan``,
+            ``inf``, or ``-inf``. Default: False (will switch to True in the future)
+        foreach (bool): use the faster foreach-based implementation.
+            If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
+            fall back to the slow implementation for other device types.
+            Default: ``None``
+
+    Returns:
+        Total norm of the parameter gradients (viewed as a single vector).
+    """
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    grads = [p.grad.to(torch.bfloat16) for p in parameters if p.grad is not None]
+    max_norm = float(max_norm)
+    norm_type = float(norm_type)
+    if len(grads) == 0:
+        return torch.tensor(0.0)
+    first_device = grads[0].device
+    grouped_grads: Dict[
+        Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
+    ] = _group_tensors_by_device_and_dtype(
+        [grads]
+    )  # type: ignore[assignment]
+
+    norms: List[Tensor] = []
+    for (device, _), ([device_grads], _) in grouped_grads.items():  # type: ignore[assignment]
+        if (foreach is None and _has_foreach_support(device_grads, device)) or (
+            foreach and _device_has_foreach_support(device)
+        ):
+            norms.extend(torch._foreach_norm(device_grads, norm_type))
+        elif foreach:
+            raise RuntimeError(
+                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
+            )
+        else:
+            norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])
+
+    total_norm = torch.linalg.vector_norm(
+        torch.stack([norm.to(first_device) for norm in norms]), norm_type
+    )
+
+    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
+        raise RuntimeError(
+            f"The total norm of order {norm_type} for gradients from "
+            "`parameters` is non-finite, so it cannot be clipped. To disable "
+            "this error and scale the gradients by the non-finite norm anyway, "
+            "set `error_if_nonfinite=False`"
+        )
+    clip_coef = max_norm / (total_norm + 1e-6)
+    # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
+    # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
+    # when the gradients do not reside in CPU memory.
+    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
+    for (device, _), ([device_grads], _) in grouped_grads.items():  # type: ignore[assignment]
+        if (foreach is None and _has_foreach_support(device_grads, device)) or (
+            foreach and _device_has_foreach_support(device)
+        ):
+            torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
+        elif foreach:
+            raise RuntimeError(
+                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
+            )
+        else:
+            clip_coef_clamped_device = clip_coef_clamped.to(device)
+            for g in device_grads:
+                g.mul_(clip_coef_clamped_device)
+
+    return total_norm
diff --git a/bytelatent/train.py b/bytelatent/train.py
index 6b20ecd..86d1c7a 100644
--- a/bytelatent/train.py
+++ b/bytelatent/train.py
@@ -47,6 +47,7 @@ from bytelatent.eval import EVAL_FOLDER_NAME, launch_eval
 from bytelatent.logger import init_logger
 from bytelatent.metrics import GPUMemoryMonitor, MetricLogger, get_num_params
 from bytelatent.model.blt import ByteLatentTransformer
+from bytelatent.norms import fixed_clip_grad_norm_
 from bytelatent.optim import build_optimizer
 from bytelatent.probe import AutoProbeD
 from bytelatent.profiling import maybe_run_profiler
@@ -147,9 +148,26 @@ def validate_train_args(args: TrainArgs, output_size: int):
         * args.distributed.tp_size
         != get_world_size()
     ):
+        logging.info("Modifying TrainArgs distributed config")
         assert get_world_size() % args.distributed.dp_shard == 0
+        logging.info("World size: %s", get_world_size())
+        logging.info(
+            "Existing setting: train_args.distributed.dp_shard=%s",
+            args.distributed.dp_shard,
+        )
+        logging.info(
+            "Setting train_args.distributed.dp_replicate=%s, was dp_replicate=%s",
+            get_world_size() // args.distributed.dp_shard,
+            args.distributed.dp_replicate,
+        )
         args.distributed.dp_replicate = get_world_size() // args.distributed.dp_shard
+        logging.info(
+            "Changing dp_replicate from %s to %s, to account for tp_size=%s",
+            args.distributed.dp_replicate,
+            args.distributed.dp_replicate // args.distributed.tp_size,
+            args.distributed.tp_size,
+        )
         assert args.distributed.dp_replicate % args.distributed.tp_size == 0
         args.distributed.dp_replicate = (
             args.distributed.dp_replicate // args.distributed.tp_size
         )
@@ -470,9 +488,20 @@ def train(args: TrainArgs):
             # For logging we undo that scaling
             loss = loss.detach() * args.grad_acc_steps
 
-            grad_norm = torch.nn.utils.clip_grad_norm_(
-                model.parameters(), max_norm=args.optim.clip, foreach=True
-            )
+            world_size = get_world_size()
+            if 1 < world_size <= 8:
+                # All-reduces used for the grad norm fail when the inputs are
+                # not bf16. fixed_clip_grad_norm_ is a patched copy of torch's
+                # clip_grad_norm_ that casts gradients to bf16 before computing
+                # norms. The error only occurs in distributed training on a
+                # single node, hence the world_size guard.
+                grad_norm = fixed_clip_grad_norm_(
+                    model.parameters(), max_norm=args.optim.clip, foreach=True
+                )
+            else:
+                grad_norm = torch.nn.utils.clip_grad_norm_(
+                    model.parameters(), max_norm=args.optim.clip, foreach=True
+                )
 
             grad_norm = (
                 grad_norm.full_tensor() if isinstance(grad_norm, DTensor) else grad_norm
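
The Test Plan lists running unit tests, but the patch itself does not add one for the new helper. A minimal sketch of such a test, runnable with pytest, could compare the patched norm against torch's implementation; the file name bytelatent/test_norms.py, the tensor shapes, and the tolerances below are assumptions chosen to absorb bf16 rounding, not part of the patch.

# bytelatent/test_norms.py (hypothetical sketch)
import torch

from bytelatent.norms import fixed_clip_grad_norm_


def test_fixed_clip_grad_norm_matches_torch():
    torch.manual_seed(0)
    # Two identical parameter sets with identical gradients.
    params_a = [torch.randn(4, 4, requires_grad=True) for _ in range(3)]
    params_b = [p.detach().clone().requires_grad_(True) for p in params_a]
    for pa, pb in zip(params_a, params_b):
        grad = torch.randn_like(pa)
        pa.grad = grad.clone()
        pb.grad = grad.clone()

    # The patched version casts gradients to bf16 before computing norms,
    # so its result should match torch's up to bf16 rounding error.
    norm_fixed = fixed_clip_grad_norm_(params_a, max_norm=1.0)
    norm_torch = torch.nn.utils.clip_grad_norm_(params_b, max_norm=1.0)

    # Loose tolerance to account for the bf16 cast.
    assert torch.allclose(norm_fixed.float(), norm_torch.float(), rtol=2e-2, atol=1e-2)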