Add bpb and n_bytes to metric logging

Summary:

Test Plan:
This commit is contained in:
Pedro Rodriguez 2025-02-05 22:26:31 +00:00
parent 1450464031
commit 2f42633b07
2 changed files with 57 additions and 5 deletions

View file

@ -127,6 +127,16 @@ def dist_max(x: Union[int, float], mesh: DeviceMesh = None):
return tensor
def dist_sum(
x: Union[int, float], mesh: DeviceMesh = None, reduce_dtype: torch.dtype = None
):
tensor = torch.tensor(x).cuda()
if reduce_dtype is not None:
tensor = tensor.to(reduce_dtype)
dist.all_reduce(tensor, op=ReduceOp.SUM, group=mesh.get_group() if mesh else None)
return tensor
def dist_mean(x: Union[int, float], mesh: DeviceMesh = None):
tensor = torch.tensor(x).cuda()
dist.all_reduce(tensor, op=ReduceOp.AVG, group=mesh.get_group() if mesh else None)
@ -236,7 +246,7 @@ def setup_env(env_args: EnvironmentArgs):
logger.warning(f"WARNING: Setting {name} to {value}")
def setup_torch_distributed(dist_args):
def setup_torch_distributed(dist_args: DistributedArgs):
"""
Handle single and multi-GPU / multi-node / SLURM jobs.
Initialize the following variables:
@ -388,14 +398,14 @@ def clean_env():
def parallelize_model(
model,
model: torch.nn.Module,
device_mesh,
model_args,
distributed_args: DistributedArgs,
fsdp_grouping_plan: Optional[List[Tuple[str, bool]]] = None,
tp_parallelize=None,
no_recompute_ops=None,
):
) -> torch.nn.Module:
if distributed_args.tp_size > 1:
assert (
distributed_args.fsdp_type == "full_shard"

View file

@ -3,6 +3,7 @@
import gc
import logging
import math
import os
import sys
from contextlib import ExitStack
@ -11,6 +12,7 @@ from dataclasses import asdict, dataclass
from timeit import default_timer as timer
from typing import Any, TypeVar
import numpy as np
import torch
import torch.distributed
import torch.nn.functional
@ -32,7 +34,9 @@ from bytelatent.data.iterators.packing_iterator import PackingIteratorState
from bytelatent.distributed import (
check_model_value_range,
clean_env,
dist_mean,
dist_mean_dict,
dist_sum,
get_device_mesh,
get_is_master,
get_world_size,
@ -391,6 +395,9 @@ def train(args: TrainArgs):
time_last_log = timer()
gc.collect()
saved = False
step_losses: list[float] = []
step_tok_losses: list[float] = []
n_bytes: int = 0
while train_state.step < args.steps and (
args.max_steps is None or train_state.step < args.max_steps
):
@ -412,6 +419,24 @@ def train(args: TrainArgs):
batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda()
mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
if args.data.tokenizer_args.name in ["bytes", "blt"]:
if mask is None:
n_bytes += batch_y.numel()
else:
n_bytes += mask.sum()
elif args.data.tokenizer_args.name in ["sp", "tiktoken"]:
for example in batch.y:
target_tokens = tokenizer.decode(example.tolist(), cut_at_eos=False)
n_bytes += (
len(bytes(target_tokens, encoding="utf-8", errors="ignore"))
+ sum(example == tokenizer.eos_id)
+ sum(example == tokenizer.bos_id)
)
else:
raise ValueError(
f"Unexpected tokenizer to count n_bytes for: {args.data.tokenizer_args.name}"
)
if (
not args.train_entropy_model
and args.model.encoder_enable_byte_ngrams
@ -486,7 +511,7 @@ def train(args: TrainArgs):
batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids
)
loss, _ = compute_loss(pred, batch_y, mask, train_state.scale)
loss, tok_loss = compute_loss(pred, batch_y, mask, train_state.scale)
# We scale loss with grad_acc_steps so the gradient is the same
# regardless of grad_acc_steps
@ -497,6 +522,10 @@ def train(args: TrainArgs):
# For logging we undo that scaling
loss = loss.detach() * args.grad_acc_steps
# Undo loss scaling so downstream down't need to worry about it
step_losses.append((loss / train_state.scale).item())
step_tok_losses.append(tok_loss / train_state.scale)
world_size = get_world_size()
if 1 < world_size <= 8:
# For some reason, there are errors in reduces due to
@ -597,20 +626,33 @@ def train(args: TrainArgs):
gpu_memory_monitor.reset_peak_stats()
nwords_since_last_log = 0
time_last_log = timer()
stacked_tok_loss = torch.cat(step_tok_losses, dim=0)
total_tok_loss = dist_sum(
stacked_tok_loss.sum().item(), reduce_dtype=torch.bfloat16
)
total_n_bytes = dist_sum(n_bytes, reduce_dtype=torch.bfloat16)
avg_bpb = total_tok_loss / math.log(2) / total_n_bytes
avg_loss = dist_mean(np.mean(step_losses).item())
logger.info(
f"step: {train_state.step}"
f" acc: {train_state.acc_step}"
f" loss: {round(loss.item(),4):>7}"
f" loss: step={round(loss.item(),4):>7} avg={avg_loss}"
f" bpb: {avg_bpb:3f}"
f" grad: {grad_norm:.2e}"
f" flops: {FLOPS:.2e}"
f" wps: {wps:.2e}"
f" iter: {curr_iter_time:>7}"
f" data: {data_load_time:>5}"
f" lr: {curr_lr:.2e}"
f" n_bytes={total_n_bytes}"
f" mem: {gpu_mem_stats.max_active_pct:.0f}%"
f" pow: {gpu_mem_stats.power_draw/1000} W"
)
n_bytes = 0
step_losses = []
step_tok_losses = []
if every_n_steps(
train_state, args.checkpoint.dump.every, acc_step=0
) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):