Merge 45bfe94c1e into sapling-pr-archive-EntilZha

Pedro Rodriguez 2025-02-05 17:27:35 -08:00 committed by GitHub
commit 2d1c766050
7 changed files with 776 additions and 80 deletions


@@ -1,8 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import json
import logging
import os
from typing import Any
import fsspec
import numpy as np
import yaml
from omegaconf import OmegaConf
@@ -10,11 +12,9 @@ from pydantic import BaseModel, ConfigDict
from bytelatent.checkpoint import CheckpointArgs
from bytelatent.data.data_types import Batch
from bytelatent.data.file_util import get_fs
from bytelatent.data.iterators.abstract_iterator import StatefulIterator
from bytelatent.data.iterators.arrow_iterator import (
ArrowFileIterator,
find_and_sanitize_chunks,
)
from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
from bytelatent.data.iterators.looping_iterator import LoopingIterator
from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator
from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator
@@ -53,6 +53,33 @@ def parse_args(args_cls):
return pydantic_args
TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"
def find_and_sanitize_chunks(
dataset_path: str,
world_size: int,
file_pattern: str,
s3_profile: str | None = None,
):
fs = get_fs(dataset_path, s3_profile=s3_profile)
path_with_glob = os.path.join(dataset_path, file_pattern)
dataset_chunks = fs.glob(path_with_glob)
n_chunks = len(dataset_chunks)
if n_chunks > world_size:
n_discard = n_chunks - world_size
dataset_chunks = dataset_chunks[:world_size]
else:
assert (
world_size % n_chunks == 0
), "World size should be a multiple of number of chunks"
assert n_chunks > 0, f"No valid chunks in {dataset_path}"
return dataset_chunks
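# Hedged usage sketch (illustration only, hypothetical path): with two files
# matching TRAIN_DATA_FILE_PATTERN and world_size=8, each chunk ends up shared by
# 8 // 2 = 4 workers; with three chunks the `world_size % n_chunks == 0` assert fires.
example_chunks = find_and_sanitize_chunks(
    "/data/my_dataset",  # hypothetical local dataset directory
    world_size=8,
    file_pattern=TRAIN_DATA_FILE_PATTERN,
)
workers_per_chunk = 8 // len(example_chunks)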
def distribute_data_to_rank(
*,
dataset_path: str,
@@ -62,9 +89,10 @@ def distribute_data_to_rank(
rank: int,
world_size: int,
s3_profile: str | None = None,
file_pattern: str = TRAIN_DATA_FILE_PATTERN,
) -> ArrowFileIterator:
dataset_chunks = find_and_sanitize_chunks(
dataset_path, world_size, s3_profile=s3_profile
dataset_path, world_size, file_pattern, s3_profile=s3_profile
)
n_workers_per_chunk = world_size // len(dataset_chunks)
rank_to_arrow_iterator_params = []

bytelatent/broken_train.py (new file, 623 additions)

@@ -0,0 +1,623 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
from datetime import timedelta
from enum import Enum
from functools import lru_cache
import logging
import math
import sys
import time
from typing import Dict, List, Optional, Tuple
from torch import Tensor
import os
import pickle
import fsspec
import torch
import torch.distributed
import torch.nn.functional
import torch.nn.functional as F
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import init_device_mesh
from torch.utils._foreach_utils import (
_device_has_foreach_support,
_group_tensors_by_device_and_dtype,
_has_foreach_support,
)
from bytelatent.args import TrainArgs
from bytelatent.distributed import (
DistributedArgs,
check_model_value_range,
parallelize_model,
setup_env,
setup_torch_distributed,
)
logger = logging.getLogger()
def set_root_log_level(log_level: str):
logger = logging.getLogger()
level: int | str = log_level.upper()
try:
level = int(log_level)
except ValueError:
pass
try:
logger.setLevel(level) # type: ignore
except Exception:
logger.warning(
f"Failed to set logging level to {log_level}, using default 'NOTSET'"
)
logger.setLevel(logging.NOTSET)
class LogFormatter(logging.Formatter):
"""
Custom logger for distributed jobs, displaying rank
and preserving indent from the custom prefix format.
"""
def __init__(self):
self.start_time = time.time()
self.rank = get_global_rank()
self.show_rank = not get_is_slurm_job() # srun has --label
def formatTime(self, record):
subsecond, seconds = math.modf(record.created)
curr_date = (
time.strftime("%y-%m-%d %H:%M:%S", time.localtime(seconds))
+ f".{int(subsecond * 1_000_000):06d}"
)
delta = timedelta(seconds=round(record.created - self.start_time))
return f"{curr_date} - {delta}"
def formatPrefix(self, record):
fmt_time = self.formatTime(record)
if self.show_rank:
return f"{self.rank}: {record.levelname:<7} {fmt_time} - "
else:
return f"{record.levelname:<7} {fmt_time} - "
def formatMessage(self, record, indent: str):
content = record.getMessage()
content = content.replace("\n", "\n" + indent)
# Exception handling as in the default formatter, albeit with indenting
# according to our custom prefix
if record.exc_info:
# Cache the traceback text to avoid converting it multiple times
# (it's constant anyway)
if not record.exc_text:
record.exc_text = self.formatException(record.exc_info)
if record.exc_text:
if content[-1:] != "\n":
content = content + "\n" + indent
content = content + indent.join(
[l + "\n" for l in record.exc_text.splitlines()]
)
if content[-1:] == "\n":
content = content[:-1]
if record.stack_info:
if content[-1:] != "\n":
content = content + "\n" + indent
stack_text = self.formatStack(record.stack_info)
content = content + indent.join([l + "\n" for l in stack_text.splitlines()])
if content[-1:] == "\n":
content = content[:-1]
return content
def format(self, record):
prefix = self.formatPrefix(record)
indent = " " * len(prefix)
content = self.formatMessage(record, indent)
return prefix + content
def init_logger(
log_file: str | None = None,
*,
name: str | None = None,
level: str = "INFO",
fs: fsspec.AbstractFileSystem | None = None,
):
"""
Setup logging.
Args:
log_file: A file name to save file logs to.
name: The name of the logger to configure, by default the root logger.
level: The logging level to use.
"""
set_root_log_level(level)
logger = logging.getLogger(name)
# stdout: everything
stdout_handler = logging.StreamHandler(sys.stdout)
stdout_handler.setLevel(logging.NOTSET)
stdout_handler.setFormatter(LogFormatter())
# stderr: warnings / errors and above
stderr_handler = logging.StreamHandler(sys.stderr)
stderr_handler.setLevel(logging.WARNING)
stderr_handler.setFormatter(LogFormatter())
# set stream handlers
logger.handlers.clear()
logger.handlers.append(stdout_handler)
logger.handlers.append(stderr_handler)
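# Hedged usage sketch (illustration only): configure the root logger, then log as
# usual. In the copy above only stdout/stderr handlers are attached; the log_file
# and fs arguments are accepted but not used by this stripped-down version.
init_logger(level="INFO")
logging.getLogger(__name__).info("logging configured")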
@torch.no_grad()
def fixed_clip_grad_norm_(
parameters: torch.Tensor | list[torch.Tensor],
max_norm: float,
norm_type: float = 2.0,
error_if_nonfinite: bool = False,
foreach: Optional[bool] = None,
) -> torch.Tensor:
r"""Clip the gradient norm of an iterable of parameters.
The norm is computed over the norms of the individual gradients of all parameters,
as if the norms of the individual gradients were concatenated into a single vector.
Gradients are modified in-place.
Args:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized
max_norm (float): max norm of the gradients
norm_type (float): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
error_if_nonfinite (bool): if True, an error is thrown if the total
norm of the gradients from :attr:`parameters` is ``nan``,
``inf``, or ``-inf``. Default: False (will switch to True in the future)
foreach (bool): use the faster foreach-based implementation.
If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
fall back to the slow implementation for other device types.
Default: ``None``
Returns:
Total norm of the parameter gradients (viewed as a single vector).
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
grads = [p.grad.to(torch.bfloat16) for p in parameters if p.grad is not None]
max_norm = float(max_norm)
norm_type = float(norm_type)
if len(grads) == 0:
return torch.tensor(0.0)
first_device = grads[0].device
grouped_grads: Dict[
Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
] = _group_tensors_by_device_and_dtype(
[grads]
) # type: ignore[assignment]
norms: List[Tensor] = []
for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment]
if (foreach is None and _has_foreach_support(device_grads, device)) or (
foreach and _device_has_foreach_support(device)
):
norms.extend(torch._foreach_norm(device_grads, norm_type))
elif foreach:
raise RuntimeError(
f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
)
else:
norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])
total_norm = torch.linalg.vector_norm(
torch.stack([norm.to(first_device) for norm in norms]), norm_type
)
if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
raise RuntimeError(
f"The total norm of order {norm_type} for gradients from "
"`parameters` is non-finite, so it cannot be clipped. To disable "
"this error and scale the gradients by the non-finite norm anyway, "
"set `error_if_nonfinite=False`"
)
clip_coef = max_norm / (total_norm + 1e-6)
# Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
# avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
# when the gradients do not reside in CPU memory.
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment]
if (foreach is None and _has_foreach_support(device_grads, device)) or (
foreach and _device_has_foreach_support(device)
):
torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
elif foreach:
raise RuntimeError(
f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
)
else:
clip_coef_clamped_device = clip_coef_clamped.to(device)
for g in device_grads:
g.mul_(clip_coef_clamped_device)
return total_norm
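# Hedged usage sketch (illustration only): clip a toy model's gradients with the
# patched routine. Note the norm is computed over bf16 copies of the gradients.
toy = torch.nn.Linear(4, 2)
toy(torch.randn(3, 4)).sum().backward()
total_norm = fixed_clip_grad_norm_(list(toy.parameters()), max_norm=1.0)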
def get_no_recompute_ops():
return None
@lru_cache()
def get_is_torch_run() -> bool:
return os.environ.get("LOCAL_RANK") is not None
@lru_cache()
def get_is_slurm_job() -> bool:
return "SLURM_JOB_ID" in os.environ and not get_is_torch_run()
@lru_cache()
def get_global_rank() -> int:
if get_is_torch_run():
return int(os.environ["RANK"])
elif get_is_slurm_job():
return int(os.environ["SLURM_PROCID"])
else:
return 0
@lru_cache()
def get_local_rank() -> int:
if get_is_torch_run():
return int(os.environ["LOCAL_RANK"])
elif get_is_slurm_job():
return int(os.environ["SLURM_LOCALID"])
else:
return 0
@lru_cache()
def get_world_size() -> int:
if get_is_torch_run():
return int(os.environ["WORLD_SIZE"])
elif get_is_slurm_job():
return int(os.environ["SLURM_NTASKS"])
else:
return 1
@lru_cache()
def get_is_master() -> bool:
return get_global_rank() == 0
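# Illustration (not from this commit): under `torchrun --nproc-per-node=8`,
# LOCAL_RANK / RANK / WORLD_SIZE are set, so get_is_torch_run() is True and
# get_world_size() returns 8; in a plain `python` process these helpers fall
# back to rank 0 and world size 1.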
def validate_train_args(args: TrainArgs, output_size: int):
# assert args.model is not None or args.entropy_model is not None
if args.entropy_model is not None:
logger.info(f"Setting model output size to {args.entropy_model.vocab_size}")
args.entropy_model.vocab_size = output_size
assert args.dump_dir, "Dump dir not set"
if (
args.distributed.dp_replicate
* args.distributed.dp_shard
* args.distributed.tp_size
!= get_world_size()
):
logging.info("Modifying TrainArgs distributed config")
assert get_world_size() % args.distributed.dp_shard == 0
logging.info("World size: %s", get_world_size())
logging.info(
"Existing setting: train_args.distributed.dp_shard=%s",
args.distributed.dp_shard,
)
logging.info(
"Setting train_args.distributed.dp_replicate=%s, was dp_replicate=%s",
get_world_size() // args.distributed.dp_shard,
args.distributed.dp_replicate,
)
args.distributed.dp_replicate = get_world_size() // args.distributed.dp_shard
logging.info(
"Changing dp_replicate from %s to %s, to account for tp_size=%s",
args.distributed.dp_replicate,
args.distributed.dp_replicate // args.distributed.tp_size,
args.distributed.tp_size,
)
assert args.distributed.dp_replicate % args.distributed.tp_size == 0
args.distributed.dp_replicate = (
args.distributed.dp_replicate // args.distributed.tp_size
)
logger.warning(
f"Setting Data Parallel size to {args.distributed.dp_replicate * args.distributed.dp_shard}"
)
assert (
args.distributed.dp_replicate
* args.distributed.dp_shard
* args.distributed.tp_size
== get_world_size()
)
if args.distributed.fsdp_type == "no_shard":
assert (
args.distributed.dp_shard == 1
and args.distributed.dp_replicate == get_world_size()
)
if args.model is not None:
args.model.max_seqlen = args.data.seq_len
if args.entropy_model is not None:
args.entropy_model.max_seqlen = args.data.seq_len
if args.distributed.tp_size == 1:
logger.warning(
"Tensor parallelism has not been tested for a while, use at your own risk"
)
assert (
args.probe_freq != args.profiling.mem_steps
), "Don't profile during probe step"
assert (
args.probe_freq != args.profiling.profile_steps
), "Don't profile during probe step"
if args.logging.wandb is not None:
args.logging.wandb.name = args.name
if args.probe_freq is not None:
assert (
args.distributed.tp_size == 1
), "Probing not supported with tensor parallelism"
assert (
args.distributed.selective_activation_checkpointing is False
), "Probing not supported with selective activation checkpointing"
def compute_loss(p, y, mask, scale):
tok_loss = scale * F.cross_entropy(
p.flatten(0, 1), y.flatten(0, 1), reduction="none"
)
if mask is None:
loss = tok_loss.mean()
else:
mask = mask.flatten(0, 1)
tok_loss = tok_loss * mask
loss = tok_loss.sum() / (mask.sum() + 1e-6)
return loss, tok_loss
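# Hedged worked example (illustration only): compute_loss expects logits of shape
# [batch, seq, vocab], integer targets of shape [batch, seq], and an optional 0/1
# mask with the same shape as the targets; scale rescales the per-token loss.
logits = torch.randn(2, 5, 11)  # hypothetical batch=2, seq=5, vocab=11
targets = torch.randint(0, 11, (2, 5))
loss, tok_loss = compute_loss(logits, targets, torch.ones(2, 5), scale=1.0)
# loss is a scalar; tok_loss is flattened to shape [batch * seq].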
def get_device_mesh(distributed_args):
tp_size = distributed_args.tp_size
dp_replicate = distributed_args.dp_replicate
dp_shard = distributed_args.dp_shard
assert (
dp_replicate * dp_shard * tp_size == get_world_size()
), f"dp_replicate * dp_shard * tp_size ({dp_replicate} * {dp_shard} * {tp_size}) != world_size ({get_world_size()})"
dims = []
names = []
if dp_replicate >= 1:
dims.append(dp_replicate)
names.append("dp_replicate")
if dp_shard > 1 or distributed_args.fsdp_type == "no_shard":
dims.append(dp_shard)
names.append("dp_shard")
if tp_size > 1:
dims.append(tp_size)
names.append("tp")
dims = tuple(dims)
names = tuple(names)
return init_device_mesh("cuda", mesh_shape=dims, mesh_dim_names=names)
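# Illustration (not from this commit): with world_size=8, dp_replicate=2,
# dp_shard=4, tp_size=1, the loop above produces mesh_shape=(2, 4) and
# mesh_dim_names=("dp_replicate", "dp_shard"); a "tp" dimension is only added
# when tp_size > 1. init_device_mesh needs CUDA and a torch.distributed
# environment (e.g. a torchrun launch), so this is not runnable standalone.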
def build_fsdp_grouping_plan():
group_plan: list[tuple[str, bool]] = []
# Group the token embeddings and the output projection separately
group_plan.append(("tok_embeddings", False))
# Grouping by layers
# for i in range(model_args.n_layers):
# group_plan.append((f"layers.{i}", False))
group_plan.append(("output", True))
return group_plan
class MinimalModel(torch.nn.Module):
def __init__(self, dim: int, vocab_size: int):
super().__init__()
self.tok_embeddings = torch.nn.Embedding(vocab_size, dim)
# self.norm = RMSNorm(args.dim, eps=args.norm_eps)
# self.layers = torch.nn.ModuleList()
# for _ in range(args.n_layers):
# self.layers.append(TransformerBlock(args))
self.output = torch.nn.Linear(
dim,
vocab_size,
bias=False,
)
def forward(self, tokens):
h = self.tok_embeddings(tokens)
logits = self.output(h)
# logits = self.output(self.norm(h))
return logits
def reset_parameters(self, init_std=None):
pass
def init_weights(self):
pass
def train():
args = TrainArgs(
dump_dir="/tmp",
name="debug_bf16",
model=None,
entropy_model=None,
distributed=DistributedArgs(
fsdp_type="full_shard",
model_dtype="bf16",
matmul_allow_tf32=False,
selective_activation_checkpointing=False,
tp_size=1,
),
)
tokenizer = args.data.tokenizer_args.build()
validate_train_args(
args,
tokenizer.n_words,
)
dump_fs = fsspec.filesystem("file")
init_logger(os.path.join(args.dump_dir, "train.log"), fs=dump_fs)
setup_env(args.env)
setup_torch_distributed(args.distributed)
world_mesh = get_device_mesh(args.distributed)
logger.info(f"Starting job: {args.name}")
# build dataloader
# need dp world size and rank
dp_mesh = world_mesh["dp_replicate"]
dp_degree = dp_mesh.size()
dp_rank = dp_mesh.get_local_rank()
if args.distributed.dp_shard > 1:
dp_rank = dp_rank * dp_degree + world_mesh["dp_shard"].get_local_rank()
dp_degree *= world_mesh["dp_shard"].size()
logger.info(f"Running on dp rank : {dp_rank}")
logger.info(f"Running on dp size : {dp_degree}")
torch.manual_seed(args.seed)
logger.info("Building model")
# Initializing the model on the meta device allows us to initialize models much bigger than a single GPU's memory
with torch.device("meta"):
model = MinimalModel(768, tokenizer.n_words)
model = parallelize_model(
model,
world_mesh,
args.model,
args.distributed,
fsdp_grouping_plan=build_fsdp_grouping_plan(),
tp_parallelize=None,
no_recompute_ops=get_no_recompute_ops(),
)
# Once we shard the model onto different GPUs we can actually initialize it
# First we create empty tensors of the correct shapes
model = model.to_empty(device="cuda")
# Then we init the model. Please make sure this function initializes *ALL* parameters
# and buffers, otherwise you will have random values in the uninitialized tensors,
# which will fail silently (e.g., produce NaN gradients)
with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
torch.manual_seed(42)
model.init_weights()
check_model_value_range(model, range=10.0, std=1.0)
# data_loader = args.data.build_from_rank(dp_rank, dp_degree)
# train loop
model.train()
# data_loader = train_state.data_loader_state.build()
# batch_iterator = data_loader.create_iter()
# batch = next(batch_iterator)
# with open(f"/storage/home/par/toy-data/batch_{dp_rank}.pickle", "wb") as f:
# pickle.dump(batch, f)
with open(f"/storage/home/par/toy-data/batch_{dp_rank}.pickle", "rb") as f:
batch = pickle.load(f)
batch_x = torch.from_numpy(
batch.x,
).cuda()
batch_y = torch.from_numpy(batch.y).cuda()
mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
pred = model(batch_x)
loss, _ = compute_loss(pred, batch_y, mask, 1.0)
# We scale loss with grad_acc_steps so the gradient is the same
# regardless of grad_acc_steps
loss = loss / args.grad_acc_steps
# backward on scaled loss to create scaled gradients
loss.backward()
# For logging we undo that scaling
loss = loss.detach() * args.grad_acc_steps
world_size = get_world_size()
if 1 < world_size <= 8 and False:
# For some reason, the reduce ops fail for non-bf16 numbers. This function is a
# patched version of clip_grad_norm_ that converts gradients to bf16 before
# computing norms. The error only happens in distributed training on a single
# node, hence the guard
grad_norm = fixed_clip_grad_norm_(
model.parameters(), max_norm=args.optim.clip, foreach=True
)
else:
grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(), max_norm=args.optim.clip, foreach=True
)
grad_norm = (
grad_norm.full_tensor() if isinstance(grad_norm, DTensor) else grad_norm
).item()
# if isinstance(data_loader, MultiprocessIterator):
# logger.info("Closing MP iterator before exiting")
# data_loader.shutdown()
def main():
"""
The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments
This accepts arguments as a dot list
So if the dataclass looks like
@dataclass
class DummyArgs:
name: str
model: LMTransformerArgs
@dataclass
class LMTransformerArgs:
dim: int
Then you can pass model.dim=32 to change values in LMTransformerArgs
or just name=tictac for top level attributes.
The behavior here is as follows:
1. We instantiate TrainArgs with its default values
2. We override those default values with the ones in the provided config file
3. We override the result with the additional arguments provided through command line
For example, if the config is the following
model:
dim: 128
n_layers: 4
and you call train.py with train.py model.dim=64
Then the final TrainArgs will have
model:
dim: 64
n_layers: 4
Plus all the default values in TrainArgs dataclass.
"""
train()
if __name__ == "__main__":
main()
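# Hedged sketch (illustration only) of the merge order described in main()'s
# docstring, written directly against the OmegaConf API; DummyArgs, the inline
# "config file" dict, and the dot-list values are hypothetical stand-ins.
from dataclasses import dataclass, field
from omegaconf import OmegaConf

@dataclass
class LMTransformerArgs:
    dim: int = 128
    n_layers: int = 4

@dataclass
class DummyArgs:
    name: str = "default"
    model: LMTransformerArgs = field(default_factory=LMTransformerArgs)

defaults = OmegaConf.structured(DummyArgs())                        # 1. dataclass defaults
file_cfg = OmegaConf.create({"model": {"dim": 256}})                # 2. stand-in for OmegaConf.load("config.yaml")
cli_cfg = OmegaConf.from_dotlist(["model.dim=64", "name=tictac"])   # 3. command-line dot list
merged = OmegaConf.merge(defaults, file_cfg, cli_cfg)
assert merged.model.dim == 64 and merged.name == "tictac" and merged.model.n_layers == 4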


@@ -16,6 +16,7 @@ from bytelatent import ByteLatentError
from bytelatent.data.data_types import BltExample
from bytelatent.data.file_util import get_fs
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
from bytelatent.preprocess.preprocess_entropies import get_id_key, get_text
logger = getLogger(__name__)
@@ -32,6 +33,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState):
arrow_batch_size: int = 100
s3_profile: str | None
filesystem_type: str | None = None
file_format: str
def build(self) -> "ArrowFileIterator":
arrow_file = ArrowFileIterator(
@@ -44,6 +46,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState):
dataset_files=self.dataset_files,
s3_profile=self.s3_profile,
filesystem_type=self.filesystem_type,
file_format=self.file_format,
)
if self.row_num != 0:
arrow_file._set_row_num(self.row_num)
@@ -70,6 +73,7 @@ class ArrowFileIterator(StatefulIterator):
dataset_files: list[str] | None = None,
s3_profile: str | None = None,
filesystem_type: str | None = None,
file_format: str = "arrow",
):
assert 0 <= worker_id < num_workers, (worker_id, num_workers)
if file_path is None and dataset_files is None:
@@ -87,12 +91,16 @@ class ArrowFileIterator(StatefulIterator):
self.arrow_batch_size = arrow_batch_size
self.s3_profile = s3_profile
self.filesystem_type = filesystem_type
self.file_format = file_format
self.fs = None
if self.filesystem_type is not None:
if self.filesystem_type == "file":
self.fs = fsspec.filesystem("file")
elif self.filesystem_type == "s3":
self.fs = fsspec.filesystem("s3", profile=s3_profile)
else:
raise ValueError("Unknown filesystem")
logger.info("Arrow iterator using fs=%s", self.fs)
if dataset_files is None:
# Prepare arrow shards
@@ -153,6 +161,7 @@ class ArrowFileIterator(StatefulIterator):
dataset_files=self.dataset_files,
s3_profile=self.s3_profile,
filesystem_type=self.filesystem_type,
file_format=self.file_format,
)
def create_iter(
@@ -164,7 +173,7 @@ class ArrowFileIterator(StatefulIterator):
else:
filesystem = None
self.dataset = pa.dataset.dataset(
self.dataset_files, format="arrow", filesystem=filesystem
self.dataset_files, format=self.file_format, filesystem=filesystem
)
self.batch_iterator = self.dataset.to_batches(
batch_size=self.arrow_batch_size
@@ -173,13 +182,22 @@
if self.batch_to_consume is not None:
batch_columns: dict[str, list] = self.batch_to_consume
self.batch_to_consume = None
if self.file_format == "arrow":
sample_ids = batch_columns["sample_id"]
texts = batch_columns["text"]
entropies = batch_columns["entropies"]
elif self.file_format == "json":
# This data hasn't been preprocessed to a uniform format,
# so we have to do it now and omit entropies
sample_ids = batch_columns[get_id_key(batch_columns)]
texts = get_text(batch_columns)
entropies = None
else:
raise ValueError(f"Unknown file format: {self.file_format}")
for i in range(len(sample_ids)):
out = BltExample(
sample_id=sample_ids[i],
entropies=entropies[i],
entropies=entropies[i] if entropies is not None else None,
text=texts[i],
tokens=None,
mask=None,
@@ -191,13 +209,22 @@
for batch in self.batch_iterator:
batch_columns = batch.to_pydict()
if self.file_format == "arrow":
sample_ids = batch_columns["sample_id"]
texts = batch_columns["text"]
entropies = batch_columns["entropies"]
elif self.file_format == "json":
# This data hasn't been preprocessed to a uniform format,
# so we have to do it now and omit entropies
sample_ids = batch_columns[get_id_key(batch_columns)]
texts = get_text(batch_columns)
entropies = None
else:
raise ValueError(f"Unknown file format: {self.file_format}")
for i in range(len(sample_ids)):
out = BltExample(
sample_id=sample_ids[i],
entropies=entropies[i],
entropies=entropies[i] if entropies is not None else None,
text=texts[i],
tokens=None,
mask=None,
@@ -231,13 +258,24 @@
for batch in self.batch_iterator:
if len(batch) > curr_remaining:
batch_columns: dict[str, list] = batch.to_pydict()
batch_columns["sample_id"] = batch_columns["sample_id"][
if self.file_format == "arrow":
leftover_sample_ids = batch_columns["sample_id"][
curr_remaining:
]
batch_columns["entropies"] = batch_columns["entropies"][
leftover_entropies = batch_columns["entropies"][curr_remaining:]
leftover_texts = batch_columns["text"][curr_remaining:]
elif self.file_format == "json":
leftover_sample_ids = batch_columns[get_id_key(batch_columns)][
curr_remaining:
]
batch_columns["text"] = batch_columns["text"][curr_remaining:]
leftover_entropies = None
leftover_texts = get_text(batch_columns)[curr_remaining:]
else:
raise ValueError(f"Unknown file format: {self.file_format}")
batch_columns["sample_id"] = leftover_sample_ids
batch_columns["entropies"] = leftover_entropies
batch_columns["text"] = leftover_texts
self.batch_to_consume = batch_columns
break
elif len(batch) == curr_remaining:
@@ -250,30 +288,3 @@
logger.info(
f"Finished setting arrow position to {target_row_num} for {self.dataset_files}"
)
TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"
def find_and_sanitize_chunks(
dataset_path: str,
world_size: int,
file_pattern: str = TRAIN_DATA_FILE_PATTERN,
s3_profile: str | None = None,
):
fs = get_fs(dataset_path, s3_profile=s3_profile)
path_with_glob = os.path.join(dataset_path, file_pattern)
dataset_chunks = fs.glob(path_with_glob)
n_chunks = len(dataset_chunks)
if n_chunks > world_size:
n_discard = n_chunks - world_size
dataset_chunks = dataset_chunks[:world_size]
else:
assert (
world_size % n_chunks == 0
), "World size should be a multiple of number of chunks"
assert n_chunks > 0, f"No valid chunks in {dataset_path}"
return dataset_chunks


@@ -15,9 +15,16 @@ from lm_eval.api.model import LM
from omegaconf import OmegaConf
from pydantic import BaseModel, ConfigDict
from bytelatent.args import EvalArgs, ValidationArgs, parse_args
from bytelatent.args import (
EvalArgs,
TrainArgs,
ValidationArgs,
find_and_sanitize_chunks,
parse_args,
)
from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints
from bytelatent.data.file_util import get_fs
from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
from bytelatent.distributed import (
DistributedArgs,
dist_mean_dict,
@@ -117,19 +124,40 @@ class EvalHarnessLM(LM):
return results
def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
srcs = {}
def eval_on_val(generator, val_args: ValidationArgs, train_cfg: TrainArgs):
srcs = []
for src in val_args.sources:
path = os.path.join(val_args.root_dir, src)
srcs[path] = 1.0
srcs.append(path)
for src in train_cfg.data.sources:
path = os.path.join(train_cfg.data.root_dir, src)
srcs[path] = 1.0
srcs.append(path)
multi_state = init_choice_state(
"", srcs, 0, get_global_rank(), get_world_size(), "*.val.jsonl"
path_to_iter = {}
for path in srcs:
chunks = find_and_sanitize_chunks(
path,
world_size=1,
file_pattern="*.val.jsonl",
s3_profile=train_cfg.data.s3_profile,
)
path_to_iter = setup_sources(multi_state)
assert (
len(chunks) == 1
), f"There should be only 1 chunk per validation file, but found: {chunks}"
chunk = chunks[0]
iterator = ArrowFileIterator(
dataset_files=[chunk],
file_path=None,
preprocess_dir=None,
entropy_model_name=None,
worker_id=0,
num_workers=1,
arrow_batch_size=train_cfg.data.arrow_batch_size,
s3_profile=train_cfg.data.s3_profile,
file_format="json",
)
path_to_iter[path] = iterator
max_gen_len = generator.max_gen_len
# We temporarily lower max gen len
@@ -137,16 +165,11 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
all_val_metrics = {}
for src in path_to_iter:
jsonl_iterator = path_to_iter[src]
example_iterator = path_to_iter[src].create_iter()
texts = []
logger.info(f"Running validation on {src}...")
for step, (content, state) in enumerate(jsonl_iterator):
if state["current_iter"] > 0 or (
val_args.max_steps is not None and step >= val_args.max_steps
):
break
content_key = "text" if ("text" in content) else "content"
texts.append(content[content_key])
for step, example in enumerate(example_iterator):
texts.append(example.text)
_, loglikelihood, _ = generator.generate(texts)
@@ -191,7 +214,7 @@ def launch_eval(eval_args: EvalArgs):
else:
consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER)
if not fs.exists(consolidate_path) and get_global_rank() == 0:
consolidate_path = consolidate_checkpoints(eval_args.ckpt_dir)
consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir)
fs.mkdirs(eval_args.dump_dir, exist_ok=True)
with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f:
@@ -210,10 +233,12 @@
wrap = EvalHarnessLM(generator)
# Redo
results = simple_evaluate(wrap, eval_args.harness.model_dump())
results = simple_evaluate(wrap, **eval_args.harness.model_dump())
val_results = None
if eval_args.validation:
val_results = eval_on_val(generator, eval_args.validation, train_cfg)
if get_global_rank() == 0:
with fs.open(os.path.join(eval_args.dump_dir, "results.json"), "w") as f:
f.write(json.dumps(results))
@@ -222,6 +247,7 @@
with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f:
f.write(json.dumps(val_results))
logger.info(f"All validation results: {val_results}")
if eval_args.metric_log_dir and get_global_rank() == 0:
metric_log_path = os.path.join(eval_args.metric_log_dir, "metrics.eval.jsonl")


@@ -387,8 +387,7 @@ def load_consolidated_model_and_tokenizer(
):
train_args_path = os.path.join(consolidated_path, "params.json")
fs = get_fs(train_args_path)
with fs.open(train_args_path) as f:
train_args = TrainArgs.model_validate_json(f.read())
train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))
if train_args.train_entropy_model:
model_args = train_args.entropy_model
@@ -401,7 +400,8 @@
train_args.distributed.model_dtype
]
tokenizer = train_args.data.tokenizer_args.build()
st_dict = torch.load(consolidated_path / CONSOLIDATE_NAME, weights_only=True)
with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f:
st_dict = torch.load(f, weights_only=True)
model.load_state_dict(st_dict["model"])
model = model.cuda().eval()
for param in model.parameters():


@@ -15,29 +15,37 @@ from bytelatent.entropy_model import load_entropy_model
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
def get_id_from_doc(doc: dict) -> int:
def get_id_key(doc: dict) -> str:
"""
We need a reliable way to ensure that samples from jsonl
and arrow are the same, but there is no unique id field,
so return the name of the best available id key.
"""
if "sample_id" in doc:
sample_id = doc["sample_id"]
return "sample_id"
elif "title" in doc:
sample_id = doc["title"]
return "title"
elif "qid" in doc:
sample_id = doc["qid"]
return "qid"
elif "paper_id" in doc:
sample_id = doc["paper_id"]
return "paper_id"
elif "path" in doc:
sample_id = doc["path"]
return "path"
elif "url" in doc:
sample_id = doc["url"]
return "url"
elif "id" in doc:
sample_id = doc["id"]
return "id"
else:
raise ValueError(f"Could not find a id key from: {doc.keys()}")
return str(sample_id)
def get_id_from_doc(doc: dict) -> str:
"""
We need a reliable way to ensure that samples from jsonl
and arrow are the same, but there is no unique id field,
so derive the best possible one and return it as a string.
"""
return str(doc[get_id_key(doc)])
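# Illustration (not from this commit): the key lookup is priority-ordered, so a
# document with several candidate fields resolves to the first match, e.g.
# get_id_key({"title": "Some Doc", "url": "https://example.com"}) -> "title"
# get_id_from_doc({"qid": 42, "text": "..."}) -> "42"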
def get_text(doc: dict):


@@ -527,7 +527,7 @@ def train(args: TrainArgs):
step_tok_losses.append(tok_loss / train_state.scale)
world_size = get_world_size()
if 1 < world_size <= 8:
if 1 < world_size <= 8 and False:
# For some reason, the reduce ops fail for non-bf16 numbers. This function is a
# patched version of clip_grad_norm_ that converts gradients to bf16 before
# computing norms.