mirror of https://github.com/facebookresearch/blt.git, synced 2025-02-23 13:32:14 +00:00
Several changes to enable entropy model training/eval
Summary:
- Make the arrow iterator able to read from jsonl files; entropies are omitted in this case.
- Make the data/checkpoint code fsspec compatible.
- Fix issues with all-reduce on non-bf16 values in dist_sum and in the grad-norm computation.
- Minimal fixes to get eval to run; it is currently slow.
- Add bpb numbers during training.

Test Plan:
This commit is contained in:
parent 7044771a12
commit c6ef4285e2
@@ -1,8 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import json
import logging
import os
from typing import Any

import fsspec
import numpy as np
import yaml
from omegaconf import OmegaConf

@@ -10,10 +12,10 @@ from pydantic import BaseModel, ConfigDict

from bytelatent.checkpoint import CheckpointArgs
from bytelatent.data.data_types import Batch
from bytelatent.data.file_util import get_fs
from bytelatent.data.iterators.abstract_iterator import StatefulIterator
from bytelatent.data.iterators.arrow_iterator import (
    ArrowFileIterator,
    find_and_sanitize_chunks,
)
from bytelatent.data.iterators.looping_iterator import LoopingIterator
from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator

@@ -53,6 +55,43 @@ def parse_args(args_cls):
    return pydantic_args


def read_args_file(fs: fsspec.AbstractFileSystem, path: str) -> Any:
    with fs.open(path, "rt") as f:
        if path.endswith(".json"):
            return json.load(f)
        elif path.endswith(".yaml"):
            return yaml.load(f)
        else:
            raise ValueError("Invalid args file format")


TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"


def find_and_sanitize_chunks(
    dataset_path: str,
    world_size: int,
    file_pattern: str,
    s3_profile: str | None = None,
):
    fs = get_fs(dataset_path, s3_profile=s3_profile)
    path_with_glob = os.path.join(dataset_path, file_pattern)
    dataset_chunks = fs.glob(path_with_glob)
    n_chunks = len(dataset_chunks)

    if n_chunks > world_size:
        n_discard = n_chunks - world_size
        dataset_chunks = dataset_chunks[:world_size]
    else:
        assert (
            world_size % n_chunks == 0
        ), "World size should be a multiple of number of chunks"

    assert n_chunks > 0, f"No valid chunks in {dataset_path}"

    return dataset_chunks


def distribute_data_to_rank(
    *,
    dataset_path: str,

@@ -62,9 +101,10 @@ def distribute_data_to_rank(
    rank: int,
    world_size: int,
    s3_profile: str | None = None,
    file_pattern: str = TRAIN_DATA_FILE_PATTERN,
) -> ArrowFileIterator:
    dataset_chunks = find_and_sanitize_chunks(
        dataset_path, world_size, s3_profile=s3_profile
        dataset_path, world_size, file_pattern, s3_profile=s3_profile
    )
    n_workers_per_chunk = world_size // len(dataset_chunks)
    rank_to_arrow_iterator_params = []
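For illustration, a hypothetical call to the chunk discovery added above might look like the sketch below; the dataset path is invented, and only `find_and_sanitize_chunks` and `TRAIN_DATA_FILE_PATTERN` come from this patch.

```python
# Hedged usage sketch; the directory layout is made up for illustration.
from bytelatent.args import TRAIN_DATA_FILE_PATTERN, find_and_sanitize_chunks

chunks = find_and_sanitize_chunks(
    "/data/entropy_corpus",                # hypothetical local dir; s3:// URIs also work via fsspec
    world_size=8,
    file_pattern=TRAIN_DATA_FILE_PATTERN,  # "*.chunk.*.jsonl"
)
# distribute_data_to_rank then assigns world_size // len(chunks) workers per chunk.
print(chunks)
```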
@@ -4,8 +4,6 @@ import json
import logging
import os
import re
from pathlib import Path
from typing import List, Optional, Tuple

import fsspec
import torch

@@ -70,26 +68,29 @@ def consolidate_checkpoints(fs: fsspec.AbstractFileSystem, ckpt_dir: str):

    Returns the path to the consolidated checkpoint
    """
    consolidate_path = Path(ckpt_dir) / CONSOLIDATE_FOLDER
    if not (consolidate_path / CONSOLIDATE_NAME).exists():
        consolidate_path.mkdir(exist_ok=True)
        logger.info(f"Consolidating to: {str(consolidate_path)}")
        dcp_to_torch_save(ckpt_dir, str(consolidate_path / CONSOLIDATE_NAME))
        (consolidate_path / CONFIG_NAME).write_text(
            (Path(ckpt_dir) / CONFIG_NAME).read_text()
    consolidate_path = os.path.join(ckpt_dir, CONSOLIDATE_FOLDER)
    consolidate_name = os.path.join(consolidate_path, CONSOLIDATE_NAME)
    if not fs.exists(consolidate_name):
        fs.mkdirs(consolidate_path, exist_ok=True)
        logger.info(f"Consolidating to: {consolidate_path}")
        dcp_to_torch_save(ckpt_dir, consolidate_name)
        fs.write_text(
            os.path.join(consolidate_path, CONFIG_NAME),
            fs.read_text(os.path.join(ckpt_dir, CONFIG_NAME)),
        )
        logger.info("Consolidated !")
    return consolidate_path


def load_from_checkpoint(
    fs: fsspec.AbstractFileSystem,
    ckpt_dir: str,
    model: nn.Module,
    optimizer: Optional[torch.optim.Optimizer] = None,
    optimizer: torch.optim.Optimizer | None = None,
    model_key: str = "model",
    optim_key: str = "optim",
):
    if not (Path(ckpt_dir) / ".metadata").exists():
    if not fs.exists(os.path.join(ckpt_dir, ".metadata")):
        raise ValueError(
            f"Please convert the checkpoint distcp format using `torch.distributed.checkpoint.format_utils.torch_save_to_dcp` before loading it"
        )
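A minimal sketch of the fsspec pattern the checkpoint code now relies on, assuming a local filesystem; the paths are invented, and `fsspec.filesystem("s3", profile=...)` is the s3 analogue used elsewhere in the patch.

```python
import fsspec

fs = fsspec.filesystem("file")         # or fsspec.filesystem("s3", profile="my-profile")
ckpt_dir = "/tmp/blt_ckpt/0000001000"  # hypothetical checkpoint folder
fs.mkdirs(ckpt_dir, exist_ok=True)
fs.write_text(f"{ckpt_dir}/params.json", '{"steps": 1000}')
assert fs.exists(f"{ckpt_dir}/params.json")
print(fs.read_text(f"{ckpt_dir}/params.json"))
```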
@@ -121,13 +122,13 @@ class CheckpointManager:

        self.existing_saves = self.get_existing_saves()

    def get_existing_saves(self) -> List[Path]:
    def get_existing_saves(self) -> list[str]:
        folders = [
            p
            for p in Path(self.path).iterdir()
            if p.is_dir() and re.match(RE_FOLDER, p.name)
            for p in self.fs.ls(self.path)
            if self.fs.isdir(p) and re.match(RE_FOLDER, os.path.basename(p))
        ]
        folders.sort(key=lambda p: _get_key_step(p.name))
        folders.sort(key=lambda p: _get_key_step(os.path.basename(p)))
        return folders

    def clean_up(self):

@@ -136,8 +137,9 @@ class CheckpointManager:
        eval_folders = []
        other_folders = []
        for p in self.existing_saves:
            is_dump = _get_key_step(p.name) % self.dump_every.every == 0
            is_eval = _get_key_step(p.name) % self.eval_every.every == 0
            assert isinstance(p, str), f"Base path type: {p}"
            is_dump = _get_key_step(os.path.basename(p)) % self.dump_every.every == 0
            is_eval = _get_key_step(os.path.basename(p)) % self.eval_every.every == 0
            if is_dump:
                dump_folders.append(p)
            if is_eval:

@@ -161,40 +163,39 @@ class CheckpointManager:

        if dist.get_rank() == 0:
            for folder in folder_to_remove:
                for file in folder.iterdir():
                    if file.is_file():
                        file.unlink()
                    elif file.is_dir():
                        assert file.name in [CONSOLIDATE_FOLDER]
                        for f in file.iterdir():
                            f.unlink()
                        file.rmdir()
                folder.rmdir()
                for file in self.fs.ls(folder):
                    if self.fs.isfile(file):
                        self.fs.rm_file(file)
                    elif self.fs.isdir(file):
                        assert os.path.name(file) in [CONSOLIDATE_FOLDER]
                        for f in self.fs.ls(file):
                            self.fs.rm(f)
                        self.fs.rmdir(file)
                self.fs.rmdir(folder)

        dist.barrier()

        self.existing_saves = list(folder_to_keep)
        self.existing_saves.sort(key=lambda p: _get_key_step(p.name))
        self.existing_saves.sort(key=lambda p: _get_key_step(os.path.basename(p)))

    def get_last_step_path(self, dp_rank: int = 0) -> Optional[Path]:
    def get_last_step_path(self, dp_rank: int = 0) -> str | None:
        path = None
        for p in reversed(self.existing_saves):
            if (p / TRAIN_STATE_NAME.format(dp_rank)).is_file():

            if self.fs.isfile(os.path.join(p, TRAIN_STATE_NAME.format(dp_rank))):
                path = p
                break
        return path

    def _create_folder(self, base_path: Path, folder_name: str) -> Path:
        folder = base_path / folder_name
    def _create_folder(self, base_path: str, folder_name: str) -> str:
        folder = os.path.join(base_path, folder_name)
        if get_is_master():
            folder.mkdir(parents=False, exist_ok=True)
            self.fs.mkdirs(folder, exist_ok=True)
        if dist.is_initialized():
            dist.barrier()
        return folder

    def _get_dp_tp_mesh(
        self, device_mesh: Optional[DeviceMesh] = None
    ) -> Tuple[int, int]:
    def _get_dp_tp_mesh(self, device_mesh: DeviceMesh | None = None) -> tuple[int, int]:
        dp_rank = 0
        tp_rank = 0
        if device_mesh is not None:

@@ -222,14 +223,14 @@ class CheckpointManager:
        model,
        optimizer,
        train_state,
        config,
        device_mesh: Optional[DeviceMesh] = None,
        config: BaseModel,
        device_mesh: DeviceMesh | None = None,
    ) -> bool:

        # When creating directory check if only rank0 or is there other solution
        path = Path(self.path)
        path = self.path
        curr_save_dir = self._create_folder(path, FOLDER_NAME.format(train_state.step))
        logger.info(f"Saving to: {str(curr_save_dir)}")
        logger.info(f"Saving to: {curr_save_dir}")

        if dist.is_initialized():
            dist.barrier()

@@ -242,17 +243,19 @@ class CheckpointManager:
        if dist.is_initialized():
            dist.barrier()

        print("config type", type(config))
        if get_is_master():
            config.dump_to_yaml_file(curr_save_dir / CONFIG_NAME)
            self.fs.write_text(
                os.path.join(curr_save_dir, CONFIG_NAME), config.model_dump_json()
            )

        # Add json dump here
        dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
        if tp_rank == 0:
            train_state_name = TRAIN_STATE_NAME.format(dp_rank)
            logger.info(
                f"Saving train state to: {str(curr_save_dir / train_state_name)}"
            )
            with open(curr_save_dir / train_state_name, "w") as f:
            train_state_full_path = os.path.join(curr_save_dir, train_state_name)
            logger.info(f"Saving train state to: {train_state_full_path}")
            with self.fs.open(train_state_full_path, "w") as f:
                json.dump(train_state.state_dict(), f)
            logger.info("Train state saved !")

@@ -271,7 +274,7 @@ class CheckpointManager:
        optimizer,
        train_state,
        device_mesh: DeviceMesh,
        path: Optional[Path] = None,
        path: str | None = None,
    ):
        dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
        # Loading tries to load the provided path, if not available the last saved step and finally from the init path

@@ -284,12 +287,12 @@ class CheckpointManager:
        # Only load train state if it's provided, the files exist and we're not loading from init path
        train_state_name = TRAIN_STATE_NAME.format(dp_rank)
        logger.info("Reloading train state")
        with open(path / train_state_name, "r") as f:
        with self.fs.open(os.path.join(path, train_state_name), "r") as f:
            train_state_dict = json.load(f)
        train_state.load_state_dict(train_state_dict)
        logger.info("Train state reloaded")

        logger.info(f"Loading from: {str(path)}")
        logger.info(f"Loading from: {path}")
        state_dict = self.get_state_dict(
            model=model,
            optimizer=optimizer,
@@ -16,6 +16,10 @@ from bytelatent import ByteLatentError
from bytelatent.data.data_types import BltExample
from bytelatent.data.file_util import get_fs
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
from bytelatent.preprocess.preprocess_entropies import (
    get_id_key,
    get_text,
)

logger = getLogger(__name__)

@@ -32,6 +36,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState):
    arrow_batch_size: int = 100
    s3_profile: str | None
    filesystem_type: str | None = None
    file_format: str

    def build(self) -> "ArrowFileIterator":
        arrow_file = ArrowFileIterator(

@@ -44,6 +49,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState):
            dataset_files=self.dataset_files,
            s3_profile=self.s3_profile,
            filesystem_type=self.filesystem_type,
            file_format=self.file_format,
        )
        if self.row_num != 0:
            arrow_file._set_row_num(self.row_num)

@@ -70,6 +76,7 @@ class ArrowFileIterator(StatefulIterator):
        dataset_files: list[str] | None = None,
        s3_profile: str | None = None,
        filesystem_type: str | None = None,
        file_format: str = "arrow",
    ):
        assert 0 <= worker_id < num_workers, (worker_id, num_workers)
        if file_path is None and dataset_files is None:

@@ -87,12 +94,16 @@ class ArrowFileIterator(StatefulIterator):
        self.arrow_batch_size = arrow_batch_size
        self.s3_profile = s3_profile
        self.filesystem_type = filesystem_type
        self.file_format = file_format
        self.fs = None
        if self.filesystem_type is not None:
            if self.filesystem_type == "file":
                self.fs = fsspec.filesystem("file")
            elif self.filesystem_type == "s3":
                self.fs = fsspec.filesystem("s3", profile=s3_profile)
            else:
                raise ValueError("Unknown filesystem")
        logger.info("Arrow iterator using fs=%s", self.fs)

        if dataset_files is None:
            # Prepare arrow shards
@@ -153,6 +164,7 @@ class ArrowFileIterator(StatefulIterator):
            dataset_files=self.dataset_files,
            s3_profile=self.s3_profile,
            filesystem_type=self.filesystem_type,
            file_format=self.file_format,
        )

    def create_iter(

@@ -164,7 +176,7 @@ class ArrowFileIterator(StatefulIterator):
            else:
                filesystem = None
            self.dataset = pa.dataset.dataset(
                self.dataset_files, format="arrow", filesystem=filesystem
                self.dataset_files, format=self.file_format, filesystem=filesystem
            )
            self.batch_iterator = self.dataset.to_batches(
                batch_size=self.arrow_batch_size

@@ -173,13 +185,22 @@ class ArrowFileIterator(StatefulIterator):
        if self.batch_to_consume is not None:
            batch_columns: dict[str, list] = self.batch_to_consume
            self.batch_to_consume = None
            sample_ids = batch_columns["sample_id"]
            texts = batch_columns["text"]
            entropies = batch_columns["entropies"]
            if self.file_format == "arrow":
                sample_ids = batch_columns["sample_id"]
                texts = batch_columns["text"]
                entropies = batch_columns["entropies"]
            elif self.file_format == "json":
                # This data hasn't been preprocessed to a uniform format,
                # so we have to do it now and omit entropies
                sample_ids = batch_columns[get_id_key(batch_columns)]
                texts = get_text(batch_columns)
                entropies = None
            else:
                raise ValueError(f"Unknown file format: {self.file_format}")
            for i in range(len(sample_ids)):
                out = BltExample(
                    sample_id=sample_ids[i],
                    entropies=entropies[i],
                    entropies=entropies[i] if entropies is not None else None,
                    text=texts[i],
                    tokens=None,
                    mask=None,
@@ -191,13 +212,22 @@ class ArrowFileIterator(StatefulIterator):

        for batch in self.batch_iterator:
            batch_columns = batch.to_pydict()
            sample_ids = batch_columns["sample_id"]
            texts = batch_columns["text"]
            entropies = batch_columns["entropies"]
            if self.file_format == "arrow":
                sample_ids = batch_columns["sample_id"]
                texts = batch_columns["text"]
                entropies = batch_columns["entropies"]
            elif self.file_format == "json":
                # This data hasn't been preprocessed to a uniform format,
                # so we have to do it now and omit entropies
                sample_ids = batch_columns[get_id_key(batch_columns)]
                texts = get_text(batch_columns)
                entropies = None
            else:
                raise ValueError(f"Unknown file format: {self.file_format}")
            for i in range(len(sample_ids)):
                out = BltExample(
                    sample_id=sample_ids[i],
                    entropies=entropies[i],
                    entropies=entropies[i] if entropies is not None else None,
                    text=texts[i],
                    tokens=None,
                    mask=None,

@@ -231,13 +261,24 @@ class ArrowFileIterator(StatefulIterator):
            for batch in self.batch_iterator:
                if len(batch) > curr_remaining:
                    batch_columns: dict[str, list] = batch.to_pydict()
                    batch_columns["sample_id"] = batch_columns["sample_id"][
                        curr_remaining:
                    ]
                    batch_columns["entropies"] = batch_columns["entropies"][
                        curr_remaining:
                    ]
                    batch_columns["text"] = batch_columns["text"][curr_remaining:]
                    if self.file_format == "arrow":
                        leftover_sample_ids = batch_columns["sample_id"][
                            curr_remaining:
                        ]
                        leftover_entropies = batch_columns["entropies"][curr_remaining:]
                        leftover_texts = batch_columns["text"][curr_remaining:]
                    elif self.file_format == "json":
                        leftover_sample_ids = batch_columns[get_id_key(batch_columns)][
                            curr_remaining:
                        ]
                        leftover_entropies = None
                        leftover_texts = get_text(batch_columns)[curr_remaining:]
                    else:
                        raise ValueError(f"Unknown file format: {self.file_format}")

                    batch_columns["sample_id"] = leftover_sample_ids
                    batch_columns["entropies"] = leftover_entropies
                    batch_columns["text"] = leftover_texts
                    self.batch_to_consume = batch_columns
                    break
                elif len(batch) == curr_remaining:

@@ -250,30 +291,3 @@ class ArrowFileIterator(StatefulIterator):
        logger.info(
            f"Finished setting arrow position to {target_row_num} for {self.dataset_files}"
        )


TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"


def find_and_sanitize_chunks(
    dataset_path: str,
    world_size: int,
    file_pattern: str = TRAIN_DATA_FILE_PATTERN,
    s3_profile: str | None = None,
):
    fs = get_fs(dataset_path, s3_profile=s3_profile)
    path_with_glob = os.path.join(dataset_path, file_pattern)
    dataset_chunks = fs.glob(path_with_glob)
    n_chunks = len(dataset_chunks)

    if n_chunks > world_size:
        n_discard = n_chunks - world_size
        dataset_chunks = dataset_chunks[:world_size]
    else:
        assert (
            world_size % n_chunks == 0
        ), "World size should be a multiple of number of chunks"

    assert n_chunks > 0, f"No valid chunks in {dataset_path}"

    return dataset_chunks
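The jsonl path through the iterator can be driven directly, mirroring how the eval changes further down construct it; this is a hedged sketch with an invented file path, and in "json" mode the entropies come back as None.

```python
from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator

iterator = ArrowFileIterator(
    dataset_files=["/data/val/wiki.val.jsonl"],  # hypothetical validation file
    file_path=None,
    preprocess_dir=None,
    entropy_model_name=None,
    worker_id=0,
    num_workers=1,
    arrow_batch_size=100,
    file_format="json",                          # entropies are omitted for jsonl input
)
for example in iterator.create_iter():
    print(example.sample_id, example.text[:40], example.entropies)
    break
```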
@@ -127,6 +127,16 @@ def dist_max(x: Union[int, float], mesh: DeviceMesh = None):
    return tensor


def dist_sum(
    x: Union[int, float], mesh: DeviceMesh = None, reduce_dtype: torch.dtype = None
):
    tensor = torch.tensor(x).cuda()
    if reduce_dtype is not None:
        tensor = tensor.to(reduce_dtype)
    dist.all_reduce(tensor, op=ReduceOp.SUM, group=mesh.get_group() if mesh else None)
    return tensor


def dist_mean(x: Union[int, float], mesh: DeviceMesh = None):
    tensor = torch.tensor(x).cuda()
    dist.all_reduce(tensor, op=ReduceOp.AVG, group=mesh.get_group() if mesh else None)
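The point of `reduce_dtype` is to pick the all-reduce dtype explicitly instead of inheriting fp32/fp64 from `torch.tensor`. A schematic, CPU-safe sketch of the idea (the real `dist_sum` above moves the tensor to CUDA and always reduces):

```python
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp


def dist_sum_sketch(x: float, reduce_dtype: torch.dtype | None = None) -> torch.Tensor:
    tensor = torch.tensor(x)              # would otherwise default to float32/float64
    if reduce_dtype is not None:
        tensor = tensor.to(reduce_dtype)  # e.g. torch.bfloat16 to match bf16 training
    if dist.is_initialized():             # guarded here; the patched helper assumes a process group
        dist.all_reduce(tensor, op=ReduceOp.SUM)
    return tensor


print(dist_sum_sketch(3.5, reduce_dtype=torch.bfloat16))
```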
@@ -236,7 +246,7 @@ def setup_env(env_args: EnvironmentArgs):
            logger.warning(f"WARNING: Setting {name} to {value}")


def setup_torch_distributed(dist_args):
def setup_torch_distributed(dist_args: DistributedArgs):
    """
    Handle single and multi-GPU / multi-node / SLURM jobs.
    Initialize the following variables:

@@ -388,14 +398,14 @@ def clean_env():


def parallelize_model(
    model,
    model: torch.nn.Module,
    device_mesh,
    model_args,
    distributed_args: DistributedArgs,
    fsdp_grouping_plan: Optional[List[Tuple[str, bool]]] = None,
    tp_parallelize=None,
    no_recompute_ops=None,
):
) -> torch.nn.Module:
    if distributed_args.tp_size > 1:
        assert (
            distributed_args.fsdp_type == "full_shard"

@@ -429,6 +439,8 @@ def parallelize_model(
            device_mesh["dp_shard"].size() == 1
        ), "dp_shard must be 1 for no_shard fsdp_type"

    # TODO: Remove with something better
    # model = model.to(param_dtype)
    fsdp_config = dict(
        mp_policy=(
            MixedPrecisionPolicy(
@@ -15,9 +15,16 @@ from lm_eval.api.model import LM
from omegaconf import OmegaConf
from pydantic import BaseModel, ConfigDict

from bytelatent.args import EvalArgs, ValidationArgs, parse_args
from bytelatent.args import (
    EvalArgs,
    TrainArgs,
    ValidationArgs,
    find_and_sanitize_chunks,
    parse_args,
)
from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints
from bytelatent.data.file_util import get_fs
from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
from bytelatent.distributed import (
    DistributedArgs,
    dist_mean_dict,

@@ -117,19 +124,40 @@ class EvalHarnessLM(LM):
        return results


def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
    srcs = {}
def eval_on_val(generator, val_args: ValidationArgs, train_cfg: TrainArgs):
    srcs = []
    for src in val_args.sources:
        path = os.path.join(val_args.root_dir, src)
        srcs[path] = 1.0
        srcs.append(path)

    for src in train_cfg.data.sources:
        path = os.path.join(train_cfg.data.root_dir, src)
        srcs[path] = 1.0
        srcs.append(path)

    multi_state = init_choice_state(
        "", srcs, 0, get_global_rank(), get_world_size(), "*.val.jsonl"
    )
    path_to_iter = setup_sources(multi_state)
    path_to_iter = {}
    for path in srcs:
        chunks = find_and_sanitize_chunks(
            path,
            world_size=1,
            file_pattern="*.val.jsonl",
            s3_profile=train_cfg.data.s3_profile,
        )
        assert (
            len(chunks) == 1
        ), f"There should be only 1 chunk per validation file, but found: {chunks}"
        chunk = chunks[0]
        iterator = ArrowFileIterator(
            dataset_files=[chunk],
            file_path=None,
            preprocess_dir=None,
            entropy_model_name=None,
            worker_id=0,
            num_workers=1,
            arrow_batch_size=train_cfg.data.arrow_batch_size,
            s3_profile=train_cfg.data.s3_profile,
            file_format="json",
        )
        path_to_iter[path] = iterator

    max_gen_len = generator.max_gen_len
    # We temporarily lower max gen len

@@ -137,16 +165,11 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg):

    all_val_metrics = {}
    for src in path_to_iter:
        jsonl_iterator = path_to_iter[src]
        example_iterator = path_to_iter[src].create_iter()
        texts = []
        logger.info(f"Running validation on {src}...")
        for step, (content, state) in enumerate(jsonl_iterator):
            if state["current_iter"] > 0 or (
                val_args.max_steps is not None and step >= val_args.max_steps
            ):
                break
            content_key = "text" if ("text" in content) else "content"
            texts.append(content[content_key])
        for step, example in enumerate(example_iterator):
            texts.append(example.text)

        _, loglikelihood, _ = generator.generate(texts)

@@ -191,7 +214,7 @@ def launch_eval(eval_args: EvalArgs):
    else:
        consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER)
        if not fs.exists(consolidate_path) and get_global_rank() == 0:
            consolidate_path = consolidate_checkpoints(eval_args.ckpt_dir)
            consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir)

    fs.mkdirs(eval_args.dump_dir, exist_ok=True)
    with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f:

@@ -210,10 +233,12 @@ def launch_eval(eval_args: EvalArgs):

    wrap = EvalHarnessLM(generator)
    # Redo
    results = simple_evaluate(wrap, eval_args.harness.model_dump())
    results = simple_evaluate(wrap, **eval_args.harness.model_dump())

    val_results = None
    if eval_args.validation:
        val_results = eval_on_val(generator, eval_args.validation, train_cfg)

    if get_global_rank() == 0:
        with fs.open(os.path.join(eval_args.dump_dir, "results.json"), "w") as f:
            f.write(json.dumps(results))

@@ -222,6 +247,7 @@ def launch_eval(eval_args: EvalArgs):
        with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f:
            f.write(json.dumps(val_results))
        logger.info(f"All validation results: {val_results}")

    if eval_args.metric_log_dir and get_global_rank() == 0:
        metric_log_path = os.path.join(eval_args.metric_log_dir, "metrics.eval.jsonl")
@@ -387,8 +387,7 @@ def load_consolidated_model_and_tokenizer(
):
    train_args_path = os.path.join(consolidated_path, "params.json")
    fs = get_fs(train_args_path)
    with fs.open(train_args_path) as f:
        train_args = TrainArgs.model_validate_json(f.read())
    train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))

    if train_args.train_entropy_model:
        model_args = train_args.entropy_model

@@ -401,7 +400,8 @@ def load_consolidated_model_and_tokenizer(
        train_args.distributed.model_dtype
    ]
    tokenizer = train_args.data.tokenizer_args.build()
    st_dict = torch.load(consolidated_path / CONSOLIDATE_NAME, weights_only=True)
    with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f:
        st_dict = torch.load(f, weights_only=True)
    model.load_state_dict(st_dict["model"])
    model = model.cuda().eval()
    for param in model.parameters():
bytelatent/norms.py (new file, 99 lines)

@@ -0,0 +1,99 @@
from typing import Optional, Tuple, Dict, List
import torch
from torch import Tensor
from torch.utils._foreach_utils import (
    _device_has_foreach_support,
    _group_tensors_by_device_and_dtype,
    _has_foreach_support,
)


@torch.no_grad()
def fixed_clip_grad_norm_(
    parameters: torch.Tensor | list[torch.Tensor],
    max_norm: float,
    norm_type: float = 2.0,
    error_if_nonfinite: bool = False,
    foreach: Optional[bool] = None,
) -> torch.Tensor:
    r"""Clip the gradient norm of an iterable of parameters.

    The norm is computed over the norms of the individual gradients of all parameters,
    as if the norms of the individual gradients were concatenated into a single vector.
    Gradients are modified in-place.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float): max norm of the gradients
        norm_type (float): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.
        error_if_nonfinite (bool): if True, an error is thrown if the total
            norm of the gradients from :attr:`parameters` is ``nan``,
            ``inf``, or ``-inf``. Default: False (will switch to True in the future)
        foreach (bool): use the faster foreach-based implementation.
            If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
            fall back to the slow implementation for other device types.
            Default: ``None``

    Returns:
        Total norm of the parameter gradients (viewed as a single vector).
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad.to(torch.bfloat16) for p in parameters if p.grad is not None]
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if len(grads) == 0:
        return torch.tensor(0.0)
    first_device = grads[0].device
    grouped_grads: Dict[
        Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
    ] = _group_tensors_by_device_and_dtype(
        [grads]
    )  # type: ignore[assignment]

    norms: List[Tensor] = []
    for (device, _), ([device_grads], _) in grouped_grads.items():  # type: ignore[assignment]
        if (foreach is None and _has_foreach_support(device_grads, device)) or (
            foreach and _device_has_foreach_support(device)
        ):
            norms.extend(torch._foreach_norm(device_grads, norm_type))
        elif foreach:
            raise RuntimeError(
                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
            )
        else:
            norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])

    total_norm = torch.linalg.vector_norm(
        torch.stack([norm.to(first_device) for norm in norms]), norm_type
    )

    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
        raise RuntimeError(
            f"The total norm of order {norm_type} for gradients from "
            "`parameters` is non-finite, so it cannot be clipped. To disable "
            "this error and scale the gradients by the non-finite norm anyway, "
            "set `error_if_nonfinite=False`"
        )
    clip_coef = max_norm / (total_norm + 1e-6)
    # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
    # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
    # when the gradients do not reside in CPU memory.
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    for (device, _), ([device_grads], _) in grouped_grads.items():  # type: ignore[assignment]
        if (foreach is None and _has_foreach_support(device_grads, device)) or (
            foreach and _device_has_foreach_support(device)
        ):
            torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
        elif foreach:
            raise RuntimeError(
                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
            )
        else:
            clip_coef_clamped_device = clip_coef_clamped.to(device)
            for g in device_grads:
                g.mul_(clip_coef_clamped_device)

    return total_norm
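Usage sketch for the new helper; the parameter and loss below are invented, and the norm is accumulated in bfloat16 as in the implementation above.

```python
import torch

from bytelatent.norms import fixed_clip_grad_norm_

w = torch.nn.Parameter(torch.randn(4, 4))
loss = (w ** 2).sum()
loss.backward()
total_norm = fixed_clip_grad_norm_([w], max_norm=1.0)
print(float(total_norm))  # pre-clip gradient norm, computed in bfloat16
```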
@@ -15,29 +15,37 @@ from bytelatent.entropy_model import load_entropy_model
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs


def get_id_from_doc(doc: dict) -> int:
def get_id_key(doc: dict) -> int:
    """
    We need a reliable way to ensure that samples from jsonl
    and arrow are the same, but there is no unique id field,
    so derive the best possible
    """
    if "sample_id" in doc:
        sample_id = doc["sample_id"]
        return "sample_id"
    elif "title" in doc:
        sample_id = doc["title"]
        return "title"
    elif "qid" in doc:
        sample_id = doc["qid"]
        return "qid"
    elif "paper_id" in doc:
        sample_id = doc["paper_id"]
        return "paper_id"
    elif "path" in doc:
        sample_id = doc["path"]
        return "path"
    elif "url" in doc:
        sample_id = doc["url"]
        return "url"
    elif "id" in doc:
        sample_id = doc["id"]
        return "id"
    else:
        raise ValueError(f"Could not find a id key from: {doc.keys()}")
    return str(sample_id)


def get_id_from_doc(doc: dict) -> int:
    """
    We need a reliable way to ensure that samples from jsonl
    and arrow are the same, but there is no unique id field,
    so derive the best possible
    """
    return str(doc[get_id_key(doc)])


def get_text(doc: dict):
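On an invented jsonl record, the key fallback above behaves like this:

```python
from bytelatent.preprocess.preprocess_entropies import get_id_from_doc, get_id_key

doc = {"title": "An Example Article", "text": "hello world"}  # invented record
print(get_id_key(doc))       # -> "title" (no "sample_id" present, so fall through)
print(get_id_from_doc(doc))  # -> "An Example Article"
```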
@@ -2,6 +2,8 @@
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import gc
import math
import numpy as np
import logging
import os
import sys

@@ -25,6 +27,7 @@ from torch.optim import lr_scheduler

from bytelatent.args import TrainArgs, parse_args
from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint
from bytelatent.data.file_util import get_fs
from bytelatent.data.iterators.multiprocess_iterator import (
    MultiprocessIterator,
    MultiprocessIteratorState,

@@ -33,7 +36,9 @@ from bytelatent.data.iterators.packing_iterator import PackingIteratorState
from bytelatent.distributed import (
    check_model_value_range,
    clean_env,
    dist_mean,
    dist_mean_dict,
    dist_sum,
    get_device_mesh,
    get_is_master,
    get_world_size,

@@ -47,6 +52,7 @@ from bytelatent.eval import EVAL_FOLDER_NAME, launch_eval
from bytelatent.logger import init_logger
from bytelatent.metrics import GPUMemoryMonitor, MetricLogger, get_num_params
from bytelatent.model.blt import ByteLatentTransformer
from bytelatent.norms import fixed_clip_grad_norm_
from bytelatent.optim import build_optimizer
from bytelatent.probe import AutoProbeD
from bytelatent.profiling import maybe_run_profiler

@@ -295,8 +301,11 @@ def train(args: TrainArgs):

    if args.checkpoint.init_ckpt_path:
        logger.info(f"Loading initial model from {args.checkpoint.init_ckpt_path}")
        ckpt_fs = get_fs(
            args.checkpoint.init_ckpt_path, s3_profile=args.checkpoint.s3_profile
        )
        load_from_checkpoint(
            args.checkpoint.init_ckpt_path, model, model_key="model"
            ckpt_fs, args.checkpoint.init_ckpt_path, model, model_key="model"
        )  # Put model_key="" if its directly the model checkpoint
        model.rope_embeddings.reset_parameters()  # For RoPe initialization since it's a buffer it might not be loaded
    else:

@@ -364,6 +373,9 @@ def train(args: TrainArgs):
    time_last_log = timer()
    gc.collect()
    saved = False
    step_losses: list[float] = []
    step_tok_losses: list[float] = []
    n_bytes: int = 0
    while train_state.step < args.steps and (
        args.max_steps is None or train_state.step < args.max_steps
    ):

@@ -385,6 +397,24 @@ def train(args: TrainArgs):
        batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda()
        mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()

        if args.data.tokenizer_args.name in ["bytes", "blt"]:
            if mask is None:
                n_bytes += batch_y.numel()
            else:
                n_bytes += mask.sum()
        elif args.data.tokenizer_args.name in ["sp", "tiktoken"]:
            for example in batch.y:
                target_tokens = tokenizer.decode(example.tolist(), cut_at_eos=False)
                n_bytes += (
                    len(bytes(target_tokens, encoding="utf-8", errors="ignore"))
                    + sum(example == tokenizer.eos_id)
                    + sum(example == tokenizer.bos_id)
                )
        else:
            raise ValueError(
                f"Unexpected tokenizer to count n_bytes for: {args.data.tokenizer_args.name}"
            )

        if (
            not args.train_entropy_model
            and args.model.encoder_enable_byte_ngrams
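For the subword tokenizers, the byte count above boils down to measuring the utf-8 length of the decoded targets (plus BOS/EOS); a tiny self-contained illustration with an invented string:

```python
text = "héllo"  # pretend this is tokenizer.decode(...) of one target sequence
n_bytes = len(bytes(text, encoding="utf-8", errors="ignore"))
print(n_bytes)  # 6: the accented character takes two bytes
```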
@@ -459,7 +489,7 @@ def train(args: TrainArgs):
                batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids
            )

            loss, _ = compute_loss(pred, batch_y, mask, train_state.scale)
            loss, tok_loss = compute_loss(pred, batch_y, mask, train_state.scale)

            # We scale loss with grad_acc_steps so the gradient is the same
            # regardless of grad_acc_steps

@@ -470,8 +500,14 @@ def train(args: TrainArgs):
            # For logging we undo that scaling
            loss = loss.detach() * args.grad_acc_steps

            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), max_norm=args.optim.clip, foreach=True
            # Undo loss scaling so downstream down't need to worry about it
            step_losses.append((loss / train_state.scale).item())
            step_tok_losses.append(tok_loss / train_state.scale)

            # grad_norm = torch.nn.utils.clip_grad_norm_(
            grad_norm = fixed_clip_grad_norm_(
                model.parameters(),
                max_norm=args.optim.clip,  # foreach=True
            )

            grad_norm = (

@@ -559,20 +595,33 @@ def train(args: TrainArgs):
                gpu_memory_monitor.reset_peak_stats()
                nwords_since_last_log = 0
                time_last_log = timer()
                stacked_tok_loss = torch.cat(step_tok_losses, dim=0)
                total_tok_loss = dist_sum(
                    stacked_tok_loss.sum().item(), reduce_dtype=torch.bfloat16
                )
                total_n_bytes = dist_sum(n_bytes, reduce_dtype=torch.bfloat16)
                avg_bpb = total_tok_loss / math.log(2) / total_n_bytes
                avg_loss = dist_mean(np.mean(step_losses).item())
                logger.info(
                    f"step: {train_state.step}"
                    f" acc: {train_state.acc_step}"
                    f" loss: {round(loss.item(),4):>7}"
                    f" loss: step={round(loss.item(),4):>7} avg={avg_loss}"
                    f" bpb: {avg_bpb:3f}"
                    f" grad: {grad_norm:.2e}"
                    f" flops: {FLOPS:.2e}"
                    f" wps: {wps:.2e}"
                    f" iter: {curr_iter_time:>7}"
                    f" data: {data_load_time:>5}"
                    f" lr: {curr_lr:.2e}"
                    f" n_bytes={total_n_bytes}"
                    f" mem: {gpu_mem_stats.max_active_pct:.0f}%"
                    f" pow: {gpu_mem_stats.power_draw/1000} W"
                )

                n_bytes = 0
                step_losses = []
                step_tok_losses = []

            if every_n_steps(
                train_state, args.checkpoint.dump.every, acc_step=0
            ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
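The bpb value logged above is just the summed token loss (in nats) converted to bits and normalized by the byte count; with invented numbers:

```python
import math

total_tok_loss = 1250.0  # invented: summed per-token cross-entropy over the logging window, in nats
total_n_bytes = 2048     # invented: utf-8 bytes covered by those targets
avg_bpb = total_tok_loss / math.log(2) / total_n_bytes
print(f"bpb: {avg_bpb:.3f}")  # roughly 0.88 bits per byte for these made-up numbers
```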