# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from datetime import timedelta
from enum import Enum
from functools import lru_cache
import logging
import math
import sys
import time
from typing import Dict, List, Optional, Tuple
from torch import Tensor
import os
import pickle

import fsspec
import torch
import torch.distributed
import torch.nn.functional
import torch.nn.functional as F
from torch.distributed._tensor import DTensor

from torch.distributed.device_mesh import init_device_mesh
from torch.utils._foreach_utils import (
    _device_has_foreach_support,
    _group_tensors_by_device_and_dtype,
    _has_foreach_support,
)

from bytelatent.args import TrainArgs
from bytelatent.distributed import (
    DistributedArgs,
    check_model_value_range,
    parallelize_model,
    setup_env,
    setup_torch_distributed,
)

logger = logging.getLogger()


def set_root_log_level(log_level: str):
    logger = logging.getLogger()
    level: int | str = log_level.upper()
    try:
        level = int(log_level)
    except ValueError:
        pass
    try:
        logger.setLevel(level)  # type: ignore
    except Exception:
        logger.warning(
            f"Failed to set logging level to {log_level}, using default 'NOTSET'"
        )
        logger.setLevel(logging.NOTSET)


class LogFormatter(logging.Formatter):
    """
    Custom log formatter for distributed jobs, displaying rank
    and preserving indent from the custom prefix format.
    """

    def __init__(self):
        self.start_time = time.time()
        self.rank = get_global_rank()
        self.show_rank = not get_is_slurm_job()  # srun has --label

    def formatTime(self, record):
        subsecond, seconds = math.modf(record.created)
        curr_date = (
            time.strftime("%y-%m-%d %H:%M:%S", time.localtime(seconds))
            + f".{int(subsecond * 1_000_000):06d}"
        )
        delta = timedelta(seconds=round(record.created - self.start_time))
        return f"{curr_date} - {delta}"

    def formatPrefix(self, record):
        fmt_time = self.formatTime(record)
        if self.show_rank:
            return f"{self.rank}: {record.levelname:<7} {fmt_time} - "
        else:
            return f"{record.levelname:<7} {fmt_time} - "

    def formatMessage(self, record, indent: str):
        content = record.getMessage()
        content = content.replace("\n", "\n" + indent)
        # Exception handling as in the default formatter, albeit with indenting
        # according to our custom prefix
        if record.exc_info:
            # Cache the traceback text to avoid converting it multiple times
            # (it's constant anyway)
            if not record.exc_text:
                record.exc_text = self.formatException(record.exc_info)
            if record.exc_text:
                if content[-1:] != "\n":
                    content = content + "\n" + indent
                content = content + indent.join(
                    [l + "\n" for l in record.exc_text.splitlines()]
                )
                if content[-1:] == "\n":
                    content = content[:-1]
        if record.stack_info:
            if content[-1:] != "\n":
                content = content + "\n" + indent
            stack_text = self.formatStack(record.stack_info)
            content = content + indent.join([l + "\n" for l in stack_text.splitlines()])
            if content[-1:] == "\n":
                content = content[:-1]

        return content

    def format(self, record):
        prefix = self.formatPrefix(record)
        indent = " " * len(prefix)
        content = self.formatMessage(record, indent)
        return prefix + content


def init_logger(
    log_file: str | None = None,
    *,
    name: str | None = None,
    level: str = "INFO",
    fs: fsspec.AbstractFileSystem | None = None,
):
    """
    Setup logging.

    Args:
        log_file: A file name to save file logs to.
        name: The name of the logger to configure, by default the root logger.
        level: The logging level to use.
    """
    set_root_log_level(level)
    logger = logging.getLogger(name)

    # stdout: everything
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setLevel(logging.NOTSET)
    stdout_handler.setFormatter(LogFormatter())

    # stderr: warnings / errors and above
    stderr_handler = logging.StreamHandler(sys.stderr)
    stderr_handler.setLevel(logging.WARNING)
    stderr_handler.setFormatter(LogFormatter())

    # set stream handlers
    logger.handlers.clear()
    logger.handlers.append(stdout_handler)
    logger.handlers.append(stderr_handler)


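# A minimal sketch (not part of the original file) of how init_logger might be
# used in a standalone script; the logger name "bytelatent.demo" is an
# illustrative assumption, not something this module defines.
def _example_init_logger_usage():
    init_logger(level="DEBUG")
    demo_logger = logging.getLogger("bytelatent.demo")
    # WARNING and above reach both the stdout and stderr handlers;
    # INFO/DEBUG only reach the stdout handler.
    demo_logger.info("informational message")
    demo_logger.warning("warning message")

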
@torch.no_grad()
def fixed_clip_grad_norm_(
    parameters: torch.Tensor | list[torch.Tensor],
    max_norm: float,
    norm_type: float = 2.0,
    error_if_nonfinite: bool = False,
    foreach: Optional[bool] = None,
) -> torch.Tensor:
    r"""Clip the gradient norm of an iterable of parameters.

    The norm is computed over the norms of the individual gradients of all parameters,
    as if the norms of the individual gradients were concatenated into a single vector.
    Gradients are modified in-place.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float): max norm of the gradients
        norm_type (float): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.
        error_if_nonfinite (bool): if True, an error is thrown if the total
            norm of the gradients from :attr:`parameters` is ``nan``,
            ``inf``, or ``-inf``. Default: False (will switch to True in the future)
        foreach (bool): use the faster foreach-based implementation.
            If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
            fall back to the slow implementation for other device types.
            Default: ``None``

    Returns:
        Total norm of the parameter gradients (viewed as a single vector).
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad.to(torch.bfloat16) for p in parameters if p.grad is not None]
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if len(grads) == 0:
        return torch.tensor(0.0)
    first_device = grads[0].device
    grouped_grads: Dict[
        Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
    ] = _group_tensors_by_device_and_dtype(
        [grads]
    )  # type: ignore[assignment]

    norms: List[Tensor] = []
    for (device, _), ([device_grads], _) in grouped_grads.items():  # type: ignore[assignment]
        if (foreach is None and _has_foreach_support(device_grads, device)) or (
            foreach and _device_has_foreach_support(device)
        ):
            norms.extend(torch._foreach_norm(device_grads, norm_type))
        elif foreach:
            raise RuntimeError(
                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
            )
        else:
            norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])

    total_norm = torch.linalg.vector_norm(
        torch.stack([norm.to(first_device) for norm in norms]), norm_type
    )

    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
        raise RuntimeError(
            f"The total norm of order {norm_type} for gradients from "
            "`parameters` is non-finite, so it cannot be clipped. To disable "
            "this error and scale the gradients by the non-finite norm anyway, "
            "set `error_if_nonfinite=False`"
        )
    clip_coef = max_norm / (total_norm + 1e-6)
    # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
    # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
    # when the gradients do not reside in CPU memory.
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    for (device, _), ([device_grads], _) in grouped_grads.items():  # type: ignore[assignment]
        if (foreach is None and _has_foreach_support(device_grads, device)) or (
            foreach and _device_has_foreach_support(device)
        ):
            torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
        elif foreach:
            raise RuntimeError(
                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
            )
        else:
            clip_coef_clamped_device = clip_coef_clamped.to(device)
            for g in device_grads:
                g.mul_(clip_coef_clamped_device)

    return total_norm


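# A minimal sketch (not part of the original file) showing how
# fixed_clip_grad_norm_ might be exercised on a tiny CPU model; the layer
# sizes and max_norm value are illustrative assumptions.
def _example_fixed_clip_grad_norm():
    tiny = torch.nn.Linear(4, 2)
    out = tiny(torch.randn(3, 4)).sum()
    out.backward()
    # Returns the total gradient norm, computed on bf16 copies of the grads.
    total_norm = fixed_clip_grad_norm_(list(tiny.parameters()), max_norm=1.0)
    return total_norm

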
def get_no_recompute_ops():
    return None


@lru_cache()
def get_is_torch_run() -> bool:
    return os.environ.get("LOCAL_RANK") is not None


@lru_cache()
def get_is_slurm_job() -> bool:
    return "SLURM_JOB_ID" in os.environ and not get_is_torch_run()


@lru_cache()
def get_global_rank() -> int:
    if get_is_torch_run():
        return int(os.environ["RANK"])
    elif get_is_slurm_job():
        return int(os.environ["SLURM_PROCID"])
    else:
        return 0


@lru_cache()
def get_local_rank() -> int:
    if get_is_torch_run():
        return int(os.environ["LOCAL_RANK"])
    elif get_is_slurm_job():
        return int(os.environ["SLURM_LOCALID"])
    else:
        return 0


@lru_cache()
def get_world_size() -> int:
    if get_is_torch_run():
        return int(os.environ["WORLD_SIZE"])
    elif get_is_slurm_job():
        return int(os.environ["SLURM_NTASKS"])
    else:
        return 1


@lru_cache()
def get_is_master() -> bool:
    return get_global_rank() == 0


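# A minimal sketch (not part of the original file) of how these helpers
# resolve ranks under torchrun, which sets LOCAL_RANK / RANK / WORLD_SIZE per
# process; the example values in the comments are illustrative assumptions.
def _example_rank_resolution():
    # Under `torchrun --nproc-per-node=8`, process 3 on the first node sees
    # LOCAL_RANK=3, RANK=3, WORLD_SIZE=8, so get_is_torch_run() is True,
    # get_global_rank() == 3, get_world_size() == 8, and get_is_master() is
    # False. The lru_cache decorators mean the first call per process fixes
    # the answer, so the environment must be set before these are used.
    return get_global_rank(), get_local_rank(), get_world_size(), get_is_master()

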
def validate_train_args(args: TrainArgs, output_size: int):
    # assert args.model is not None or args.entropy_model is not None
    if args.entropy_model is not None:
        logger.info(f"Setting model output size to {output_size}")
        args.entropy_model.vocab_size = output_size

    assert args.dump_dir, "Dump dir not set"

    if (
        args.distributed.dp_replicate
        * args.distributed.dp_shard
        * args.distributed.tp_size
        != get_world_size()
    ):
        logging.info("Modifying TrainArgs distributed config")
        assert get_world_size() % args.distributed.dp_shard == 0
        logging.info("World size: %s", get_world_size())
        logging.info(
            "Existing setting: train_args.distributed.dp_shard=%s",
            args.distributed.dp_shard,
        )
        logging.info(
            "Setting train_args.distributed.dp_replicate=%s, was dp_replicate=%s",
            get_world_size() // args.distributed.dp_shard,
            args.distributed.dp_replicate,
        )
        args.distributed.dp_replicate = get_world_size() // args.distributed.dp_shard

        logging.info(
            "Changing dp_replicate from %s to %s, to account for tp_size=%s",
            args.distributed.dp_replicate,
            args.distributed.dp_replicate // args.distributed.tp_size,
            args.distributed.tp_size,
        )
        assert args.distributed.dp_replicate % args.distributed.tp_size == 0
        args.distributed.dp_replicate = (
            args.distributed.dp_replicate // args.distributed.tp_size
        )

        logger.warning(
            f"Setting Data Parallel size to {args.distributed.dp_replicate * args.distributed.dp_shard}"
        )
        assert (
            args.distributed.dp_replicate
            * args.distributed.dp_shard
            * args.distributed.tp_size
            == get_world_size()
        )

        if args.distributed.fsdp_type == "no_shard":
            assert (
                args.distributed.dp_shard == 1
                and args.distributed.dp_replicate == get_world_size()
            )

    if args.model is not None:
        args.model.max_seqlen = args.data.seq_len
    if args.entropy_model is not None:
        args.entropy_model.max_seqlen = args.data.seq_len

    if args.distributed.tp_size != 1:
        logger.warning(
            "Tensor parallelism has not been tested for a while, use at your own risk"
        )

    assert (
        args.probe_freq != args.profiling.mem_steps
    ), "Don't profile during probe step"
    assert (
        args.probe_freq != args.profiling.profile_steps
    ), "Don't profile during probe step"
    if args.logging.wandb is not None:
        args.logging.wandb.name = args.name

    if args.probe_freq is not None:
        assert (
            args.distributed.tp_size == 1
        ), "Probing not supported with tensor parallelism"
        assert (
            args.distributed.selective_activation_checkpointing is False
        ), "Probing not supported with selective activation checkpointing"


def compute_loss(p, y, mask, scale):
    tok_loss = scale * F.cross_entropy(
        p.flatten(0, 1), y.flatten(0, 1), reduction="none"
    )
    if mask is None:
        loss = tok_loss.mean()
    else:
        mask = mask.flatten(0, 1)
        tok_loss = tok_loss * mask
        loss = tok_loss.sum() / (mask.sum() + 1e-6)
    return loss, tok_loss


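# A minimal sketch (not part of the original file) of compute_loss on dummy
# tensors; the batch size, sequence length, and vocab size are illustrative.
def _example_compute_loss():
    bsz, seqlen, vocab = 2, 5, 11
    logits = torch.randn(bsz, seqlen, vocab)
    targets = torch.randint(0, vocab, (bsz, seqlen))
    # Mask out the second half of each sequence; only unmasked tokens
    # contribute to the averaged loss.
    mask = torch.zeros(bsz, seqlen)
    mask[:, : seqlen // 2] = 1.0
    loss, tok_loss = compute_loss(logits, targets, mask, scale=1.0)
    assert tok_loss.shape == (bsz * seqlen,)
    return loss

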
def get_device_mesh(distributed_args):
    tp_size = distributed_args.tp_size
    dp_replicate = distributed_args.dp_replicate
    dp_shard = distributed_args.dp_shard

    assert (
        dp_replicate * dp_shard * tp_size == get_world_size()
    ), f"dp_replicate * dp_shard * tp_size ({dp_replicate} * {dp_shard} * {tp_size}) != world_size ({get_world_size()})"

    dims = []
    names = []
    if dp_replicate >= 1:
        dims.append(dp_replicate)
        names.append("dp_replicate")
    if dp_shard > 1 or distributed_args.fsdp_type == "no_shard":
        dims.append(dp_shard)
        names.append("dp_shard")
    if tp_size > 1:
        dims.append(tp_size)
        names.append("tp")
    dims = tuple(dims)
    names = tuple(names)

    return init_device_mesh("cuda", mesh_shape=dims, mesh_dim_names=names)


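# A minimal sketch (not part of the original file) of which mesh dimensions
# get_device_mesh would request for a hypothetical 8-GPU node with
# dp_replicate=2, dp_shard=4, tp_size=1 and the default "full_shard" FSDP
# type; it only mirrors the dim/name selection and does not create a mesh.
def _example_mesh_dims(dp_replicate: int = 2, dp_shard: int = 4, tp_size: int = 1):
    dims, names = [dp_replicate], ["dp_replicate"]
    if dp_shard > 1:
        dims.append(dp_shard)
        names.append("dp_shard")
    if tp_size > 1:
        dims.append(tp_size)
        names.append("tp")
    # For these values: dims == [2, 4], names == ["dp_replicate", "dp_shard"]
    return tuple(dims), tuple(names)

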
def build_fsdp_grouping_plan():
    group_plan: list[tuple[str, bool]] = []

    # Group embeddings and output separately
    group_plan.append(("tok_embeddings", False))

    # Grouping by layers
    # for i in range(model_args.n_layers):
    #     group_plan.append((f"layers.{i}", False))

    group_plan.append(("output", True))

    return group_plan


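# A minimal sketch (not part of the original file) showing the plan this
# helper produces: module-name prefixes paired with a flag that is consumed
# by parallelize_model (its exact meaning is defined there).
def _example_grouping_plan():
    plan = build_fsdp_grouping_plan()
    assert plan == [("tok_embeddings", False), ("output", True)]
    return plan

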
class MinimalModel(torch.nn.Module):
    def __init__(self, dim: int, vocab_size: int):
        super().__init__()
        self.tok_embeddings = torch.nn.Embedding(vocab_size, dim)

        # self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        # self.layers = torch.nn.ModuleList()
        # for _ in range(args.n_layers):
        #     self.layers.append(TransformerBlock(args))

        self.output = torch.nn.Linear(
            dim,
            vocab_size,
            bias=False,
        )

    def forward(self, tokens):
        h = self.tok_embeddings(tokens)
        logits = self.output(h)
        # logits = self.output(self.norm(h))
        return logits

    def reset_parameters(self, init_std=None):
        pass

    def init_weights(self):
        pass


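# A minimal sketch (not part of the original file) of a CPU forward pass
# through MinimalModel; dim, vocab, and batch sizes are illustrative.
def _example_minimal_model_forward():
    model = MinimalModel(dim=16, vocab_size=32)
    tokens = torch.randint(0, 32, (2, 7))
    logits = model(tokens)
    # Embedding then a bias-free linear head: (batch, seq, vocab) logits.
    assert logits.shape == (2, 7, 32)
    return logits

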
def train():
    args = TrainArgs(
        dump_dir="/tmp",
        name="debug_bf16",
        model=None,
        entropy_model=None,
        distributed=DistributedArgs(
            fsdp_type="full_shard",
            model_dtype="bf16",
            matmul_allow_tf32=False,
            selective_activation_checkpointing=False,
            tp_size=1,
        ),
    )
    tokenizer = args.data.tokenizer_args.build()
    validate_train_args(
        args,
        tokenizer.n_words,
    )
    dump_fs = fsspec.filesystem("file")
    init_logger(os.path.join(args.dump_dir, "train.log"), fs=dump_fs)
    setup_env(args.env)
    setup_torch_distributed(args.distributed)
    world_mesh = get_device_mesh(args.distributed)
    logger.info(f"Starting job: {args.name}")

    # build dataloader
    # need dp world size and rank
    dp_mesh = world_mesh["dp_replicate"]
    dp_degree = dp_mesh.size()
    dp_rank = dp_mesh.get_local_rank()
    if args.distributed.dp_shard > 1:
        dp_rank = dp_rank * dp_degree + world_mesh["dp_shard"].get_local_rank()
        dp_degree *= world_mesh["dp_shard"].size()

    logger.info(f"Running on dp rank : {dp_rank}")
    logger.info(f"Running on dp size : {dp_degree}")

    torch.manual_seed(args.seed)
    logger.info("Building model")

    # Initializing the model on the meta device allows us to build models much
    # bigger than a single GPU's memory
    with torch.device("meta"):
        model = MinimalModel(768, tokenizer.n_words)

    model = parallelize_model(
        model,
        world_mesh,
        args.model,
        args.distributed,
        fsdp_grouping_plan=build_fsdp_grouping_plan(),
        tp_parallelize=None,
        no_recompute_ops=get_no_recompute_ops(),
    )

    # Once we shard the model across GPUs we can actually initialize it.
    # First we create empty tensors of the correct shapes
    model = model.to_empty(device="cuda")
    # Then we init the model. Please make sure this function initializes *ALL*
    # parameters and buffers, otherwise the uninitialized tensors will contain
    # random values that fail silently (e.g. give NaN gradients)

    with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
        torch.manual_seed(42)
        model.init_weights()
    check_model_value_range(model, range=10.0, std=1.0)

    # data_loader = args.data.build_from_rank(dp_rank, dp_degree)

    # train loop
    model.train()
    # data_loader = train_state.data_loader_state.build()
    # batch_iterator = data_loader.create_iter()
    # batch = next(batch_iterator)
    # with open(f"/storage/home/par/toy-data/batch_{dp_rank}.pickle", "wb") as f:
    #     pickle.dump(batch, f)
    with open(f"/storage/home/par/toy-data/batch_{dp_rank}.pickle", "rb") as f:
        batch = pickle.load(f)

    batch_x = torch.from_numpy(
        batch.x,
    ).cuda()
    batch_y = torch.from_numpy(batch.y).cuda()
    mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
    pred = model(batch_x)
    loss, _ = compute_loss(pred, batch_y, mask, 1.0)

    # We scale loss with grad_acc_steps so the gradient is the same
    # regardless of grad_acc_steps
    loss = loss / args.grad_acc_steps

    # backward on scaled loss to create scaled gradients
    loss.backward()
    # For logging we undo that scaling
    loss = loss.detach() * args.grad_acc_steps

    world_size = get_world_size()
    if 1 < world_size <= 8 and False:
        # For some reason, gradient-norm reductions fail for non-bf16 numbers.
        # fixed_clip_grad_norm_ is a patched version that casts gradients to
        # bf16 before computing norms. The error only happens in distributed
        # training on a single node, hence the world-size guard.
        grad_norm = fixed_clip_grad_norm_(
            model.parameters(), max_norm=args.optim.clip, foreach=True
        )
    else:
        grad_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(), max_norm=args.optim.clip, foreach=True
        )

    grad_norm = (
        grad_norm.full_tensor() if isinstance(grad_norm, DTensor) else grad_norm
    ).item()

    # if isinstance(data_loader, MultiprocessIterator):
    #     logger.info("Closing MP iterator before exiting")
    #     data_loader.shutdown()


def main():
    """
    The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments
    It accepts arguments as a dot list.
    So if the dataclass looks like

    @dataclass
    class DummyArgs:
        name: str
        model: LMTransformerArgs

    @dataclass
    class LMTransformerArgs:
        dim: int

    Then you can pass model.dim=32 to change values in LMTransformerArgs
    or just name=tictac for top-level attributes.

    The behavior here is as follows:
    1. We instantiate TrainArgs with its default values
    2. We override those default values with the ones in the provided config file
    3. We override the result with the additional arguments provided through command line

    For example, if the config is the following

    model:
        dim: 128
        n_layers: 4

    and you call train.py with train.py model.dim=64

    Then the final TrainArgs will have

    model:
        dim: 64
        n_layers: 4

    Plus all the default values in the TrainArgs dataclass.
    (A runnable sketch of this merge appears after this function.)
    """
    train()


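# A minimal sketch (not part of the original file) of the OmegaConf dot-list
# merging described in main()'s docstring. DummyArgs / DummyModelArgs are
# hypothetical stand-ins for TrainArgs; the override values are illustrative.
def _example_omegaconf_dotlist():
    from dataclasses import dataclass, field

    from omegaconf import OmegaConf

    @dataclass
    class DummyModelArgs:
        dim: int = 128
        n_layers: int = 4

    @dataclass
    class DummyArgs:
        name: str = "default"
        model: DummyModelArgs = field(default_factory=DummyModelArgs)

    base = OmegaConf.structured(DummyArgs())
    overrides = OmegaConf.from_dotlist(["model.dim=64", "name=tictac"])
    merged = OmegaConf.merge(base, overrides)
    # merged.model.dim == 64, merged.model.n_layers == 4, merged.name == "tictac"
    return merged

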
if __name__ == "__main__":
    main()