diff --git a/bytelatent/distributed.py b/bytelatent/distributed.py index 168cb7c..b3f31a6 100644 --- a/bytelatent/distributed.py +++ b/bytelatent/distributed.py @@ -127,6 +127,16 @@ def dist_max(x: Union[int, float], mesh: DeviceMesh = None): return tensor +def dist_sum( + x: Union[int, float], mesh: DeviceMesh = None, reduce_dtype: torch.dtype = None +): + tensor = torch.tensor(x).cuda() + if reduce_dtype is not None: + tensor = tensor.to(reduce_dtype) + dist.all_reduce(tensor, op=ReduceOp.SUM, group=mesh.get_group() if mesh else None) + return tensor + + def dist_mean(x: Union[int, float], mesh: DeviceMesh = None): tensor = torch.tensor(x).cuda() dist.all_reduce(tensor, op=ReduceOp.AVG, group=mesh.get_group() if mesh else None) @@ -236,7 +246,7 @@ def setup_env(env_args: EnvironmentArgs): logger.warning(f"WARNING: Setting {name} to {value}") -def setup_torch_distributed(dist_args): +def setup_torch_distributed(dist_args: DistributedArgs): """ Handle single and multi-GPU / multi-node / SLURM jobs. Initialize the following variables: @@ -388,14 +398,14 @@ def clean_env(): def parallelize_model( - model, + model: torch.nn.Module, device_mesh, model_args, distributed_args: DistributedArgs, fsdp_grouping_plan: Optional[List[Tuple[str, bool]]] = None, tp_parallelize=None, no_recompute_ops=None, -): +) -> torch.nn.Module: if distributed_args.tp_size > 1: assert ( distributed_args.fsdp_type == "full_shard" diff --git a/bytelatent/train.py b/bytelatent/train.py index 2c3ea01..4641746 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -3,6 +3,7 @@ import gc import logging +import math import os import sys from contextlib import ExitStack @@ -11,6 +12,7 @@ from dataclasses import asdict, dataclass from timeit import default_timer as timer from typing import Any, TypeVar +import numpy as np import torch import torch.distributed import torch.nn.functional @@ -32,7 +34,9 @@ from bytelatent.data.iterators.packing_iterator import PackingIteratorState from bytelatent.distributed import ( check_model_value_range, clean_env, + dist_mean, dist_mean_dict, + dist_sum, get_device_mesh, get_is_master, get_world_size, @@ -391,6 +395,9 @@ def train(args: TrainArgs): time_last_log = timer() gc.collect() saved = False + step_losses: list[float] = [] + step_tok_losses: list[float] = [] + n_bytes: int = 0 while train_state.step < args.steps and ( args.max_steps is None or train_state.step < args.max_steps ): @@ -412,6 +419,24 @@ def train(args: TrainArgs): batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda() mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda() + if args.data.tokenizer_args.name in ["bytes", "blt"]: + if mask is None: + n_bytes += batch_y.numel() + else: + n_bytes += mask.sum() + elif args.data.tokenizer_args.name in ["sp", "tiktoken"]: + for example in batch.y: + target_tokens = tokenizer.decode(example.tolist(), cut_at_eos=False) + n_bytes += ( + len(bytes(target_tokens, encoding="utf-8", errors="ignore")) + + sum(example == tokenizer.eos_id) + + sum(example == tokenizer.bos_id) + ) + else: + raise ValueError( + f"Unexpected tokenizer to count n_bytes for: {args.data.tokenizer_args.name}" + ) + if ( not args.train_entropy_model and args.model.encoder_enable_byte_ngrams @@ -486,7 +511,7 @@ def train(args: TrainArgs): batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids ) - loss, _ = compute_loss(pred, batch_y, mask, train_state.scale) + loss, tok_loss = compute_loss(pred, batch_y, mask, train_state.scale) # We scale loss with grad_acc_steps so the gradient is the same # regardless of grad_acc_steps @@ -497,6 +522,10 @@ def train(args: TrainArgs): # For logging we undo that scaling loss = loss.detach() * args.grad_acc_steps + # Undo loss scaling so downstream down't need to worry about it + step_losses.append((loss / train_state.scale).item()) + step_tok_losses.append(tok_loss / train_state.scale) + world_size = get_world_size() if 1 < world_size <= 8: # For some reason, there are errors in reduces due to @@ -597,20 +626,33 @@ def train(args: TrainArgs): gpu_memory_monitor.reset_peak_stats() nwords_since_last_log = 0 time_last_log = timer() + stacked_tok_loss = torch.cat(step_tok_losses, dim=0) + total_tok_loss = dist_sum( + stacked_tok_loss.sum().item(), reduce_dtype=torch.bfloat16 + ) + total_n_bytes = dist_sum(n_bytes, reduce_dtype=torch.bfloat16) + avg_bpb = total_tok_loss / math.log(2) / total_n_bytes + avg_loss = dist_mean(np.mean(step_losses).item()) logger.info( f"step: {train_state.step}" f" acc: {train_state.acc_step}" - f" loss: {round(loss.item(),4):>7}" + f" loss: step={round(loss.item(),4):>7} avg={avg_loss}" + f" bpb: {avg_bpb:3f}" f" grad: {grad_norm:.2e}" f" flops: {FLOPS:.2e}" f" wps: {wps:.2e}" f" iter: {curr_iter_time:>7}" f" data: {data_load_time:>5}" f" lr: {curr_lr:.2e}" + f" n_bytes={total_n_bytes}" f" mem: {gpu_mem_stats.max_active_pct:.0f}%" f" pow: {gpu_mem_stats.power_draw/1000} W" ) + n_bytes = 0 + step_losses = [] + step_tok_losses = [] + if every_n_steps( train_state, args.checkpoint.dump.every, acc_step=0 ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):