# blt/bytelatent/broken_train.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
from datetime import timedelta
from enum import Enum
from functools import lru_cache
import logging
import math
import sys
import time
from typing import Dict, List, Optional, Tuple
from torch import Tensor
import os
import pickle
import fsspec
import torch
import torch.distributed
import torch.nn.functional
import torch.nn.functional as F
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import init_device_mesh
from torch.utils._foreach_utils import (
_device_has_foreach_support,
_group_tensors_by_device_and_dtype,
_has_foreach_support,
)
from bytelatent.args import TrainArgs
from bytelatent.distributed import (
DistributedArgs,
check_model_value_range,
parallelize_model,
setup_env,
setup_torch_distributed,
)
logger = logging.getLogger()
def set_root_log_level(log_level: str):
logger = logging.getLogger()
level: int | str = log_level.upper()
try:
level = int(log_level)
except ValueError:
pass
try:
logger.setLevel(level) # type: ignore
except Exception:
logger.warning(
f"Failed to set logging level to {log_level}, using default 'NOTSET'"
)
logger.setLevel(logging.NOTSET)
class LogFormatter(logging.Formatter):
"""
Custom logger for distributed jobs, displaying rank
and preserving indent from the custom prefix format.
"""
def __init__(self):
self.start_time = time.time()
self.rank = get_global_rank()
self.show_rank = not get_is_slurm_job() # srun has --label
def formatTime(self, record):
subsecond, seconds = math.modf(record.created)
curr_date = (
time.strftime("%y-%m-%d %H:%M:%S", time.localtime(seconds))
+ f".{int(subsecond * 1_000_000):06d}"
)
delta = timedelta(seconds=round(record.created - self.start_time))
return f"{curr_date} - {delta}"
def formatPrefix(self, record):
fmt_time = self.formatTime(record)
if self.show_rank:
return f"{self.rank}: {record.levelname:<7} {fmt_time} - "
else:
return f"{record.levelname:<7} {fmt_time} - "
def formatMessage(self, record, indent: str):
content = record.getMessage()
content = content.replace("\n", "\n" + indent)
# Exception handling as in the default formatter, albeit with indenting
# according to our custom prefix
if record.exc_info:
# Cache the traceback text to avoid converting it multiple times
# (it's constant anyway)
if not record.exc_text:
record.exc_text = self.formatException(record.exc_info)
if record.exc_text:
if content[-1:] != "\n":
content = content + "\n" + indent
content = content + indent.join(
[l + "\n" for l in record.exc_text.splitlines()]
)
if content[-1:] == "\n":
content = content[:-1]
if record.stack_info:
if content[-1:] != "\n":
content = content + "\n" + indent
stack_text = self.formatStack(record.stack_info)
content = content + indent.join([l + "\n" for l in stack_text.splitlines()])
if content[-1:] == "\n":
content = content[:-1]
return content
def format(self, record):
prefix = self.formatPrefix(record)
indent = " " * len(prefix)
content = self.formatMessage(record, indent)
return prefix + content
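# A small sketch (not used by the training code) of how LogFormatter renders a
# record: a rank/level/time prefix, with continuation lines indented under it.
# The record fields below are illustrative values only.
def _demo_log_formatter() -> str:
    fmt = LogFormatter()
    record = logging.LogRecord(
        name="demo",
        level=logging.INFO,
        pathname=__file__,
        lineno=0,
        msg="first line\nsecond line",
        args=(),
        exc_info=None,
    )
    # The second line comes back indented to align under the prefix.
    return fmt.format(record)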
def init_logger(
log_file: str | None = None,
*,
name: str | None = None,
level: str = "INFO",
fs: fsspec.AbstractFileSystem | None = None,
):
"""
Setup logging.
Args:
log_file: A file name to save file logs to.
name: The name of the logger to configure, by default the root logger.
level: The logging level to use.
"""
set_root_log_level(level)
logger = logging.getLogger(name)
# stdout: everything
stdout_handler = logging.StreamHandler(sys.stdout)
stdout_handler.setLevel(logging.NOTSET)
stdout_handler.setFormatter(LogFormatter())
# stderr: warnings / errors and above
stderr_handler = logging.StreamHandler(sys.stderr)
stderr_handler.setLevel(logging.WARNING)
stderr_handler.setFormatter(LogFormatter())
# set stream handlers
logger.handlers.clear()
logger.handlers.append(stdout_handler)
logger.handlers.append(stderr_handler)
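# A tiny sketch of wiring the logger up (assumption: run outside SLURM/torchrun,
# so the rank defaults to 0); train() below performs the real call.
def _demo_init_logger() -> None:
    init_logger(level="INFO")
    logging.getLogger().info("logging configured")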
@torch.no_grad()
def fixed_clip_grad_norm_(
parameters: torch.Tensor | list[torch.Tensor],
max_norm: float,
norm_type: float = 2.0,
error_if_nonfinite: bool = False,
foreach: Optional[bool] = None,
) -> torch.Tensor:
r"""Clip the gradient norm of an iterable of parameters.
The norm is computed over the norms of the individual gradients of all parameters,
as if the norms of the individual gradients were concatenated into a single vector.
Gradients are modified in-place.
Args:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized
max_norm (float): max norm of the gradients
norm_type (float): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
error_if_nonfinite (bool): if True, an error is thrown if the total
norm of the gradients from :attr:`parameters` is ``nan``,
``inf``, or ``-inf``. Default: False (will switch to True in the future)
foreach (bool): use the faster foreach-based implementation.
If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
fall back to the slow implementation for other device types.
Default: ``None``
Returns:
Total norm of the parameter gradients (viewed as a single vector).
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
    # Patched relative to torch.nn.utils.clip_grad_norm_: gradients are cast to
    # bf16 before norms are computed (see the comment at the call site in train()).
    grads = [p.grad.to(torch.bfloat16) for p in parameters if p.grad is not None]
max_norm = float(max_norm)
norm_type = float(norm_type)
if len(grads) == 0:
return torch.tensor(0.0)
first_device = grads[0].device
grouped_grads: Dict[
Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
] = _group_tensors_by_device_and_dtype(
[grads]
) # type: ignore[assignment]
norms: List[Tensor] = []
for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment]
if (foreach is None and _has_foreach_support(device_grads, device)) or (
foreach and _device_has_foreach_support(device)
):
norms.extend(torch._foreach_norm(device_grads, norm_type))
elif foreach:
raise RuntimeError(
f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
)
else:
norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])
total_norm = torch.linalg.vector_norm(
torch.stack([norm.to(first_device) for norm in norms]), norm_type
)
if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
raise RuntimeError(
f"The total norm of order {norm_type} for gradients from "
"`parameters` is non-finite, so it cannot be clipped. To disable "
"this error and scale the gradients by the non-finite norm anyway, "
"set `error_if_nonfinite=False`"
)
clip_coef = max_norm / (total_norm + 1e-6)
# Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
# avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
# when the gradients do not reside in CPU memory.
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment]
if (foreach is None and _has_foreach_support(device_grads, device)) or (
foreach and _device_has_foreach_support(device)
):
torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
elif foreach:
raise RuntimeError(
f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
)
else:
clip_coef_clamped_device = clip_coef_clamped.to(device)
for g in device_grads:
g.mul_(clip_coef_clamped_device)
return total_norm
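# A minimal sketch (assumption: CPU-only, no distributed setup) of exercising
# fixed_clip_grad_norm_. Because it casts gradient copies to bf16 before
# computing norms, with fp32 parameters only the returned norm is meaningful
# (the in-place clip hits the bf16 copies). Not called anywhere in this script.
def _demo_fixed_clip_grad_norm() -> torch.Tensor:
    layer = torch.nn.Linear(4, 4)
    out = layer(torch.randn(2, 4)).sum()
    out.backward()
    total_norm = fixed_clip_grad_norm_(list(layer.parameters()), max_norm=1.0)
    return total_norm  # bf16 scalar tensor holding the pre-clip gradient norm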
def get_no_recompute_ops():
return None
@lru_cache()
def get_is_torch_run() -> bool:
return os.environ.get("LOCAL_RANK") is not None
@lru_cache()
def get_is_slurm_job() -> bool:
return "SLURM_JOB_ID" in os.environ and not get_is_torch_run()
@lru_cache()
def get_global_rank() -> int:
if get_is_torch_run():
return int(os.environ["RANK"])
elif get_is_slurm_job():
return int(os.environ["SLURM_PROCID"])
else:
return 0
@lru_cache()
def get_local_rank() -> int:
if get_is_torch_run():
return int(os.environ["LOCAL_RANK"])
elif get_is_slurm_job():
return int(os.environ["SLURM_LOCALID"])
else:
return 0
@lru_cache()
def get_world_size() -> int:
if get_is_torch_run():
return int(os.environ["WORLD_SIZE"])
elif get_is_slurm_job():
return int(os.environ["SLURM_NTASKS"])
else:
return 1
@lru_cache()
def get_is_master() -> bool:
return get_global_rank() == 0
def validate_train_args(args: TrainArgs, output_size: int):
# assert args.model is not None or args.entropy_model is not None
if args.entropy_model is not None:
logger.info(f"Setting model output size to {args.entropy_model.vocab_size}")
args.entropy_model.vocab_size = output_size
assert args.dump_dir, "Dump dir not set"
if (
args.distributed.dp_replicate
* args.distributed.dp_shard
* args.distributed.tp_size
!= get_world_size()
):
logging.info("Modifying TrainArgs distributed config")
assert get_world_size() % args.distributed.dp_shard == 0
logging.info("World size: %s", get_world_size())
logging.info(
"Existing setting: train_args.distributed.dp_shard=%s",
args.distributed.dp_shard,
)
logging.info(
"Setting train_args.distributed.dp_replicate=%s, was dp_replicate=%s",
get_world_size() // args.distributed.dp_shard,
args.distributed.dp_replicate,
)
args.distributed.dp_replicate = get_world_size() // args.distributed.dp_shard
logging.info(
"Changing dp_replicate from %s to %s, to account for tp_size=%s",
args.distributed.dp_replicate,
args.distributed.dp_replicate // args.distributed.tp_size,
args.distributed.tp_size,
)
assert args.distributed.dp_replicate % args.distributed.tp_size == 0
args.distributed.dp_replicate = (
args.distributed.dp_replicate // args.distributed.tp_size
)
logger.warning(
f"Setting Data Parallel size to {args.distributed.dp_replicate * args.distributed.dp_shard}"
)
assert (
args.distributed.dp_replicate
* args.distributed.dp_shard
* args.distributed.tp_size
== get_world_size()
)
if args.distributed.fsdp_type == "no_shard":
assert (
args.distributed.dp_shard == 1
and args.distributed.dp_replicate == get_world_size()
)
if args.model is not None:
args.model.max_seqlen = args.data.seq_len
if args.entropy_model is not None:
args.entropy_model.max_seqlen = args.data.seq_len
if args.distributed.tp_size == 1:
logger.warning(
"Tensor parallelism has not been tested for a while, use at your own risk"
)
assert (
args.probe_freq != args.profiling.mem_steps
), "Don't profile during probe step"
assert (
args.probe_freq != args.profiling.profile_steps
), "Don't profile during probe step"
if args.logging.wandb is not None:
args.logging.wandb.name = args.name
if args.probe_freq is not None:
assert (
args.distributed.tp_size == 1
), "Probing not supported with tensor parallelism"
assert (
args.distributed.selective_activation_checkpointing is False
), "Probing not supported with selective activation checkpointing"
def compute_loss(p, y, mask, scale):
tok_loss = scale * F.cross_entropy(
p.flatten(0, 1), y.flatten(0, 1), reduction="none"
)
if mask is None:
loss = tok_loss.mean()
else:
mask = mask.flatten(0, 1)
tok_loss = tok_loss * mask
loss = tok_loss.sum() / (mask.sum() + 1e-6)
return loss, tok_loss
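# A small sketch (toy shapes, CPU) of the masked loss above: per-token cross
# entropy is averaged only over positions where mask == 1. Not called by train().
def _demo_compute_loss():
    torch.manual_seed(0)
    logits = torch.randn(2, 4, 10)          # (batch, seq, vocab)
    targets = torch.randint(0, 10, (2, 4))  # (batch, seq)
    mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.float)
    loss, tok_loss = compute_loss(logits, targets, mask, 1.0)
    return loss, tok_loss  # scalar mean over the 5 unmasked tokens, and per-token losses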
def get_device_mesh(distributed_args):
tp_size = distributed_args.tp_size
dp_replicate = distributed_args.dp_replicate
dp_shard = distributed_args.dp_shard
assert (
dp_replicate * dp_shard * tp_size == get_world_size()
), f"dp_replicate * dp_shard * tp_size ({dp_replicate} * {dp_shard} * {tp_size}) != world_size ({get_world_size()})"
dims = []
names = []
if dp_replicate >= 1:
dims.append(dp_replicate)
names.append("dp_replicate")
if dp_shard > 1 or distributed_args.fsdp_type == "no_shard":
dims.append(dp_shard)
names.append("dp_shard")
if tp_size > 1:
dims.append(tp_size)
names.append("tp")
dims = tuple(dims)
names = tuple(names)
return init_device_mesh("cuda", mesh_shape=dims, mesh_dim_names=names)
def build_fsdp_grouping_plan():
    group_plan: list[tuple[str, bool]] = []
    # Group token embeddings and output separately
group_plan.append(("tok_embeddings", False))
# Grouping by layers
# for i in range(model_args.n_layers):
# group_plan.append((f"layers.{i}", False))
group_plan.append(("output", True))
return group_plan
class MinimalModel(torch.nn.Module):
def __init__(self, dim: int, vocab_size: int):
super().__init__()
self.tok_embeddings = torch.nn.Embedding(vocab_size, dim)
# self.norm = RMSNorm(args.dim, eps=args.norm_eps)
# self.layers = torch.nn.ModuleList()
# for _ in range(args.n_layers):
# self.layers.append(TransformerBlock(args))
self.output = torch.nn.Linear(
dim,
vocab_size,
bias=False,
)
def forward(self, tokens):
h = self.tok_embeddings(tokens)
logits = self.output(h)
# logits = self.output(self.norm(h))
return logits
def reset_parameters(self, init_std=None):
pass
def init_weights(self):
pass
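# A quick sketch of the toy model above on CPU, without the meta-device/FSDP
# path used by train() below; shapes are illustrative. Not called anywhere.
def _demo_minimal_model() -> torch.Tensor:
    model = MinimalModel(dim=16, vocab_size=32)
    tokens = torch.randint(0, 32, (2, 8))  # (batch, seq)
    logits = model(tokens)                 # (batch, seq, vocab) == (2, 8, 32)
    return logits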
def train():
args = TrainArgs(
dump_dir="/tmp",
name="debug_bf16",
model=None,
entropy_model=None,
distributed=DistributedArgs(
fsdp_type="full_shard",
model_dtype="bf16",
matmul_allow_tf32=False,
selective_activation_checkpointing=False,
tp_size=1,
),
)
tokenizer = args.data.tokenizer_args.build()
validate_train_args(
args,
tokenizer.n_words,
)
dump_fs = fsspec.filesystem("file")
init_logger(os.path.join(args.dump_dir, "train.log"), fs=dump_fs)
setup_env(args.env)
setup_torch_distributed(args.distributed)
world_mesh = get_device_mesh(args.distributed)
logger.info(f"Starting job: {args.name}")
# build dataloader
# need dp world size and rank
dp_mesh = world_mesh["dp_replicate"]
dp_degree = dp_mesh.size()
dp_rank = dp_mesh.get_local_rank()
if args.distributed.dp_shard > 1:
dp_rank = dp_rank * dp_degree + world_mesh["dp_shard"].get_local_rank()
dp_degree *= world_mesh["dp_shard"].size()
logger.info(f"Running on dp rank : {dp_rank}")
logger.info(f"Running on dp size : {dp_degree}")
torch.manual_seed(args.seed)
logger.info("Building model")
    # Initializing the model on the meta device allows us to initialize models much bigger than a single GPU's memory
with torch.device("meta"):
model = MinimalModel(768, tokenizer.n_words)
model = parallelize_model(
model,
world_mesh,
args.model,
args.distributed,
fsdp_grouping_plan=build_fsdp_grouping_plan(),
tp_parallelize=None,
no_recompute_ops=get_no_recompute_ops(),
)
# Once we shard the model on different gpus we can actually initialize the model
# First we create empty tensors of the correct shapes
model = model.to_empty(device="cuda")
    # Then we init the model. Please make sure this function initializes *ALL* parameters
    # and buffers, otherwise you will have random values in the uninitialized tensors,
    # which will fail silently (giving NaN gradients, for example)
with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
torch.manual_seed(42)
model.init_weights()
check_model_value_range(model, range=10.0, std=1.0)
# data_loader = args.data.build_from_rank(dp_rank, dp_degree)
# train loop
model.train()
# data_loader = train_state.data_loader_state.build()
# batch_iterator = data_loader.create_iter()
# batch = next(batch_iterator)
# with open(f"/storage/home/par/toy-data/batch_{dp_rank}.pickle", "wb") as f:
# pickle.dump(batch, f)
with open(f"/storage/home/par/toy-data/batch_{dp_rank}.pickle", "rb") as f:
batch = pickle.load(f)
batch_x = torch.from_numpy(
batch.x,
).cuda()
batch_y = torch.from_numpy(batch.y).cuda()
mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
pred = model(batch_x)
loss, _ = compute_loss(pred, batch_y, mask, 1.0)
# We scale loss with grad_acc_steps so the gradient is the same
# regardless of grad_acc_steps
loss = loss / args.grad_acc_steps
# backward on scaled loss to create scaled gradients
loss.backward()
# For logging we undo that scaling
loss = loss.detach() * args.grad_acc_steps
world_size = get_world_size()
    # Note: the `and False` keeps this branch disabled; drop it to re-enable the patched clipping path.
    if 1 < world_size <= 8 and False:
        # For some reason, the reduces error out because they do not work for
        # non-bf16 numbers. fixed_clip_grad_norm_ is a patched version that
        # converts gradients to bf16 before computing norms. The error only
        # happens in distributed training on a single node, hence the guard.
grad_norm = fixed_clip_grad_norm_(
model.parameters(), max_norm=args.optim.clip, foreach=True
)
else:
grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(), max_norm=args.optim.clip, foreach=True
)
grad_norm = (
grad_norm.full_tensor() if isinstance(grad_norm, DTensor) else grad_norm
).item()
# if isinstance(data_loader, MultiprocessIterator):
# logger.info("Closing MP iterator before exiting")
# data_loader.shutdown()
def main():
"""
The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments
This accepts arguments as a dot list
So if the dataclass looks like
@dataclass
class DummyArgs:
name: str
model: LMTransformerArgsgs
@dataclass
class LMTransformerArgsgs:
dim: int
Then you can pass model.dim=32 to change values in LMTransformerArgsgs
or just name=tictac for top level attributes.
The behavior here is as follows:
1. We instantiate TrainArgs with its default values
2. We override those default values with the ones in the provided config file
3. We override the result with the additional arguments provided through command line
For example, if the config is the following
model:
dim: 128
n_layers: 4
and you call train.py with train.py model.dim=64
Then the final TrainArgs will have
model:
dim: 64
n_layers: 4
Plus all the default values in TrainArgs dataclass.
"""
train()
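# A minimal sketch of the dot-list behavior described in the docstring above
# (assumption: omegaconf is installed). This broken_train.py variant hard-codes
# its TrainArgs and ignores the CLI, so this helper is illustrative only.
def _demo_dotlist_override():
    from omegaconf import OmegaConf

    defaults = OmegaConf.create({"name": "debug", "model": {"dim": 128, "n_layers": 4}})
    overrides = OmegaConf.from_dotlist(["model.dim=64", "name=tictac"])
    merged = OmegaConf.merge(defaults, overrides)
    return merged  # merged.model.dim == 64, merged.model.n_layers == 4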
if __name__ == "__main__":
main()