Mirror of https://github.com/facebookresearch/blt.git, synced 2025-02-22 21:12:15 +00:00
Update checkpointing to use fsspec (#39)
Summary:

- Make the data/checkpoint code fsspec compatible
- Still will not work with s3 saves, due to `torch.distributed.checkpoint.save` not being workable with `fsspec` out of the box. Will implement in a follow-up PR.

Test Plan:

Run unit tests and the commands below

```
python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100
```

```
torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100
```

These currently won't work due to the torch distributed save, but these should be tested at a later date

```
python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 dump_dir=s3://blt/scratch/checkpoint-test/
```

```
torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 dump_dir=s3://blt/scratch/checkpoint-test/
```
parent 739dc71a0a
commit afedb16598
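For context, the pattern this change moves toward can be sketched as follows: resolve an `fsspec` filesystem object from a path (local or `s3://`) and route every path operation through it, so the same `exists`/`mkdirs`/`open`/`write_text` calls work against a local directory or an S3 bucket. The snippet below is a minimal illustration only, assuming the `fsspec` and `s3fs` packages are installed; `resolve_fs` and the paths are hypothetical stand-ins, not the repo's actual `get_fs` helper.

```python
# Minimal sketch of the fsspec pattern adopted by this change (illustrative only).
# resolve_fs is a hypothetical stand-in for the repo's get_fs helper.
import fsspec
import s3fs


def resolve_fs(path: str, s3_profile: str | None = None) -> fsspec.AbstractFileSystem:
    if path.startswith("s3://"):
        # s3fs can take a named AWS profile for credentials
        return s3fs.S3FileSystem(profile=s3_profile)
    return fsspec.filesystem("file")


dump_dir = "/tmp/blt-demo/checkpoints"  # could equally be "s3://bucket/prefix"
fs = resolve_fs(dump_dir)
fs.mkdirs(dump_dir, exist_ok=True)                      # replaces os.makedirs / Path.mkdir
fs.write_text(f"{dump_dir}/config.yaml", "seed: 42\n")  # replaces Path.write_text
with fs.open(f"{dump_dir}/train.log", "a") as f:        # replaces the builtin open
    f.write("run started\n")
print(fs.exists(f"{dump_dir}/config.yaml"))             # same call on either backend
```

Note that, per the summary above, `torch.distributed.checkpoint.save` still writes through its own storage layer rather than this filesystem object, which is why the `s3://` dump commands in the test plan are deferred to a follow-up PR.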
@@ -294,6 +294,14 @@ class TrainArgs(BaseModel):
     def dump_to_yaml_file(
         self, path: str, log_config: bool = True, sort_keys: bool = True
     ):
+        yaml_str = self.dump_to_yaml_str(sort_keys=sort_keys)
+        with open(path, "w") as f:
+            if log_config:
+                logger.info("Using the following config for this run:")
+                logger.info(yaml_str)
+            f.write(yaml_str)
+
+    def dump_to_yaml_str(self, sort_keys: bool = True):
         model_dict = self.model_dump(mode="json")
         yaml_str = yaml.dump(
             model_dict,
@@ -301,8 +309,4 @@ class TrainArgs(BaseModel):
             sort_keys=sort_keys,
             default_flow_style=False,
         )
-        with open(path, "w") as f:
-            if log_config:
-                logger.info("Using the following config for this run:")
-                logger.info(yaml_str)
-            f.write(yaml_str)
+        return yaml_str
@@ -4,10 +4,9 @@ import json
 import logging
 import os
 import re
-from pathlib import Path
-from typing import List, Optional, Tuple
 
 import fsspec
+import s3fs
 import torch
 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
@@ -70,26 +69,29 @@ def consolidate_checkpoints(fs: fsspec.AbstractFileSystem, ckpt_dir: str):
 
     Returns the path to the consolidated checkpoint
     """
-    consolidate_path = Path(ckpt_dir) / CONSOLIDATE_FOLDER
-    if not (consolidate_path / CONSOLIDATE_NAME).exists():
-        consolidate_path.mkdir(exist_ok=True)
-        logger.info(f"Consolidating to: {str(consolidate_path)}")
-        dcp_to_torch_save(ckpt_dir, str(consolidate_path / CONSOLIDATE_NAME))
-        (consolidate_path / CONFIG_NAME).write_text(
-            (Path(ckpt_dir) / CONFIG_NAME).read_text()
+    consolidate_path = os.path.join(ckpt_dir, CONSOLIDATE_FOLDER)
+    consolidate_name = os.path.join(consolidate_path, CONSOLIDATE_NAME)
+    if not fs.exists(consolidate_name):
+        fs.mkdirs(consolidate_path, exist_ok=True)
+        logger.info(f"Consolidating to: {consolidate_path}")
+        dcp_to_torch_save(ckpt_dir, consolidate_name)
+        fs.write_text(
+            os.path.join(consolidate_path, CONFIG_NAME),
+            fs.read_text(os.path.join(ckpt_dir, CONFIG_NAME)),
         )
         logger.info("Consolidated !")
     return consolidate_path
 
 
 def load_from_checkpoint(
+    fs: fsspec.AbstractFileSystem,
     ckpt_dir: str,
     model: nn.Module,
-    optimizer: Optional[torch.optim.Optimizer] = None,
+    optimizer: torch.optim.Optimizer | None = None,
     model_key: str = "model",
     optim_key: str = "optim",
 ):
-    if not (Path(ckpt_dir) / ".metadata").exists():
+    if not fs.exists(os.path.join(ckpt_dir, ".metadata")):
         raise ValueError(
             f"Please convert the checkpoint distcp format using `torch.distributed.checkpoint.format_utils.torch_save_to_dcp` before loading it"
         )
@@ -115,19 +117,24 @@ class CheckpointManager:
         self.init_ckpt_path = args.init_ckpt_path
         self.continue_training_from_init = args.continue_training_from_init
 
-        assert self.fs.exists(
-            self.path
-        ), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)"
+        if not isinstance(self.fs, s3fs.S3FileSystem):
+            # S3 does not have a concept of directories
+            assert self.fs.exists(
+                self.path
+            ), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)"
 
         self.existing_saves = self.get_existing_saves()
 
-    def get_existing_saves(self) -> List[Path]:
-        folders = [
-            p
-            for p in Path(self.path).iterdir()
-            if p.is_dir() and re.match(RE_FOLDER, p.name)
-        ]
-        folders.sort(key=lambda p: _get_key_step(p.name))
+    def get_existing_saves(self) -> list[str]:
+        if self.fs.exists(self.path) and self.fs.isdir(self.path):
+            folders = [
+                p
+                for p in self.fs.ls(self.path)
+                if self.fs.isdir(p) and re.match(RE_FOLDER, os.path.basename(p))
+            ]
+        else:
+            folders = []
+        folders.sort(key=lambda p: _get_key_step(os.path.basename(p)))
         return folders
 
     def clean_up(self):
@@ -136,8 +143,9 @@ class CheckpointManager:
         eval_folders = []
         other_folders = []
         for p in self.existing_saves:
-            is_dump = _get_key_step(p.name) % self.dump_every.every == 0
-            is_eval = _get_key_step(p.name) % self.eval_every.every == 0
+            assert isinstance(p, str), f"Base path type: {p}"
+            is_dump = _get_key_step(os.path.basename(p)) % self.dump_every.every == 0
+            is_eval = _get_key_step(os.path.basename(p)) % self.eval_every.every == 0
             if is_dump:
                 dump_folders.append(p)
             if is_eval:
@@ -161,40 +169,39 @@ class CheckpointManager:
 
         if dist.get_rank() == 0:
             for folder in folder_to_remove:
-                for file in folder.iterdir():
-                    if file.is_file():
-                        file.unlink()
-                    elif file.is_dir():
-                        assert file.name in [CONSOLIDATE_FOLDER]
-                        for f in file.iterdir():
-                            f.unlink()
-                        file.rmdir()
-                folder.rmdir()
+                for file in self.fs.ls(folder):
+                    if self.fs.isfile(file):
+                        self.fs.rm_file(file)
+                    elif self.fs.isdir(file):
+                        assert os.path.name(file) in [CONSOLIDATE_FOLDER]
+                        for f in self.fs.ls(file):
+                            self.fs.rm(f)
+                        self.fs.rmdir(file)
+                self.fs.rmdir(folder)
 
         dist.barrier()
 
         self.existing_saves = list(folder_to_keep)
-        self.existing_saves.sort(key=lambda p: _get_key_step(p.name))
+        self.existing_saves.sort(key=lambda p: _get_key_step(os.path.basename(p)))
 
-    def get_last_step_path(self, dp_rank: int = 0) -> Optional[Path]:
+    def get_last_step_path(self, dp_rank: int = 0) -> str | None:
         path = None
         for p in reversed(self.existing_saves):
-            if (p / TRAIN_STATE_NAME.format(dp_rank)).is_file():
+
+            if self.fs.isfile(os.path.join(p, TRAIN_STATE_NAME.format(dp_rank))):
                 path = p
                 break
         return path
 
-    def _create_folder(self, base_path: Path, folder_name: str) -> Path:
-        folder = base_path / folder_name
+    def _create_folder(self, base_path: str, folder_name: str) -> str:
+        folder = os.path.join(base_path, folder_name)
         if get_is_master():
-            folder.mkdir(parents=False, exist_ok=True)
+            self.fs.mkdirs(folder, exist_ok=True)
         if dist.is_initialized():
             dist.barrier()
         return folder
 
-    def _get_dp_tp_mesh(
-        self, device_mesh: Optional[DeviceMesh] = None
-    ) -> Tuple[int, int]:
+    def _get_dp_tp_mesh(self, device_mesh: DeviceMesh | None = None) -> tuple[int, int]:
         dp_rank = 0
         tp_rank = 0
         if device_mesh is not None:
@@ -222,14 +229,14 @@ class CheckpointManager:
         model,
         optimizer,
         train_state,
-        config,
-        device_mesh: Optional[DeviceMesh] = None,
+        config: BaseModel,
+        device_mesh: DeviceMesh | None = None,
     ) -> bool:
 
         # When creating directory check if only rank0 or is there other solution
-        path = Path(self.path)
+        path = self.path
         curr_save_dir = self._create_folder(path, FOLDER_NAME.format(train_state.step))
-        logger.info(f"Saving to: {str(curr_save_dir)}")
+        logger.info(f"Saving to: {curr_save_dir}")
 
         if dist.is_initialized():
             dist.barrier()
@@ -242,17 +249,19 @@ class CheckpointManager:
         if dist.is_initialized():
             dist.barrier()
 
+        print("config type", type(config))
         if get_is_master():
-            config.dump_to_yaml_file(curr_save_dir / CONFIG_NAME)
+            self.fs.write_text(
+                os.path.join(curr_save_dir, CONFIG_NAME), config.model_dump_json()
+            )
 
         # Add json dump here
         dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
         if tp_rank == 0:
             train_state_name = TRAIN_STATE_NAME.format(dp_rank)
-            logger.info(
-                f"Saving train state to: {str(curr_save_dir / train_state_name)}"
-            )
-            with open(curr_save_dir / train_state_name, "w") as f:
+            train_state_full_path = os.path.join(curr_save_dir, train_state_name)
+            logger.info(f"Saving train state to: {train_state_full_path}")
+            with self.fs.open(train_state_full_path, "w") as f:
                 json.dump(train_state.state_dict(), f)
             logger.info("Train state saved !")
 
@@ -271,7 +280,7 @@ class CheckpointManager:
         optimizer,
         train_state,
         device_mesh: DeviceMesh,
-        path: Optional[Path] = None,
+        path: str | None = None,
     ):
         dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
         # Loading tries to load the provided path, if not available the last saved step and finally from the init path
@@ -284,12 +293,12 @@ class CheckpointManager:
         # Only load train state if it's provided, the files exist and we're not loading from init path
         train_state_name = TRAIN_STATE_NAME.format(dp_rank)
         logger.info("Reloading train state")
-        with open(path / train_state_name, "r") as f:
+        with self.fs.open(os.path.join(path, train_state_name), "r") as f:
             train_state_dict = json.load(f)
         train_state.load_state_dict(train_state_dict)
         logger.info("Train state reloaded")
 
-        logger.info(f"Loading from: {str(path)}")
+        logger.info(f"Loading from: {path}")
         state_dict = self.get_state_dict(
             model=model,
             optimizer=optimizer,
@@ -6,6 +6,8 @@ import sys
 import time
 from datetime import timedelta
 
+import fsspec
+
 from bytelatent.distributed import get_global_rank, get_is_slurm_job
 
 
@@ -92,6 +94,7 @@ def init_logger(
     *,
     name: str | None = None,
     level: str = "INFO",
+    fs: fsspec.AbstractFileSystem | None = None,
 ):
     """
     Setup logging.
@@ -121,7 +124,11 @@ def init_logger(
 
     if log_file is not None and get_global_rank() == 0:
         # build file handler
-        file_handler = logging.FileHandler(log_file, "a")
+        if fs is None:
+            file_handler = logging.FileHandler(log_file, "a")
+        else:
+            file_stream = fs.open(log_file, mode="a")
+            file_handler = logging.StreamHandler(file_stream)
         file_handler.setLevel(logging.NOTSET)
         file_handler.setFormatter(LogFormatter())
         # update logger
@@ -8,6 +8,7 @@ from datetime import datetime, timezone
 from pathlib import Path
 from typing import Any, Union
 
+import fsspec
 import torch
 import torch.nn as nn
 import wandb
@@ -53,14 +54,24 @@ class LoggingArgs(BaseModel):
 
 
 class MetricLogger:
-    def __init__(self, outdir: Path, args: Any | None = None):
+    def __init__(
+        self,
+        outdir: Path,
+        # args: TrainArgs
+        args: Any | None = None,
+        fs: fsspec.AbstractFileSystem | None = None,
+    ):
         self.outdir = outdir
         self.jsonl_writer = None
+        self.fs = fs
         self.args = args
 
     def open(self):
         if self.jsonl_writer is None:
-            self.jsonl_writer = open(self.outdir, "a")
+            if self.fs is None:
+                self.jsonl_writer = open(self.outdir, "a")
+            else:
+                self.jsonl_writer = self.fs.open(self.outdir, "a")
         if (
             self.args is not None
             and self.args.logging.wandb is not None
@@ -8,7 +8,6 @@ import sys
 from contextlib import ExitStack
 from copy import deepcopy
 from dataclasses import asdict, dataclass
-from pathlib import Path
 from timeit import default_timer as timer
 from typing import Any, TypeVar
 
@@ -18,13 +17,13 @@ import torch.nn.functional
 import torch.nn.functional as F
 import wandb
 import xformers.profiler
-from omegaconf import OmegaConf
 from torch.distributed._tensor import DTensor
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.optim import lr_scheduler
 
 from bytelatent.args import TrainArgs, parse_args
 from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint
+from bytelatent.data.file_util import get_fs
 from bytelatent.data.iterators.multiprocess_iterator import (
     MultiprocessIterator,
     MultiprocessIteratorState,
@@ -136,11 +135,12 @@ def validate_train_args(args: TrainArgs, output_size: int):
 
     if args.checkpoint.path is None:
         logger.info(f"Setting checkpoint path to {args.checkpoint.path}")
-        args.checkpoint.path = str(Path(args.dump_dir) / "checkpoints")
+        args.checkpoint.path = os.path.join(args.dump_dir, "checkpoints")
 
+    data_fs = get_fs(args.data.root_dir, s3_profile=args.data.s3_profile)
     for source in args.data.sources:
         data_path = os.path.join(args.data.root_dir, source)
-        assert os.path.exists(data_path), f"{data_path} doesn't exist"
+        assert data_fs.exists(data_path), f"{data_path} doesn't exist"
 
     if (
         args.distributed.dp_replicate
@@ -255,10 +255,15 @@ def train(args: TrainArgs):
             args,
             tokenizer.n_words,
         )
+        dump_fs = get_fs(args.dump_dir, s3_profile=args.checkpoint.s3_profile)
         if get_is_master():
-            os.makedirs(args.dump_dir, exist_ok=True)
-            args.dump_to_yaml_file(Path(args.dump_dir) / "config.yaml")
-        init_logger(Path(args.dump_dir) / "train.log")
+            dump_fs.mkdirs(args.dump_dir, exist_ok=True)
+            config_yaml_str = args.dump_to_yaml_str()
+            logging.info("TrainArgs: \n%s", config_yaml_str)
+            dump_fs.write_text(
+                os.path.join(args.dump_dir, "config.yaml"), config_yaml_str
+            )
+        init_logger(os.path.join(args.dump_dir, "train.log"), fs=dump_fs)
         init_signal_handler(set_preemption_flag)  # For handling preemption signals.
         setup_env(args.env)
         setup_torch_distributed(args.distributed)
@@ -313,8 +318,11 @@ def train(args: TrainArgs):
 
         if args.checkpoint.init_ckpt_path:
             logger.info(f"Loading initial model from {args.checkpoint.init_ckpt_path}")
+            ckpt_fs = get_fs(
+                args.checkpoint.init_ckpt_path, s3_profile=args.checkpoint.s3_profile
+            )
             load_from_checkpoint(
-                args.checkpoint.init_ckpt_path, model, model_key="model"
+                ckpt_fs, args.checkpoint.init_ckpt_path, model, model_key="model"
             )  # Put model_key="" if its directly the model checkpoint
             model.rope_embeddings.reset_parameters()  # For RoPe initialization since it's a buffer it might not be loaded
         else:
@@ -352,13 +360,14 @@ def train(args: TrainArgs):
         checkpoint.load(model, optimizer, train_state, world_mesh)
         # Either load from latest checkpoint or start from scratch
         if args.probe_freq is not None:
+            # TODO: Convert this to fsspec compatible
            if get_is_master():
-                os.makedirs(Path(args.dump_dir) / "probe", exist_ok=True)
+                os.makedirs(os.path.join(args.dump_dir, "probe"), exist_ok=True)
             torch.distributed.barrier()
             probe = AutoProbeD(
                 model,
                 (
-                    Path(args.dump_dir) / "probe" / f"probe.{dp_rank}.jsonl"
+                    os.path.join(args.dump_dir, "probe", f"probe.{dp_rank}.jsonl")
                     if (dp_rank % 128 == 0)
                     else None
                 ),
@@ -370,7 +379,7 @@ def train(args: TrainArgs):
         # train loop
         model.train()
         metric_logger = context_stack.enter_context(
-            MetricLogger(Path(args.dump_dir) / "metrics.jsonl", args)
+            MetricLogger(os.path.join(args.dump_dir, "metrics.jsonl"), args, fs=dump_fs)
         )
         data_loader = train_state.data_loader_state.build()
         batch_iterator = data_loader.create_iter()