Several changes to enable entropy model training/eval

Summary:

- Make arrow iterator able to read from jsonl files, the entropies are omitted in this case
- Make the data/checkpoint code fsspec compatible
- Fix issues with all reduce with non-bf16 in dist_sum and norm computation.
- Minimal fixes to get eval to run, it is slow currently
- Add bpb numbers during training


Test Plan:

Run

```
torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/entropy_model.yaml eval=null max_steps=10100
```
This commit is contained in:
Pedro Rodriguez 2025-02-04 18:04:54 +00:00
parent 7044771a12
commit ab399e981d
9 changed files with 381 additions and 134 deletions

View file

@ -1,8 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import json
import logging
import os
from typing import Any
import fsspec
import numpy as np
import yaml
from omegaconf import OmegaConf
@ -10,11 +12,9 @@ from pydantic import BaseModel, ConfigDict
from bytelatent.checkpoint import CheckpointArgs
from bytelatent.data.data_types import Batch
from bytelatent.data.file_util import get_fs
from bytelatent.data.iterators.abstract_iterator import StatefulIterator
from bytelatent.data.iterators.arrow_iterator import (
ArrowFileIterator,
find_and_sanitize_chunks,
)
from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
from bytelatent.data.iterators.looping_iterator import LoopingIterator
from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator
from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator
@ -53,6 +53,43 @@ def parse_args(args_cls):
return pydantic_args
def read_args_file(fs: fsspec.AbstractFileSystem, path: str) -> Any:
with fs.open(path, "rt") as f:
if path.endswith(".json"):
return json.load(f)
elif path.endswith(".yaml"):
return yaml.load(f)
else:
raise ValueError("Invalid args file format")
TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"
def find_and_sanitize_chunks(
dataset_path: str,
world_size: int,
file_pattern: str,
s3_profile: str | None = None,
):
fs = get_fs(dataset_path, s3_profile=s3_profile)
path_with_glob = os.path.join(dataset_path, file_pattern)
dataset_chunks = fs.glob(path_with_glob)
n_chunks = len(dataset_chunks)
if n_chunks > world_size:
n_discard = n_chunks - world_size
dataset_chunks = dataset_chunks[:world_size]
else:
assert (
world_size % n_chunks == 0
), "World size should be a multiple of number of chunks"
assert n_chunks > 0, f"No valid chunks in {dataset_path}"
return dataset_chunks
def distribute_data_to_rank(
*,
dataset_path: str,
@ -62,9 +99,10 @@ def distribute_data_to_rank(
rank: int,
world_size: int,
s3_profile: str | None = None,
file_pattern: str = TRAIN_DATA_FILE_PATTERN,
) -> ArrowFileIterator:
dataset_chunks = find_and_sanitize_chunks(
dataset_path, world_size, s3_profile=s3_profile
dataset_path, world_size, file_pattern, s3_profile=s3_profile
)
n_workers_per_chunk = world_size // len(dataset_chunks)
rank_to_arrow_iterator_params = []

View file

@ -4,8 +4,6 @@ import json
import logging
import os
import re
from pathlib import Path
from typing import List, Optional, Tuple
import fsspec
import torch
@ -70,26 +68,29 @@ def consolidate_checkpoints(fs: fsspec.AbstractFileSystem, ckpt_dir: str):
Returns the path to the consolidated checkpoint
"""
consolidate_path = Path(ckpt_dir) / CONSOLIDATE_FOLDER
if not (consolidate_path / CONSOLIDATE_NAME).exists():
consolidate_path.mkdir(exist_ok=True)
logger.info(f"Consolidating to: {str(consolidate_path)}")
dcp_to_torch_save(ckpt_dir, str(consolidate_path / CONSOLIDATE_NAME))
(consolidate_path / CONFIG_NAME).write_text(
(Path(ckpt_dir) / CONFIG_NAME).read_text()
consolidate_path = os.path.join(ckpt_dir, CONSOLIDATE_FOLDER)
consolidate_name = os.path.join(consolidate_path, CONSOLIDATE_NAME)
if not fs.exists(consolidate_name):
fs.mkdirs(consolidate_path, exist_ok=True)
logger.info(f"Consolidating to: {consolidate_path}")
dcp_to_torch_save(ckpt_dir, consolidate_name)
fs.write_text(
os.path.join(consolidate_path, CONFIG_NAME),
fs.read_text(os.path.join(ckpt_dir, CONFIG_NAME)),
)
logger.info("Consolidated !")
return consolidate_path
def load_from_checkpoint(
fs: fsspec.AbstractFileSystem,
ckpt_dir: str,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
optimizer: torch.optim.Optimizer | None = None,
model_key: str = "model",
optim_key: str = "optim",
):
if not (Path(ckpt_dir) / ".metadata").exists():
if not fs.exists(os.path.join(ckpt_dir, ".metadata")):
raise ValueError(
f"Please convert the checkpoint distcp format using `torch.distributed.checkpoint.format_utils.torch_save_to_dcp` before loading it"
)
@ -121,13 +122,13 @@ class CheckpointManager:
self.existing_saves = self.get_existing_saves()
def get_existing_saves(self) -> List[Path]:
def get_existing_saves(self) -> list[str]:
folders = [
p
for p in Path(self.path).iterdir()
if p.is_dir() and re.match(RE_FOLDER, p.name)
for p in self.fs.ls(self.path)
if self.fs.isdir(p) and re.match(RE_FOLDER, os.path.basename(p))
]
folders.sort(key=lambda p: _get_key_step(p.name))
folders.sort(key=lambda p: _get_key_step(os.path.basename(p)))
return folders
def clean_up(self):
@ -136,8 +137,9 @@ class CheckpointManager:
eval_folders = []
other_folders = []
for p in self.existing_saves:
is_dump = _get_key_step(p.name) % self.dump_every.every == 0
is_eval = _get_key_step(p.name) % self.eval_every.every == 0
assert isinstance(p, str), f"Base path type: {p}"
is_dump = _get_key_step(os.path.basename(p)) % self.dump_every.every == 0
is_eval = _get_key_step(os.path.basename(p)) % self.eval_every.every == 0
if is_dump:
dump_folders.append(p)
if is_eval:
@ -161,40 +163,39 @@ class CheckpointManager:
if dist.get_rank() == 0:
for folder in folder_to_remove:
for file in folder.iterdir():
if file.is_file():
file.unlink()
elif file.is_dir():
assert file.name in [CONSOLIDATE_FOLDER]
for f in file.iterdir():
f.unlink()
file.rmdir()
folder.rmdir()
for file in self.fs.ls(folder):
if self.fs.isfile(file):
self.fs.rm_file(file)
elif self.fs.isdir(file):
assert os.path.name(file) in [CONSOLIDATE_FOLDER]
for f in self.fs.ls(file):
self.fs.rm(f)
self.fs.rmdir(file)
self.fs.rmdir(folder)
dist.barrier()
self.existing_saves = list(folder_to_keep)
self.existing_saves.sort(key=lambda p: _get_key_step(p.name))
self.existing_saves.sort(key=lambda p: _get_key_step(os.path.basename(p)))
def get_last_step_path(self, dp_rank: int = 0) -> Optional[Path]:
def get_last_step_path(self, dp_rank: int = 0) -> str | None:
path = None
for p in reversed(self.existing_saves):
if (p / TRAIN_STATE_NAME.format(dp_rank)).is_file():
if self.fs.isfile(os.path.join(p, TRAIN_STATE_NAME.format(dp_rank))):
path = p
break
return path
def _create_folder(self, base_path: Path, folder_name: str) -> Path:
folder = base_path / folder_name
def _create_folder(self, base_path: str, folder_name: str) -> str:
folder = os.path.join(base_path, folder_name)
if get_is_master():
folder.mkdir(parents=False, exist_ok=True)
self.fs.mkdirs(folder, exist_ok=True)
if dist.is_initialized():
dist.barrier()
return folder
def _get_dp_tp_mesh(
self, device_mesh: Optional[DeviceMesh] = None
) -> Tuple[int, int]:
def _get_dp_tp_mesh(self, device_mesh: DeviceMesh | None = None) -> tuple[int, int]:
dp_rank = 0
tp_rank = 0
if device_mesh is not None:
@ -222,14 +223,14 @@ class CheckpointManager:
model,
optimizer,
train_state,
config,
device_mesh: Optional[DeviceMesh] = None,
config: BaseModel,
device_mesh: DeviceMesh | None = None,
) -> bool:
# When creating directory check if only rank0 or is there other solution
path = Path(self.path)
path = self.path
curr_save_dir = self._create_folder(path, FOLDER_NAME.format(train_state.step))
logger.info(f"Saving to: {str(curr_save_dir)}")
logger.info(f"Saving to: {curr_save_dir}")
if dist.is_initialized():
dist.barrier()
@ -242,17 +243,19 @@ class CheckpointManager:
if dist.is_initialized():
dist.barrier()
print("config type", type(config))
if get_is_master():
config.dump_to_yaml_file(curr_save_dir / CONFIG_NAME)
self.fs.write_text(
os.path.join(curr_save_dir, CONFIG_NAME), config.model_dump_json()
)
# Add json dump here
dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
if tp_rank == 0:
train_state_name = TRAIN_STATE_NAME.format(dp_rank)
logger.info(
f"Saving train state to: {str(curr_save_dir / train_state_name)}"
)
with open(curr_save_dir / train_state_name, "w") as f:
train_state_full_path = os.path.join(curr_save_dir, train_state_name)
logger.info(f"Saving train state to: {train_state_full_path}")
with self.fs.open(train_state_full_path, "w") as f:
json.dump(train_state.state_dict(), f)
logger.info("Train state saved !")
@ -271,7 +274,7 @@ class CheckpointManager:
optimizer,
train_state,
device_mesh: DeviceMesh,
path: Optional[Path] = None,
path: str | None = None,
):
dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
# Loading tries to load the provided path, if not available the last saved step and finally from the init path
@ -284,12 +287,12 @@ class CheckpointManager:
# Only load train state if it's provided, the files exist and we're not loading from init path
train_state_name = TRAIN_STATE_NAME.format(dp_rank)
logger.info("Reloading train state")
with open(path / train_state_name, "r") as f:
with self.fs.open(os.path.join(path, train_state_name), "r") as f:
train_state_dict = json.load(f)
train_state.load_state_dict(train_state_dict)
logger.info("Train state reloaded")
logger.info(f"Loading from: {str(path)}")
logger.info(f"Loading from: {path}")
state_dict = self.get_state_dict(
model=model,
optimizer=optimizer,

View file

@ -16,6 +16,7 @@ from bytelatent import ByteLatentError
from bytelatent.data.data_types import BltExample
from bytelatent.data.file_util import get_fs
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
from bytelatent.preprocess.preprocess_entropies import get_id_key, get_text
logger = getLogger(__name__)
@ -32,6 +33,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState):
arrow_batch_size: int = 100
s3_profile: str | None
filesystem_type: str | None = None
file_format: str
def build(self) -> "ArrowFileIterator":
arrow_file = ArrowFileIterator(
@ -44,6 +46,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState):
dataset_files=self.dataset_files,
s3_profile=self.s3_profile,
filesystem_type=self.filesystem_type,
file_format=self.file_format,
)
if self.row_num != 0:
arrow_file._set_row_num(self.row_num)
@ -70,6 +73,7 @@ class ArrowFileIterator(StatefulIterator):
dataset_files: list[str] | None = None,
s3_profile: str | None = None,
filesystem_type: str | None = None,
file_format: str = "arrow",
):
assert 0 <= worker_id < num_workers, (worker_id, num_workers)
if file_path is None and dataset_files is None:
@ -87,12 +91,16 @@ class ArrowFileIterator(StatefulIterator):
self.arrow_batch_size = arrow_batch_size
self.s3_profile = s3_profile
self.filesystem_type = filesystem_type
self.file_format = file_format
self.fs = None
if self.filesystem_type is not None:
if self.filesystem_type == "file":
self.fs = fsspec.filesystem("file")
elif self.filesystem_type == "s3":
self.fs = fsspec.filesystem("s3", profile=s3_profile)
else:
raise ValueError("Unknown filesystem")
logger.info("Arrow iterator using fs=%s", self.fs)
if dataset_files is None:
# Prepare arrow shards
@ -153,6 +161,7 @@ class ArrowFileIterator(StatefulIterator):
dataset_files=self.dataset_files,
s3_profile=self.s3_profile,
filesystem_type=self.filesystem_type,
file_format=self.file_format,
)
def create_iter(
@ -164,7 +173,7 @@ class ArrowFileIterator(StatefulIterator):
else:
filesystem = None
self.dataset = pa.dataset.dataset(
self.dataset_files, format="arrow", filesystem=filesystem
self.dataset_files, format=self.file_format, filesystem=filesystem
)
self.batch_iterator = self.dataset.to_batches(
batch_size=self.arrow_batch_size
@ -173,13 +182,22 @@ class ArrowFileIterator(StatefulIterator):
if self.batch_to_consume is not None:
batch_columns: dict[str, list] = self.batch_to_consume
self.batch_to_consume = None
if self.file_format == "arrow":
sample_ids = batch_columns["sample_id"]
texts = batch_columns["text"]
entropies = batch_columns["entropies"]
elif self.file_format == "json":
# This data hasn't been preprocessed to a uniform format,
# so we have to do it now and omit entropies
sample_ids = batch_columns[get_id_key(batch_columns)]
texts = get_text(batch_columns)
entropies = None
else:
raise ValueError(f"Unknown file format: {self.file_format}")
for i in range(len(sample_ids)):
out = BltExample(
sample_id=sample_ids[i],
entropies=entropies[i],
entropies=entropies[i] if entropies is not None else None,
text=texts[i],
tokens=None,
mask=None,
@ -191,13 +209,22 @@ class ArrowFileIterator(StatefulIterator):
for batch in self.batch_iterator:
batch_columns = batch.to_pydict()
if self.file_format == "arrow":
sample_ids = batch_columns["sample_id"]
texts = batch_columns["text"]
entropies = batch_columns["entropies"]
elif self.file_format == "json":
# This data hasn't been preprocessed to a uniform format,
# so we have to do it now and omit entropies
sample_ids = batch_columns[get_id_key(batch_columns)]
texts = get_text(batch_columns)
entropies = None
else:
raise ValueError(f"Unknown file format: {self.file_format}")
for i in range(len(sample_ids)):
out = BltExample(
sample_id=sample_ids[i],
entropies=entropies[i],
entropies=entropies[i] if entropies is not None else None,
text=texts[i],
tokens=None,
mask=None,
@ -231,13 +258,24 @@ class ArrowFileIterator(StatefulIterator):
for batch in self.batch_iterator:
if len(batch) > curr_remaining:
batch_columns: dict[str, list] = batch.to_pydict()
batch_columns["sample_id"] = batch_columns["sample_id"][
if self.file_format == "arrow":
leftover_sample_ids = batch_columns["sample_id"][
curr_remaining:
]
batch_columns["entropies"] = batch_columns["entropies"][
leftover_entropies = batch_columns["entropies"][curr_remaining:]
leftover_texts = batch_columns["text"][curr_remaining:]
elif self.file_format == "json":
leftover_sample_ids = batch_columns[get_id_key(batch_columns)][
curr_remaining:
]
batch_columns["text"] = batch_columns["text"][curr_remaining:]
leftover_entropies = None
leftover_texts = get_text(batch_columns)[curr_remaining:]
else:
raise ValueError(f"Unknown file format: {self.file_format}")
batch_columns["sample_id"] = leftover_sample_ids
batch_columns["entropies"] = leftover_entropies
batch_columns["text"] = leftover_texts
self.batch_to_consume = batch_columns
break
elif len(batch) == curr_remaining:
@ -250,30 +288,3 @@ class ArrowFileIterator(StatefulIterator):
logger.info(
f"Finished setting arrow position to {target_row_num} for {self.dataset_files}"
)
TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"
def find_and_sanitize_chunks(
dataset_path: str,
world_size: int,
file_pattern: str = TRAIN_DATA_FILE_PATTERN,
s3_profile: str | None = None,
):
fs = get_fs(dataset_path, s3_profile=s3_profile)
path_with_glob = os.path.join(dataset_path, file_pattern)
dataset_chunks = fs.glob(path_with_glob)
n_chunks = len(dataset_chunks)
if n_chunks > world_size:
n_discard = n_chunks - world_size
dataset_chunks = dataset_chunks[:world_size]
else:
assert (
world_size % n_chunks == 0
), "World size should be a multiple of number of chunks"
assert n_chunks > 0, f"No valid chunks in {dataset_path}"
return dataset_chunks

View file

@ -127,6 +127,16 @@ def dist_max(x: Union[int, float], mesh: DeviceMesh = None):
return tensor
def dist_sum(
x: Union[int, float], mesh: DeviceMesh = None, reduce_dtype: torch.dtype = None
):
tensor = torch.tensor(x).cuda()
if reduce_dtype is not None:
tensor = tensor.to(reduce_dtype)
dist.all_reduce(tensor, op=ReduceOp.SUM, group=mesh.get_group() if mesh else None)
return tensor
def dist_mean(x: Union[int, float], mesh: DeviceMesh = None):
tensor = torch.tensor(x).cuda()
dist.all_reduce(tensor, op=ReduceOp.AVG, group=mesh.get_group() if mesh else None)
@ -236,7 +246,7 @@ def setup_env(env_args: EnvironmentArgs):
logger.warning(f"WARNING: Setting {name} to {value}")
def setup_torch_distributed(dist_args):
def setup_torch_distributed(dist_args: DistributedArgs):
"""
Handle single and multi-GPU / multi-node / SLURM jobs.
Initialize the following variables:
@ -388,14 +398,14 @@ def clean_env():
def parallelize_model(
model,
model: torch.nn.Module,
device_mesh,
model_args,
distributed_args: DistributedArgs,
fsdp_grouping_plan: Optional[List[Tuple[str, bool]]] = None,
tp_parallelize=None,
no_recompute_ops=None,
):
) -> torch.nn.Module:
if distributed_args.tp_size > 1:
assert (
distributed_args.fsdp_type == "full_shard"
@ -429,6 +439,8 @@ def parallelize_model(
device_mesh["dp_shard"].size() == 1
), "dp_shard must be 1 for no_shard fsdp_type"
# TODO: Remove with something better
# model = model.to(param_dtype)
fsdp_config = dict(
mp_policy=(
MixedPrecisionPolicy(

View file

@ -15,9 +15,16 @@ from lm_eval.api.model import LM
from omegaconf import OmegaConf
from pydantic import BaseModel, ConfigDict
from bytelatent.args import EvalArgs, ValidationArgs, parse_args
from bytelatent.args import (
EvalArgs,
TrainArgs,
ValidationArgs,
find_and_sanitize_chunks,
parse_args,
)
from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints
from bytelatent.data.file_util import get_fs
from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
from bytelatent.distributed import (
DistributedArgs,
dist_mean_dict,
@ -117,19 +124,40 @@ class EvalHarnessLM(LM):
return results
def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
srcs = {}
def eval_on_val(generator, val_args: ValidationArgs, train_cfg: TrainArgs):
srcs = []
for src in val_args.sources:
path = os.path.join(val_args.root_dir, src)
srcs[path] = 1.0
srcs.append(path)
for src in train_cfg.data.sources:
path = os.path.join(train_cfg.data.root_dir, src)
srcs[path] = 1.0
srcs.append(path)
multi_state = init_choice_state(
"", srcs, 0, get_global_rank(), get_world_size(), "*.val.jsonl"
path_to_iter = {}
for path in srcs:
chunks = find_and_sanitize_chunks(
path,
world_size=1,
file_pattern="*.val.jsonl",
s3_profile=train_cfg.data.s3_profile,
)
path_to_iter = setup_sources(multi_state)
assert (
len(chunks) == 1
), f"There should be only 1 chunk per validation file, but found: {chunks}"
chunk = chunks[0]
iterator = ArrowFileIterator(
dataset_files=[chunk],
file_path=None,
preprocess_dir=None,
entropy_model_name=None,
worker_id=0,
num_workers=1,
arrow_batch_size=train_cfg.data.arrow_batch_size,
s3_profile=train_cfg.data.s3_profile,
file_format="json",
)
path_to_iter[path] = iterator
max_gen_len = generator.max_gen_len
# We temporarily lower max gen len
@ -137,16 +165,11 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
all_val_metrics = {}
for src in path_to_iter:
jsonl_iterator = path_to_iter[src]
example_iterator = path_to_iter[src].create_iter()
texts = []
logger.info(f"Running validation on {src}...")
for step, (content, state) in enumerate(jsonl_iterator):
if state["current_iter"] > 0 or (
val_args.max_steps is not None and step >= val_args.max_steps
):
break
content_key = "text" if ("text" in content) else "content"
texts.append(content[content_key])
for step, example in enumerate(example_iterator):
texts.append(example.text)
_, loglikelihood, _ = generator.generate(texts)
@ -191,7 +214,7 @@ def launch_eval(eval_args: EvalArgs):
else:
consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER)
if not fs.exists(consolidate_path) and get_global_rank() == 0:
consolidate_path = consolidate_checkpoints(eval_args.ckpt_dir)
consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir)
fs.mkdirs(eval_args.dump_dir, exist_ok=True)
with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f:
@ -210,10 +233,12 @@ def launch_eval(eval_args: EvalArgs):
wrap = EvalHarnessLM(generator)
# Redo
results = simple_evaluate(wrap, eval_args.harness.model_dump())
results = simple_evaluate(wrap, **eval_args.harness.model_dump())
val_results = None
if eval_args.validation:
val_results = eval_on_val(generator, eval_args.validation, train_cfg)
if get_global_rank() == 0:
with fs.open(os.path.join(eval_args.dump_dir, "results.json"), "w") as f:
f.write(json.dumps(results))
@ -222,6 +247,7 @@ def launch_eval(eval_args: EvalArgs):
with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f:
f.write(json.dumps(val_results))
logger.info(f"All validation results: {val_results}")
if eval_args.metric_log_dir and get_global_rank() == 0:
metric_log_path = os.path.join(eval_args.metric_log_dir, "metrics.eval.jsonl")

View file

@ -387,8 +387,7 @@ def load_consolidated_model_and_tokenizer(
):
train_args_path = os.path.join(consolidated_path, "params.json")
fs = get_fs(train_args_path)
with fs.open(train_args_path) as f:
train_args = TrainArgs.model_validate_json(f.read())
train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))
if train_args.train_entropy_model:
model_args = train_args.entropy_model
@ -401,7 +400,8 @@ def load_consolidated_model_and_tokenizer(
train_args.distributed.model_dtype
]
tokenizer = train_args.data.tokenizer_args.build()
st_dict = torch.load(consolidated_path / CONSOLIDATE_NAME, weights_only=True)
with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f:
st_dict = torch.load(f, weights_only=True)
model.load_state_dict(st_dict["model"])
model = model.cuda().eval()
for param in model.parameters():

100
bytelatent/norms.py Normal file
View file

@ -0,0 +1,100 @@
from typing import Dict, List, Optional, Tuple
import torch
from torch import Tensor
from torch.utils._foreach_utils import (
_device_has_foreach_support,
_group_tensors_by_device_and_dtype,
_has_foreach_support,
)
@torch.no_grad()
def fixed_clip_grad_norm_(
parameters: torch.Tensor | list[torch.Tensor],
max_norm: float,
norm_type: float = 2.0,
error_if_nonfinite: bool = False,
foreach: Optional[bool] = None,
) -> torch.Tensor:
r"""Clip the gradient norm of an iterable of parameters.
The norm is computed over the norms of the individual gradients of all parameters,
as if the norms of the individual gradients were concatenated into a single vector.
Gradients are modified in-place.
Args:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized
max_norm (float): max norm of the gradients
norm_type (float): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
error_if_nonfinite (bool): if True, an error is thrown if the total
norm of the gradients from :attr:`parameters` is ``nan``,
``inf``, or ``-inf``. Default: False (will switch to True in the future)
foreach (bool): use the faster foreach-based implementation.
If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
fall back to the slow implementation for other device types.
Default: ``None``
Returns:
Total norm of the parameter gradients (viewed as a single vector).
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
grads = [p.grad.to(torch.bfloat16) for p in parameters if p.grad is not None]
max_norm = float(max_norm)
norm_type = float(norm_type)
if len(grads) == 0:
return torch.tensor(0.0)
first_device = grads[0].device
grouped_grads: Dict[
Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
] = _group_tensors_by_device_and_dtype(
[grads]
) # type: ignore[assignment]
norms: List[Tensor] = []
for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment]
if (foreach is None and _has_foreach_support(device_grads, device)) or (
foreach and _device_has_foreach_support(device)
):
norms.extend(torch._foreach_norm(device_grads, norm_type))
elif foreach:
raise RuntimeError(
f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
)
else:
norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])
total_norm = torch.linalg.vector_norm(
torch.stack([norm.to(first_device) for norm in norms]), norm_type
)
if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
raise RuntimeError(
f"The total norm of order {norm_type} for gradients from "
"`parameters` is non-finite, so it cannot be clipped. To disable "
"this error and scale the gradients by the non-finite norm anyway, "
"set `error_if_nonfinite=False`"
)
clip_coef = max_norm / (total_norm + 1e-6)
# Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
# avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
# when the gradients do not reside in CPU memory.
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment]
if (foreach is None and _has_foreach_support(device_grads, device)) or (
foreach and _device_has_foreach_support(device)
):
torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
elif foreach:
raise RuntimeError(
f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
)
else:
clip_coef_clamped_device = clip_coef_clamped.to(device)
for g in device_grads:
g.mul_(clip_coef_clamped_device)
return total_norm

View file

@ -15,29 +15,37 @@ from bytelatent.entropy_model import load_entropy_model
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
def get_id_from_doc(doc: dict) -> int:
def get_id_key(doc: dict) -> int:
"""
We need a reliable way to ensure that samples from jsonl
and arrow are the same, but there is no unique id field,
so derive the best possible
"""
if "sample_id" in doc:
sample_id = doc["sample_id"]
return "sample_id"
elif "title" in doc:
sample_id = doc["title"]
return "title"
elif "qid" in doc:
sample_id = doc["qid"]
return "qid"
elif "paper_id" in doc:
sample_id = doc["paper_id"]
return "paper_id"
elif "path" in doc:
sample_id = doc["path"]
return "path"
elif "url" in doc:
sample_id = doc["url"]
return "url"
elif "id" in doc:
sample_id = doc["id"]
return "id"
else:
raise ValueError(f"Could not find a id key from: {doc.keys()}")
return str(sample_id)
def get_id_from_doc(doc: dict) -> int:
"""
We need a reliable way to ensure that samples from jsonl
and arrow are the same, but there is no unique id field,
so derive the best possible
"""
return str(doc[get_id_key(doc)])
def get_text(doc: dict):

View file

@ -3,6 +3,7 @@
import gc
import logging
import math
import os
import sys
from contextlib import ExitStack
@ -12,6 +13,7 @@ from pathlib import Path
from timeit import default_timer as timer
from typing import Any, TypeVar
import numpy as np
import torch
import torch.distributed
import torch.nn.functional
@ -25,6 +27,7 @@ from torch.optim import lr_scheduler
from bytelatent.args import TrainArgs, parse_args
from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint
from bytelatent.data.file_util import get_fs
from bytelatent.data.iterators.multiprocess_iterator import (
MultiprocessIterator,
MultiprocessIteratorState,
@ -33,7 +36,9 @@ from bytelatent.data.iterators.packing_iterator import PackingIteratorState
from bytelatent.distributed import (
check_model_value_range,
clean_env,
dist_mean,
dist_mean_dict,
dist_sum,
get_device_mesh,
get_is_master,
get_world_size,
@ -47,6 +52,7 @@ from bytelatent.eval import EVAL_FOLDER_NAME, launch_eval
from bytelatent.logger import init_logger
from bytelatent.metrics import GPUMemoryMonitor, MetricLogger, get_num_params
from bytelatent.model.blt import ByteLatentTransformer
from bytelatent.norms import fixed_clip_grad_norm_
from bytelatent.optim import build_optimizer
from bytelatent.probe import AutoProbeD
from bytelatent.profiling import maybe_run_profiler
@ -295,8 +301,11 @@ def train(args: TrainArgs):
if args.checkpoint.init_ckpt_path:
logger.info(f"Loading initial model from {args.checkpoint.init_ckpt_path}")
ckpt_fs = get_fs(
args.checkpoint.init_ckpt_path, s3_profile=args.checkpoint.s3_profile
)
load_from_checkpoint(
args.checkpoint.init_ckpt_path, model, model_key="model"
ckpt_fs, args.checkpoint.init_ckpt_path, model, model_key="model"
) # Put model_key="" if its directly the model checkpoint
model.rope_embeddings.reset_parameters() # For RoPe initialization since it's a buffer it might not be loaded
else:
@ -364,6 +373,9 @@ def train(args: TrainArgs):
time_last_log = timer()
gc.collect()
saved = False
step_losses: list[float] = []
step_tok_losses: list[float] = []
n_bytes: int = 0
while train_state.step < args.steps and (
args.max_steps is None or train_state.step < args.max_steps
):
@ -385,6 +397,24 @@ def train(args: TrainArgs):
batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda()
mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
if args.data.tokenizer_args.name in ["bytes", "blt"]:
if mask is None:
n_bytes += batch_y.numel()
else:
n_bytes += mask.sum()
elif args.data.tokenizer_args.name in ["sp", "tiktoken"]:
for example in batch.y:
target_tokens = tokenizer.decode(example.tolist(), cut_at_eos=False)
n_bytes += (
len(bytes(target_tokens, encoding="utf-8", errors="ignore"))
+ sum(example == tokenizer.eos_id)
+ sum(example == tokenizer.bos_id)
)
else:
raise ValueError(
f"Unexpected tokenizer to count n_bytes for: {args.data.tokenizer_args.name}"
)
if (
not args.train_entropy_model
and args.model.encoder_enable_byte_ngrams
@ -459,7 +489,7 @@ def train(args: TrainArgs):
batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids
)
loss, _ = compute_loss(pred, batch_y, mask, train_state.scale)
loss, tok_loss = compute_loss(pred, batch_y, mask, train_state.scale)
# We scale loss with grad_acc_steps so the gradient is the same
# regardless of grad_acc_steps
@ -470,8 +500,14 @@ def train(args: TrainArgs):
# For logging we undo that scaling
loss = loss.detach() * args.grad_acc_steps
grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(), max_norm=args.optim.clip, foreach=True
# Undo loss scaling so downstream down't need to worry about it
step_losses.append((loss / train_state.scale).item())
step_tok_losses.append(tok_loss / train_state.scale)
# grad_norm = torch.nn.utils.clip_grad_norm_(
grad_norm = fixed_clip_grad_norm_(
model.parameters(),
max_norm=args.optim.clip, # foreach=True
)
grad_norm = (
@ -559,20 +595,33 @@ def train(args: TrainArgs):
gpu_memory_monitor.reset_peak_stats()
nwords_since_last_log = 0
time_last_log = timer()
stacked_tok_loss = torch.cat(step_tok_losses, dim=0)
total_tok_loss = dist_sum(
stacked_tok_loss.sum().item(), reduce_dtype=torch.bfloat16
)
total_n_bytes = dist_sum(n_bytes, reduce_dtype=torch.bfloat16)
avg_bpb = total_tok_loss / math.log(2) / total_n_bytes
avg_loss = dist_mean(np.mean(step_losses).item())
logger.info(
f"step: {train_state.step}"
f" acc: {train_state.acc_step}"
f" loss: {round(loss.item(),4):>7}"
f" loss: step={round(loss.item(),4):>7} avg={avg_loss}"
f" bpb: {avg_bpb:3f}"
f" grad: {grad_norm:.2e}"
f" flops: {FLOPS:.2e}"
f" wps: {wps:.2e}"
f" iter: {curr_iter_time:>7}"
f" data: {data_load_time:>5}"
f" lr: {curr_lr:.2e}"
f" n_bytes={total_n_bytes}"
f" mem: {gpu_mem_stats.max_active_pct:.0f}%"
f" pow: {gpu_mem_stats.power_draw/1000} W"
)
n_bytes = 0
step_losses = []
step_tok_losses = []
if every_n_steps(
train_state, args.checkpoint.dump.every, acc_step=0
) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):