mirror of
https://github.com/facebookresearch/blt.git
synced 2025-02-22 13:02:14 +00:00
Fix distributed all reduce grad norm (#40)
Summary: With >1 GPU, but only 1 node, all reduces fail when inputs are not bf16. This uses a modified copy of torch's grad norm to avoid failures Test Plan: - Run unit tests: - Run single gpu training: `python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100` - Run 1 node, multi-gpu training `torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100`
This commit is contained in:
parent
7044771a12
commit
c79b1fdbd0
100
bytelatent/norms.py
Normal file
100
bytelatent/norms.py
Normal file
|
@ -0,0 +1,100 @@
|
|||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.utils._foreach_utils import (
|
||||
_device_has_foreach_support,
|
||||
_group_tensors_by_device_and_dtype,
|
||||
_has_foreach_support,
|
||||
)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def fixed_clip_grad_norm_(
|
||||
parameters: torch.Tensor | list[torch.Tensor],
|
||||
max_norm: float,
|
||||
norm_type: float = 2.0,
|
||||
error_if_nonfinite: bool = False,
|
||||
foreach: Optional[bool] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""Clip the gradient norm of an iterable of parameters.
|
||||
|
||||
The norm is computed over the norms of the individual gradients of all parameters,
|
||||
as if the norms of the individual gradients were concatenated into a single vector.
|
||||
Gradients are modified in-place.
|
||||
|
||||
Args:
|
||||
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
|
||||
single Tensor that will have gradients normalized
|
||||
max_norm (float): max norm of the gradients
|
||||
norm_type (float): type of the used p-norm. Can be ``'inf'`` for
|
||||
infinity norm.
|
||||
error_if_nonfinite (bool): if True, an error is thrown if the total
|
||||
norm of the gradients from :attr:`parameters` is ``nan``,
|
||||
``inf``, or ``-inf``. Default: False (will switch to True in the future)
|
||||
foreach (bool): use the faster foreach-based implementation.
|
||||
If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
|
||||
fall back to the slow implementation for other device types.
|
||||
Default: ``None``
|
||||
|
||||
Returns:
|
||||
Total norm of the parameter gradients (viewed as a single vector).
|
||||
"""
|
||||
if isinstance(parameters, torch.Tensor):
|
||||
parameters = [parameters]
|
||||
grads = [p.grad.to(torch.bfloat16) for p in parameters if p.grad is not None]
|
||||
max_norm = float(max_norm)
|
||||
norm_type = float(norm_type)
|
||||
if len(grads) == 0:
|
||||
return torch.tensor(0.0)
|
||||
first_device = grads[0].device
|
||||
grouped_grads: Dict[
|
||||
Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
|
||||
] = _group_tensors_by_device_and_dtype(
|
||||
[grads]
|
||||
) # type: ignore[assignment]
|
||||
|
||||
norms: List[Tensor] = []
|
||||
for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment]
|
||||
if (foreach is None and _has_foreach_support(device_grads, device)) or (
|
||||
foreach and _device_has_foreach_support(device)
|
||||
):
|
||||
norms.extend(torch._foreach_norm(device_grads, norm_type))
|
||||
elif foreach:
|
||||
raise RuntimeError(
|
||||
f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
|
||||
)
|
||||
else:
|
||||
norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])
|
||||
|
||||
total_norm = torch.linalg.vector_norm(
|
||||
torch.stack([norm.to(first_device) for norm in norms]), norm_type
|
||||
)
|
||||
|
||||
if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
|
||||
raise RuntimeError(
|
||||
f"The total norm of order {norm_type} for gradients from "
|
||||
"`parameters` is non-finite, so it cannot be clipped. To disable "
|
||||
"this error and scale the gradients by the non-finite norm anyway, "
|
||||
"set `error_if_nonfinite=False`"
|
||||
)
|
||||
clip_coef = max_norm / (total_norm + 1e-6)
|
||||
# Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
|
||||
# avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
|
||||
# when the gradients do not reside in CPU memory.
|
||||
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
|
||||
for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment]
|
||||
if (foreach is None and _has_foreach_support(device_grads, device)) or (
|
||||
foreach and _device_has_foreach_support(device)
|
||||
):
|
||||
torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
|
||||
elif foreach:
|
||||
raise RuntimeError(
|
||||
f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
|
||||
)
|
||||
else:
|
||||
clip_coef_clamped_device = clip_coef_clamped.to(device)
|
||||
for g in device_grads:
|
||||
g.mul_(clip_coef_clamped_device)
|
||||
|
||||
return total_norm
|
|
@ -47,6 +47,7 @@ from bytelatent.eval import EVAL_FOLDER_NAME, launch_eval
|
|||
from bytelatent.logger import init_logger
|
||||
from bytelatent.metrics import GPUMemoryMonitor, MetricLogger, get_num_params
|
||||
from bytelatent.model.blt import ByteLatentTransformer
|
||||
from bytelatent.norms import fixed_clip_grad_norm_
|
||||
from bytelatent.optim import build_optimizer
|
||||
from bytelatent.probe import AutoProbeD
|
||||
from bytelatent.profiling import maybe_run_profiler
|
||||
|
@ -147,9 +148,26 @@ def validate_train_args(args: TrainArgs, output_size: int):
|
|||
* args.distributed.tp_size
|
||||
!= get_world_size()
|
||||
):
|
||||
logging.info("Modifying TrainArgs distributed config")
|
||||
assert get_world_size() % args.distributed.dp_shard == 0
|
||||
logging.info("World size: %s", get_world_size())
|
||||
logging.info(
|
||||
"Existing setting: train_args.distributed.dp_shard=%s",
|
||||
args.distributed.dp_shard,
|
||||
)
|
||||
logging.info(
|
||||
"Setting train_args.distributed.dp_replicate=%s, was dp_replicate=%s",
|
||||
get_world_size() // args.distributed.dp_shard,
|
||||
args.distributed.dp_replicate,
|
||||
)
|
||||
args.distributed.dp_replicate = get_world_size() // args.distributed.dp_shard
|
||||
|
||||
logging.info(
|
||||
"Changing dp_replicate from %s to %s, to account for tp_size=%s",
|
||||
args.distributed.dp_replicate,
|
||||
args.distributed.dp_replicate // args.distributed.tp_size,
|
||||
args.distributed.tp_size,
|
||||
)
|
||||
assert args.distributed.dp_replicate % args.distributed.tp_size == 0
|
||||
args.distributed.dp_replicate = (
|
||||
args.distributed.dp_replicate // args.distributed.tp_size
|
||||
|
@ -470,9 +488,20 @@ def train(args: TrainArgs):
|
|||
# For logging we undo that scaling
|
||||
loss = loss.detach() * args.grad_acc_steps
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
model.parameters(), max_norm=args.optim.clip, foreach=True
|
||||
)
|
||||
world_size = get_world_size()
|
||||
if 1 < world_size <= 8:
|
||||
# For some reason, there are errors in reduces due to
|
||||
# not working for non-bf16 numbers. This function is a patched
|
||||
# version that converts gradients to bf16 before computing norms.
|
||||
# The error only happens in distributed training on one node,
|
||||
# hence the guard
|
||||
grad_norm = fixed_clip_grad_norm_(
|
||||
model.parameters(), max_norm=args.optim.clip, foreach=True
|
||||
)
|
||||
else:
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
model.parameters(), max_norm=args.optim.clip, foreach=True
|
||||
)
|
||||
|
||||
grad_norm = (
|
||||
grad_norm.full_tensor() if isinstance(grad_norm, DTensor) else grad_norm
|
||||
|
|
Loading…
Reference in a new issue