Mirror of https://github.com/facebookresearch/blt.git (synced 2025-02-22 21:12:15 +00:00)

Commit c3d7f720f0: Merge b6e53f1d4c into sapling-pr-archive-EntilZha
bytelatent/checkpoint.py

@@ -4,8 +4,6 @@ import json
 import logging
 import os
 import re
-from pathlib import Path
-from typing import List, Optional, Tuple
 
 import fsspec
 import torch
@@ -70,26 +68,29 @@ def consolidate_checkpoints(fs: fsspec.AbstractFileSystem, ckpt_dir: str):
 
     Returns the path to the consolidated checkpoint
     """
-    consolidate_path = Path(ckpt_dir) / CONSOLIDATE_FOLDER
-    if not (consolidate_path / CONSOLIDATE_NAME).exists():
-        consolidate_path.mkdir(exist_ok=True)
-        logger.info(f"Consolidating to: {str(consolidate_path)}")
-        dcp_to_torch_save(ckpt_dir, str(consolidate_path / CONSOLIDATE_NAME))
-        (consolidate_path / CONFIG_NAME).write_text(
-            (Path(ckpt_dir) / CONFIG_NAME).read_text()
+    consolidate_path = os.path.join(ckpt_dir, CONSOLIDATE_FOLDER)
+    consolidate_name = os.path.join(consolidate_path, CONSOLIDATE_NAME)
+    if not fs.exists(consolidate_name):
+        fs.mkdirs(consolidate_path, exist_ok=True)
+        logger.info(f"Consolidating to: {consolidate_path}")
+        dcp_to_torch_save(ckpt_dir, consolidate_name)
+        fs.write_text(
+            os.path.join(consolidate_path, CONFIG_NAME),
+            fs.read_text(os.path.join(ckpt_dir, CONFIG_NAME)),
         )
         logger.info("Consolidated !")
     return consolidate_path
 
 
 def load_from_checkpoint(
+    fs: fsspec.AbstractFileSystem,
     ckpt_dir: str,
     model: nn.Module,
-    optimizer: Optional[torch.optim.Optimizer] = None,
+    optimizer: torch.optim.Optimizer | None = None,
     model_key: str = "model",
     optim_key: str = "optim",
 ):
-    if not (Path(ckpt_dir) / ".metadata").exists():
+    if not fs.exists(os.path.join(ckpt_dir, ".metadata")):
         raise ValueError(
             f"Please convert the checkpoint distcp format using `torch.distributed.checkpoint.format_utils.torch_save_to_dcp` before loading it"
         )
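Note on the hunk above: `consolidate_checkpoints` and `load_from_checkpoint` now operate on an fsspec `AbstractFileSystem` plus string paths instead of `pathlib.Path`, so the same logic works on local disk and object stores. A minimal usage sketch of the new consolidation path, with an illustrative local directory (not part of the diff):

```python
# Hedged sketch: calling the fsspec-based consolidation on a local checkpoint dir.
import fsspec

from bytelatent.checkpoint import consolidate_checkpoints

fs = fsspec.filesystem("file")        # local filesystem; an S3 filesystem works the same way
ckpt_dir = "/checkpoints/0000001000"  # illustrative checkpoint directory
consolidated_dir = consolidate_checkpoints(fs, ckpt_dir)
# Returns <ckpt_dir>/<CONSOLIDATE_FOLDER>, creating and populating it on first call.
```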
@@ -121,13 +122,13 @@ class CheckpointManager:
 
         self.existing_saves = self.get_existing_saves()
 
-    def get_existing_saves(self) -> List[Path]:
+    def get_existing_saves(self) -> list[str]:
         folders = [
             p
-            for p in Path(self.path).iterdir()
-            if p.is_dir() and re.match(RE_FOLDER, p.name)
+            for p in self.fs.ls(self.path)
+            if self.fs.isdir(p) and re.match(RE_FOLDER, os.path.basename(p))
         ]
-        folders.sort(key=lambda p: _get_key_step(p.name))
+        folders.sort(key=lambda p: _get_key_step(os.path.basename(p)))
         return folders
 
     def clean_up(self):
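Because `self.fs.ls()` returns plain string paths rather than `Path` objects, `p.name` becomes `os.path.basename(p)` here and in the following hunks. A small illustration of the sorting logic under that assumption (the folder naming and the `_get_key_step` helper are approximated):

```python
import os

# Illustration only: string paths as fs.ls() would return them, with assumed
# zero-padded step-numbered folder names.
paths = ["/ckpts/0000002000", "/ckpts/0000000500", "/ckpts/0000001000"]
paths.sort(key=lambda p: int(os.path.basename(p)))  # stand-in for _get_key_step
print(paths)  # ['/ckpts/0000000500', '/ckpts/0000001000', '/ckpts/0000002000']
```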
@@ -136,8 +137,9 @@ class CheckpointManager:
         eval_folders = []
         other_folders = []
         for p in self.existing_saves:
-            is_dump = _get_key_step(p.name) % self.dump_every.every == 0
-            is_eval = _get_key_step(p.name) % self.eval_every.every == 0
+            assert isinstance(p, str), f"Base path type: {p}"
+            is_dump = _get_key_step(os.path.basename(p)) % self.dump_every.every == 0
+            is_eval = _get_key_step(os.path.basename(p)) % self.eval_every.every == 0
             if is_dump:
                 dump_folders.append(p)
             if is_eval:
@@ -161,40 +163,39 @@ class CheckpointManager:
 
         if dist.get_rank() == 0:
             for folder in folder_to_remove:
-                for file in folder.iterdir():
-                    if file.is_file():
-                        file.unlink()
-                    elif file.is_dir():
-                        assert file.name in [CONSOLIDATE_FOLDER]
-                        for f in file.iterdir():
-                            f.unlink()
-                        file.rmdir()
-                folder.rmdir()
+                for file in self.fs.ls(folder):
+                    if self.fs.isfile(file):
+                        self.fs.rm_file(file)
+                    elif self.fs.isdir(file):
+                        assert os.path.basename(file) in [CONSOLIDATE_FOLDER]
+                        for f in self.fs.ls(file):
+                            self.fs.rm(f)
+                        self.fs.rmdir(file)
+                self.fs.rmdir(folder)
 
         dist.barrier()
 
         self.existing_saves = list(folder_to_keep)
-        self.existing_saves.sort(key=lambda p: _get_key_step(p.name))
+        self.existing_saves.sort(key=lambda p: _get_key_step(os.path.basename(p)))
 
-    def get_last_step_path(self, dp_rank: int = 0) -> Optional[Path]:
+    def get_last_step_path(self, dp_rank: int = 0) -> str | None:
         path = None
         for p in reversed(self.existing_saves):
-            if (p / TRAIN_STATE_NAME.format(dp_rank)).is_file():
+
+            if self.fs.isfile(os.path.join(p, TRAIN_STATE_NAME.format(dp_rank))):
                 path = p
                 break
         return path
 
-    def _create_folder(self, base_path: Path, folder_name: str) -> Path:
-        folder = base_path / folder_name
+    def _create_folder(self, base_path: str, folder_name: str) -> str:
+        folder = os.path.join(base_path, folder_name)
         if get_is_master():
-            folder.mkdir(parents=False, exist_ok=True)
+            self.fs.mkdirs(folder, exist_ok=True)
         if dist.is_initialized():
             dist.barrier()
         return folder
 
-    def _get_dp_tp_mesh(
-        self, device_mesh: Optional[DeviceMesh] = None
-    ) -> Tuple[int, int]:
+    def _get_dp_tp_mesh(self, device_mesh: DeviceMesh | None = None) -> tuple[int, int]:
         dp_rank = 0
         tp_rank = 0
         if device_mesh is not None:
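The removal loop above mirrors the old `Path.iterdir()` walk one entry at a time. fsspec also supports recursive deletion, which would be a possible simplification rather than what the diff does; a sketch with an illustrative path:

```python
import fsspec

fs = fsspec.filesystem("file")
folder = "/checkpoints/0000000500"  # illustrative folder slated for cleanup
# AbstractFileSystem.rm() can delete a whole tree, replacing the manual walk.
fs.rm(folder, recursive=True)
```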
@@ -222,14 +223,14 @@ class CheckpointManager:
         model,
         optimizer,
         train_state,
-        config,
-        device_mesh: Optional[DeviceMesh] = None,
+        config: BaseModel,
+        device_mesh: DeviceMesh | None = None,
     ) -> bool:
 
         # When creating directory check if only rank0 or is there other solution
-        path = Path(self.path)
+        path = self.path
         curr_save_dir = self._create_folder(path, FOLDER_NAME.format(train_state.step))
-        logger.info(f"Saving to: {str(curr_save_dir)}")
+        logger.info(f"Saving to: {curr_save_dir}")
 
         if dist.is_initialized():
             dist.barrier()
@@ -242,17 +243,19 @@ class CheckpointManager:
         if dist.is_initialized():
             dist.barrier()
 
+        print("config type", type(config))
         if get_is_master():
-            config.dump_to_yaml_file(curr_save_dir / CONFIG_NAME)
+            self.fs.write_text(
+                os.path.join(curr_save_dir, CONFIG_NAME), config.model_dump_json()
+            )
 
         # Add json dump here
         dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
         if tp_rank == 0:
             train_state_name = TRAIN_STATE_NAME.format(dp_rank)
-            logger.info(
-                f"Saving train state to: {str(curr_save_dir / train_state_name)}"
-            )
-            with open(curr_save_dir / train_state_name, "w") as f:
+            train_state_full_path = os.path.join(curr_save_dir, train_state_name)
+            logger.info(f"Saving train state to: {train_state_full_path}")
+            with self.fs.open(train_state_full_path, "w") as f:
                 json.dump(train_state.state_dict(), f)
             logger.info("Train state saved !")
 
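The config is now serialized with `config.model_dump_json()`, which assumes `config` is a pydantic v2 `BaseModel` (consistent with the new `config: BaseModel` annotation). A self-contained sketch of what ends up in the file, using a hypothetical stand-in for the real args class:

```python
from pydantic import BaseModel


class DummyTrainArgs(BaseModel):  # hypothetical stand-in for the real config object
    dump_dir: str = "/tmp/run"
    steps: int = 1000


# model_dump_json() is the pydantic v2 serializer passed to fs.write_text above.
print(DummyTrainArgs().model_dump_json())  # {"dump_dir":"/tmp/run","steps":1000}
```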
@@ -271,7 +274,7 @@ class CheckpointManager:
         optimizer,
         train_state,
         device_mesh: DeviceMesh,
-        path: Optional[Path] = None,
+        path: str | None = None,
     ):
         dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
         # Loading tries to load the provided path, if not available the last saved step and finally from the init path
@@ -284,12 +287,12 @@ class CheckpointManager:
         # Only load train state if it's provided, the files exist and we're not loading from init path
         train_state_name = TRAIN_STATE_NAME.format(dp_rank)
         logger.info("Reloading train state")
-        with open(path / train_state_name, "r") as f:
+        with self.fs.open(os.path.join(path, train_state_name), "r") as f:
             train_state_dict = json.load(f)
         train_state.load_state_dict(train_state_dict)
         logger.info("Train state reloaded")
 
-        logger.info(f"Loading from: {str(path)}")
+        logger.info(f"Loading from: {path}")
         state_dict = self.get_state_dict(
             model=model,
             optimizer=optimizer,
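Reading the train state now goes through `self.fs.open()`, so the same `json.load` works for local or remote checkpoints. A minimal sketch of that read path (directory and file name are hypothetical; the real name comes from `TRAIN_STATE_NAME.format(dp_rank)`):

```python
import json
import os

import fsspec

fs = fsspec.filesystem("file")
ckpt_path = "/checkpoints/0000001000"        # hypothetical checkpoint folder
train_state_file = "train_state_00000.json"  # hypothetical; real name from TRAIN_STATE_NAME
with fs.open(os.path.join(ckpt_path, train_state_file), "r") as f:
    train_state_dict = json.load(f)
```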

bytelatent/train.py
@@ -25,6 +25,7 @@ from torch.optim import lr_scheduler
 
 from bytelatent.args import TrainArgs, parse_args
 from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint
+from bytelatent.data.file_util import get_fs
 from bytelatent.data.iterators.multiprocess_iterator import (
     MultiprocessIterator,
     MultiprocessIteratorState,
@@ -313,8 +314,11 @@ def train(args: TrainArgs):
 
     if args.checkpoint.init_ckpt_path:
         logger.info(f"Loading initial model from {args.checkpoint.init_ckpt_path}")
+        ckpt_fs = get_fs(
+            args.checkpoint.init_ckpt_path, s3_profile=args.checkpoint.s3_profile
+        )
         load_from_checkpoint(
-            args.checkpoint.init_ckpt_path, model, model_key="model"
+            ckpt_fs, args.checkpoint.init_ckpt_path, model, model_key="model"
         )  # Put model_key="" if its directly the model checkpoint
         model.rope_embeddings.reset_parameters()  # For RoPe initialization since it's a buffer it might not be loaded
     else:
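`load_from_checkpoint` now takes the filesystem as its first argument, and `get_fs` resolves it from the checkpoint path (optionally with an S3 profile). A hedged sketch of the call pattern mirroring the diff; the paths and profile name are illustrative:

```python
from bytelatent.checkpoint import load_from_checkpoint
from bytelatent.data.file_util import get_fs

init_ckpt_path = "s3://my-bucket/checkpoints/init"      # illustrative path
ckpt_fs = get_fs(init_ckpt_path, s3_profile="default")  # profile name is illustrative
# `model` is assumed to be an already-constructed nn.Module.
load_from_checkpoint(ckpt_fs, init_ckpt_path, model, model_key="model")
```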