Update checkpointing to use fsspec

Summary:

- Make the arrow iterator able to read from jsonl files; entropies are omitted in this case
- Make the data/checkpoint code fsspec compatible
- Fix issues with all-reduce on non-bf16 dtypes in dist_sum and norm computation
- Minimal fixes to get eval to run; it is currently slow
- Add bpb numbers during training (see the bpb sketch right after this list)
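
The bpb metric above is just the cross-entropy loss converted from nats to bits and normalized by the number of bytes it covers. A minimal sketch of that conversion (the function name and signature are illustrative, not the training code's actual API):

```python
import math


def bits_per_byte(total_nll_nats: float, n_bytes: int) -> float:
    """Convert a summed cross-entropy loss (in nats) over n_bytes bytes to bits-per-byte."""
    # nats -> bits via ln(2), then normalize by the number of bytes covered.
    return total_nll_nats / (math.log(2) * n_bytes)
```

For a pure byte-level model such as the entropy model, every predicted token is one byte, so this reduces to the mean per-token loss divided by ln(2).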


Test Plan:

Run

```
torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/entropy_model.yaml eval=null max_steps=10100
```

```
python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null
```

```
torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null
```
Pedro Rodriguez 2025-02-05 00:20:57 +00:00
parent 7044771a12
commit e742218d65


```diff
@@ -4,8 +4,6 @@ import json
 import logging
 import os
 import re
-from pathlib import Path
-from typing import List, Optional, Tuple
 
 import fsspec
 import torch
@@ -70,26 +68,29 @@ def consolidate_checkpoints(fs: fsspec.AbstractFileSystem, ckpt_dir: str):
     Returns the path to the consolidated checkpoint
     """
-    consolidate_path = Path(ckpt_dir) / CONSOLIDATE_FOLDER
-    if not (consolidate_path / CONSOLIDATE_NAME).exists():
-        consolidate_path.mkdir(exist_ok=True)
-        logger.info(f"Consolidating to: {str(consolidate_path)}")
-        dcp_to_torch_save(ckpt_dir, str(consolidate_path / CONSOLIDATE_NAME))
-        (consolidate_path / CONFIG_NAME).write_text(
-            (Path(ckpt_dir) / CONFIG_NAME).read_text()
+    consolidate_path = os.path.join(ckpt_dir, CONSOLIDATE_FOLDER)
+    consolidate_name = os.path.join(consolidate_path, CONSOLIDATE_NAME)
+    if not fs.exists(consolidate_name):
+        fs.mkdirs(consolidate_path, exist_ok=True)
+        logger.info(f"Consolidating to: {consolidate_path}")
+        dcp_to_torch_save(ckpt_dir, consolidate_name)
+        fs.write_text(
+            os.path.join(consolidate_path, CONFIG_NAME),
+            fs.read_text(os.path.join(ckpt_dir, CONFIG_NAME)),
         )
         logger.info("Consolidated !")
     return consolidate_path
 
 
 def load_from_checkpoint(
+    fs: fsspec.AbstractFileSystem,
     ckpt_dir: str,
     model: nn.Module,
-    optimizer: Optional[torch.optim.Optimizer] = None,
+    optimizer: torch.optim.Optimizer | None = None,
     model_key: str = "model",
     optim_key: str = "optim",
 ):
-    if not (Path(ckpt_dir) / ".metadata").exists():
+    if not fs.exists(os.path.join(ckpt_dir, ".metadata")):
         raise ValueError(
             f"Please convert the checkpoint distcp format using `torch.distributed.checkpoint.format_utils.torch_save_to_dcp` before loading it"
         )
@@ -121,13 +122,13 @@ class CheckpointManager:
         self.existing_saves = self.get_existing_saves()
 
-    def get_existing_saves(self) -> List[Path]:
+    def get_existing_saves(self) -> list[str]:
         folders = [
             p
-            for p in Path(self.path).iterdir()
-            if p.is_dir() and re.match(RE_FOLDER, p.name)
+            for p in self.fs.ls(self.path)
+            if self.fs.isdir(p) and re.match(RE_FOLDER, os.path.basename(p))
         ]
-        folders.sort(key=lambda p: _get_key_step(p.name))
+        folders.sort(key=lambda p: _get_key_step(os.path.basename(p)))
         return folders
 
     def clean_up(self):
@@ -136,8 +137,9 @@ class CheckpointManager:
         eval_folders = []
         other_folders = []
         for p in self.existing_saves:
-            is_dump = _get_key_step(p.name) % self.dump_every.every == 0
-            is_eval = _get_key_step(p.name) % self.eval_every.every == 0
+            assert isinstance(p, str), f"Base path type: {p}"
+            is_dump = _get_key_step(os.path.basename(p)) % self.dump_every.every == 0
+            is_eval = _get_key_step(os.path.basename(p)) % self.eval_every.every == 0
             if is_dump:
                 dump_folders.append(p)
             if is_eval:
@@ -161,40 +163,39 @@ class CheckpointManager:
         if dist.get_rank() == 0:
             for folder in folder_to_remove:
-                for file in folder.iterdir():
-                    if file.is_file():
-                        file.unlink()
-                    elif file.is_dir():
-                        assert file.name in [CONSOLIDATE_FOLDER]
-                        for f in file.iterdir():
-                            f.unlink()
-                        file.rmdir()
-                folder.rmdir()
+                for file in self.fs.ls(folder):
+                    if self.fs.isfile(file):
+                        self.fs.rm_file(file)
+                    elif self.fs.isdir(file):
+                        assert os.path.name(file) in [CONSOLIDATE_FOLDER]
+                        for f in self.fs.ls(file):
+                            self.fs.rm(f)
+                        self.fs.rmdir(file)
+                self.fs.rmdir(folder)
 
         dist.barrier()
 
         self.existing_saves = list(folder_to_keep)
-        self.existing_saves.sort(key=lambda p: _get_key_step(p.name))
+        self.existing_saves.sort(key=lambda p: _get_key_step(os.path.basename(p)))
 
-    def get_last_step_path(self, dp_rank: int = 0) -> Optional[Path]:
+    def get_last_step_path(self, dp_rank: int = 0) -> str | None:
         path = None
         for p in reversed(self.existing_saves):
-            if (p / TRAIN_STATE_NAME.format(dp_rank)).is_file():
+            if self.fs.isfile(os.path.join(p, TRAIN_STATE_NAME.format(dp_rank))):
                 path = p
                 break
         return path
 
-    def _create_folder(self, base_path: Path, folder_name: str) -> Path:
-        folder = base_path / folder_name
+    def _create_folder(self, base_path: str, folder_name: str) -> str:
+        folder = os.path.join(base_path, folder_name)
         if get_is_master():
-            folder.mkdir(parents=False, exist_ok=True)
+            self.fs.mkdirs(folder, exist_ok=True)
         if dist.is_initialized():
             dist.barrier()
         return folder
 
-    def _get_dp_tp_mesh(
-        self, device_mesh: Optional[DeviceMesh] = None
-    ) -> Tuple[int, int]:
+    def _get_dp_tp_mesh(self, device_mesh: DeviceMesh | None = None) -> tuple[int, int]:
         dp_rank = 0
         tp_rank = 0
         if device_mesh is not None:
@@ -222,14 +223,14 @@ class CheckpointManager:
         model,
         optimizer,
         train_state,
-        config,
-        device_mesh: Optional[DeviceMesh] = None,
+        config: BaseModel,
+        device_mesh: DeviceMesh | None = None,
     ) -> bool:
         # When creating directory check if only rank0 or is there other solution
-        path = Path(self.path)
+        path = self.path
         curr_save_dir = self._create_folder(path, FOLDER_NAME.format(train_state.step))
-        logger.info(f"Saving to: {str(curr_save_dir)}")
+        logger.info(f"Saving to: {curr_save_dir}")
         if dist.is_initialized():
             dist.barrier()
@@ -242,17 +243,19 @@ class CheckpointManager:
         if dist.is_initialized():
             dist.barrier()
 
+        print("config type", type(config))
         if get_is_master():
-            config.dump_to_yaml_file(curr_save_dir / CONFIG_NAME)
+            self.fs.write_text(
+                os.path.join(curr_save_dir, CONFIG_NAME), config.model_dump_json()
+            )
         # Add json dump here
         dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
         if tp_rank == 0:
             train_state_name = TRAIN_STATE_NAME.format(dp_rank)
-            logger.info(
-                f"Saving train state to: {str(curr_save_dir / train_state_name)}"
-            )
-            with open(curr_save_dir / train_state_name, "w") as f:
+            train_state_full_path = os.path.join(curr_save_dir, train_state_name)
+            logger.info(f"Saving train state to: {train_state_full_path}")
+            with self.fs.open(train_state_full_path, "w") as f:
                 json.dump(train_state.state_dict(), f)
             logger.info("Train state saved !")
@@ -271,7 +274,7 @@ class CheckpointManager:
         optimizer,
         train_state,
         device_mesh: DeviceMesh,
-        path: Optional[Path] = None,
+        path: str | None = None,
     ):
         dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
         # Loading tries to load the provided path, if not available the last saved step and finally from the init path
@@ -284,12 +287,12 @@ class CheckpointManager:
         # Only load train state if it's provided, the files exist and we're not loading from init path
         train_state_name = TRAIN_STATE_NAME.format(dp_rank)
         logger.info("Reloading train state")
-        with open(path / train_state_name, "r") as f:
+        with self.fs.open(os.path.join(path, train_state_name), "r") as f:
             train_state_dict = json.load(f)
         train_state.load_state_dict(train_state_dict)
         logger.info("Train state reloaded")
 
-        logger.info(f"Loading from: {str(path)}")
+        logger.info(f"Loading from: {path}")
         state_dict = self.get_state_dict(
             model=model,
             optimizer=optimizer,
```
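
As a rough usage sketch of the new fsspec-based signatures (not part of the commit; the import path, bucket name, and placeholder model/optimizer below are assumptions), the filesystem can be resolved from the checkpoint URL and passed in explicitly:

```python
import torch
from fsspec.core import url_to_fs

# Assumed import location for the helpers shown in the diff above (hypothetical path).
from bytelatent.checkpoint import consolidate_checkpoints, load_from_checkpoint

# Resolve an fsspec filesystem plus a normalized path from a URL; this works
# for local paths as well as remote ones (the bucket name here is made up).
fs, ckpt_dir = url_to_fs("s3://my-bucket/checkpoints/0000010000")

# Consolidate the sharded DCP checkpoint into a single torch.save file ...
consolidated_dir = consolidate_checkpoints(fs, ckpt_dir)

# ... or load the sharded checkpoint directly into an existing model/optimizer.
model = torch.nn.Linear(8, 8)  # placeholder module standing in for the real model
optimizer = torch.optim.AdamW(model.parameters())
load_from_checkpoint(fs, ckpt_dir, model, optimizer=optimizer)
```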