mirror of
https://github.com/facebookresearch/blt.git
synced 2025-01-18 08:27:45 +00:00
233 lines
6.7 KiB
Python
233 lines
6.7 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
|
|
|
|
import json
|
|
import logging
|
|
from collections import namedtuple
|
|
from dataclasses import asdict
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Any, Union
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import wandb
|
|
from pydantic import BaseModel, ConfigDict
|
|
|
|
from bytelatent.distributed import get_is_master
|
|
|
|
# Root logger: getLogger() is called without a name. In this file it is used
# by GPUMemoryMonitor.get_peak_stats to warn about CUDA allocation retries/OOMs.
logger = logging.getLogger()
|
|
|
|
|
|
class WandbArgs(BaseModel):
    """Optional Weights & Biases run settings.

    Every field defaults to None. In this file the model is expanded into
    keyword arguments of ``wandb.init`` by ``MetricLogger.open``, so each
    field name is used verbatim as a ``wandb.init`` keyword.
    """

    # Unknown keys are a validation error rather than being silently ignored.
    model_config = ConfigDict(extra="forbid")

    # Run identity / grouping.
    job_type: str | None = None
    dir: str | None = None
    project: str | None = None
    entity: str | None = None
    tags: list | None = None
    group: str | None = None
    name: str | None = None
    notes: str | None = None

    # Remaining settings; their meaning is defined by wandb, not by this file.
    config_exclude_keys: list[str] | None = None
    config_include_keys: list[str] | None = None
    anonymous: str | None = None
    mode: str | None = None
    allow_val_change: bool | None = None
    # Accepts either a bool or a string.
    resume: Union[bool, str] | None = None
    force: bool | None = None
    tensorboard: bool | None = None
    sync_tensorboard: bool | None = None
    monitor_gym: bool | None = None
    save_code: bool | None = None
    id: str | None = None
    fork_from: str | None = None
    resume_from: str | None = None
|
|
|
|
|
|
class LoggingArgs(BaseModel):
    """Logging cadence plus the optional Weights & Biases configuration."""

    # Unknown keys are a validation error rather than being silently ignored.
    model_config = ConfigDict(extra="forbid")

    freq: int = 10  # Log every freq optimizer steps
    acc_freq: int | None = None  # Log every acc_freq gradient accumulation steps

    # None means no wandb settings were supplied. MetricLogger checks
    # `args.logging.wandb is not None` before touching wandb (presumably this
    # field -- confirm against the caller that builds `args`).
    wandb: WandbArgs | None = None
|
|
|
|
|
|
class MetricLogger:
|
|
def __init__(self, outdir: Path, args: Any | None = None):
|
|
self.outdir = outdir
|
|
self.jsonl_writer = None
|
|
self.args = args
|
|
|
|
def open(self):
|
|
if self.jsonl_writer is None:
|
|
self.jsonl_writer = open(self.outdir, "a")
|
|
if (
|
|
self.args is not None
|
|
and self.args.logging.wandb is not None
|
|
and get_is_master()
|
|
):
|
|
run = wandb.init(
|
|
config=asdict(self.args),
|
|
**asdict(self.args.logging.wandb),
|
|
)
|
|
|
|
def log(self, metrics: dict[str, Any]):
|
|
if (
|
|
self.args is not None
|
|
and self.args.logging.wandb is not None
|
|
and (wandb.run is not None)
|
|
):
|
|
wandb.log(metrics, step=metrics["global_step"])
|
|
|
|
metrics.update({"created_at": datetime.now(timezone.utc).isoformat()})
|
|
print(json.dumps(metrics), file=self.jsonl_writer, flush=True)
|
|
|
|
def close(self):
|
|
if self.jsonl_writer is not None:
|
|
self.jsonl_writer.close()
|
|
self.jsonl_writer = None
|
|
|
|
def __enter__(self):
|
|
self.open()
|
|
return self
|
|
|
|
def __exit__(self, exc_type, exc_value, traceback):
|
|
self.close()
|
|
|
|
def __del__(self):
|
|
self.close()
|
|
|
|
|
|
# Record returned by GPUMemoryMonitor.get_peak_stats: peak active/reserved
# memory (GiB and percent of capacity), allocator retry/OOM counters, and the
# power-draw reading.
GPUMemStats = namedtuple(
    "GPUMemStats",
    "max_active_gib max_active_pct "
    "max_reserved_gib max_reserved_pct "
    "num_alloc_retries num_ooms power_draw",
)
|
|
|
|
|
|
class GPUMemoryMonitor:
    """
    Class to monitor GPU memory usage

    All queries and resets are bound to the device given at construction, so
    the reported name, capacity, index, statistics and power draw always
    describe the same GPU.

    Args:
        device: Device string understood by ``torch.device`` (e.g. "cuda:0"
            or "cuda"). When it carries no index, the current CUDA device
            index is recorded as ``device_index``.
    """

    def __init__(self, device: str = "cuda:0"):
        self.device = torch.device(device)  # device object
        self.device_name = torch.cuda.get_device_name(self.device)
        # Prefer the explicit index of the monitored device; fall back to the
        # current device only when the device string did not specify one.
        explicit_index = self.device.index
        self.device_index = (
            explicit_index
            if explicit_index is not None
            else torch.cuda.current_device()
        )
        self.device_capacity = torch.cuda.get_device_properties(
            self.device
        ).total_memory
        self.device_capacity_gib = self._to_gib(self.device_capacity)

        # reset stats, clear cache
        torch.cuda.reset_peak_memory_stats(self.device)
        torch.cuda.empty_cache()

    def _to_gib(self, memory_in_bytes):
        """Convert a byte count to GiB."""
        # NOTE: GiB (gibibyte) is 1024, vs GB is 1000
        _gib_in_bytes = 1024 * 1024 * 1024
        return memory_in_bytes / _gib_in_bytes

    def _to_pct(self, memory):
        """Express a byte count as a percentage of the device capacity."""
        return 100 * memory / self.device_capacity

    def get_peak_stats(self):
        """Return a ``GPUMemStats`` built from the allocator stats of ``self.device``.

        Logs a warning when the allocator reports allocation retries or OOMs.
        """
        cuda_info = torch.cuda.memory_stats(self.device)

        max_active = cuda_info["active_bytes.all.peak"]
        max_reserved = cuda_info["reserved_bytes.all.peak"]
        num_retries = cuda_info["num_alloc_retries"]
        num_ooms = cuda_info["num_ooms"]
        power_draw = torch.cuda.power_draw(self.device)

        if num_retries > 0:
            logger.warning(f"{num_retries} CUDA memory allocation retries.")
        if num_ooms > 0:
            logger.warning(f"{num_ooms} CUDA OOM errors thrown.")

        return GPUMemStats(
            self._to_gib(max_active),
            self._to_pct(max_active),
            self._to_gib(max_reserved),
            self._to_pct(max_reserved),
            num_retries,
            num_ooms,
            power_draw,
        )

    def reset_peak_stats(self):
        """Reset peak and accumulated allocator statistics of ``self.device``."""
        torch.cuda.reset_peak_memory_stats(self.device)
        torch.cuda.reset_accumulated_memory_stats(self.device)

    def __str__(self):
        mem_stats = self.get_peak_stats()
        display_str = f"{self.device_name} ({self.device_index}): {self.device_capacity_gib} GiB capacity, "
        display_str += (
            f"{mem_stats.max_reserved_gib} GiB peak, {mem_stats.max_reserved_pct}% peak"
        )
        return display_str
|
|
|
|
|
|
def _replay_jsonl_to_wandb(metrics_path, cfg, project, entity, to_payload):
    """Create one wandb run and log every JSONL record of ``metrics_path`` to it."""
    wandb.init(config=cfg, name=cfg["name"], project=project, entity=entity)
    try:
        with open(metrics_path) as f:
            for line in f:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                if not line.strip():
                    continue
                record = json.loads(line)
                wandb.log(to_payload(record), step=record["global_step"])
    except BaseException:
        # Do not leave the run open (or marked successful) on a failed upload.
        wandb.finish(exit_code=1)
        raise
    wandb.finish()


def upload_train_to_wandb(
    ckpt_dir, project="lingua", entity="codegen-team", train=True, eval=True
):
    """Replay the metric files stored in a checkpoint directory into wandb.

    Args:
        ckpt_dir: Directory containing ``config.yaml`` and the metric files.
        project: wandb project to log to.
        entity: wandb entity to log to.
        train: When true, upload ``metrics.jsonl`` as one run, logging every
            record unchanged.
        eval: When true, upload ``metrics.eval.jsonl`` as a separate run. Only
            keys containing "/" are logged, renamed to
            ``evals/<key with "/" replaced by ".">``.

    Each record must contain ``global_step``, which is used as the wandb step.
    The run name is taken from the ``name`` entry of the loaded config.

    Raises:
        Whatever reading/parsing the files or wandb raises; the active run is
        finished with a non-zero exit code before the error propagates.
    """
    # Imported lazily: omegaconf is only needed by this offline upload helper.
    from omegaconf import OmegaConf

    ckpt_dir = Path(ckpt_dir)
    cfg = OmegaConf.to_container(OmegaConf.load(ckpt_dir / "config.yaml"))

    if train:
        _replay_jsonl_to_wandb(
            ckpt_dir / "metrics.jsonl",
            cfg,
            project,
            entity,
            lambda record: record,
        )

    if eval:
        _replay_jsonl_to_wandb(
            ckpt_dir / "metrics.eval.jsonl",
            cfg,
            project,
            entity,
            lambda record: {
                f"evals/{name.replace('/', '.')}": value
                for name, value in record.items()
                if "/" in name
            },
        )
|
|
|
|
|
|
def get_num_params(model: nn.Module, only_trainable: bool = False) -> int:
    """
    Get the total model params

    Args:
        model: module whose parameters are counted (as yielded by
            ``model.named_parameters()``).
        only_trainable: whether to only count trainable params, i.e. those
            with ``requires_grad`` set. Defaults to False (count everything).

    Returns:
        The summed number of elements over the selected parameters.
    """
    return sum(
        param.numel()
        for _, param in model.named_parameters()
        if not only_trainable or param.requires_grad
    )
|