diff --git a/bytelatent/broken_train.py b/bytelatent/broken_train.py
new file mode 100644
index 0000000..e1630a7
--- /dev/null
+++ b/bytelatent/broken_train.py
@@ -0,0 +1,623 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import logging
+import math
+import os
+import pickle
+import sys
+import time
+from datetime import timedelta
+from functools import lru_cache
+from typing import Dict, List, Optional, Tuple
+
+import fsspec
+import torch
+import torch.distributed
+import torch.nn.functional as F
+from torch import Tensor
+from torch.distributed._tensor import DTensor
+from torch.distributed.device_mesh import init_device_mesh
+from torch.utils._foreach_utils import (
+    _device_has_foreach_support,
+    _group_tensors_by_device_and_dtype,
+    _has_foreach_support,
+)
+
+from bytelatent.args import TrainArgs
+from bytelatent.distributed import (
+    DistributedArgs,
+    check_model_value_range,
+    parallelize_model,
+    setup_env,
+    setup_torch_distributed,
+)
+
+logger = logging.getLogger()
+
+
+def set_root_log_level(log_level: str):
+    logger = logging.getLogger()
+    level: int | str = log_level.upper()
+    try:
+        level = int(log_level)
+    except ValueError:
+        pass
+    try:
+        logger.setLevel(level)  # type: ignore
+    except Exception:
+        logger.warning(
+            f"Failed to set logging level to {log_level}, using default 'NOTSET'"
+        )
+        logger.setLevel(logging.NOTSET)
+
+
+class LogFormatter(logging.Formatter):
+    """
+    Custom formatter for distributed jobs, displaying rank
+    and preserving indent from the custom prefix format.
+    """
+
+    def __init__(self):
+        self.start_time = time.time()
+        self.rank = get_global_rank()
+        self.show_rank = not get_is_slurm_job()  # srun has --label
+
+    def formatTime(self, record):
+        subsecond, seconds = math.modf(record.created)
+        curr_date = (
+            time.strftime("%y-%m-%d %H:%M:%S", time.localtime(seconds))
+            + f".{int(subsecond * 1_000_000):06d}"
+        )
+        delta = timedelta(seconds=round(record.created - self.start_time))
+        return f"{curr_date} - {delta}"
+
+    def formatPrefix(self, record):
+        fmt_time = self.formatTime(record)
+        if self.show_rank:
+            return f"{self.rank}: {record.levelname:<7} {fmt_time} - "
+        else:
+            return f"{record.levelname:<7} {fmt_time} - "
+
+    def formatMessage(self, record, indent: str):
+        content = record.getMessage()
+        content = content.replace("\n", "\n" + indent)
+        # Exception handling as in the default formatter, albeit with indenting
+        # according to our custom prefix
+        if record.exc_info:
+            # Cache the traceback text to avoid converting it multiple times
+            # (it's constant anyway)
+            if not record.exc_text:
+                record.exc_text = self.formatException(record.exc_info)
+        if record.exc_text:
+            if content[-1:] != "\n":
+                content = content + "\n" + indent
+            content = content + indent.join(
+                [l + "\n" for l in record.exc_text.splitlines()]
+            )
+            if content[-1:] == "\n":
+                content = content[:-1]
+        if record.stack_info:
+            if content[-1:] != "\n":
+                content = content + "\n" + indent
+            stack_text = self.formatStack(record.stack_info)
+            content = content + indent.join(
+                [l + "\n" for l in stack_text.splitlines()]
+            )
+            if content[-1:] == "\n":
+                content = content[:-1]
+
+        return content
+
+    def format(self, record):
+        prefix = self.formatPrefix(record)
+        indent = " " * len(prefix)
+        content = self.formatMessage(record, indent)
+        return prefix + content
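+
+
+# Minimal sketch of the formatter in action (illustrative only; not part of
+# the original script and never called by the training path): attach it to a
+# throwaway logger and emit a multi-line message to see continuation lines
+# indented under the rank/level/time prefix.
+def _demo_log_formatter():
+    handler = logging.StreamHandler(sys.stdout)
+    handler.setFormatter(LogFormatter())
+    demo = logging.getLogger("formatter_demo")
+    demo.addHandler(handler)
+    demo.propagate = False  # avoid double-logging through the root handlers
+    demo.warning("first line\nsecond line is indented to align with the first")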
+
+
+def init_logger(
+    log_file: str | None = None,
+    *,
+    name: str | None = None,
+    level: str = "INFO",
+    fs: fsspec.AbstractFileSystem | None = None,
+):
+    """
+    Set up logging.
+
+    Args:
+        log_file: A file name to save file logs to.
+        name: The name of the logger to configure, by default the root logger.
+        level: The logging level to use.
+        fs: An optional fsspec filesystem for writing log_file.
+    """
+    # NOTE: log_file and fs are accepted for signature compatibility but are
+    # unused in this repro script; only stream handlers are installed.
+    set_root_log_level(level)
+    logger = logging.getLogger(name)
+
+    # stdout: everything
+    stdout_handler = logging.StreamHandler(sys.stdout)
+    stdout_handler.setLevel(logging.NOTSET)
+    stdout_handler.setFormatter(LogFormatter())
+
+    # stderr: warnings / errors and above
+    stderr_handler = logging.StreamHandler(sys.stderr)
+    stderr_handler.setLevel(logging.WARNING)
+    stderr_handler.setFormatter(LogFormatter())
+
+    # set stream handlers
+    logger.handlers.clear()
+    logger.handlers.append(stdout_handler)
+    logger.handlers.append(stderr_handler)
+
+
+@torch.no_grad()
+def fixed_clip_grad_norm_(
+    parameters: torch.Tensor | list[torch.Tensor],
+    max_norm: float,
+    norm_type: float = 2.0,
+    error_if_nonfinite: bool = False,
+    foreach: Optional[bool] = None,
+) -> torch.Tensor:
+    r"""Clip the gradient norm of an iterable of parameters.
+
+    The norm is computed over the norms of the individual gradients of all parameters,
+    as if the norms of the individual gradients were concatenated into a single vector.
+    Gradients are modified in-place.
+
+    Args:
+        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+            single Tensor that will have gradients normalized
+        max_norm (float): max norm of the gradients
+        norm_type (float): type of the used p-norm. Can be ``'inf'`` for
+            infinity norm.
+        error_if_nonfinite (bool): if True, an error is thrown if the total
+            norm of the gradients from :attr:`parameters` is ``nan``,
+            ``inf``, or ``-inf``. Default: False (will switch to True in the future)
+        foreach (bool): use the faster foreach-based implementation.
+            If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
+            fall back to the slow implementation for other device types.
+            Default: ``None``
+
+    Returns:
+        Total norm of the parameter gradients (viewed as a single vector).
+    """
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    # NOTE: `.to(torch.bfloat16)` returns the gradient tensor itself when it is
+    # already bf16, but a *copy* otherwise. In the copy case, the in-place
+    # scaling at the end of this function modifies the copies rather than the
+    # original gradients.
+    grads = [p.grad.to(torch.bfloat16) for p in parameters if p.grad is not None]
+    max_norm = float(max_norm)
+    norm_type = float(norm_type)
+    if len(grads) == 0:
+        return torch.tensor(0.0)
+    first_device = grads[0].device
+    grouped_grads: Dict[
+        Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
+    ] = _group_tensors_by_device_and_dtype(
+        [grads]
+    )  # type: ignore[assignment]
+
+    norms: List[Tensor] = []
+    for (device, _), ([device_grads], _) in grouped_grads.items():  # type: ignore[assignment]
+        if (foreach is None and _has_foreach_support(device_grads, device)) or (
+            foreach and _device_has_foreach_support(device)
+        ):
+            norms.extend(torch._foreach_norm(device_grads, norm_type))
+        elif foreach:
+            raise RuntimeError(
+                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
+            )
+        else:
+            norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])
+
+    total_norm = torch.linalg.vector_norm(
+        torch.stack([norm.to(first_device) for norm in norms]), norm_type
+    )
+
+    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
+        raise RuntimeError(
+            f"The total norm of order {norm_type} for gradients from "
+            "`parameters` is non-finite, so it cannot be clipped. To disable "
+            "this error and scale the gradients by the non-finite norm anyway, "
+            "set `error_if_nonfinite=False`"
+        )
+    clip_coef = max_norm / (total_norm + 1e-6)
+    # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
+    # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
+    # when the gradients do not reside in CPU memory.
+    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
+    for (device, _), ([device_grads], _) in grouped_grads.items():  # type: ignore[assignment]
+        if (foreach is None and _has_foreach_support(device_grads, device)) or (
+            foreach and _device_has_foreach_support(device)
+        ):
+            torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
+        elif foreach:
+            raise RuntimeError(
+                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
+            )
+        else:
+            clip_coef_clamped_device = clip_coef_clamped.to(device)
+            for g in device_grads:
+                g.mul_(clip_coef_clamped_device)
+
+    return total_norm
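+
+
+# Minimal sketch (illustrative, never called in this script) contrasting the
+# patched clip with the stock `torch.nn.utils.clip_grad_norm_`. With a bf16
+# gradient the patched version clips in place; with an fp32 gradient only the
+# bf16 copies are scaled, per the NOTE above.
+def _demo_fixed_clip():
+    p = torch.nn.Parameter(torch.zeros(4, dtype=torch.bfloat16))
+    p.grad = torch.full((4,), 10.0, dtype=torch.bfloat16)
+    # Norm of [10, 10, 10, 10] is 20, so the grad is rescaled to norm ~1.
+    total = fixed_clip_grad_norm_([p], max_norm=1.0)
+    # Stock equivalent, which clips in place for any dtype:
+    # torch.nn.utils.clip_grad_norm_([p], max_norm=1.0)
+    return total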
To disable " + "this error and scale the gradients by the non-finite norm anyway, " + "set `error_if_nonfinite=False`" + ) + clip_coef = max_norm / (total_norm + 1e-6) + # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so + # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization + # when the gradients do not reside in CPU memory. + clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment] + if (foreach is None and _has_foreach_support(device_grads, device)) or ( + foreach and _device_has_foreach_support(device) + ): + torch._foreach_mul_(device_grads, clip_coef_clamped.to(device)) + elif foreach: + raise RuntimeError( + f"foreach=True was passed, but can't use the foreach API on {device.type} tensors" + ) + else: + clip_coef_clamped_device = clip_coef_clamped.to(device) + for g in device_grads: + g.mul_(clip_coef_clamped_device) + + return total_norm + + +def get_no_recompute_ops(): + return None + + +@lru_cache() +def get_is_torch_run() -> bool: + return os.environ.get("LOCAL_RANK") is not None + + +@lru_cache() +def get_is_slurm_job() -> bool: + return "SLURM_JOB_ID" in os.environ and not get_is_torch_run() + + +@lru_cache() +def get_global_rank() -> int: + if get_is_torch_run(): + return int(os.environ["RANK"]) + elif get_is_slurm_job(): + return int(os.environ["SLURM_PROCID"]) + else: + return 0 + + +@lru_cache() +def get_local_rank() -> int: + if get_is_torch_run(): + return int(os.environ["LOCAL_RANK"]) + elif get_is_slurm_job(): + return int(os.environ["SLURM_LOCALID"]) + else: + return 0 + + +@lru_cache() +def get_world_size() -> int: + if get_is_torch_run(): + return int(os.environ["WORLD_SIZE"]) + elif get_is_slurm_job(): + return int(os.environ["SLURM_NTASKS"]) + else: + return 1 + + +@lru_cache() +def get_is_master() -> bool: + return get_global_rank() == 0 + + +def validate_train_args(args: TrainArgs, output_size: int): + # assert args.model is not None or args.entropy_model is not None + if args.entropy_model is not None: + logger.info(f"Setting model output size to {args.entropy_model.vocab_size}") + args.entropy_model.vocab_size = output_size + + assert args.dump_dir, "Dump dir not set" + + if ( + args.distributed.dp_replicate + * args.distributed.dp_shard + * args.distributed.tp_size + != get_world_size() + ): + logging.info("Modifying TrainArgs distributed config") + assert get_world_size() % args.distributed.dp_shard == 0 + logging.info("World size: %s", get_world_size()) + logging.info( + "Existing setting: train_args.distributed.dp_shard=%s", + args.distributed.dp_shard, + ) + logging.info( + "Setting train_args.distributed.dp_replicate=%s, was dp_replicate=%s", + get_world_size() // args.distributed.dp_shard, + args.distributed.dp_replicate, + ) + args.distributed.dp_replicate = get_world_size() // args.distributed.dp_shard + + logging.info( + "Changing dp_replicate from %s to %s, to account for tp_size=%s", + args.distributed.dp_replicate, + args.distributed.dp_replicate // args.distributed.tp_size, + args.distributed.tp_size, + ) + assert args.distributed.dp_replicate % args.distributed.tp_size == 0 + args.distributed.dp_replicate = ( + args.distributed.dp_replicate // args.distributed.tp_size + ) + + logger.warning( + f"Setting Data Parallel size to {args.distributed.dp_replicate * args.distributed.dp_shard}" + ) + assert ( + args.distributed.dp_replicate + * args.distributed.dp_shard + * 
+
+
+def compute_loss(p, y, mask, scale):
+    tok_loss = scale * F.cross_entropy(
+        p.flatten(0, 1), y.flatten(0, 1), reduction="none"
+    )
+    if mask is None:
+        loss = tok_loss.mean()
+    else:
+        mask = mask.flatten(0, 1)
+        tok_loss = tok_loss * mask
+        loss = tok_loss.sum() / (mask.sum() + 1e-6)
+    return loss, tok_loss
+
+
+def get_device_mesh(distributed_args):
+    tp_size = distributed_args.tp_size
+    dp_replicate = distributed_args.dp_replicate
+    dp_shard = distributed_args.dp_shard
+
+    assert (
+        dp_replicate * dp_shard * tp_size == get_world_size()
+    ), f"dp_replicate * dp_shard * tp_size ({dp_replicate} * {dp_shard} * {tp_size}) != world_size ({get_world_size()})"
+
+    dims = []
+    names = []
+    if dp_replicate >= 1:
+        dims.append(dp_replicate)
+        names.append("dp_replicate")
+    if dp_shard > 1 or distributed_args.fsdp_type == "no_shard":
+        dims.append(dp_shard)
+        names.append("dp_shard")
+    if tp_size > 1:
+        dims.append(tp_size)
+        names.append("tp")
+    dims = tuple(dims)
+    names = tuple(names)
+
+    return init_device_mesh("cuda", mesh_shape=dims, mesh_dim_names=names)
+
+
+def build_fsdp_grouping_plan():
+    group_plan: list[tuple[str, bool]] = []
+
+    # Group embeddings and output separately
+    group_plan.append(("tok_embeddings", False))
+
+    # Grouping by layers
+    # for i in range(model_args.n_layers):
+    #     group_plan.append((f"layers.{i}", False))
+
+    group_plan.append(("output", True))
+
+    return group_plan
+
+
+class MinimalModel(torch.nn.Module):
+    def __init__(self, dim: int, vocab_size: int):
+        super().__init__()
+        self.tok_embeddings = torch.nn.Embedding(vocab_size, dim)
+
+        # self.norm = RMSNorm(args.dim, eps=args.norm_eps)
+        # self.layers = torch.nn.ModuleList()
+        # for _ in range(args.n_layers):
+        #     self.layers.append(TransformerBlock(args))
+
+        self.output = torch.nn.Linear(
+            dim,
+            vocab_size,
+            bias=False,
+        )
+
+    def forward(self, tokens):
+        h = self.tok_embeddings(tokens)
+        logits = self.output(h)
+        # logits = self.output(self.norm(h))
+        return logits
+
+    def reset_parameters(self, init_std=None):
+        pass
+
+    def init_weights(self):
+        pass
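+
+
+# CPU-only usage sketch of the pieces above (shapes and sizes here are
+# assumptions, and this is never called by the training path): run the toy
+# model on a tiny batch and feed the logits through `compute_loss`.
+def _demo_minimal_model():
+    model = MinimalModel(dim=16, vocab_size=32)
+    tokens = torch.randint(0, 32, (2, 8))  # [batch, seq]
+    targets = torch.randint(0, 32, (2, 8))
+    logits = model(tokens)  # [batch, seq, vocab]
+    loss, tok_loss = compute_loss(logits, targets, mask=None, scale=1.0)
+    return loss, tok_loss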
+
+
+def train():
+    args = TrainArgs(
+        dump_dir="/tmp",
+        name="debug_bf16",
+        model=None,
+        entropy_model=None,
+        distributed=DistributedArgs(
+            fsdp_type="full_shard",
+            model_dtype="bf16",
+            matmul_allow_tf32=False,
+            selective_activation_checkpointing=False,
+            tp_size=1,
+        ),
+    )
+    tokenizer = args.data.tokenizer_args.build()
+    validate_train_args(
+        args,
+        tokenizer.n_words,
+    )
+    dump_fs = fsspec.filesystem("file")
+    init_logger(os.path.join(args.dump_dir, "train.log"), fs=dump_fs)
+    setup_env(args.env)
+    setup_torch_distributed(args.distributed)
+    world_mesh = get_device_mesh(args.distributed)
+    logger.info(f"Starting job: {args.name}")
+
+    # build dataloader
+    # need dp world size and rank
+    dp_mesh = world_mesh["dp_replicate"]
+    dp_degree = dp_mesh.size()
+    dp_rank = dp_mesh.get_local_rank()
+    if args.distributed.dp_shard > 1:
+        dp_rank = dp_rank * dp_degree + world_mesh["dp_shard"].get_local_rank()
+        dp_degree *= world_mesh["dp_shard"].size()
+
+    logger.info(f"Running on dp rank : {dp_rank}")
+    logger.info(f"Running on dp size : {dp_degree}")
+
+    torch.manual_seed(args.seed)
+    logger.info("Building model")
+
+    # Initializing the model on the meta device lets us build models much
+    # bigger than a single GPU's memory
+    with torch.device("meta"):
+        model = MinimalModel(768, tokenizer.n_words)
+
+    model = parallelize_model(
+        model,
+        world_mesh,
+        args.model,
+        args.distributed,
+        fsdp_grouping_plan=build_fsdp_grouping_plan(),
+        tp_parallelize=None,
+        no_recompute_ops=get_no_recompute_ops(),
+    )
+
+    # Once we shard the model on different GPUs we can actually initialize it.
+    # First we create empty tensors of the correct shapes
+    model = model.to_empty(device="cuda")
+    # Then we init the model. Please make sure this function initializes *ALL*
+    # parameters and buffers; otherwise uninitialized tensors will hold garbage
+    # values that fail silently (e.g. by producing NaN gradients)
+
+    with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
+        torch.manual_seed(42)
+        model.init_weights()
+    check_model_value_range(model, range=10.0, std=1.0)
+
+    # data_loader = args.data.build_from_rank(dp_rank, dp_degree)
+
+    # train loop
+    model.train()
+    # data_loader = train_state.data_loader_state.build()
+    # batch_iterator = data_loader.create_iter()
+    # batch = next(batch_iterator)
+    # with open(f"/storage/home/par/toy-data/batch_{dp_rank}.pickle", "wb") as f:
+    #     pickle.dump(batch, f)
+    with open(f"/storage/home/par/toy-data/batch_{dp_rank}.pickle", "rb") as f:
+        batch = pickle.load(f)
+
+    batch_x = torch.from_numpy(batch.x).cuda()
+    batch_y = torch.from_numpy(batch.y).cuda()
+    mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
+    pred = model(batch_x)
+    loss, _ = compute_loss(pred, batch_y, mask, 1.0)
+
+    # We scale loss with grad_acc_steps so the gradient is the same
+    # regardless of grad_acc_steps
+    loss = loss / args.grad_acc_steps
+
+    # backward on scaled loss to create scaled gradients
+    loss.backward()
+    # For logging we undo that scaling
+    loss = loss.detach() * args.grad_acc_steps
+
+    world_size = get_world_size()
+    # `and False` deliberately disables the patched clipping path for this
+    # repro, so the stock implementation below always runs.
+    if 1 < world_size <= 8 and False:
+        # For some reason, the gradient-norm reduces fail for non-bf16 numbers.
+        # This function is a patched version that converts gradients to bf16
+        # before computing norms. The error only happens in distributed
+        # training on one node, hence the guard.
+        grad_norm = fixed_clip_grad_norm_(
+            model.parameters(), max_norm=args.optim.clip, foreach=True
+        )
+    else:
+        grad_norm = torch.nn.utils.clip_grad_norm_(
+            model.parameters(), max_norm=args.optim.clip, foreach=True
+        )
+
+    grad_norm = (
+        grad_norm.full_tensor() if isinstance(grad_norm, DTensor) else grad_norm
+    ).item()
+
+    # if isinstance(data_loader, MultiprocessIterator):
+    #     logger.info("Closing MP iterator before exiting")
+    #     data_loader.shutdown()
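+
+
+# Hedged sketch of the dotlist-override precedence that main()'s docstring
+# below describes (the config values here are illustrative; this repro script
+# actually builds TrainArgs directly and ignores the command line):
+def _demo_cli_overrides():
+    from omegaconf import OmegaConf
+
+    defaults = OmegaConf.create({"model": {"dim": 128, "n_layers": 4}})  # 1. defaults
+    file_cfg = OmegaConf.create({"model": {"dim": 96}})  # 2. config file
+    cli_cfg = OmegaConf.from_dotlist(["model.dim=64"])  # 3. command line
+    merged = OmegaConf.merge(defaults, file_cfg, cli_cfg)
+    assert merged.model.dim == 64 and merged.model.n_layers == 4
+    return merged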
+
+
+def main():
+    """
+    The command line interface here uses OmegaConf
+    https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments
+    This accepts arguments as a dot list.
+    So if the dataclass looks like
+
+    @dataclass
+    class DummyArgs:
+        name: str
+        model: LMTransformerArgs
+
+    @dataclass
+    class LMTransformerArgs:
+        dim: int
+
+    Then you can pass model.dim=32 to change values in LMTransformerArgs
+    or just name=tictac for top level attributes.
+
+    The behavior here is as follows:
+    1. We instantiate TrainArgs with its default values
+    2. We override those default values with the ones in the provided config file
+    3. We override the result with the additional arguments provided through command line
+
+    For example, if the config is the following
+
+    model:
+        dim: 128
+        n_layers: 4
+
+    and you call train.py with train.py model.dim=64
+
+    Then the final TrainArgs will have
+
+    model:
+        dim: 64
+        n_layers: 4
+
+    Plus all the default values in the TrainArgs dataclass.
+
+    NOTE: in this stripped-down repro, main() ignores the command line and
+    calls train() with hard-coded TrainArgs.
+    """
+    train()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/bytelatent/train.py b/bytelatent/train.py
index 4641746..a775e46 100644
--- a/bytelatent/train.py
+++ b/bytelatent/train.py
@@ -527,7 +527,7 @@ def train(args: TrainArgs):
         step_tok_losses.append(tok_loss / train_state.scale)
 
         world_size = get_world_size()
-        if 1 < world_size <= 8:
+        if 1 < world_size <= 8 and False:
             # For some reason, there are errors in reduces due to
             # not working for non-bf16 numbers. This function is a patched
             # version that converts gradients to bf16 before computing norms.