Mirror of https://github.com/facebookresearch/blt.git (synced 2025-02-23 05:22:16 +00:00)
Commit 2d1c766050: Merge 45bfe94c1e into sapling-pr-archive-EntilZha
@@ -1,8 +1,10 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import fsspec
|
||||
import numpy as np
|
||||
import yaml
|
||||
from omegaconf import OmegaConf
|
||||
|
@@ -10,11 +12,9 @@ from pydantic import BaseModel, ConfigDict
|
|||
|
||||
from bytelatent.checkpoint import CheckpointArgs
|
||||
from bytelatent.data.data_types import Batch
|
||||
from bytelatent.data.file_util import get_fs
|
||||
from bytelatent.data.iterators.abstract_iterator import StatefulIterator
|
||||
from bytelatent.data.iterators.arrow_iterator import (
|
||||
ArrowFileIterator,
|
||||
find_and_sanitize_chunks,
|
||||
)
|
||||
from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
|
||||
from bytelatent.data.iterators.looping_iterator import LoopingIterator
|
||||
from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator
|
||||
from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator
|
||||
|
@@ -53,6 +53,33 @@ def parse_args(args_cls):
|
|||
return pydantic_args
|
||||
|
||||
|
||||
TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"
|
||||
|
||||
|
||||
def find_and_sanitize_chunks(
|
||||
dataset_path: str,
|
||||
world_size: int,
|
||||
file_pattern: str,
|
||||
s3_profile: str | None = None,
|
||||
):
|
||||
fs = get_fs(dataset_path, s3_profile=s3_profile)
|
||||
path_with_glob = os.path.join(dataset_path, file_pattern)
|
||||
dataset_chunks = fs.glob(path_with_glob)
|
||||
n_chunks = len(dataset_chunks)
|
||||
|
||||
if n_chunks > world_size:
|
||||
n_discard = n_chunks - world_size
|
||||
dataset_chunks = dataset_chunks[:world_size]
|
||||
else:
|
||||
assert (
|
||||
world_size % n_chunks == 0
|
||||
), "World size should be a multiple of number of chunks"
|
||||
|
||||
assert n_chunks > 0, f"No valid chunks in {dataset_path}"
|
||||
|
||||
return dataset_chunks
|
||||
|
||||
|
||||
def distribute_data_to_rank(
|
||||
*,
|
||||
dataset_path: str,
|
||||
|
@@ -62,9 +89,10 @@ def distribute_data_to_rank(
|
|||
rank: int,
|
||||
world_size: int,
|
||||
s3_profile: str | None = None,
|
||||
file_pattern: str = TRAIN_DATA_FILE_PATTERN,
|
||||
) -> ArrowFileIterator:
|
||||
dataset_chunks = find_and_sanitize_chunks(
|
||||
dataset_path, world_size, s3_profile=s3_profile
|
||||
dataset_path, world_size, file_pattern, s3_profile=s3_profile
|
||||
)
|
||||
n_workers_per_chunk = world_size // len(dataset_chunks)
|
||||
rank_to_arrow_iterator_params = []
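A minimal sketch (not part of the commit) of how the asserts above constrain the chunk-to-rank mapping, assuming, as the truncated hunk suggests, that distribute_data_to_rank gives each chunk world_size // n_chunks consecutive ranks; the helper name and numbers are illustrative:

def assign_chunk(rank: int, world_size: int, n_chunks: int) -> tuple[int, int]:
    # The asserts above guarantee the division is exact (extra chunks are
    # discarded, otherwise world_size must be a multiple of n_chunks).
    assert world_size % n_chunks == 0
    n_workers_per_chunk = world_size // n_chunks
    chunk_id = rank // n_workers_per_chunk   # which chunk file this rank reads
    worker_id = rank % n_workers_per_chunk   # this rank's worker slot within that file
    return chunk_id, worker_id

# e.g. world_size=8, n_chunks=4: rank 5 reads chunk 2 as worker 1 of 2
assert assign_chunk(5, 8, 4) == (2, 1)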
|
||||
|
|
bytelatent/broken_train.py (new file, 623 lines)
|
@@ -0,0 +1,623 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
|
||||
|
||||
from datetime import timedelta
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
import logging
|
||||
import math
|
||||
import sys
|
||||
import time
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from torch import Tensor
|
||||
import os
|
||||
import pickle
|
||||
|
||||
import fsspec
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.nn.functional
|
||||
import torch.nn.functional as F
|
||||
from torch.distributed._tensor import DTensor
|
||||
|
||||
from torch.distributed.device_mesh import init_device_mesh
|
||||
from torch.utils._foreach_utils import (
|
||||
_device_has_foreach_support,
|
||||
_group_tensors_by_device_and_dtype,
|
||||
_has_foreach_support,
|
||||
)
|
||||
|
||||
from bytelatent.args import TrainArgs
|
||||
from bytelatent.distributed import (
|
||||
DistributedArgs,
|
||||
check_model_value_range,
|
||||
parallelize_model,
|
||||
setup_env,
|
||||
setup_torch_distributed,
|
||||
)
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def set_root_log_level(log_level: str):
|
||||
logger = logging.getLogger()
|
||||
level: int | str = log_level.upper()
|
||||
try:
|
||||
level = int(log_level)
|
||||
except ValueError:
|
||||
pass
|
||||
try:
|
||||
logger.setLevel(level) # type: ignore
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"Failed to set logging level to {log_level}, using default 'NOTSET'"
|
||||
)
|
||||
logger.setLevel(logging.NOTSET)
|
||||
|
||||
|
||||
class LogFormatter(logging.Formatter):
|
||||
"""
|
||||
Custom logger for distributed jobs, displaying rank
|
||||
and preserving indent from the custom prefix format.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.start_time = time.time()
|
||||
self.rank = get_global_rank()
|
||||
self.show_rank = not get_is_slurm_job() # srun has --label
|
||||
|
||||
def formatTime(self, record):
|
||||
subsecond, seconds = math.modf(record.created)
|
||||
curr_date = (
|
||||
time.strftime("%y-%m-%d %H:%M:%S", time.localtime(seconds))
|
||||
+ f".{int(subsecond * 1_000_000):06d}"
|
||||
)
|
||||
delta = timedelta(seconds=round(record.created - self.start_time))
|
||||
return f"{curr_date} - {delta}"
|
||||
|
||||
def formatPrefix(self, record):
|
||||
fmt_time = self.formatTime(record)
|
||||
if self.show_rank:
|
||||
return f"{self.rank}: {record.levelname:<7} {fmt_time} - "
|
||||
else:
|
||||
return f"{record.levelname:<7} {fmt_time} - "
|
||||
|
||||
def formatMessage(self, record, indent: str):
|
||||
content = record.getMessage()
|
||||
content = content.replace("\n", "\n" + indent)
|
||||
# Exception handling as in the default formatter, albeit with indenting
|
||||
# according to our custom prefix
|
||||
if record.exc_info:
|
||||
# Cache the traceback text to avoid converting it multiple times
|
||||
# (it's constant anyway)
|
||||
if not record.exc_text:
|
||||
record.exc_text = self.formatException(record.exc_info)
|
||||
if record.exc_text:
|
||||
if content[-1:] != "\n":
|
||||
content = content + "\n" + indent
|
||||
content = content + indent.join(
|
||||
[l + "\n" for l in record.exc_text.splitlines()]
|
||||
)
|
||||
if content[-1:] == "\n":
|
||||
content = content[:-1]
|
||||
if record.stack_info:
|
||||
if content[-1:] != "\n":
|
||||
content = content + "\n" + indent
|
||||
stack_text = self.formatStack(record.stack_info)
|
||||
content = content + indent.join([l + "\n" for l in stack_text.splitlines()])
|
||||
if content[-1:] == "\n":
|
||||
content = content[:-1]
|
||||
|
||||
return content
|
||||
|
||||
def format(self, record):
|
||||
prefix = self.formatPrefix(record)
|
||||
indent = " " * len(prefix)
|
||||
content = self.formatMessage(record, indent)
|
||||
return prefix + content
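# Example output (illustrative values): with show_rank=True on rank 0, a record
# logged one second after startup renders roughly as
# "0: INFO    25-02-23 05:22:17.000123 - 0:00:01 - message text", and embedded
# newlines in the message are indented to line up under that prefix.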
|
||||
|
||||
|
||||
def init_logger(
|
||||
log_file: str | None = None,
|
||||
*,
|
||||
name: str | None = None,
|
||||
level: str = "INFO",
|
||||
fs: fsspec.AbstractFileSystem | None = None,
|
||||
):
|
||||
"""
|
||||
Setup logging.
|
||||
|
||||
Args:
|
||||
log_file: A file name to save file logs to.
|
||||
name: The name of the logger to configure, by default the root logger.
|
||||
level: The logging level to use.
|
||||
"""
|
||||
set_root_log_level(level)
|
||||
logger = logging.getLogger(name)
|
||||
|
||||
# stdout: everything
|
||||
stdout_handler = logging.StreamHandler(sys.stdout)
|
||||
stdout_handler.setLevel(logging.NOTSET)
|
||||
stdout_handler.setFormatter(LogFormatter())
|
||||
|
||||
# stderr: warnings / errors and above
|
||||
stderr_handler = logging.StreamHandler(sys.stderr)
|
||||
stderr_handler.setLevel(logging.WARNING)
|
||||
stderr_handler.setFormatter(LogFormatter())
|
||||
|
||||
# set stream handlers
|
||||
logger.handlers.clear()
|
||||
logger.handlers.append(stdout_handler)
|
||||
logger.handlers.append(stderr_handler)
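# Illustrative usage of this trimmed copy: init_logger(level="DEBUG") attaches the
# stdout/stderr handlers above to the root logger; note that the log_file and fs
# arguments are accepted but no file handler is added here.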
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def fixed_clip_grad_norm_(
|
||||
parameters: torch.Tensor | list[torch.Tensor],
|
||||
max_norm: float,
|
||||
norm_type: float = 2.0,
|
||||
error_if_nonfinite: bool = False,
|
||||
foreach: Optional[bool] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""Clip the gradient norm of an iterable of parameters.
|
||||
|
||||
The norm is computed over the norms of the individual gradients of all parameters,
|
||||
as if the norms of the individual gradients were concatenated into a single vector.
|
||||
Gradients are modified in-place.
|
||||
|
||||
Args:
|
||||
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
|
||||
single Tensor that will have gradients normalized
|
||||
max_norm (float): max norm of the gradients
|
||||
norm_type (float): type of the used p-norm. Can be ``'inf'`` for
|
||||
infinity norm.
|
||||
error_if_nonfinite (bool): if True, an error is thrown if the total
|
||||
norm of the gradients from :attr:`parameters` is ``nan``,
|
||||
``inf``, or ``-inf``. Default: False (will switch to True in the future)
|
||||
foreach (bool): use the faster foreach-based implementation.
|
||||
If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
|
||||
fall back to the slow implementation for other device types.
|
||||
Default: ``None``
|
||||
|
||||
Returns:
|
||||
Total norm of the parameter gradients (viewed as a single vector).
|
||||
"""
|
||||
if isinstance(parameters, torch.Tensor):
|
||||
parameters = [parameters]
|
||||
grads = [p.grad.to(torch.bfloat16) for p in parameters if p.grad is not None]
|
||||
max_norm = float(max_norm)
|
||||
norm_type = float(norm_type)
|
||||
if len(grads) == 0:
|
||||
return torch.tensor(0.0)
|
||||
first_device = grads[0].device
|
||||
grouped_grads: Dict[
|
||||
Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
|
||||
] = _group_tensors_by_device_and_dtype(
|
||||
[grads]
|
||||
) # type: ignore[assignment]
|
||||
|
||||
norms: List[Tensor] = []
|
||||
for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment]
|
||||
if (foreach is None and _has_foreach_support(device_grads, device)) or (
|
||||
foreach and _device_has_foreach_support(device)
|
||||
):
|
||||
norms.extend(torch._foreach_norm(device_grads, norm_type))
|
||||
elif foreach:
|
||||
raise RuntimeError(
|
||||
f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
|
||||
)
|
||||
else:
|
||||
norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])
|
||||
|
||||
total_norm = torch.linalg.vector_norm(
|
||||
torch.stack([norm.to(first_device) for norm in norms]), norm_type
|
||||
)
|
||||
|
||||
if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
|
||||
raise RuntimeError(
|
||||
f"The total norm of order {norm_type} for gradients from "
|
||||
"`parameters` is non-finite, so it cannot be clipped. To disable "
|
||||
"this error and scale the gradients by the non-finite norm anyway, "
|
||||
"set `error_if_nonfinite=False`"
|
||||
)
|
||||
clip_coef = max_norm / (total_norm + 1e-6)
|
||||
# Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
|
||||
# avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
|
||||
# when the gradients do not reside in CPU memory.
|
||||
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
|
||||
for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment]
|
||||
if (foreach is None and _has_foreach_support(device_grads, device)) or (
|
||||
foreach and _device_has_foreach_support(device)
|
||||
):
|
||||
torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
|
||||
elif foreach:
|
||||
raise RuntimeError(
|
||||
f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
|
||||
)
|
||||
else:
|
||||
clip_coef_clamped_device = clip_coef_clamped.to(device)
|
||||
for g in device_grads:
|
||||
g.mul_(clip_coef_clamped_device)
|
||||
|
||||
return total_norm
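# Usage sketch (illustrative; same calling convention as torch.nn.utils.clip_grad_norm_):
#   total_norm = fixed_clip_grad_norm_(model.parameters(), max_norm=1.0, foreach=True)
# total_norm is the global norm computed over bf16 copies of the grads. Note that
# when a grad is not already bf16, .to(torch.bfloat16) above returns a new tensor,
# so the in-place scaling applies to that copy rather than to p.grad itself.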
|
||||
|
||||
|
||||
def get_no_recompute_ops():
|
||||
return None
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_is_torch_run() -> bool:
|
||||
return os.environ.get("LOCAL_RANK") is not None
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_is_slurm_job() -> bool:
|
||||
return "SLURM_JOB_ID" in os.environ and not get_is_torch_run()
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_global_rank() -> int:
|
||||
if get_is_torch_run():
|
||||
return int(os.environ["RANK"])
|
||||
elif get_is_slurm_job():
|
||||
return int(os.environ["SLURM_PROCID"])
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_local_rank() -> int:
|
||||
if get_is_torch_run():
|
||||
return int(os.environ["LOCAL_RANK"])
|
||||
elif get_is_slurm_job():
|
||||
return int(os.environ["SLURM_LOCALID"])
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_world_size() -> int:
|
||||
if get_is_torch_run():
|
||||
return int(os.environ["WORLD_SIZE"])
|
||||
elif get_is_slurm_job():
|
||||
return int(os.environ["SLURM_NTASKS"])
|
||||
else:
|
||||
return 1
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_is_master() -> bool:
|
||||
return get_global_rank() == 0
|
||||
|
||||
|
||||
def validate_train_args(args: TrainArgs, output_size: int):
|
||||
# assert args.model is not None or args.entropy_model is not None
|
||||
if args.entropy_model is not None:
|
||||
logger.info(f"Setting model output size to {args.entropy_model.vocab_size}")
|
||||
args.entropy_model.vocab_size = output_size
|
||||
|
||||
assert args.dump_dir, "Dump dir not set"
|
||||
|
||||
if (
|
||||
args.distributed.dp_replicate
|
||||
* args.distributed.dp_shard
|
||||
* args.distributed.tp_size
|
||||
!= get_world_size()
|
||||
):
|
||||
logging.info("Modifying TrainArgs distributed config")
|
||||
assert get_world_size() % args.distributed.dp_shard == 0
|
||||
logging.info("World size: %s", get_world_size())
|
||||
logging.info(
|
||||
"Existing setting: train_args.distributed.dp_shard=%s",
|
||||
args.distributed.dp_shard,
|
||||
)
|
||||
logging.info(
|
||||
"Setting train_args.distributed.dp_replicate=%s, was dp_replicate=%s",
|
||||
get_world_size() // args.distributed.dp_shard,
|
||||
args.distributed.dp_replicate,
|
||||
)
|
||||
args.distributed.dp_replicate = get_world_size() // args.distributed.dp_shard
|
||||
|
||||
logging.info(
|
||||
"Changing dp_replicate from %s to %s, to account for tp_size=%s",
|
||||
args.distributed.dp_replicate,
|
||||
args.distributed.dp_replicate // args.distributed.tp_size,
|
||||
args.distributed.tp_size,
|
||||
)
|
||||
assert args.distributed.dp_replicate % args.distributed.tp_size == 0
|
||||
args.distributed.dp_replicate = (
|
||||
args.distributed.dp_replicate // args.distributed.tp_size
|
||||
)
|
||||
|
||||
logger.warning(
|
||||
f"Setting Data Parallel size to {args.distributed.dp_replicate * args.distributed.dp_shard}"
|
||||
)
|
||||
assert (
|
||||
args.distributed.dp_replicate
|
||||
* args.distributed.dp_shard
|
||||
* args.distributed.tp_size
|
||||
== get_world_size()
|
||||
)
|
||||
|
||||
if args.distributed.fsdp_type == "no_shard":
|
||||
assert (
|
||||
args.distributed.dp_shard == 1
|
||||
and args.distributed.dp_replicate == get_world_size()
|
||||
)
|
||||
|
||||
if args.model is not None:
|
||||
args.model.max_seqlen = args.data.seq_len
|
||||
if args.entropy_model is not None:
|
||||
args.entropy_model.max_seqlen = args.data.seq_len
|
||||
|
||||
if args.distributed.tp_size != 1:
|
||||
logger.warning(
|
||||
"Tensor parallelism has not been tested for a while, use at your own risk"
|
||||
)
|
||||
|
||||
assert (
|
||||
args.probe_freq != args.profiling.mem_steps
|
||||
), "Don't profile during probe step"
|
||||
assert (
|
||||
args.probe_freq != args.profiling.profile_steps
|
||||
), "Don't profile during probe step"
|
||||
if args.logging.wandb is not None:
|
||||
args.logging.wandb.name = args.name
|
||||
|
||||
if args.probe_freq is not None:
|
||||
assert (
|
||||
args.distributed.tp_size == 1
|
||||
), "Probing not supported with tensor parallelism"
|
||||
assert (
|
||||
args.distributed.selective_activation_checkpointing is False
|
||||
), "Probing not supported with selective activation checkpointing"
|
||||
|
||||
|
||||
def compute_loss(p, y, mask, scale):
|
||||
tok_loss = scale * F.cross_entropy(
|
||||
p.flatten(0, 1), y.flatten(0, 1), reduction="none"
|
||||
)
|
||||
if mask is None:
|
||||
loss = tok_loss.mean()
|
||||
else:
|
||||
mask = mask.flatten(0, 1)
|
||||
tok_loss = tok_loss * mask
|
||||
loss = tok_loss.sum() / (mask.sum() + 1e-6)
|
||||
return loss, tok_loss
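# Shape note (illustrative): p is [batch, seq, vocab] logits, y is [batch, seq]
# targets and mask, if given, is a 0/1 [batch, seq] tensor; loss is then the
# mask-weighted mean token cross-entropy and tok_loss the per-token losses
# flattened to [batch * seq].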
|
||||
|
||||
|
||||
def get_device_mesh(distributed_args):
|
||||
tp_size = distributed_args.tp_size
|
||||
dp_replicate = distributed_args.dp_replicate
|
||||
dp_shard = distributed_args.dp_shard
|
||||
|
||||
assert (
|
||||
dp_replicate * dp_shard * tp_size == get_world_size()
|
||||
), f"dp_replicate * dp_shard * tp_size ({dp_replicate} * {dp_shard} * {tp_size}) != world_size ({get_world_size()})"
|
||||
|
||||
dims = []
|
||||
names = []
|
||||
if dp_replicate >= 1:
|
||||
dims.append(dp_replicate)
|
||||
names.append("dp_replicate")
|
||||
if dp_shard > 1 or distributed_args.fsdp_type == "no_shard":
|
||||
dims.append(dp_shard)
|
||||
names.append("dp_shard")
|
||||
if tp_size > 1:
|
||||
dims.append(tp_size)
|
||||
names.append("tp")
|
||||
dims = tuple(dims)
|
||||
names = tuple(names)
|
||||
|
||||
return init_device_mesh("cuda", mesh_shape=dims, mesh_dim_names=names)
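# Example (illustrative): dp_replicate=2, dp_shard=4, tp_size=1 on 8 GPUs yields
# mesh_shape=(2, 4) with names ("dp_replicate", "dp_shard"); a "tp" dimension is
# only added when tp_size > 1.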
|
||||
|
||||
|
||||
def build_fsdp_grouping_plan():
|
||||
group_plan: list[tuple[str, bool]] = []
|
||||
|
||||
# Grouping and output separately
|
||||
group_plan.append(("tok_embeddings", False))
|
||||
|
||||
# Grouping by layers
|
||||
# for i in range(model_args.n_layers):
|
||||
# group_plan.append((f"layers.{i}", False))
|
||||
|
||||
group_plan.append(("output", True))
|
||||
|
||||
return group_plan
|
||||
|
||||
|
||||
class MinimalModel(torch.nn.Module):
|
||||
def __init__(self, dim: int, vocab_size: int):
|
||||
super().__init__()
|
||||
self.tok_embeddings = torch.nn.Embedding(vocab_size, dim)
|
||||
|
||||
# self.norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
# self.layers = torch.nn.ModuleList()
|
||||
# for _ in range(args.n_layers):
|
||||
# self.layers.append(TransformerBlock(args))
|
||||
|
||||
self.output = torch.nn.Linear(
|
||||
dim,
|
||||
vocab_size,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, tokens):
|
||||
h = self.tok_embeddings(tokens)
|
||||
logits = self.output(h)
|
||||
# logits = self.output(self.norm(h))
|
||||
return logits
|
||||
|
||||
def reset_parameters(self, init_std=None):
|
||||
pass
|
||||
|
||||
def init_weights(self):
|
||||
pass
|
||||
|
||||
|
||||
def train():
|
||||
args = TrainArgs(
|
||||
dump_dir="/tmp",
|
||||
name="debug_bf16",
|
||||
model=None,
|
||||
entropy_model=None,
|
||||
distributed=DistributedArgs(
|
||||
fsdp_type="full_shard",
|
||||
model_dtype="bf16",
|
||||
matmul_allow_tf32=False,
|
||||
selective_activation_checkpointing=False,
|
||||
tp_size=1,
|
||||
),
|
||||
)
|
||||
tokenizer = args.data.tokenizer_args.build()
|
||||
validate_train_args(
|
||||
args,
|
||||
tokenizer.n_words,
|
||||
)
|
||||
dump_fs = fsspec.filesystem("file")
|
||||
init_logger(os.path.join(args.dump_dir, "train.log"), fs=dump_fs)
|
||||
setup_env(args.env)
|
||||
setup_torch_distributed(args.distributed)
|
||||
world_mesh = get_device_mesh(args.distributed)
|
||||
logger.info(f"Starting job: {args.name}")
|
||||
|
||||
# build dataloader
|
||||
# need dp world size and rank
|
||||
dp_mesh = world_mesh["dp_replicate"]
|
||||
dp_degree = dp_mesh.size()
|
||||
dp_rank = dp_mesh.get_local_rank()
|
||||
if args.distributed.dp_shard > 1:
|
||||
dp_rank = dp_rank * dp_degree + world_mesh["dp_shard"].get_local_rank()
|
||||
dp_degree *= world_mesh["dp_shard"].size()
|
||||
|
||||
logger.info(f"Running on dp rank : {dp_rank}")
|
||||
logger.info(f"Running on dp size : {dp_degree}")
|
||||
|
||||
torch.manual_seed(args.seed)
|
||||
logger.info("Building model")
|
||||
|
||||
# Initializing the model on the meta device allows us to initialize models much bigger than one GPU's memory
|
||||
with torch.device("meta"):
|
||||
model = MinimalModel(768, tokenizer.n_words)
|
||||
|
||||
model = parallelize_model(
|
||||
model,
|
||||
world_mesh,
|
||||
args.model,
|
||||
args.distributed,
|
||||
fsdp_grouping_plan=build_fsdp_grouping_plan(),
|
||||
tp_parallelize=None,
|
||||
no_recompute_ops=get_no_recompute_ops(),
|
||||
)
|
||||
|
||||
# Once we shard the model on different gpus we can actually initialize the model
|
||||
# First we create empty tensors of the correct shapes
|
||||
model = model.to_empty(device="cuda")
|
||||
# Then we init the model. Please make sure this function initializes *ALL* parameters
|
||||
# and buffers, otherwise you will have random values in the uninitialized tensors
|
||||
# which will fail silently (for example, by producing NaN gradients)
|
||||
|
||||
with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
|
||||
torch.manual_seed(42)
|
||||
model.init_weights()
|
||||
check_model_value_range(model, range=10.0, std=1.0)
|
||||
|
||||
# data_loader = args.data.build_from_rank(dp_rank, dp_degree)
|
||||
|
||||
# train loop
|
||||
model.train()
|
||||
# data_loader = train_state.data_loader_state.build()
|
||||
# batch_iterator = data_loader.create_iter()
|
||||
# batch = next(batch_iterator)
|
||||
# with open(f"/storage/home/par/toy-data/batch_{dp_rank}.pickle", "wb") as f:
|
||||
# pickle.dump(batch, f)
|
||||
with open(f"/storage/home/par/toy-data/batch_{dp_rank}.pickle", "rb") as f:
|
||||
batch = pickle.load(f)
|
||||
|
||||
batch_x = torch.from_numpy(
|
||||
batch.x,
|
||||
).cuda()
|
||||
batch_y = torch.from_numpy(batch.y).cuda()
|
||||
mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
|
||||
pred = model(batch_x)
|
||||
loss, _ = compute_loss(pred, batch_y, mask, 1.0)
|
||||
|
||||
# We scale loss with grad_acc_steps so the gradient is the same
|
||||
# regardless of grad_acc_steps
|
||||
loss = loss / args.grad_acc_steps
|
||||
|
||||
# backward on scaled loss to create scaled gradients
|
||||
loss.backward()
|
||||
# For logging we undo that scaling
|
||||
loss = loss.detach() * args.grad_acc_steps
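# Worked example (illustrative): with grad_acc_steps == 4, each microbatch
# backward uses loss / 4, so the accumulated gradient equals the gradient of the
# mean loss over the 4 microbatches; the detach() * grad_acc_steps above only
# restores the unscaled value for logging and does not touch the gradients.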
|
||||
|
||||
world_size = get_world_size()
|
||||
if 1 < world_size <= 8 and False:
|
||||
# For some reason, there are errors in gradient-norm reductions due to
|
||||
# them not working for non-bf16 numbers. This function is a patched
|
||||
# version that converts gradients to bf16 before computing norms.
|
||||
# The error only happens in distributed training on one node,
|
||||
# hence the guard
|
||||
grad_norm = fixed_clip_grad_norm_(
|
||||
model.parameters(), max_norm=args.optim.clip, foreach=True
|
||||
)
|
||||
else:
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
model.parameters(), max_norm=args.optim.clip, foreach=True
|
||||
)
|
||||
|
||||
grad_norm = (
|
||||
grad_norm.full_tensor() if isinstance(grad_norm, DTensor) else grad_norm
|
||||
).item()
|
||||
|
||||
# if isinstance(data_loader, MultiprocessIterator):
|
||||
# logger.info("Closing MP iterator before exiting")
|
||||
# data_loader.shutdown()
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments
|
||||
This accepts arguments as a dot list
|
||||
So if the dataclass looks like
|
||||
|
||||
@dataclass
|
||||
class DummyArgs:
|
||||
name: str
|
||||
model: LMTransformerArgs
|
||||
|
||||
@dataclass
|
||||
class LMTransformerArgs:
|
||||
dim: int
|
||||
|
||||
Then you can pass model.dim=32 to change values in LMTransformerArgs
|
||||
or just name=tictac for top level attributes.
|
||||
|
||||
The behavior here is as follows:
|
||||
1. We instantiate TrainArgs with its default values
|
||||
2. We override those default values with the ones in the provided config file
|
||||
3. We override the result with the additional arguments provided through command line
|
||||
|
||||
For example, if the config is the following
|
||||
|
||||
model:
|
||||
dim: 128
|
||||
n_layers: 4
|
||||
|
||||
and you call train.py with train.py model.dim=64
|
||||
|
||||
Then the final TrainArgs will have
|
||||
|
||||
model:
|
||||
dim: 64
|
||||
n_layers: 4
|
||||
|
||||
Plus all the default values in TrainArgs dataclass.
|
||||
"""
|
||||
train()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@@ -16,6 +16,7 @@ from bytelatent import ByteLatentError
|
|||
from bytelatent.data.data_types import BltExample
|
||||
from bytelatent.data.file_util import get_fs
|
||||
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
|
||||
from bytelatent.preprocess.preprocess_entropies import get_id_key, get_text
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@@ -32,6 +33,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState):
|
|||
arrow_batch_size: int = 100
|
||||
s3_profile: str | None
|
||||
filesystem_type: str | None = None
|
||||
file_format: str
|
||||
|
||||
def build(self) -> "ArrowFileIterator":
|
||||
arrow_file = ArrowFileIterator(
|
||||
|
@@ -44,6 +46,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState):
|
|||
dataset_files=self.dataset_files,
|
||||
s3_profile=self.s3_profile,
|
||||
filesystem_type=self.filesystem_type,
|
||||
file_format=self.file_format,
|
||||
)
|
||||
if self.row_num != 0:
|
||||
arrow_file._set_row_num(self.row_num)
|
||||
|
@@ -70,6 +73,7 @@ class ArrowFileIterator(StatefulIterator):
|
|||
dataset_files: list[str] | None = None,
|
||||
s3_profile: str | None = None,
|
||||
filesystem_type: str | None = None,
|
||||
file_format: str = "arrow",
|
||||
):
|
||||
assert 0 <= worker_id < num_workers, (worker_id, num_workers)
|
||||
if file_path is None and dataset_files is None:
|
||||
|
@@ -87,12 +91,16 @@ class ArrowFileIterator(StatefulIterator):
|
|||
self.arrow_batch_size = arrow_batch_size
|
||||
self.s3_profile = s3_profile
|
||||
self.filesystem_type = filesystem_type
|
||||
self.file_format = file_format
|
||||
self.fs = None
|
||||
if self.filesystem_type is not None:
|
||||
if self.filesystem_type == "file":
|
||||
self.fs = fsspec.filesystem("file")
|
||||
elif self.filesystem_type == "s3":
|
||||
self.fs = fsspec.filesystem("s3", profile=s3_profile)
|
||||
else:
|
||||
raise ValueError("Unknown filesystem")
|
||||
logger.info("Arrow iterator using fs=%s", self.fs)
|
||||
|
||||
if dataset_files is None:
|
||||
# Prepare arrow shards
|
||||
|
@@ -153,6 +161,7 @@ class ArrowFileIterator(StatefulIterator):
|
|||
dataset_files=self.dataset_files,
|
||||
s3_profile=self.s3_profile,
|
||||
filesystem_type=self.filesystem_type,
|
||||
file_format=self.file_format,
|
||||
)
|
||||
|
||||
def create_iter(
|
||||
|
@@ -164,7 +173,7 @@ class ArrowFileIterator(StatefulIterator):
|
|||
else:
|
||||
filesystem = None
|
||||
self.dataset = pa.dataset.dataset(
|
||||
self.dataset_files, format="arrow", filesystem=filesystem
|
||||
self.dataset_files, format=self.file_format, filesystem=filesystem
|
||||
)
|
||||
self.batch_iterator = self.dataset.to_batches(
|
||||
batch_size=self.arrow_batch_size
|
||||
|
@@ -173,13 +182,22 @@ class ArrowFileIterator(StatefulIterator):
|
|||
if self.batch_to_consume is not None:
|
||||
batch_columns: dict[str, list] = self.batch_to_consume
|
||||
self.batch_to_consume = None
|
||||
if self.file_format == "arrow":
|
||||
sample_ids = batch_columns["sample_id"]
|
||||
texts = batch_columns["text"]
|
||||
entropies = batch_columns["entropies"]
|
||||
elif self.file_format == "json":
|
||||
# This data hasn't been preprocessed to a uniform format,
|
||||
# so we have to do it now and omit entropies
|
||||
sample_ids = batch_columns[get_id_key(batch_columns)]
|
||||
texts = get_text(batch_columns)
|
||||
entropies = None
|
||||
else:
|
||||
raise ValueError(f"Unknown file format: {self.file_format}")
|
||||
for i in range(len(sample_ids)):
|
||||
out = BltExample(
|
||||
sample_id=sample_ids[i],
|
||||
entropies=entropies[i],
|
||||
entropies=entropies[i] if entropies is not None else None,
|
||||
text=texts[i],
|
||||
tokens=None,
|
||||
mask=None,
|
||||
|
@@ -191,13 +209,22 @@ class ArrowFileIterator(StatefulIterator):
|
|||
|
||||
for batch in self.batch_iterator:
|
||||
batch_columns = batch.to_pydict()
|
||||
if self.file_format == "arrow":
|
||||
sample_ids = batch_columns["sample_id"]
|
||||
texts = batch_columns["text"]
|
||||
entropies = batch_columns["entropies"]
|
||||
elif self.file_format == "json":
|
||||
# This data hasn't been preprocessed to a uniform format,
|
||||
# so we have to do it now and omit entropies
|
||||
sample_ids = batch_columns[get_id_key(batch_columns)]
|
||||
texts = get_text(batch_columns)
|
||||
entropies = None
|
||||
else:
|
||||
raise ValueError(f"Unknown file format: {self.file_format}")
|
||||
for i in range(len(sample_ids)):
|
||||
out = BltExample(
|
||||
sample_id=sample_ids[i],
|
||||
entropies=entropies[i],
|
||||
entropies=entropies[i] if entropies is not None else None,
|
||||
text=texts[i],
|
||||
tokens=None,
|
||||
mask=None,
|
||||
|
@@ -231,13 +258,24 @@ class ArrowFileIterator(StatefulIterator):
|
|||
for batch in self.batch_iterator:
|
||||
if len(batch) > curr_remaining:
|
||||
batch_columns: dict[str, list] = batch.to_pydict()
|
||||
batch_columns["sample_id"] = batch_columns["sample_id"][
|
||||
if self.file_format == "arrow":
|
||||
leftover_sample_ids = batch_columns["sample_id"][
|
||||
curr_remaining:
|
||||
]
|
||||
batch_columns["entropies"] = batch_columns["entropies"][
|
||||
leftover_entropies = batch_columns["entropies"][curr_remaining:]
|
||||
leftover_texts = batch_columns["text"][curr_remaining:]
|
||||
elif self.file_format == "json":
|
||||
leftover_sample_ids = batch_columns[get_id_key(batch_columns)][
|
||||
curr_remaining:
|
||||
]
|
||||
batch_columns["text"] = batch_columns["text"][curr_remaining:]
|
||||
leftover_entropies = None
|
||||
leftover_texts = get_text(batch_columns)[curr_remaining:]
|
||||
else:
|
||||
raise ValueError(f"Unknown file format: {self.file_format}")
|
||||
|
||||
batch_columns["sample_id"] = leftover_sample_ids
|
||||
batch_columns["entropies"] = leftover_entropies
|
||||
batch_columns["text"] = leftover_texts
|
||||
self.batch_to_consume = batch_columns
|
||||
break
|
||||
elif len(batch) == curr_remaining:
|
||||
|
@@ -250,30 +288,3 @@ class ArrowFileIterator(StatefulIterator):
|
|||
logger.info(
|
||||
f"Finished setting arrow position to {target_row_num} for {self.dataset_files}"
|
||||
)
|
||||
|
||||
|
||||
TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"
|
||||
|
||||
|
||||
def find_and_sanitize_chunks(
|
||||
dataset_path: str,
|
||||
world_size: int,
|
||||
file_pattern: str = TRAIN_DATA_FILE_PATTERN,
|
||||
s3_profile: str | None = None,
|
||||
):
|
||||
fs = get_fs(dataset_path, s3_profile=s3_profile)
|
||||
path_with_glob = os.path.join(dataset_path, file_pattern)
|
||||
dataset_chunks = fs.glob(path_with_glob)
|
||||
n_chunks = len(dataset_chunks)
|
||||
|
||||
if n_chunks > world_size:
|
||||
n_discard = n_chunks - world_size
|
||||
dataset_chunks = dataset_chunks[:world_size]
|
||||
else:
|
||||
assert (
|
||||
world_size % n_chunks == 0
|
||||
), "World size should be a multiple of number of chunks"
|
||||
|
||||
assert n_chunks > 0, f"No valid chunks in {dataset_path}"
|
||||
|
||||
return dataset_chunks
|
||||
|
|
|
@@ -15,9 +15,16 @@ from lm_eval.api.model import LM
|
|||
from omegaconf import OmegaConf
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from bytelatent.args import EvalArgs, ValidationArgs, parse_args
|
||||
from bytelatent.args import (
|
||||
EvalArgs,
|
||||
TrainArgs,
|
||||
ValidationArgs,
|
||||
find_and_sanitize_chunks,
|
||||
parse_args,
|
||||
)
|
||||
from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints
|
||||
from bytelatent.data.file_util import get_fs
|
||||
from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
|
||||
from bytelatent.distributed import (
|
||||
DistributedArgs,
|
||||
dist_mean_dict,
|
||||
|
@@ -117,19 +124,40 @@ class EvalHarnessLM(LM):
|
|||
return results
|
||||
|
||||
|
||||
def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
|
||||
srcs = {}
|
||||
def eval_on_val(generator, val_args: ValidationArgs, train_cfg: TrainArgs):
|
||||
srcs = []
|
||||
for src in val_args.sources:
|
||||
path = os.path.join(val_args.root_dir, src)
|
||||
srcs[path] = 1.0
|
||||
srcs.append(path)
|
||||
|
||||
for src in train_cfg.data.sources:
|
||||
path = os.path.join(train_cfg.data.root_dir, src)
|
||||
srcs[path] = 1.0
|
||||
srcs.append(path)
|
||||
|
||||
multi_state = init_choice_state(
|
||||
"", srcs, 0, get_global_rank(), get_world_size(), "*.val.jsonl"
|
||||
path_to_iter = {}
|
||||
for path in srcs:
|
||||
chunks = find_and_sanitize_chunks(
|
||||
path,
|
||||
world_size=1,
|
||||
file_pattern="*.val.jsonl",
|
||||
s3_profile=train_cfg.data.s3_profile,
|
||||
)
|
||||
path_to_iter = setup_sources(multi_state)
|
||||
assert (
|
||||
len(chunks) == 1
|
||||
), f"There should be only 1 chunk per validation file, but found: {chunks}"
|
||||
chunk = chunks[0]
|
||||
iterator = ArrowFileIterator(
|
||||
dataset_files=[chunk],
|
||||
file_path=None,
|
||||
preprocess_dir=None,
|
||||
entropy_model_name=None,
|
||||
worker_id=0,
|
||||
num_workers=1,
|
||||
arrow_batch_size=train_cfg.data.arrow_batch_size,
|
||||
s3_profile=train_cfg.data.s3_profile,
|
||||
file_format="json",
|
||||
)
|
||||
path_to_iter[path] = iterator
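# Illustrative: for val_args.root_dir "/data/val" and source "wiki", a single chunk
# such as "/data/val/wiki/00.val.jsonl" is matched and read as raw JSON lines, so
# entropies are omitted and sample ids / text are derived via get_id_key / get_text
# inside ArrowFileIterator.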
|
||||
|
||||
max_gen_len = generator.max_gen_len
|
||||
# We temporarily lower max gen len
|
||||
|
@@ -137,16 +165,11 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
|
|||
|
||||
all_val_metrics = {}
|
||||
for src in path_to_iter:
|
||||
jsonl_iterator = path_to_iter[src]
|
||||
example_iterator = path_to_iter[src].create_iter()
|
||||
texts = []
|
||||
logger.info(f"Running validation on {src}...")
|
||||
for step, (content, state) in enumerate(jsonl_iterator):
|
||||
if state["current_iter"] > 0 or (
|
||||
val_args.max_steps is not None and step >= val_args.max_steps
|
||||
):
|
||||
break
|
||||
content_key = "text" if ("text" in content) else "content"
|
||||
texts.append(content[content_key])
|
||||
for step, example in enumerate(example_iterator):
|
||||
texts.append(example.text)
|
||||
|
||||
_, loglikelihood, _ = generator.generate(texts)
|
||||
|
||||
|
@@ -191,7 +214,7 @@ def launch_eval(eval_args: EvalArgs):
|
|||
else:
|
||||
consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER)
|
||||
if not fs.exists(consolidate_path) and get_global_rank() == 0:
|
||||
consolidate_path = consolidate_checkpoints(eval_args.ckpt_dir)
|
||||
consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir)
|
||||
|
||||
fs.mkdirs(eval_args.dump_dir, exist_ok=True)
|
||||
with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f:
|
||||
|
@@ -210,10 +233,12 @@ def launch_eval(eval_args: EvalArgs):
|
|||
|
||||
wrap = EvalHarnessLM(generator)
|
||||
# Redo
|
||||
results = simple_evaluate(wrap, eval_args.harness.model_dump())
|
||||
results = simple_evaluate(wrap, **eval_args.harness.model_dump())
|
||||
|
||||
val_results = None
|
||||
if eval_args.validation:
|
||||
val_results = eval_on_val(generator, eval_args.validation, train_cfg)
|
||||
|
||||
if get_global_rank() == 0:
|
||||
with fs.open(os.path.join(eval_args.dump_dir, "results.json"), "w") as f:
|
||||
f.write(json.dumps(results))
|
||||
|
@@ -222,6 +247,7 @@ def launch_eval(eval_args: EvalArgs):
|
|||
with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f:
|
||||
f.write(json.dumps(val_results))
|
||||
logger.info(f"All validation results: {val_results}")
|
||||
|
||||
if eval_args.metric_log_dir and get_global_rank() == 0:
|
||||
metric_log_path = os.path.join(eval_args.metric_log_dir, "metrics.eval.jsonl")
|
||||
|
||||
|
|
|
@@ -387,8 +387,7 @@ def load_consolidated_model_and_tokenizer(
|
|||
):
|
||||
train_args_path = os.path.join(consolidated_path, "params.json")
|
||||
fs = get_fs(train_args_path)
|
||||
with fs.open(train_args_path) as f:
|
||||
train_args = TrainArgs.model_validate_json(f.read())
|
||||
train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))
|
||||
|
||||
if train_args.train_entropy_model:
|
||||
model_args = train_args.entropy_model
|
||||
|
@@ -401,7 +400,8 @@ def load_consolidated_model_and_tokenizer(
|
|||
train_args.distributed.model_dtype
|
||||
]
|
||||
tokenizer = train_args.data.tokenizer_args.build()
|
||||
st_dict = torch.load(consolidated_path / CONSOLIDATE_NAME, weights_only=True)
|
||||
with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f:
|
||||
st_dict = torch.load(f, weights_only=True)
|
||||
model.load_state_dict(st_dict["model"])
|
||||
model = model.cuda().eval()
|
||||
for param in model.parameters():
|
||||
|
|
|
@@ -15,29 +15,37 @@ from bytelatent.entropy_model import load_entropy_model
|
|||
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
|
||||
|
||||
|
||||
def get_id_from_doc(doc: dict) -> int:
|
||||
def get_id_key(doc: dict) -> str:
|
||||
"""
|
||||
We need a reliable way to ensure that samples from jsonl
|
||||
and arrow are the same, but there is no unique id field,
|
||||
so derive the best possible
|
||||
"""
|
||||
if "sample_id" in doc:
|
||||
sample_id = doc["sample_id"]
|
||||
return "sample_id"
|
||||
elif "title" in doc:
|
||||
sample_id = doc["title"]
|
||||
return "title"
|
||||
elif "qid" in doc:
|
||||
sample_id = doc["qid"]
|
||||
return "qid"
|
||||
elif "paper_id" in doc:
|
||||
sample_id = doc["paper_id"]
|
||||
return "paper_id"
|
||||
elif "path" in doc:
|
||||
sample_id = doc["path"]
|
||||
return "path"
|
||||
elif "url" in doc:
|
||||
sample_id = doc["url"]
|
||||
return "url"
|
||||
elif "id" in doc:
|
||||
sample_id = doc["id"]
|
||||
return "id"
|
||||
else:
|
||||
raise ValueError(f"Could not find a id key from: {doc.keys()}")
|
||||
return str(sample_id)
|
||||
|
||||
|
||||
def get_id_from_doc(doc: dict) -> str:
|
||||
"""
|
||||
We need a reliable way to ensure that samples from jsonl
|
||||
and arrow are the same, but there is no unique id field,
|
||||
so derive the best possible
|
||||
"""
|
||||
return str(doc[get_id_key(doc)])
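# Example (illustrative): get_id_key({"title": "Attention", "text": "..."}) returns
# "title", and get_id_from_doc({"qid": 123, "text": "..."}) returns "123".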
|
||||
|
||||
|
||||
def get_text(doc: dict):
|
||||
|
|
|
@@ -527,7 +527,7 @@ def train(args: TrainArgs):
|
|||
step_tok_losses.append(tok_loss / train_state.scale)
|
||||
|
||||
world_size = get_world_size()
|
||||
if 1 < world_size <= 8:
|
||||
if 1 < world_size <= 8 and False:
|
||||
# For some reason, there are errors in gradient-norm reductions due to
|
||||
# them not working for non-bf16 numbers. This function is a patched
|
||||
# version that converts gradients to bf16 before computing norms.
|
||||
|
|