blt/bytelatent/checkpoint.py

312 lines
10 KiB
Python
Raw Permalink Normal View History

2024-12-12 23:32:30 +00:00
# Copyright (c) Meta Platforms, Inc. and affiliates.
import json
import logging
import os
import re
from pathlib import Path
from typing import List, Optional, Tuple
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
import torch.optim.optimizer
from pydantic import BaseModel, ConfigDict
from torch.distributed._tensor import DeviceMesh
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
from torch.distributed.checkpoint.state_dict import (
get_model_state_dict,
get_state_dict,
set_state_dict,
)
from bytelatent.distributed import get_is_master
logger = logging.getLogger("CHECKPOINT")
FOLDER_NAME = "{:010d}"
RE_FOLDER = r"\d{10}"
RE_CKPT = r"__\d_\d\.distcp"
CONSOLIDATE_FOLDER = "consolidated"
CONSOLIDATE_NAME = "consolidated.pth"
CONFIG_NAME = "params.json"
TRAIN_STATE_NAME = "train_state_{:05d}.json"
RE_DIGITS = re.compile(r"\d+")
class SaveEvery(BaseModel):
model_config = ConfigDict(extra="forbid")
every: int = 1000
keep: int = 0
class CheckpointArgs(BaseModel):
model_config = ConfigDict(extra="forbid")
dump: SaveEvery = SaveEvery()
eval: SaveEvery = SaveEvery()
path: str | None = None
init_ckpt_path: str | None = None
continue_training_from_init: bool = False
def _get_key_step(name: str):
return int(re.findall(RE_DIGITS, name)[-1])
def consolidate_checkpoints(ckpt_dir: str):
"""
Consolidates all FSDP checkpoints in a directory to a single file
Consolidate checkpoint is saved in a subdirectory of ckpt_dir
Parameters:
ckpt_dir: str - path to the directory containing the checkpoints
Returns the path to the consolidated checkpoint
"""
consolidate_path = Path(ckpt_dir) / CONSOLIDATE_FOLDER
if not (consolidate_path / CONSOLIDATE_NAME).exists():
consolidate_path.mkdir(exist_ok=True)
logger.info(f"Consolidating to: {str(consolidate_path)}")
dcp_to_torch_save(ckpt_dir, str(consolidate_path / CONSOLIDATE_NAME))
(consolidate_path / CONFIG_NAME).write_text(
(Path(ckpt_dir) / CONFIG_NAME).read_text()
)
logger.info("Consolidated !")
return consolidate_path
def load_from_checkpoint(
ckpt_dir: str,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
model_key: str = "model",
optim_key: str = "optim",
):
if not (Path(ckpt_dir) / ".metadata").exists():
raise ValueError(
f"Please convert the checkpoint distcp format using `torch.distributed.checkpoint.format_utils.torch_save_to_dcp` before loading it"
)
state_dict = {}
if optimizer is not None:
state_dict[model_key], state_dict[optim_key] = get_state_dict(model, optimizer)
else:
state_dict[model_key] = get_model_state_dict(model)
if model_key == "": # If only loading a model directly, the key should be empty
state_dict = state_dict.pop(model_key)
dcp.load(state_dict, checkpoint_id=ckpt_dir)
class CheckpointManager:
def __init__(self, args: CheckpointArgs):
self.path = args.path
self.dump_every = args.dump
self.eval_every = args.eval
self.init_ckpt_path = args.init_ckpt_path
self.continue_training_from_init = args.continue_training_from_init
assert os.path.exists(
self.path
), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)"
self.existing_saves = self.get_existing_saves()
def get_existing_saves(self) -> List[Path]:
folders = [
p
for p in Path(self.path).iterdir()
if p.is_dir() and re.match(RE_FOLDER, p.name)
]
folders.sort(key=lambda p: _get_key_step(p.name))
return folders
def clean_up(self):
logger.info("Cleaning up checkpoints...")
dump_folders = []
eval_folders = []
other_folders = []
for p in self.existing_saves:
is_dump = _get_key_step(p.name) % self.dump_every.every == 0
is_eval = _get_key_step(p.name) % self.eval_every.every == 0
if is_dump:
dump_folders.append(p)
if is_eval:
eval_folders.append(p)
if not (is_dump or is_eval):
other_folders.append(p)
logger.info(f"Dump folders: {dump_folders}")
logger.info(f"Eval folders: {eval_folders}")
logger.info(f"Other folders: {other_folders}")
if self.dump_every.keep > 0:
dump_folders = dump_folders[-self.dump_every.keep :]
if self.eval_every.keep > 0:
eval_folders = eval_folders[-self.eval_every.keep :]
folder_to_keep = set(other_folders + dump_folders + eval_folders)
folder_to_remove = set(self.existing_saves) - folder_to_keep
logger.info(f"Removing folders: {folder_to_remove}")
if dist.get_rank() == 0:
for folder in folder_to_remove:
for file in folder.iterdir():
if file.is_file():
file.unlink()
elif file.is_dir():
assert file.name in [CONSOLIDATE_FOLDER]
for f in file.iterdir():
f.unlink()
file.rmdir()
folder.rmdir()
dist.barrier()
self.existing_saves = list(folder_to_keep)
self.existing_saves.sort(key=lambda p: _get_key_step(p.name))
def get_last_step_path(self, dp_rank: int = 0) -> Optional[Path]:
path = None
for p in reversed(self.existing_saves):
if (p / TRAIN_STATE_NAME.format(dp_rank)).is_file():
path = p
break
return path
def _create_folder(self, base_path: Path, folder_name: str) -> Path:
folder = base_path / folder_name
if get_is_master():
folder.mkdir(parents=False, exist_ok=True)
if dist.is_initialized():
dist.barrier()
return folder
def _get_dp_tp_mesh(
self, device_mesh: Optional[DeviceMesh] = None
) -> Tuple[int, int]:
dp_rank = 0
tp_rank = 0
if device_mesh is not None:
if "dp_replicate" in device_mesh.mesh_dim_names:
dp_rank = device_mesh.get_local_rank("dp_replicate")
if "dp_shard" in device_mesh.mesh_dim_names:
dp_rank = dp_rank * device_mesh[
"dp_replicate"
].size() + device_mesh.get_local_rank("dp_shard")
if "tp" in device_mesh.mesh_dim_names:
tp_rank = device_mesh.get_local_rank("tp")
return dp_rank, tp_rank
@torch.no_grad()
def get_state_dict(
self,
model,
optimizer,
):
model_sd, optim_sd = get_state_dict(model, optimizer)
return {"model": model_sd, "optim": optim_sd}
def save(
self,
model,
optimizer,
train_state,
config,
device_mesh: Optional[DeviceMesh] = None,
) -> bool:
# When creating directory check if only rank0 or is there other solution
path = Path(self.path)
curr_save_dir = self._create_folder(path, FOLDER_NAME.format(train_state.step))
logger.info(f"Saving to: {str(curr_save_dir)}")
if dist.is_initialized():
dist.barrier()
logger.info("Saving...")
state_dict = self.get_state_dict(model, optimizer)
dcp.save(state_dict, checkpoint_id=curr_save_dir)
logger.info("State dict saved!")
if dist.is_initialized():
dist.barrier()
if get_is_master():
config.dump_to_yaml_file(curr_save_dir / CONFIG_NAME)
# Add json dump here
dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
if tp_rank == 0:
train_state_name = TRAIN_STATE_NAME.format(dp_rank)
logger.info(
f"Saving train state to: {str(curr_save_dir / train_state_name)}"
)
with open(curr_save_dir / train_state_name, "w") as f:
json.dump(train_state.state_dict(), f)
logger.info("Train state saved !")
self.existing_saves.append(curr_save_dir)
self.clean_up()
if dist.is_initialized():
dist.barrier()
return True
@torch.no_grad()
def load(
self,
model: nn.Module,
optimizer,
train_state,
device_mesh: DeviceMesh,
path: Optional[Path] = None,
):
dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
# Loading tries to load the provided path, if not available the last saved step and finally from the init path
path = path or self.get_last_step_path(dp_rank=dp_rank)
# If none of those are available don't do anything
if path is None:
# If no checkpoints exist do nothing
return
# Only load train state if it's provided, the files exist and we're not loading from init path
train_state_name = TRAIN_STATE_NAME.format(dp_rank)
logger.info("Reloading train state")
with open(path / train_state_name, "r") as f:
train_state_dict = json.load(f)
train_state.load_state_dict(train_state_dict)
logger.info("Train state reloaded")
logger.info(f"Loading from: {str(path)}")
state_dict = self.get_state_dict(
model=model,
optimizer=optimizer,
)
dcp.load(state_dict, checkpoint_id=path)
logger.info("State dict loaded.")
logger.info("Reloading model and optim")
set_state_dict(
model,
optimizer,
model_state_dict=state_dict["model"],
optim_state_dict=state_dict["optim"],
)
logger.info("Model and optim reloaded")
@classmethod
def instantiate_and_make_dir(cls, args: CheckpointArgs):
if get_is_master():
os.makedirs(args.path, exist_ok=True)
dist.barrier()
return cls(args)