blt/bytelatent/metrics.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import json
import logging
from collections import namedtuple
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import fsspec
import torch
import torch.nn as nn
import wandb
from pydantic import BaseModel, ConfigDict
from bytelatent.distributed import get_is_master
logger = logging.getLogger()
class WandbArgs(BaseModel):
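    """Keyword arguments forwarded verbatim to `wandb.init`; unset fields
    stay None and wandb applies its own defaults."""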
model_config = ConfigDict(extra="forbid")
job_type: str | None = None
dir: str | None = None
project: str | None = None
entity: str | None = None
    tags: list[str] | None = None
group: str | None = None
name: str | None = None
notes: str | None = None
config_exclude_keys: list[str] | None = None
config_include_keys: list[str] | None = None
anonymous: str | None = None
mode: str | None = None
allow_val_change: bool | None = None
    resume: bool | str | None = None
force: bool | None = None
tensorboard: bool | None = None
sync_tensorboard: bool | None = None
monitor_gym: bool | None = None
save_code: bool | None = None
id: str | None = None
fork_from: str | None = None
resume_from: str | None = None
class LoggingArgs(BaseModel):
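    """Metric logging cadence, plus optional wandb configuration."""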
model_config = ConfigDict(extra="forbid")
freq: int = 10 # Log every freq optimizer steps
acc_freq: int | None = None # Log every acc_freq gradient accumulation steps
wandb: WandbArgs | None = None
class MetricLogger:
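    """Append metrics as JSON lines to `outdir`, mirroring them to wandb
    when wandb is configured in `args.logging`.

    Minimal usage sketch (assumes `args` exposes `logging: LoggingArgs`
    and a pydantic `model_dump()`, as TrainArgs does):

        with MetricLogger(Path("metrics.jsonl"), args=args) as ml:
            ml.log({"global_step": step, "loss": loss})
    """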
def __init__(
self,
outdir: Path,
# args: TrainArgs
args: Any | None = None,
fs: fsspec.AbstractFileSystem | None = None,
):
self.outdir = outdir
self.jsonl_writer = None
self.fs = fs
self.args = args
def open(self):
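        """Open the JSONL writer (via fsspec when a filesystem is given) and,
        on the master rank, initialize wandb if it is configured."""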
if self.jsonl_writer is None:
if self.fs is None:
self.jsonl_writer = open(self.outdir, "a")
else:
self.jsonl_writer = self.fs.open(self.outdir, "a")
if (
self.args is not None
and self.args.logging.wandb is not None
and get_is_master()
):
            wandb.init(
config=self.args.model_dump(),
**self.args.logging.wandb.model_dump(),
)
def log(self, metrics: dict[str, Any]):
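        """Send `metrics` to wandb (when a run is active) and append them,
        stamped with a UTC `created_at` time, to the JSONL file. A
        "global_step" key is required when wandb is active; it is used as
        the step index."""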
if (
self.args is not None
and self.args.logging.wandb is not None
and (wandb.run is not None)
):
wandb.log(metrics, step=metrics["global_step"])
metrics.update({"created_at": datetime.now(timezone.utc).isoformat()})
print(json.dumps(metrics), file=self.jsonl_writer, flush=True)
def close(self):
if self.jsonl_writer is not None:
self.jsonl_writer.close()
self.jsonl_writer = None
def __enter__(self):
self.open()
return self
def __exit__(self, exc_type, exc_value, traceback):
self.close()
def __del__(self):
self.close()
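
# Peak GPU memory statistics, as reported by GPUMemoryMonitor.get_peak_stats().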
GPUMemStats = namedtuple(
"GPUMemStats",
[
"max_active_gib",
"max_active_pct",
"max_reserved_gib",
"max_reserved_pct",
"num_alloc_retries",
"num_ooms",
"power_draw",
],
)
class GPUMemoryMonitor:
"""
Class to monitor GPU memory usage
"""
def __init__(self, device: str = "cuda:0"):
self.device = torch.device(device) # device object
self.device_name = torch.cuda.get_device_name(self.device)
self.device_index = torch.cuda.current_device()
self.device_capacity = torch.cuda.get_device_properties(
self.device
).total_memory
self.device_capacity_gib = self._to_gib(self.device_capacity)
# reset stats, clear cache
torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()
def _to_gib(self, memory_in_bytes):
# NOTE: GiB (gibibyte) is 1024, vs GB is 1000
_gib_in_bytes = 1024 * 1024 * 1024
memory_in_gib = memory_in_bytes / _gib_in_bytes
return memory_in_gib
def _to_pct(self, memory):
return 100 * memory / self.device_capacity
def get_peak_stats(self):
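        """Read peak allocator counters from torch.cuda.memory_stats and
        return them as a GPUMemStats tuple, warning if any allocation
        retries or OOM events occurred since the last reset."""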
cuda_info = torch.cuda.memory_stats(self.device)
max_active = cuda_info["active_bytes.all.peak"]
max_active_gib = self._to_gib(max_active)
max_active_pct = self._to_pct(max_active)
max_reserved = cuda_info["reserved_bytes.all.peak"]
max_reserved_gib = self._to_gib(max_reserved)
max_reserved_pct = self._to_pct(max_reserved)
num_retries = cuda_info["num_alloc_retries"]
num_ooms = cuda_info["num_ooms"]
power_draw = torch.cuda.power_draw()
if num_retries > 0:
logger.warning(f"{num_retries} CUDA memory allocation retries.")
if num_ooms > 0:
logger.warning(f"{num_ooms} CUDA OOM errors thrown.")
return GPUMemStats(
max_active_gib,
max_active_pct,
max_reserved_gib,
max_reserved_pct,
num_retries,
num_ooms,
power_draw,
)
def reset_peak_stats(self):
torch.cuda.reset_peak_memory_stats()
torch.cuda.reset_accumulated_memory_stats()
    def __str__(self):
        mem_stats = self.get_peak_stats()
        display_str = f"{self.device_name} ({self.device_index}): {self.device_capacity_gib:.2f} GiB capacity, "
        display_str += (
            f"{mem_stats.max_reserved_gib:.2f} GiB peak, {mem_stats.max_reserved_pct:.2f}% peak"
        )
        return display_str
def upload_train_to_wandb(
ckpt_dir, project="lingua", entity="codegen-team", train=True, eval=True
):
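    """Replay saved metrics from a checkpoint directory into wandb.

    Loads config.yaml from `ckpt_dir`, then re-logs metrics.jsonl (train)
    and metrics.eval.jsonl (eval) record by record at each saved
    global_step.
    """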
    from omegaconf import OmegaConf  # deferred import: omegaconf is only needed here
cfg = OmegaConf.load(Path(ckpt_dir) / "config.yaml")
cfg = OmegaConf.to_container(cfg)
if train:
wandb.init(config=cfg, name=cfg["name"], project=project, entity=entity)
with open(Path(ckpt_dir) / "metrics.jsonl") as f:
            for line in f:
                m = json.loads(line)
wandb.log(m, step=m["global_step"])
wandb.finish()
if eval:
wandb.init(config=cfg, name=cfg["name"], project=project, entity=entity)
with open(Path(ckpt_dir) / "metrics.eval.jsonl") as f:
            for line in f:
                m = json.loads(line)
wandb.log(
{
f"evals/{name.replace('/','.')}": value
for name, value in m.items()
if "/" in name
},
step=m["global_step"],
)
wandb.finish()
def get_num_params(model: nn.Module) -> int:
"""
Get the total model params
Args : only_trainable: whether to only count trainable params
"""
numel = {n: p.numel() for n, p in model.named_parameters()}
return sum(numel.values())