Mirror of https://github.com/facebookresearch/blt.git, synced 2025-02-22 21:12:15 +00:00
Parent: aebdc481a8
Commit: fe45f69fbf

.gitignore (vendored)
@@ -167,3 +167,4 @@ figures/
 .DS_Store
 internal/
 jobs_parallel-copy/
+wandb/

@@ -127,6 +127,16 @@ def dist_max(x: Union[int, float], mesh: DeviceMesh = None):
     return tensor
 
 
+def dist_sum(
+    x: Union[int, float], mesh: DeviceMesh = None, reduce_dtype: torch.dtype = None
+):
+    tensor = torch.tensor(x).cuda()
+    if reduce_dtype is not None:
+        tensor = tensor.to(reduce_dtype)
+    dist.all_reduce(tensor, op=ReduceOp.SUM, group=mesh.get_group() if mesh else None)
+    return tensor
+
+
 def dist_mean(x: Union[int, float], mesh: DeviceMesh = None):
     tensor = torch.tensor(x).cuda()
     dist.all_reduce(tensor, op=ReduceOp.AVG, group=mesh.get_group() if mesh else None)

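For orientation, here is a minimal sketch of how these reduction helpers behave. It is adapted to CPU with the gloo backend so it runs without a GPU and without the training DeviceMesh; the real helpers call .cuda() and use the mesh's process group. The dist_sum_cpu name and the single-process group setup are illustrative assumptions, not part of the commit.

    # Minimal sketch, assuming a CPU/gloo adaptation of the helpers above.
    import os

    import torch
    import torch.distributed as dist
    from torch.distributed import ReduceOp


    def dist_sum_cpu(x, reduce_dtype=None):
        tensor = torch.tensor(x)
        if reduce_dtype is not None:
            tensor = tensor.to(reduce_dtype)
        dist.all_reduce(tensor, op=ReduceOp.SUM)  # sums the value over all ranks
        return tensor


    if __name__ == "__main__":
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group("gloo", rank=0, world_size=1)  # single-process group
        print(dist_sum_cpu(3.5))                  # tensor(3.5000)
        print(dist_sum_cpu(3.5, torch.bfloat16))  # tensor(3.5000, dtype=torch.bfloat16)
        dist.destroy_process_group()

Note that ReduceOp.AVG, which dist_mean relies on, is only supported by the NCCL backend, so this sketch exercises only the SUM path.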
@@ -236,7 +246,7 @@ def setup_env(env_args: EnvironmentArgs):
             logger.warning(f"WARNING: Setting {name} to {value}")
 
 
-def setup_torch_distributed(dist_args):
+def setup_torch_distributed(dist_args: DistributedArgs):
     """
     Handle single and multi-GPU / multi-node / SLURM jobs.
     Initialize the following variables:

@@ -388,14 +398,14 @@ def clean_env():
 
 
 def parallelize_model(
-    model,
+    model: torch.nn.Module,
     device_mesh,
     model_args,
     distributed_args: DistributedArgs,
     fsdp_grouping_plan: Optional[List[Tuple[str, bool]]] = None,
     tp_parallelize=None,
     no_recompute_ops=None,
-):
+) -> torch.nn.Module:
     if distributed_args.tp_size > 1:
         assert (
             distributed_args.fsdp_type == "full_shard"

@@ -49,7 +49,6 @@ class LoggingArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
     freq: int = 10  # Log every freq optimizer steps
     acc_freq: int | None = None  # Log every acc_freq gradient accumulation steps
-
     wandb: WandbArgs | None = None
 
 
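As a side note, the optional wandb field together with extra="forbid" presumably means W&B logging stays off unless a wandb section is supplied in the config, and unknown keys are rejected at validation time. A stripped-down pydantic-v2 stand-in (the WandbArgs project field is an assumed placeholder, not taken from the repo):

    # Illustrative sketch; only the fields shown in the diff are real.
    from pydantic import BaseModel, ConfigDict, ValidationError


    class WandbArgs(BaseModel):
        model_config = ConfigDict(extra="forbid")
        project: str = "blt"  # assumed field, for illustration only


    class LoggingArgs(BaseModel):
        model_config = ConfigDict(extra="forbid")
        freq: int = 10  # Log every freq optimizer steps
        acc_freq: int | None = None  # Log every acc_freq gradient accumulation steps
        wandb: WandbArgs | None = None


    print(LoggingArgs().wandb)                     # None -> W&B logging disabled
    print(LoggingArgs(wandb={"project": "demo"}))  # nested dict parsed into WandbArgs

    try:
        LoggingArgs(typo_field=1)
    except ValidationError as err:
        print(err.errors()[0]["type"])             # extra_forbidden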
@@ -3,6 +3,7 @@
 
 import gc
 import logging
+import math
 import os
 import sys
 from contextlib import ExitStack

@@ -11,6 +12,7 @@ from dataclasses import asdict, dataclass
 from timeit import default_timer as timer
 from typing import Any, TypeVar
 
+import numpy as np
 import torch
 import torch.distributed
 import torch.nn.functional

@@ -32,7 +34,9 @@ from bytelatent.data.iterators.packing_iterator import PackingIteratorState
 from bytelatent.distributed import (
     check_model_value_range,
     clean_env,
+    dist_mean,
     dist_mean_dict,
+    dist_sum,
     get_device_mesh,
     get_is_master,
     get_world_size,

@@ -392,6 +396,9 @@ def train(args: TrainArgs):
     time_last_log = timer()
     gc.collect()
     saved = False
+    step_losses: list[float] = []
+    step_tok_losses: list[float] = []
+    n_bytes: int = 0
     while train_state.step < args.steps and (
         args.max_steps is None or train_state.step < args.max_steps
     ):

@@ -413,6 +420,21 @@ def train(args: TrainArgs):
         batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda()
         mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
 
+        if args.data.tokenizer_args.name in ["bytes", "blt"]:
+            n_bytes += batch_y.numel() if mask is None else mask.sum()
+        elif args.data.tokenizer_args.name in ["sp", "tiktoken"]:
+            for example in batch.y:
+                target_tokens = tokenizer.decode(example.tolist(), cut_at_eos=False)
+                n_bytes += (
+                    len(bytes(target_tokens, encoding="utf-8", errors="ignore"))
+                    + sum(example == tokenizer.eos_id)
+                    + sum(example == tokenizer.bos_id)
+                )
+        else:
+            raise ValueError(
+                f"Unexpected tokenizer to count n_bytes for: {args.data.tokenizer_args.name}"
+            )
+
         if (
             not args.train_entropy_model
             and args.model.encoder_enable_byte_ngrams

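The branch above counts target bytes per batch: for the byte-level tokenizers ("bytes", "blt") every unmasked target position is one byte, while for subword tokenizers ("sp", "tiktoken") the targets are decoded back to text and their UTF-8 length is counted, crediting one byte each for BOS and EOS. A toy sketch of the same logic, where ToyTokenizer and its ids are assumptions standing in for the real tokenizer:

    # Sketch of the byte-counting logic above, with a stand-in tokenizer.
    import torch


    class ToyTokenizer:
        bos_id, eos_id = 1, 2

        def decode(self, ids, cut_at_eos=False):
            # Pretend every non-special id maps to the two-character string "ab".
            return "".join("ab" for t in ids if t not in (self.bos_id, self.eos_id))


    tokenizer = ToyTokenizer()

    # Byte-level ("bytes"/"blt") case: every unmasked target position is one byte.
    batch_y = torch.randint(0, 256, (2, 8))
    mask = torch.ones_like(batch_y, dtype=torch.bool)
    n_bytes = batch_y.numel() if mask is None else int(mask.sum())
    print(n_bytes)  # 16

    # Subword ("sp"/"tiktoken") case: decode targets to text, count UTF-8 bytes,
    # plus one byte credited per BOS/EOS token.
    batch_y = torch.tensor([[1, 5, 6, 2]])
    n_bytes = 0
    for example in batch_y:
        target_text = tokenizer.decode(example.tolist(), cut_at_eos=False)
        n_bytes += (
            len(bytes(target_text, encoding="utf-8", errors="ignore"))
            + int((example == tokenizer.eos_id).sum())
            + int((example == tokenizer.bos_id).sum())
        )
    print(n_bytes)  # 4 text bytes ("abab") + 1 eos + 1 bos = 6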
@@ -487,7 +509,7 @@ def train(args: TrainArgs):
             batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids
         )
 
-        loss, _ = compute_loss(pred, batch_y, mask, train_state.scale)
+        loss, tok_loss = compute_loss(pred, batch_y, mask, train_state.scale)
 
         # We scale loss with grad_acc_steps so the gradient is the same
         # regardless of grad_acc_steps

@@ -498,6 +520,10 @@ def train(args: TrainArgs):
         # For logging we undo that scaling
         loss = loss.detach() * args.grad_acc_steps
 
+        # Undo loss scaling so downstream code doesn't need to worry about it
+        step_losses.append((loss / train_state.scale).item())
+        step_tok_losses.append(tok_loss / train_state.scale)
+
         world_size = get_world_size()
         if 1 < world_size <= 8:
             # For some reason, there are errors in reduces due to

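To make the unscaling concrete, a toy walk-through with made-up numbers; the exact points where train_state.scale and the 1/grad_acc_steps factor are applied are inferred from the surrounding comments and are not shown in this diff.

    # Toy numbers only; where each factor is applied is an assumption (see above).
    raw_loss = 2.0        # unscaled mean loss for this micro-batch
    grad_acc_steps = 4
    scale = 0.5           # stand-in for train_state.scale

    scaled_for_backward = raw_loss * scale / grad_acc_steps  # what backward() would see
    logged_loss = scaled_for_backward * grad_acc_steps       # undo grad-acc scaling -> 1.0
    recorded_step_loss = logged_loss / scale                 # undo train_state.scale -> 2.0
    assert recorded_step_loss == raw_loss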
@@ -568,8 +594,39 @@ def train(args: TrainArgs):
                 * wps
             )
 
-            metrics = flatten_dict(
-                {
+            # Below, semantics are:
+            # per_gpu: Metrics on a given rank
+            # across_gpus: Metrics averaged/summed across all ranks
+            # step: Metric at a step
+            # interval: Metric averaged/summed across all steps since the last log interval.
+            #     Typically, this is 10
+            step_loss_per_gpu = loss.item()
+            step_loss_across_gpus = dist_mean(step_loss_per_gpu).item()
+            interval_loss_per_gpu = np.mean(step_losses).item()
+            interval_loss_across_gpus = dist_mean(interval_loss_per_gpu).item()
+
+            stacked_tok_loss = torch.cat(step_tok_losses, dim=0)
+            interval_total_tok_loss_per_gpu = stacked_tok_loss.sum().item()
+            interval_total_tok_loss_across_gpus = dist_sum(
+                interval_total_tok_loss_per_gpu, reduce_dtype=torch.bfloat16
+            ).item()
+            interval_total_n_bytes_per_gpu = n_bytes
+            interval_total_n_bytes_across_gpus = dist_sum(
+                n_bytes, reduce_dtype=torch.bfloat16
+            ).item()
+
+            interval_bpb_per_gpu = (
+                interval_total_tok_loss_per_gpu
+                / math.log(2)
+                / interval_total_n_bytes_per_gpu
+            )
+            interval_bpb_across_gpus = (
+                interval_total_tok_loss_across_gpus
+                / math.log(2)
+                / interval_total_n_bytes_across_gpus
+            )
+
+            metric_dict = {
                 "global_step": train_state.step,
                 "acc_step": train_state.acc_step,
                 "speed": {

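The bits-per-byte metric divides the summed per-token negative log-likelihood (in nats) by ln(2) to convert to bits, then by the number of target bytes. A small worked example with made-up totals:

    import math

    interval_total_tok_loss = 3466.0  # summed token NLL in nats over the interval
    interval_total_n_bytes = 5000     # target bytes covered by those tokens

    bpb = interval_total_tok_loss / math.log(2) / interval_total_n_bytes
    print(f"{bpb:.3f}")  # 1.000 -> roughly one bit of loss per byte of text

Summing the cross-rank totals with reduce_dtype=torch.bfloat16 presumably keeps the all-reduce cheap, at the cost of only two to three significant decimal digits of precision in the reduced sums.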
|
@ -584,34 +641,61 @@ def train(args: TrainArgs):
|
||||||
"total_tokens": total_tokens,
|
"total_tokens": total_tokens,
|
||||||
},
|
},
|
||||||
"memory": gpu_mem_stats._asdict(),
|
"memory": gpu_mem_stats._asdict(),
|
||||||
|
"loss": {
|
||||||
|
"step_per_gpu": step_loss_per_gpu,
|
||||||
|
"step_across_gpu": step_loss_across_gpus,
|
||||||
|
"interval_per_gpu": interval_loss_per_gpu,
|
||||||
|
"interval_across_gpu": interval_loss_across_gpus,
|
||||||
},
|
},
|
||||||
|
"bpb": {
|
||||||
|
"interval_per_gpu": interval_bpb_per_gpu,
|
||||||
|
"interval_across_gpus": interval_bpb_across_gpus,
|
||||||
|
},
|
||||||
|
"n_bytes": {
|
||||||
|
"interval_per_gpu": interval_total_n_bytes_per_gpu,
|
||||||
|
"interval_across_gpus": interval_total_n_bytes_across_gpus,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
metrics = flatten_dict(
|
||||||
|
metric_dict,
|
||||||
sep="/",
|
sep="/",
|
||||||
)
|
)
|
||||||
|
|
||||||
to_sync = {}
|
|
||||||
to_sync["loss/out"] = loss.item()
|
|
||||||
metrics.update(dist_mean_dict(to_sync))
|
|
||||||
|
|
||||||
if get_is_master():
|
if get_is_master():
|
||||||
metric_logger.log(metrics)
|
metric_logger.log(metrics)
|
||||||
|
|
||||||
gpu_memory_monitor.reset_peak_stats()
|
# Below semantics are:
|
||||||
nwords_since_last_log = 0
|
# step=Metrics at a step
|
||||||
time_last_log = timer()
|
# interval=Metrics averaged across the logging interval
|
||||||
|
# local=On one rank
|
||||||
|
# global=Across all ranks
|
||||||
logger.info(
|
logger.info(
|
||||||
f"step: {train_state.step}"
|
f"step: {train_state.step}"
|
||||||
f" acc: {train_state.acc_step}"
|
f" acc: {train_state.acc_step}"
|
||||||
f" loss: {round(loss.item(),4):>7}"
|
f" loss_gpu: {round(interval_loss_per_gpu, 4):>7}"
|
||||||
|
f" loss_avg: {round(interval_loss_across_gpus, 4):>7}"
|
||||||
|
f" bpb_gpu: {interval_bpb_per_gpu:3f}"
|
||||||
|
f" bpb_avg: {interval_bpb_across_gpus:3f}"
|
||||||
f" grad: {grad_norm:.2e}"
|
f" grad: {grad_norm:.2e}"
|
||||||
f" flops: {FLOPS:.2e}"
|
f" flops: {FLOPS:.2e}"
|
||||||
f" wps: {wps:.2e}"
|
f" wps: {wps:.2e}"
|
||||||
f" iter: {curr_iter_time:>7}"
|
f" iter: {curr_iter_time:>7}"
|
||||||
f" data: {data_load_time:>5}"
|
f" data: {data_load_time:>5}"
|
||||||
f" lr: {curr_lr:.2e}"
|
f" lr: {curr_lr:.2e}"
|
||||||
|
f" n_bytes_gpu: {int(interval_total_n_bytes_per_gpu)}"
|
||||||
|
f" n_bytes_sum: {int(interval_total_n_bytes_across_gpus)}"
|
||||||
f" mem: {gpu_mem_stats.max_active_pct:.0f}%"
|
f" mem: {gpu_mem_stats.max_active_pct:.0f}%"
|
||||||
f" pow: {gpu_mem_stats.power_draw/1000} W"
|
f" pow: {gpu_mem_stats.power_draw/1000} W"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
n_bytes = 0
|
||||||
|
step_losses = []
|
||||||
|
step_tok_losses = []
|
||||||
|
gpu_memory_monitor.reset_peak_stats()
|
||||||
|
nwords_since_last_log = 0
|
||||||
|
time_last_log = timer()
|
||||||
|
|
||||||
if every_n_steps(
|
if every_n_steps(
|
||||||
train_state, args.checkpoint.dump.every, acc_step=0
|
train_state, args.checkpoint.dump.every, acc_step=0
|
||||||
) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
|
) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
|
||||||
|
|
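For reference, flatten_dict with sep="/" turns the nested metric_dict into flat, W&B-style keys such as loss/interval_per_gpu. The helper below is a stand-in sketch; the real flatten_dict imported in train.py may differ in details.

    # Stand-in sketch of the flattening step.
    def flatten_dict(d, sep="/", prefix=""):
        out = {}
        for k, v in d.items():
            key = f"{prefix}{sep}{k}" if prefix else str(k)
            if isinstance(v, dict):
                out.update(flatten_dict(v, sep=sep, prefix=key))
            else:
                out[key] = v
        return out


    metric_dict = {
        "global_step": 100,
        "loss": {"interval_per_gpu": 2.31, "interval_across_gpu": 2.29},
        "bpb": {"interval_per_gpu": 1.04},
    }
    print(flatten_dict(metric_dict, sep="/"))
    # {'global_step': 100, 'loss/interval_per_gpu': 2.31,
    #  'loss/interval_across_gpu': 2.29, 'bpb/interval_per_gpu': 1.04}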