From 7cf8fab49bc75737c73d050e1f3bb85fd826e00c Mon Sep 17 00:00:00 2001 From: Srinivasan Iyer Date: Wed, 5 Feb 2025 16:24:39 -0800 Subject: [PATCH 1/4] Fix wandb logging (#42) Co-authored-by: Srini Iyer --- bytelatent/metrics.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/bytelatent/metrics.py b/bytelatent/metrics.py index 77dc4d7..e746e4f 100644 --- a/bytelatent/metrics.py +++ b/bytelatent/metrics.py @@ -4,7 +4,6 @@ import json import logging from collections import namedtuple -from dataclasses import asdict from datetime import datetime, timezone from pathlib import Path from typing import Any, Union @@ -68,8 +67,8 @@ class MetricLogger: and get_is_master() ): run = wandb.init( - config=asdict(self.args), - **asdict(self.args.logging.wandb), + config=self.args.model_dump(), + **self.args.logging.wandb.model_dump(), ) def log(self, metrics: dict[str, Any]): From 6fbaf7266f8a19c3dc06d9f7bfef79e98eca9dc2 Mon Sep 17 00:00:00 2001 From: Srinivasan Iyer Date: Wed, 5 Feb 2025 17:18:40 -0800 Subject: [PATCH 2/4] fix stool (#44) Co-authored-by: Srini Iyer --- bytelatent/stool.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/bytelatent/stool.py b/bytelatent/stool.py index 965f4cb..b156177 100644 --- a/bytelatent/stool.py +++ b/bytelatent/stool.py @@ -4,14 +4,15 @@ import json import os import shutil import subprocess -from dataclasses import dataclass +from pydantic import BaseModel from typing import Any, Dict from omegaconf import OmegaConf -@dataclass -class StoolArgs: +class StoolArgs(BaseModel): + name: str = None + dump_dir: str = None config: Any = None launcher: str = "sbatch" # Can be sbatch or bash if already in salloc script: str = "apps.main.train" # The script to run. @@ -64,7 +65,7 @@ source activate {conda_env_path} export OMP_NUM_THREADS=1 export LAUNCH_WITH="SBATCH" export DUMP_DIR={dump_dir} -srun {log_output} -n {tasks} -N {nodes_per_run} python -u -m {script} config=$DUMP_DIR/base_config.yaml +srun {log_output} -n {tasks} -N {nodes_per_run} python -u -m {script} config=$DUMP_DIR/base_config.yaml dump_dir=$DUMP_DIR name={name} """ @@ -150,8 +151,8 @@ def validate_args(args) -> None: def launch_job(args: StoolArgs): # Set up args default and validate them depending on the cluster or partition requested validate_args(args) - dump_dir = args.config["dump_dir"] - job_name = args.config["name"] + job_name = args.name or args.config["name"] + dump_dir = os.path.join(args.dump_dir, job_name) or args.config["dump_dir"] print("Creating directories...") os.makedirs(dump_dir, exist_ok=args.dirs_exists_ok or args.override) if args.override: @@ -230,8 +231,7 @@ if __name__ == "__main__": Then you can pass model.dim=32 to change values in LMTransformerArgs or just name=tictac for top level attributes. 
""" - raise NotImplementedError("Update this to blt code") args = OmegaConf.from_cli() args.config = OmegaConf.load(args.config) - args = dataclass_from_dict(StoolArgs, args) + args = StoolArgs.model_validate(args) launch_job(args) From 739dc71a0a94af02701b46be2b6a88e68050dcdc Mon Sep 17 00:00:00 2001 From: Srinivasan Iyer Date: Wed, 5 Feb 2025 17:19:37 -0800 Subject: [PATCH 3/4] Add rope fp32 (#43) * Log model * Add flag for rope outer in fp32 --------- Co-authored-by: Srini Iyer --- bytelatent/base_transformer.py | 33 ++++++++++++++++++++++++++++---- bytelatent/model/blt.py | 8 +++----- bytelatent/model/local_models.py | 1 + bytelatent/train.py | 1 + 4 files changed, 34 insertions(+), 9 deletions(-) diff --git a/bytelatent/base_transformer.py b/bytelatent/base_transformer.py index dd0cce6..87d7334 100644 --- a/bytelatent/base_transformer.py +++ b/bytelatent/base_transformer.py @@ -45,6 +45,7 @@ class BaseTransformerArgs(BaseModel): norm_eps: float = 1e-5 rope_theta: float = 10000.0 + rope_use_fp32_in_outer_product: bool = False init_base_std: float | None = None init_std_factor: InitStdFactor = InitStdFactor.DISABLED @@ -78,7 +79,12 @@ def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor: ) -def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): +def precompute_freqs_cis( + dim: int, + end: int, + theta: float = 10000.0, + rope_use_fp32_in_outer_product: bool = False, +): """ Precompute the frequency tensor for complex exponentials (cis) with given dimensions. @@ -96,6 +102,9 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): """ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) t = torch.arange(end, device=freqs.device) + if rope_use_fp32_in_outer_product: + t = t.to(torch.float32) + freqs = torch.outer(t, freqs).float() cos, sin = freqs.cos(), freqs.sin() @@ -232,22 +241,37 @@ class RotaryEmbedding(torch.nn.Module): RotaryEmbedding Module """ - def __init__(self, theta: float, head_dim: int, max_seqlen: int = 1024): + def __init__( + self, + theta: float, + head_dim: int, + max_seqlen: int = 1024, + rope_use_fp32_in_outer_product: bool = False, + ): super().__init__() self.theta = theta self.head_dim = head_dim self.max_seqlen = max_seqlen + self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product self.register_buffer( "freqs_cis", - precompute_freqs_cis(dim=head_dim, end=max_seqlen, theta=theta), + precompute_freqs_cis( + dim=head_dim, + end=max_seqlen, + theta=theta, + rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product, + ), persistent=False, ) def reset_parameters(self): self.freqs_cis[...] 
= precompute_freqs_cis( - dim=self.head_dim, end=self.max_seqlen, theta=self.theta + dim=self.head_dim, + end=self.max_seqlen, + theta=self.theta, + rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product, ) def forward( @@ -577,6 +601,7 @@ class BaseTransformer(nn.Module): theta=args.rope_theta, head_dim=args.head_dim or args.dim // args.n_heads, max_seqlen=args.max_seqlen, + rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product, ) self.eos_id = args.eos_id diff --git a/bytelatent/model/blt.py b/bytelatent/model/blt.py index a62be23..53a3be6 100644 --- a/bytelatent/model/blt.py +++ b/bytelatent/model/blt.py @@ -414,7 +414,7 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): patch_in_forward: bool = False # Architecture and dimensions - dim_token: int = 256 + dim_token: int | None = None dim_global: int = 512 dim_local_decoder: int = 512 dim_local_encoder: int = 512 @@ -523,10 +523,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs): use_fsdp: bool = True attn_to_keep: str = "all" - # RoPE parameters - rope_theta: float = 10000.0 - rope_use_fp32_in_outer_product: bool = False - # Parameter mixing pm_size: int = 0 @@ -619,6 +615,7 @@ def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder: sliding_window=args.local_attention_window_len, use_rope=args.use_rope, rope_theta=args.rope_theta, + rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product, init_base_std=args.init_base_std, init_std_factor=args.init_std_factor, n_kv_heads=args.n_kv_heads, @@ -661,6 +658,7 @@ def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder: sliding_window=args.local_attention_window_len, use_rope=args.use_rope, rope_theta=args.rope_theta, + rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product, init_base_std=args.init_base_std, init_std_factor=args.init_std_factor, n_kv_heads=args.n_kv_heads, diff --git a/bytelatent/model/local_models.py b/bytelatent/model/local_models.py index c16f62e..d0e24c0 100644 --- a/bytelatent/model/local_models.py +++ b/bytelatent/model/local_models.py @@ -86,6 +86,7 @@ class LocalModelBase(nn.Module): theta=args.rope_theta, head_dim=args.head_dim or args.dim // args.n_heads, max_seqlen=args.max_seqlen, + rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product, ) self.pos_embeddings = None diff --git a/bytelatent/train.py b/bytelatent/train.py index 86d1c7a..bb8307a 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -325,6 +325,7 @@ def train(args: TrainArgs): # log model size + logger.info(model) logger.info(f"Model size: {model_param_count:,} total parameters") gpu_memory_monitor = GPUMemoryMonitor("cuda") From f058373889d55fc657bfa6bc34cf039e0832207e Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Thu, 6 Feb 2025 17:37:20 +0000 Subject: [PATCH 4/4] Update checkpointing to use fsspec Summary: - Make the data/checkpoint code fsspec compatible - Still will not work with s3 saves, due to `torch.distributed.checkpoint.save` not being out of the box workable with `fsspec`. 
Will implement in followup PR Test Plan: Run unit tests and the commands below ``` python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 ``` ``` torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 ``` These currently won't work due to the torch distributed save, but theses hould be tested at a later date ``` python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 dump_dir=s3://blt/scratch/checkpoint-test/ ``` ``` torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 dump_dir=s3://blt/scratch/checkpoint-test/ ``` --- bytelatent/args.py | 14 +++-- bytelatent/checkpoint.py | 115 +++++++++++++++++++++------------------ bytelatent/logger.py | 9 ++- bytelatent/metrics.py | 15 ++++- bytelatent/train.py | 31 +++++++---- 5 files changed, 112 insertions(+), 72 deletions(-) diff --git a/bytelatent/args.py b/bytelatent/args.py index d1bac46..fc72b32 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -294,6 +294,14 @@ class TrainArgs(BaseModel): def dump_to_yaml_file( self, path: str, log_config: bool = True, sort_keys: bool = True ): + yaml_str = self.dump_to_yaml_str(sort_keys=sort_keys) + with open(path, "w") as f: + if log_config: + logger.info("Using the following config for this run:") + logger.info(yaml_str) + f.write(yaml_str) + + def dump_to_yaml_str(self, sort_keys: bool = True): model_dict = self.model_dump(mode="json") yaml_str = yaml.dump( model_dict, @@ -301,8 +309,4 @@ class TrainArgs(BaseModel): sort_keys=sort_keys, default_flow_style=False, ) - with open(path, "w") as f: - if log_config: - logger.info("Using the following config for this run:") - logger.info(yaml_str) - f.write(yaml_str) + return yaml_str diff --git a/bytelatent/checkpoint.py b/bytelatent/checkpoint.py index f213c84..1668c88 100644 --- a/bytelatent/checkpoint.py +++ b/bytelatent/checkpoint.py @@ -4,10 +4,9 @@ import json import logging import os import re -from pathlib import Path -from typing import List, Optional, Tuple import fsspec +import s3fs import torch import torch.distributed as dist import torch.distributed.checkpoint as dcp @@ -70,26 +69,29 @@ def consolidate_checkpoints(fs: fsspec.AbstractFileSystem, ckpt_dir: str): Returns the path to the consolidated checkpoint """ - consolidate_path = Path(ckpt_dir) / CONSOLIDATE_FOLDER - if not (consolidate_path / CONSOLIDATE_NAME).exists(): - consolidate_path.mkdir(exist_ok=True) - logger.info(f"Consolidating to: {str(consolidate_path)}") - dcp_to_torch_save(ckpt_dir, str(consolidate_path / CONSOLIDATE_NAME)) - (consolidate_path / CONFIG_NAME).write_text( - (Path(ckpt_dir) / CONFIG_NAME).read_text() + consolidate_path = os.path.join(ckpt_dir, CONSOLIDATE_FOLDER) + consolidate_name = os.path.join(consolidate_path, CONSOLIDATE_NAME) + if not fs.exists(consolidate_name): + fs.mkdirs(consolidate_path, exist_ok=True) + logger.info(f"Consolidating to: {consolidate_path}") + dcp_to_torch_save(ckpt_dir, consolidate_name) + fs.write_text( + os.path.join(consolidate_path, CONFIG_NAME), + fs.read_text(os.path.join(ckpt_dir, CONFIG_NAME)), ) logger.info("Consolidated !") return consolidate_path def load_from_checkpoint( + fs: fsspec.AbstractFileSystem, ckpt_dir: str, model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, + optimizer: torch.optim.Optimizer | None = None, model_key: str = "model", optim_key: str = "optim", ): - if 
not (Path(ckpt_dir) / ".metadata").exists(): + if not fs.exists(os.path.join(ckpt_dir, ".metadata")): raise ValueError( f"Please convert the checkpoint distcp format using `torch.distributed.checkpoint.format_utils.torch_save_to_dcp` before loading it" ) @@ -115,19 +117,24 @@ class CheckpointManager: self.init_ckpt_path = args.init_ckpt_path self.continue_training_from_init = args.continue_training_from_init - assert self.fs.exists( - self.path - ), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)" + if not isinstance(self.fs, s3fs.S3FileSystem): + # S3 does not have a concept of directories + assert self.fs.exists( + self.path + ), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)" self.existing_saves = self.get_existing_saves() - def get_existing_saves(self) -> List[Path]: - folders = [ - p - for p in Path(self.path).iterdir() - if p.is_dir() and re.match(RE_FOLDER, p.name) - ] - folders.sort(key=lambda p: _get_key_step(p.name)) + def get_existing_saves(self) -> list[str]: + if self.fs.exists(self.path) and self.fs.isdir(self.path): + folders = [ + p + for p in self.fs.ls(self.path) + if self.fs.isdir(p) and re.match(RE_FOLDER, os.path.basename(p)) + ] + else: + folders = [] + folders.sort(key=lambda p: _get_key_step(os.path.basename(p))) return folders def clean_up(self): @@ -136,8 +143,9 @@ class CheckpointManager: eval_folders = [] other_folders = [] for p in self.existing_saves: - is_dump = _get_key_step(p.name) % self.dump_every.every == 0 - is_eval = _get_key_step(p.name) % self.eval_every.every == 0 + assert isinstance(p, str), f"Base path type: {p}" + is_dump = _get_key_step(os.path.basename(p)) % self.dump_every.every == 0 + is_eval = _get_key_step(os.path.basename(p)) % self.eval_every.every == 0 if is_dump: dump_folders.append(p) if is_eval: @@ -161,40 +169,39 @@ class CheckpointManager: if dist.get_rank() == 0: for folder in folder_to_remove: - for file in folder.iterdir(): - if file.is_file(): - file.unlink() - elif file.is_dir(): - assert file.name in [CONSOLIDATE_FOLDER] - for f in file.iterdir(): - f.unlink() - file.rmdir() - folder.rmdir() + for file in self.fs.ls(folder): + if self.fs.isfile(file): + self.fs.rm_file(file) + elif self.fs.isdir(file): + assert os.path.name(file) in [CONSOLIDATE_FOLDER] + for f in self.fs.ls(file): + self.fs.rm(f) + self.fs.rmdir(file) + self.fs.rmdir(folder) dist.barrier() self.existing_saves = list(folder_to_keep) - self.existing_saves.sort(key=lambda p: _get_key_step(p.name)) + self.existing_saves.sort(key=lambda p: _get_key_step(os.path.basename(p))) - def get_last_step_path(self, dp_rank: int = 0) -> Optional[Path]: + def get_last_step_path(self, dp_rank: int = 0) -> str | None: path = None for p in reversed(self.existing_saves): - if (p / TRAIN_STATE_NAME.format(dp_rank)).is_file(): + + if self.fs.isfile(os.path.join(p, TRAIN_STATE_NAME.format(dp_rank))): path = p break return path - def _create_folder(self, base_path: Path, folder_name: str) -> Path: - folder = base_path / folder_name + def _create_folder(self, base_path: str, folder_name: str) -> str: + folder = os.path.join(base_path, folder_name) if get_is_master(): - folder.mkdir(parents=False, exist_ok=True) + self.fs.mkdirs(folder, exist_ok=True) if dist.is_initialized(): dist.barrier() return folder - def _get_dp_tp_mesh( - self, device_mesh: Optional[DeviceMesh] = None - ) -> Tuple[int, int]: + def _get_dp_tp_mesh(self, device_mesh: 
DeviceMesh | None = None) -> tuple[int, int]: dp_rank = 0 tp_rank = 0 if device_mesh is not None: @@ -222,14 +229,14 @@ class CheckpointManager: model, optimizer, train_state, - config, - device_mesh: Optional[DeviceMesh] = None, + config: BaseModel, + device_mesh: DeviceMesh | None = None, ) -> bool: # When creating directory check if only rank0 or is there other solution - path = Path(self.path) + path = self.path curr_save_dir = self._create_folder(path, FOLDER_NAME.format(train_state.step)) - logger.info(f"Saving to: {str(curr_save_dir)}") + logger.info(f"Saving to: {curr_save_dir}") if dist.is_initialized(): dist.barrier() @@ -242,17 +249,19 @@ class CheckpointManager: if dist.is_initialized(): dist.barrier() + print("config type", type(config)) if get_is_master(): - config.dump_to_yaml_file(curr_save_dir / CONFIG_NAME) + self.fs.write_text( + os.path.join(curr_save_dir, CONFIG_NAME), config.model_dump_json() + ) # Add json dump here dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh) if tp_rank == 0: train_state_name = TRAIN_STATE_NAME.format(dp_rank) - logger.info( - f"Saving train state to: {str(curr_save_dir / train_state_name)}" - ) - with open(curr_save_dir / train_state_name, "w") as f: + train_state_full_path = os.path.join(curr_save_dir, train_state_name) + logger.info(f"Saving train state to: {train_state_full_path}") + with self.fs.open(train_state_full_path, "w") as f: json.dump(train_state.state_dict(), f) logger.info("Train state saved !") @@ -271,7 +280,7 @@ class CheckpointManager: optimizer, train_state, device_mesh: DeviceMesh, - path: Optional[Path] = None, + path: str | None = None, ): dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh) # Loading tries to load the provided path, if not available the last saved step and finally from the init path @@ -284,12 +293,12 @@ class CheckpointManager: # Only load train state if it's provided, the files exist and we're not loading from init path train_state_name = TRAIN_STATE_NAME.format(dp_rank) logger.info("Reloading train state") - with open(path / train_state_name, "r") as f: + with self.fs.open(os.path.join(path, train_state_name), "r") as f: train_state_dict = json.load(f) train_state.load_state_dict(train_state_dict) logger.info("Train state reloaded") - logger.info(f"Loading from: {str(path)}") + logger.info(f"Loading from: {path}") state_dict = self.get_state_dict( model=model, optimizer=optimizer, diff --git a/bytelatent/logger.py b/bytelatent/logger.py index 87f04cc..6f9a397 100644 --- a/bytelatent/logger.py +++ b/bytelatent/logger.py @@ -6,6 +6,8 @@ import sys import time from datetime import timedelta +import fsspec + from bytelatent.distributed import get_global_rank, get_is_slurm_job @@ -92,6 +94,7 @@ def init_logger( *, name: str | None = None, level: str = "INFO", + fs: fsspec.AbstractFileSystem | None = None, ): """ Setup logging. 
@@ -121,7 +124,11 @@ def init_logger( if log_file is not None and get_global_rank() == 0: # build file handler - file_handler = logging.FileHandler(log_file, "a") + if fs is None: + file_handler = logging.FileHandler(log_file, "a") + else: + file_stream = fs.open(log_file, mode="a") + file_handler = logging.StreamHandler(file_stream) file_handler.setLevel(logging.NOTSET) file_handler.setFormatter(LogFormatter()) # update logger diff --git a/bytelatent/metrics.py b/bytelatent/metrics.py index e746e4f..fb443d7 100644 --- a/bytelatent/metrics.py +++ b/bytelatent/metrics.py @@ -8,6 +8,7 @@ from datetime import datetime, timezone from pathlib import Path from typing import Any, Union +import fsspec import torch import torch.nn as nn import wandb @@ -53,14 +54,24 @@ class LoggingArgs(BaseModel): class MetricLogger: - def __init__(self, outdir: Path, args: Any | None = None): + def __init__( + self, + outdir: Path, + # args: TrainArgs + args: Any | None = None, + fs: fsspec.AbstractFileSystem | None = None, + ): self.outdir = outdir self.jsonl_writer = None + self.fs = fs self.args = args def open(self): if self.jsonl_writer is None: - self.jsonl_writer = open(self.outdir, "a") + if self.fs is None: + self.jsonl_writer = open(self.outdir, "a") + else: + self.jsonl_writer = self.fs.open(self.outdir, "a") if ( self.args is not None and self.args.logging.wandb is not None diff --git a/bytelatent/train.py b/bytelatent/train.py index bb8307a..9bfe12a 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -8,7 +8,6 @@ import sys from contextlib import ExitStack from copy import deepcopy from dataclasses import asdict, dataclass -from pathlib import Path from timeit import default_timer as timer from typing import Any, TypeVar @@ -18,13 +17,13 @@ import torch.nn.functional import torch.nn.functional as F import wandb import xformers.profiler -from omegaconf import OmegaConf from torch.distributed._tensor import DTensor from torch.distributed.checkpoint.stateful import Stateful from torch.optim import lr_scheduler from bytelatent.args import TrainArgs, parse_args from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint +from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.multiprocess_iterator import ( MultiprocessIterator, MultiprocessIteratorState, @@ -136,11 +135,12 @@ def validate_train_args(args: TrainArgs, output_size: int): if args.checkpoint.path is None: logger.info(f"Setting checkpoint path to {args.checkpoint.path}") - args.checkpoint.path = str(Path(args.dump_dir) / "checkpoints") + args.checkpoint.path = os.path.join(args.dump_dir, "checkpoints") + data_fs = get_fs(args.data.root_dir, s3_profile=args.data.s3_profile) for source in args.data.sources: data_path = os.path.join(args.data.root_dir, source) - assert os.path.exists(data_path), f"{data_path} doesn't exist" + assert data_fs.exists(data_path), f"{data_path} doesn't exist" if ( args.distributed.dp_replicate @@ -255,10 +255,15 @@ def train(args: TrainArgs): args, tokenizer.n_words, ) + dump_fs = get_fs(args.dump_dir, s3_profile=args.checkpoint.s3_profile) if get_is_master(): - os.makedirs(args.dump_dir, exist_ok=True) - args.dump_to_yaml_file(Path(args.dump_dir) / "config.yaml") - init_logger(Path(args.dump_dir) / "train.log") + dump_fs.mkdirs(args.dump_dir, exist_ok=True) + config_yaml_str = args.dump_to_yaml_str() + logging.info("TrainArgs: \n%s", config_yaml_str) + dump_fs.write_text( + os.path.join(args.dump_dir, "config.yaml"), config_yaml_str + ) + 
init_logger(os.path.join(args.dump_dir, "train.log"), fs=dump_fs) init_signal_handler(set_preemption_flag) # For handling preemption signals. setup_env(args.env) setup_torch_distributed(args.distributed) @@ -313,8 +318,11 @@ def train(args: TrainArgs): if args.checkpoint.init_ckpt_path: logger.info(f"Loading initial model from {args.checkpoint.init_ckpt_path}") + ckpt_fs = get_fs( + args.checkpoint.init_ckpt_path, s3_profile=args.checkpoint.s3_profile + ) load_from_checkpoint( - args.checkpoint.init_ckpt_path, model, model_key="model" + ckpt_fs, args.checkpoint.init_ckpt_path, model, model_key="model" ) # Put model_key="" if its directly the model checkpoint model.rope_embeddings.reset_parameters() # For RoPe initialization since it's a buffer it might not be loaded else: @@ -352,13 +360,14 @@ def train(args: TrainArgs): checkpoint.load(model, optimizer, train_state, world_mesh) # Either load from latest checkpoint or start from scratch if args.probe_freq is not None: + # TODO: Convert this to fsspec compatible if get_is_master(): - os.makedirs(Path(args.dump_dir) / "probe", exist_ok=True) + os.makedirs(os.path.join(args.dump_dir, "probe"), exist_ok=True) torch.distributed.barrier() probe = AutoProbeD( model, ( - Path(args.dump_dir) / "probe" / f"probe.{dp_rank}.jsonl" + os.path.join(args.dump_dir, "probe", f"probe.{dp_rank}.jsonl") if (dp_rank % 128 == 0) else None ), @@ -370,7 +379,7 @@ def train(args: TrainArgs): # train loop model.train() metric_logger = context_stack.enter_context( - MetricLogger(Path(args.dump_dir) / "metrics.jsonl", args) + MetricLogger(os.path.join(args.dump_dir, "metrics.jsonl"), args, fs=dump_fs) ) data_loader = train_state.data_loader_state.build() batch_iterator = data_loader.create_iter()
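
The fourth patch threads an fsspec filesystem object through init_logger, MetricLogger, and CheckpointManager in place of pathlib/os calls. Below is a minimal, hedged sketch of that pattern using only fsspec's public API; the dump directory, the local "file" protocol, and the demo values are illustrative assumptions (the patched code resolves the filesystem via bytelatent.data.file_util.get_fs with an optional S3 profile, which is not reproduced here).

```python
# Illustrative sketch only -- not part of the patches above.
# Assumes a local "file" filesystem and a throwaway dump dir; the real code
# obtains `fs` from bytelatent.data.file_util.get_fs(dump_dir, s3_profile=...).
import json
import logging
import os

import fsspec

dump_dir = "/tmp/blt-demo"
fs = fsspec.filesystem("file", auto_mkdir=True)

# train.py after the patch: create the dump dir and write config.yaml
# through the filesystem object instead of os/pathlib.
fs.makedirs(dump_dir, exist_ok=True)
fs.write_text(os.path.join(dump_dir, "config.yaml"), "name: demo\n")

# logger.py after the patch: when an fs is passed, the file handler wraps an
# fsspec stream rather than using logging.FileHandler on a local path.
log_stream = fs.open(os.path.join(dump_dir, "train.log"), mode="a")
handler = logging.StreamHandler(log_stream)
handler.setLevel(logging.NOTSET)
logging.getLogger().addHandler(handler)
logging.getLogger().warning("hello from an fsspec-backed handler")

# metrics.py after the patch: the jsonl writer is fs.open(..., "a") as well.
with fs.open(os.path.join(dump_dir, "metrics.jsonl"), mode="a") as f:
    f.write(json.dumps({"global_step": 1, "loss": 2.3}) + "\n")
```

The same filesystem-agnostic calls (fs.exists, fs.ls, fs.open, fs.write_text) are what CheckpointManager now uses for listing, cleaning up, and saving checkpoints; only torch.distributed.checkpoint.save remains local-path-only, which is why the s3:// dump_dir commands in the test plan are deferred to the follow-up PR.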