Update checkpointing to use fsspec (#39)

Summary:

- Make the data/checkpoint code fsspec compatible (a sketch of the filesystem-selection idea is shown after this list)
- S3 saves still will not work, because `torch.distributed.checkpoint.save` does not support `fsspec` out of the box; this will be implemented in a follow-up PR
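
All path handling in this change goes through an `fsspec.AbstractFileSystem` returned by `get_fs(path, s3_profile=...)` (imported from `bytelatent.data.file_util` in the train diff below). The helper itself is not part of this diff; the following is only a minimal sketch of what such a selector might look like, assuming `s3://` paths map to `s3fs` (with an optional AWS profile) and everything else to the local filesystem:

```
import fsspec


def get_fs(path: str, s3_profile: str | None = None) -> fsspec.AbstractFileSystem:
    # Sketch only: the real bytelatent.data.file_util.get_fs is not shown in this diff.
    if path.startswith("s3://"):
        # s3fs resolves credentials via botocore; an AWS profile name can be passed through.
        return fsspec.filesystem("s3", profile=s3_profile)
    # Anything without an s3:// scheme falls back to the local filesystem.
    return fsspec.filesystem("file")
```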


Test Plan:

Run unit tests and the commands below

```
python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100
```

```
torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100
```

These currently won't work due to the torch distributed save, but they should be tested at a later date (one possible approach is sketched after these commands):

```
python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 dump_dir=s3://blt/scratch/checkpoint-test/
```

```
torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 dump_dir=s3://blt/scratch/checkpoint-test/
```
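
A possible direction for that follow-up (an assumption, not something this PR implements): `torch.distributed.checkpoint.save` accepts an explicit `storage_writer`, and newer PyTorch releases ship an fsspec-backed writer that can target `s3://` URLs. A minimal sketch, assuming `FsspecWriter` is importable from `torch.distributed.checkpoint` in the installed version; the helper name below is hypothetical:

```
import torch.distributed.checkpoint as dcp

# Assumption: FsspecWriter ships with recent PyTorch releases (the exact import
# path may vary between versions) and requires fsspec/s3fs to be installed.
from torch.distributed.checkpoint import FsspecWriter


def dcp_save_to_url(state_dict: dict, ckpt_dir: str) -> None:
    # Hypothetical helper: replaces the default local FileSystemWriter with an
    # fsspec-backed writer so ckpt_dir can be an s3:// URL.
    dcp.save(state_dict, storage_writer=FsspecWriter(ckpt_dir))
```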
Pedro Rodriguez authored on 2025-02-06 09:41:58 -08:00, committed by GitHub
commit afedb16598 (parent 739dc71a0a)
5 changed files with 112 additions and 72 deletions


@@ -294,6 +294,14 @@ class TrainArgs(BaseModel):
def dump_to_yaml_file(
self, path: str, log_config: bool = True, sort_keys: bool = True
):
yaml_str = self.dump_to_yaml_str(sort_keys=sort_keys)
with open(path, "w") as f:
if log_config:
logger.info("Using the following config for this run:")
logger.info(yaml_str)
f.write(yaml_str)
def dump_to_yaml_str(self, sort_keys: bool = True):
model_dict = self.model_dump(mode="json")
yaml_str = yaml.dump(
model_dict,
@@ -301,8 +309,4 @@ class TrainArgs(BaseModel):
sort_keys=sort_keys,
default_flow_style=False,
)
with open(path, "w") as f:
if log_config:
logger.info("Using the following config for this run:")
logger.info(yaml_str)
f.write(yaml_str)
return yaml_str


@@ -4,10 +4,9 @@ import json
import logging
import os
import re
from pathlib import Path
from typing import List, Optional, Tuple
import fsspec
import s3fs
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
@@ -70,26 +69,29 @@ def consolidate_checkpoints(fs: fsspec.AbstractFileSystem, ckpt_dir: str):
Returns the path to the consolidated checkpoint
"""
consolidate_path = Path(ckpt_dir) / CONSOLIDATE_FOLDER
if not (consolidate_path / CONSOLIDATE_NAME).exists():
consolidate_path.mkdir(exist_ok=True)
logger.info(f"Consolidating to: {str(consolidate_path)}")
dcp_to_torch_save(ckpt_dir, str(consolidate_path / CONSOLIDATE_NAME))
(consolidate_path / CONFIG_NAME).write_text(
(Path(ckpt_dir) / CONFIG_NAME).read_text()
consolidate_path = os.path.join(ckpt_dir, CONSOLIDATE_FOLDER)
consolidate_name = os.path.join(consolidate_path, CONSOLIDATE_NAME)
if not fs.exists(consolidate_name):
fs.mkdirs(consolidate_path, exist_ok=True)
logger.info(f"Consolidating to: {consolidate_path}")
dcp_to_torch_save(ckpt_dir, consolidate_name)
fs.write_text(
os.path.join(consolidate_path, CONFIG_NAME),
fs.read_text(os.path.join(ckpt_dir, CONFIG_NAME)),
)
logger.info("Consolidated !")
return consolidate_path
def load_from_checkpoint(
fs: fsspec.AbstractFileSystem,
ckpt_dir: str,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
optimizer: torch.optim.Optimizer | None = None,
model_key: str = "model",
optim_key: str = "optim",
):
if not (Path(ckpt_dir) / ".metadata").exists():
if not fs.exists(os.path.join(ckpt_dir, ".metadata")):
raise ValueError(
f"Please convert the checkpoint distcp format using `torch.distributed.checkpoint.format_utils.torch_save_to_dcp` before loading it"
)
@@ -115,19 +117,24 @@ class CheckpointManager:
self.init_ckpt_path = args.init_ckpt_path
self.continue_training_from_init = args.continue_training_from_init
if not isinstance(self.fs, s3fs.S3FileSystem):
# S3 does not have a concept of directories
assert self.fs.exists(
self.path
), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)"
self.existing_saves = self.get_existing_saves()
def get_existing_saves(self) -> List[Path]:
def get_existing_saves(self) -> list[str]:
if self.fs.exists(self.path) and self.fs.isdir(self.path):
folders = [
p
for p in Path(self.path).iterdir()
if p.is_dir() and re.match(RE_FOLDER, p.name)
for p in self.fs.ls(self.path)
if self.fs.isdir(p) and re.match(RE_FOLDER, os.path.basename(p))
]
folders.sort(key=lambda p: _get_key_step(p.name))
else:
folders = []
folders.sort(key=lambda p: _get_key_step(os.path.basename(p)))
return folders
def clean_up(self):
@@ -136,8 +143,9 @@ class CheckpointManager:
eval_folders = []
other_folders = []
for p in self.existing_saves:
is_dump = _get_key_step(p.name) % self.dump_every.every == 0
is_eval = _get_key_step(p.name) % self.eval_every.every == 0
assert isinstance(p, str), f"Base path type: {p}"
is_dump = _get_key_step(os.path.basename(p)) % self.dump_every.every == 0
is_eval = _get_key_step(os.path.basename(p)) % self.eval_every.every == 0
if is_dump:
dump_folders.append(p)
if is_eval:
@@ -161,40 +169,39 @@ class CheckpointManager:
if dist.get_rank() == 0:
for folder in folder_to_remove:
for file in folder.iterdir():
if file.is_file():
file.unlink()
elif file.is_dir():
assert file.name in [CONSOLIDATE_FOLDER]
for f in file.iterdir():
f.unlink()
file.rmdir()
folder.rmdir()
for file in self.fs.ls(folder):
if self.fs.isfile(file):
self.fs.rm_file(file)
elif self.fs.isdir(file):
assert os.path.basename(file) in [CONSOLIDATE_FOLDER]
for f in self.fs.ls(file):
self.fs.rm(f)
self.fs.rmdir(file)
self.fs.rmdir(folder)
dist.barrier()
self.existing_saves = list(folder_to_keep)
self.existing_saves.sort(key=lambda p: _get_key_step(p.name))
self.existing_saves.sort(key=lambda p: _get_key_step(os.path.basename(p)))
def get_last_step_path(self, dp_rank: int = 0) -> Optional[Path]:
def get_last_step_path(self, dp_rank: int = 0) -> str | None:
path = None
for p in reversed(self.existing_saves):
if (p / TRAIN_STATE_NAME.format(dp_rank)).is_file():
if self.fs.isfile(os.path.join(p, TRAIN_STATE_NAME.format(dp_rank))):
path = p
break
return path
def _create_folder(self, base_path: Path, folder_name: str) -> Path:
folder = base_path / folder_name
def _create_folder(self, base_path: str, folder_name: str) -> str:
folder = os.path.join(base_path, folder_name)
if get_is_master():
folder.mkdir(parents=False, exist_ok=True)
self.fs.mkdirs(folder, exist_ok=True)
if dist.is_initialized():
dist.barrier()
return folder
def _get_dp_tp_mesh(
self, device_mesh: Optional[DeviceMesh] = None
) -> Tuple[int, int]:
def _get_dp_tp_mesh(self, device_mesh: DeviceMesh | None = None) -> tuple[int, int]:
dp_rank = 0
tp_rank = 0
if device_mesh is not None:
@@ -222,14 +229,14 @@ class CheckpointManager:
model,
optimizer,
train_state,
config,
device_mesh: Optional[DeviceMesh] = None,
config: BaseModel,
device_mesh: DeviceMesh | None = None,
) -> bool:
# When creating directory check if only rank0 or is there other solution
path = Path(self.path)
path = self.path
curr_save_dir = self._create_folder(path, FOLDER_NAME.format(train_state.step))
logger.info(f"Saving to: {str(curr_save_dir)}")
logger.info(f"Saving to: {curr_save_dir}")
if dist.is_initialized():
dist.barrier()
@@ -242,17 +249,19 @@ class CheckpointManager:
if dist.is_initialized():
dist.barrier()
print("config type", type(config))
if get_is_master():
config.dump_to_yaml_file(curr_save_dir / CONFIG_NAME)
self.fs.write_text(
os.path.join(curr_save_dir, CONFIG_NAME), config.model_dump_json()
)
# Add json dump here
dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
if tp_rank == 0:
train_state_name = TRAIN_STATE_NAME.format(dp_rank)
logger.info(
f"Saving train state to: {str(curr_save_dir / train_state_name)}"
)
with open(curr_save_dir / train_state_name, "w") as f:
train_state_full_path = os.path.join(curr_save_dir, train_state_name)
logger.info(f"Saving train state to: {train_state_full_path}")
with self.fs.open(train_state_full_path, "w") as f:
json.dump(train_state.state_dict(), f)
logger.info("Train state saved !")
@@ -271,7 +280,7 @@ class CheckpointManager:
optimizer,
train_state,
device_mesh: DeviceMesh,
path: Optional[Path] = None,
path: str | None = None,
):
dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
# Loading tries to load the provided path, if not available the last saved step and finally from the init path
@@ -284,12 +293,12 @@ class CheckpointManager:
# Only load train state if it's provided, the files exist and we're not loading from init path
train_state_name = TRAIN_STATE_NAME.format(dp_rank)
logger.info("Reloading train state")
with open(path / train_state_name, "r") as f:
with self.fs.open(os.path.join(path, train_state_name), "r") as f:
train_state_dict = json.load(f)
train_state.load_state_dict(train_state_dict)
logger.info("Train state reloaded")
logger.info(f"Loading from: {str(path)}")
logger.info(f"Loading from: {path}")
state_dict = self.get_state_dict(
model=model,
optimizer=optimizer,


@@ -6,6 +6,8 @@ import sys
import time
from datetime import timedelta
import fsspec
from bytelatent.distributed import get_global_rank, get_is_slurm_job
@@ -92,6 +94,7 @@ def init_logger(
*,
name: str | None = None,
level: str = "INFO",
fs: fsspec.AbstractFileSystem | None = None,
):
"""
Setup logging.
@@ -121,7 +124,11 @@
if log_file is not None and get_global_rank() == 0:
# build file handler
if fs is None:
file_handler = logging.FileHandler(log_file, "a")
else:
file_stream = fs.open(log_file, mode="a")
file_handler = logging.StreamHandler(file_stream)
file_handler.setLevel(logging.NOTSET)
file_handler.setFormatter(LogFormatter())
# update logger


@@ -8,6 +8,7 @@ from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Union
import fsspec
import torch
import torch.nn as nn
import wandb
@@ -53,14 +54,24 @@ class LoggingArgs(BaseModel):
class MetricLogger:
def __init__(self, outdir: Path, args: Any | None = None):
def __init__(
self,
outdir: Path,
# args: TrainArgs
args: Any | None = None,
fs: fsspec.AbstractFileSystem | None = None,
):
self.outdir = outdir
self.jsonl_writer = None
self.fs = fs
self.args = args
def open(self):
if self.jsonl_writer is None:
if self.fs is None:
self.jsonl_writer = open(self.outdir, "a")
else:
self.jsonl_writer = self.fs.open(self.outdir, "a")
if (
self.args is not None
and self.args.logging.wandb is not None


@@ -8,7 +8,6 @@ import sys
from contextlib import ExitStack
from copy import deepcopy
from dataclasses import asdict, dataclass
from pathlib import Path
from timeit import default_timer as timer
from typing import Any, TypeVar
@@ -18,13 +17,13 @@ import torch.nn.functional
import torch.nn.functional as F
import wandb
import xformers.profiler
from omegaconf import OmegaConf
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint.stateful import Stateful
from torch.optim import lr_scheduler
from bytelatent.args import TrainArgs, parse_args
from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint
from bytelatent.data.file_util import get_fs
from bytelatent.data.iterators.multiprocess_iterator import (
MultiprocessIterator,
MultiprocessIteratorState,
@@ -136,11 +135,12 @@ def validate_train_args(args: TrainArgs, output_size: int):
if args.checkpoint.path is None:
logger.info(f"Setting checkpoint path to {args.checkpoint.path}")
args.checkpoint.path = str(Path(args.dump_dir) / "checkpoints")
args.checkpoint.path = os.path.join(args.dump_dir, "checkpoints")
data_fs = get_fs(args.data.root_dir, s3_profile=args.data.s3_profile)
for source in args.data.sources:
data_path = os.path.join(args.data.root_dir, source)
assert os.path.exists(data_path), f"{data_path} doesn't exist"
assert data_fs.exists(data_path), f"{data_path} doesn't exist"
if (
args.distributed.dp_replicate
@@ -255,10 +255,15 @@ def train(args: TrainArgs):
args,
tokenizer.n_words,
)
dump_fs = get_fs(args.dump_dir, s3_profile=args.checkpoint.s3_profile)
if get_is_master():
os.makedirs(args.dump_dir, exist_ok=True)
args.dump_to_yaml_file(Path(args.dump_dir) / "config.yaml")
init_logger(Path(args.dump_dir) / "train.log")
dump_fs.mkdirs(args.dump_dir, exist_ok=True)
config_yaml_str = args.dump_to_yaml_str()
logging.info("TrainArgs: \n%s", config_yaml_str)
dump_fs.write_text(
os.path.join(args.dump_dir, "config.yaml"), config_yaml_str
)
init_logger(os.path.join(args.dump_dir, "train.log"), fs=dump_fs)
init_signal_handler(set_preemption_flag) # For handling preemption signals.
setup_env(args.env)
setup_torch_distributed(args.distributed)
@@ -313,8 +318,11 @@ def train(args: TrainArgs):
if args.checkpoint.init_ckpt_path:
logger.info(f"Loading initial model from {args.checkpoint.init_ckpt_path}")
ckpt_fs = get_fs(
args.checkpoint.init_ckpt_path, s3_profile=args.checkpoint.s3_profile
)
load_from_checkpoint(
args.checkpoint.init_ckpt_path, model, model_key="model"
ckpt_fs, args.checkpoint.init_ckpt_path, model, model_key="model"
) # Put model_key="" if it's directly the model checkpoint
model.rope_embeddings.reset_parameters() # For RoPe initialization since it's a buffer it might not be loaded
else:
@@ -352,13 +360,14 @@ def train(args: TrainArgs):
checkpoint.load(model, optimizer, train_state, world_mesh)
# Either load from latest checkpoint or start from scratch
if args.probe_freq is not None:
# TODO: Convert this to fsspec compatible
if get_is_master():
os.makedirs(Path(args.dump_dir) / "probe", exist_ok=True)
os.makedirs(os.path.join(args.dump_dir, "probe"), exist_ok=True)
torch.distributed.barrier()
probe = AutoProbeD(
model,
(
Path(args.dump_dir) / "probe" / f"probe.{dp_rank}.jsonl"
os.path.join(args.dump_dir, "probe", f"probe.{dp_rank}.jsonl")
if (dp_rank % 128 == 0)
else None
),
@@ -370,7 +379,7 @@ def train(args: TrainArgs):
# train loop
model.train()
metric_logger = context_stack.enter_context(
MetricLogger(Path(args.dump_dir) / "metrics.jsonl", args)
MetricLogger(os.path.join(args.dump_dir, "metrics.jsonl"), args, fs=dump_fs)
)
data_loader = train_state.data_loader_state.build()
batch_iterator = data_loader.create_iter()