Mirror of https://github.com/facebookresearch/blt.git (synced 2025-02-23 05:22:16 +00:00)
Merge 45bfe94c1e into sapling-pr-archive-EntilZha
This commit is contained in: commit 2d1c766050
@@ -1,8 +1,10 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+import json
 import logging
 import os
 from typing import Any

+import fsspec
 import numpy as np
 import yaml
 from omegaconf import OmegaConf
@@ -10,11 +12,9 @@ from pydantic import BaseModel, ConfigDict

 from bytelatent.checkpoint import CheckpointArgs
 from bytelatent.data.data_types import Batch
+from bytelatent.data.file_util import get_fs
 from bytelatent.data.iterators.abstract_iterator import StatefulIterator
-from bytelatent.data.iterators.arrow_iterator import (
-    ArrowFileIterator,
-    find_and_sanitize_chunks,
-)
+from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
 from bytelatent.data.iterators.looping_iterator import LoopingIterator
 from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator
 from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator
@@ -53,6 +53,33 @@ def parse_args(args_cls):
     return pydantic_args


+TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"
+
+
+def find_and_sanitize_chunks(
+    dataset_path: str,
+    world_size: int,
+    file_pattern: str,
+    s3_profile: str | None = None,
+):
+    fs = get_fs(dataset_path, s3_profile=s3_profile)
+    path_with_glob = os.path.join(dataset_path, file_pattern)
+    dataset_chunks = fs.glob(path_with_glob)
+    n_chunks = len(dataset_chunks)
+
+    if n_chunks > world_size:
+        n_discard = n_chunks - world_size
+        dataset_chunks = dataset_chunks[:world_size]
+    else:
+        assert (
+            world_size % n_chunks == 0
+        ), "World size should be a multiple of number of chunks"
+
+    assert n_chunks > 0, f"No valid chunks in {dataset_path}"
+
+    return dataset_chunks
+
+
 def distribute_data_to_rank(
     *,
     dataset_path: str,
@@ -62,9 +89,10 @@ def distribute_data_to_rank(
     rank: int,
     world_size: int,
     s3_profile: str | None = None,
+    file_pattern: str = TRAIN_DATA_FILE_PATTERN,
 ) -> ArrowFileIterator:
     dataset_chunks = find_and_sanitize_chunks(
-        dataset_path, world_size, s3_profile=s3_profile
+        dataset_path, world_size, file_pattern, s3_profile=s3_profile
     )
     n_workers_per_chunk = world_size // len(dataset_chunks)
     rank_to_arrow_iterator_params = []
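A minimal, self-contained sketch (editorial, not part of the commit) of how the
find_and_sanitize_chunks result above feeds the world_size // len(dataset_chunks) split in
distribute_data_to_rank. The file names and the rank-to-chunk mapping are illustrative only,
since the hunk does not show how rank_to_arrow_iterator_params is filled.

    chunks = ["data.chunk.00.jsonl", "data.chunk.01.jsonl"]  # e.g. what fs.glob could return
    world_size = 4  # the assert above requires world_size to be a multiple of len(chunks)
    n_workers_per_chunk = world_size // len(chunks)  # -> 2 workers share each chunk
    for rank in range(world_size):
        chunk = chunks[rank // n_workers_per_chunk]
        worker_id = rank % n_workers_per_chunk
        print(f"rank {rank}: reads {chunk} as worker {worker_id}/{n_workers_per_chunk}")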
bytelatent/broken_train.py (new normal file, 623 lines added)
@@ -0,0 +1,623 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from datetime import timedelta
from enum import Enum
from functools import lru_cache
import logging
import math
import sys
import time
from typing import Dict, List, Optional, Tuple
from torch import Tensor
import os
import pickle

import fsspec
import torch
import torch.distributed
import torch.nn.functional
import torch.nn.functional as F
from torch.distributed._tensor import DTensor

from torch.distributed.device_mesh import init_device_mesh
from torch.utils._foreach_utils import (
    _device_has_foreach_support,
    _group_tensors_by_device_and_dtype,
    _has_foreach_support,
)

from bytelatent.args import TrainArgs
from bytelatent.distributed import (
    DistributedArgs,
    check_model_value_range,
    parallelize_model,
    setup_env,
    setup_torch_distributed,
)

logger = logging.getLogger()


def set_root_log_level(log_level: str):
    logger = logging.getLogger()
    level: int | str = log_level.upper()
    try:
        level = int(log_level)
    except ValueError:
        pass
    try:
        logger.setLevel(level)  # type: ignore
    except Exception:
        logger.warning(
            f"Failed to set logging level to {log_level}, using default 'NOTSET'"
        )
        logger.setLevel(logging.NOTSET)


class LogFormatter(logging.Formatter):
    """
    Custom logger for distributed jobs, displaying rank
    and preserving indent from the custom prefix format.
    """

    def __init__(self):
        self.start_time = time.time()
        self.rank = get_global_rank()
        self.show_rank = not get_is_slurm_job()  # srun has --label

    def formatTime(self, record):
        subsecond, seconds = math.modf(record.created)
        curr_date = (
            time.strftime("%y-%m-%d %H:%M:%S", time.localtime(seconds))
            + f".{int(subsecond * 1_000_000):06d}"
        )
        delta = timedelta(seconds=round(record.created - self.start_time))
        return f"{curr_date} - {delta}"

    def formatPrefix(self, record):
        fmt_time = self.formatTime(record)
        if self.show_rank:
            return f"{self.rank}: {record.levelname:<7} {fmt_time} - "
        else:
            return f"{record.levelname:<7} {fmt_time} - "

    def formatMessage(self, record, indent: str):
        content = record.getMessage()
        content = content.replace("\n", "\n" + indent)
        # Exception handling as in the default formatter, albeit with indenting
        # according to our custom prefix
        if record.exc_info:
            # Cache the traceback text to avoid converting it multiple times
            # (it's constant anyway)
            if not record.exc_text:
                record.exc_text = self.formatException(record.exc_info)
            if record.exc_text:
                if content[-1:] != "\n":
                    content = content + "\n" + indent
                content = content + indent.join(
                    [l + "\n" for l in record.exc_text.splitlines()]
                )
                if content[-1:] == "\n":
                    content = content[:-1]
        if record.stack_info:
            if content[-1:] != "\n":
                content = content + "\n" + indent
            stack_text = self.formatStack(record.stack_info)
            content = content + indent.join([l + "\n" for l in stack_text.splitlines()])
            if content[-1:] == "\n":
                content = content[:-1]

        return content

    def format(self, record):
        prefix = self.formatPrefix(record)
        indent = " " * len(prefix)
        content = self.formatMessage(record, indent)
        return prefix + content


def init_logger(
    log_file: str | None = None,
    *,
    name: str | None = None,
    level: str = "INFO",
    fs: fsspec.AbstractFileSystem | None = None,
):
    """
    Setup logging.

    Args:
        log_file: A file name to save file logs to.
        name: The name of the logger to configure, by default the root logger.
        level: The logging level to use.
    """
    set_root_log_level(level)
    logger = logging.getLogger(name)

    # stdout: everything
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setLevel(logging.NOTSET)
    stdout_handler.setFormatter(LogFormatter())

    # stderr: warnings / errors and above
    stderr_handler = logging.StreamHandler(sys.stderr)
    stderr_handler.setLevel(logging.WARNING)
    stderr_handler.setFormatter(LogFormatter())

    # set stream handlers
    logger.handlers.clear()
    logger.handlers.append(stdout_handler)
    logger.handlers.append(stderr_handler)

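# Illustrative usage of init_logger above (editorial comment, not part of the commit):
#
#     init_logger(os.path.join("/tmp", "train.log"), level="INFO")
#     logging.getLogger(__name__).info("rank-aware logging configured")
#
# train() below calls it the same way, pointing log_file at the run's dump_dir.
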
@torch.no_grad()
def fixed_clip_grad_norm_(
    parameters: torch.Tensor | list[torch.Tensor],
    max_norm: float,
    norm_type: float = 2.0,
    error_if_nonfinite: bool = False,
    foreach: Optional[bool] = None,
) -> torch.Tensor:
    r"""Clip the gradient norm of an iterable of parameters.

    The norm is computed over the norms of the individual gradients of all parameters,
    as if the norms of the individual gradients were concatenated into a single vector.
    Gradients are modified in-place.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float): max norm of the gradients
        norm_type (float): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.
        error_if_nonfinite (bool): if True, an error is thrown if the total
            norm of the gradients from :attr:`parameters` is ``nan``,
            ``inf``, or ``-inf``. Default: False (will switch to True in the future)
        foreach (bool): use the faster foreach-based implementation.
            If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
            fall back to the slow implementation for other device types.
            Default: ``None``

    Returns:
        Total norm of the parameter gradients (viewed as a single vector).
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad.to(torch.bfloat16) for p in parameters if p.grad is not None]
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if len(grads) == 0:
        return torch.tensor(0.0)
    first_device = grads[0].device
    grouped_grads: Dict[
        Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
    ] = _group_tensors_by_device_and_dtype(
        [grads]
    )  # type: ignore[assignment]

    norms: List[Tensor] = []
    for (device, _), ([device_grads], _) in grouped_grads.items():  # type: ignore[assignment]
        if (foreach is None and _has_foreach_support(device_grads, device)) or (
            foreach and _device_has_foreach_support(device)
        ):
            norms.extend(torch._foreach_norm(device_grads, norm_type))
        elif foreach:
            raise RuntimeError(
                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
            )
        else:
            norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])

    total_norm = torch.linalg.vector_norm(
        torch.stack([norm.to(first_device) for norm in norms]), norm_type
    )

    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
        raise RuntimeError(
            f"The total norm of order {norm_type} for gradients from "
            "`parameters` is non-finite, so it cannot be clipped. To disable "
            "this error and scale the gradients by the non-finite norm anyway, "
            "set `error_if_nonfinite=False`"
        )
    clip_coef = max_norm / (total_norm + 1e-6)
    # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
    # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
    # when the gradients do not reside in CPU memory.
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    for (device, _), ([device_grads], _) in grouped_grads.items():  # type: ignore[assignment]
        if (foreach is None and _has_foreach_support(device_grads, device)) or (
            foreach and _device_has_foreach_support(device)
        ):
            torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
        elif foreach:
            raise RuntimeError(
                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
            )
        else:
            clip_coef_clamped_device = clip_coef_clamped.to(device)
            for g in device_grads:
                g.mul_(clip_coef_clamped_device)

    return total_norm

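# Illustrative usage of fixed_clip_grad_norm_ above (editorial comment, not part of the commit):
#
#     loss.backward()
#     grad_norm = fixed_clip_grad_norm_(model.parameters(), max_norm=1.0, foreach=True)
#
# It mirrors torch.nn.utils.clip_grad_norm_ but computes the norms on bfloat16 copies of
# the gradients, which is the workaround the guarded branch in train() below refers to.
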
def get_no_recompute_ops():
    return None


@lru_cache()
def get_is_torch_run() -> bool:
    return os.environ.get("LOCAL_RANK") is not None


@lru_cache()
def get_is_slurm_job() -> bool:
    return "SLURM_JOB_ID" in os.environ and not get_is_torch_run()


@lru_cache()
def get_global_rank() -> int:
    if get_is_torch_run():
        return int(os.environ["RANK"])
    elif get_is_slurm_job():
        return int(os.environ["SLURM_PROCID"])
    else:
        return 0


@lru_cache()
def get_local_rank() -> int:
    if get_is_torch_run():
        return int(os.environ["LOCAL_RANK"])
    elif get_is_slurm_job():
        return int(os.environ["SLURM_LOCALID"])
    else:
        return 0


@lru_cache()
def get_world_size() -> int:
    if get_is_torch_run():
        return int(os.environ["WORLD_SIZE"])
    elif get_is_slurm_job():
        return int(os.environ["SLURM_NTASKS"])
    else:
        return 1


@lru_cache()
def get_is_master() -> bool:
    return get_global_rank() == 0


def validate_train_args(args: TrainArgs, output_size: int):
    # assert args.model is not None or args.entropy_model is not None
    if args.entropy_model is not None:
        logger.info(f"Setting model output size to {args.entropy_model.vocab_size}")
        args.entropy_model.vocab_size = output_size

    assert args.dump_dir, "Dump dir not set"

    if (
        args.distributed.dp_replicate
        * args.distributed.dp_shard
        * args.distributed.tp_size
        != get_world_size()
    ):
        logging.info("Modifying TrainArgs distributed config")
        assert get_world_size() % args.distributed.dp_shard == 0
        logging.info("World size: %s", get_world_size())
        logging.info(
            "Existing setting: train_args.distributed.dp_shard=%s",
            args.distributed.dp_shard,
        )
        logging.info(
            "Setting train_args.distributed.dp_replicate=%s, was dp_replicate=%s",
            get_world_size() // args.distributed.dp_shard,
            args.distributed.dp_replicate,
        )
        args.distributed.dp_replicate = get_world_size() // args.distributed.dp_shard

        logging.info(
            "Changing dp_replicate from %s to %s, to account for tp_size=%s",
            args.distributed.dp_replicate,
            args.distributed.dp_replicate // args.distributed.tp_size,
            args.distributed.tp_size,
        )
        assert args.distributed.dp_replicate % args.distributed.tp_size == 0
        args.distributed.dp_replicate = (
            args.distributed.dp_replicate // args.distributed.tp_size
        )

        logger.warning(
            f"Setting Data Parallel size to {args.distributed.dp_replicate * args.distributed.dp_shard}"
        )
        assert (
            args.distributed.dp_replicate
            * args.distributed.dp_shard
            * args.distributed.tp_size
            == get_world_size()
        )

        if args.distributed.fsdp_type == "no_shard":
            assert (
                args.distributed.dp_shard == 1
                and args.distributed.dp_replicate == get_world_size()
            )

    if args.model is not None:
        args.model.max_seqlen = args.data.seq_len
    if args.entropy_model is not None:
        args.entropy_model.max_seqlen = args.data.seq_len

    if args.distributed.tp_size == 1:
        logger.warning(
            "Tensor parallelism has not been tested for a while, use at your own risk"
        )

    assert (
        args.probe_freq != args.profiling.mem_steps
    ), "Don't profile during probe step"
    assert (
        args.probe_freq != args.profiling.profile_steps
    ), "Don't profile during probe step"
    if args.logging.wandb is not None:
        args.logging.wandb.name = args.name

    if args.probe_freq is not None:
        assert (
            args.distributed.tp_size == 1
        ), "Probing not supported with tensor parallelism"
        assert (
            args.distributed.selective_activation_checkpointing is False
        ), "Probing not supported with selective activation checkpointing"

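# Worked example for validate_train_args above (editorial comment, not part of the commit):
# with get_world_size() == 8, dp_shard == 2 and tp_size == 1, dp_replicate is reset to
# 8 // 2 = 4, so dp_replicate * dp_shard * tp_size == 8 holds again before training starts.
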
def compute_loss(p, y, mask, scale):
    tok_loss = scale * F.cross_entropy(
        p.flatten(0, 1), y.flatten(0, 1), reduction="none"
    )
    if mask is None:
        loss = tok_loss.mean()
    else:
        mask = mask.flatten(0, 1)
        tok_loss = tok_loss * mask
        loss = tok_loss.sum() / (mask.sum() + 1e-6)
    return loss, tok_loss

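# Worked example for compute_loss above (editorial comment, not part of the commit):
# with tok_loss = [1.0, 2.0, 3.0, 4.0] and mask = [1, 1, 0, 0], the returned loss is
# (1.0 + 2.0) / (2 + 1e-6) ~= 1.5, so masked (e.g. padding) positions do not dilute the mean.
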
def get_device_mesh(distributed_args):
    tp_size = distributed_args.tp_size
    dp_replicate = distributed_args.dp_replicate
    dp_shard = distributed_args.dp_shard

    assert (
        dp_replicate * dp_shard * tp_size == get_world_size()
    ), f"dp_replicate * dp_shard * tp_size ({dp_replicate} * {dp_shard} * {tp_size}) != world_size ({get_world_size()})"

    dims = []
    names = []
    if dp_replicate >= 1:
        dims.append(dp_replicate)
        names.append("dp_replicate")
    if dp_shard > 1 or distributed_args.fsdp_type == "no_shard":
        dims.append(dp_shard)
        names.append("dp_shard")
    if tp_size > 1:
        dims.append(tp_size)
        names.append("tp")
    dims = tuple(dims)
    names = tuple(names)

    return init_device_mesh("cuda", mesh_shape=dims, mesh_dim_names=names)


def build_fsdp_grouping_plan():
    group_plan: tuple[int, bool] = []

    # Grouping and output separately
    group_plan.append(("tok_embeddings", False))

    # Grouping by layers
    # for i in range(model_args.n_layers):
    #     group_plan.append((f"layers.{i}", False))

    group_plan.append(("output", True))

    return group_plan


class MinimalModel(torch.nn.Module):
    def __init__(self, dim: int, vocab_size: int):
        super().__init__()
        self.tok_embeddings = torch.nn.Embedding(vocab_size, dim)

        # self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        # self.layers = torch.nn.ModuleList()
        # for _ in range(args.n_layers):
        #     self.layers.append(TransformerBlock(args))

        self.output = torch.nn.Linear(
            dim,
            vocab_size,
            bias=False,
        )

    def forward(self, tokens):
        h = self.tok_embeddings(tokens)
        logits = self.output(h)
        # logits = self.output(self.norm(h))
        return logits

    def reset_parameters(self, init_std=None):
        pass

    def init_weights(self):
        pass


def train():
    args = TrainArgs(
        dump_dir="/tmp",
        name="debug_bf16",
        model=None,
        entropy_model=None,
        distributed=DistributedArgs(
            fsdp_type="full_shard",
            model_dtype="bf16",
            matmul_allow_tf32=False,
            selective_activation_checkpointing=False,
            tp_size=1,
        ),
    )
    tokenizer = args.data.tokenizer_args.build()
    validate_train_args(
        args,
        tokenizer.n_words,
    )
    dump_fs = fsspec.filesystem("file")
    init_logger(os.path.join(args.dump_dir, "train.log"), fs=dump_fs)
    setup_env(args.env)
    setup_torch_distributed(args.distributed)
    world_mesh = get_device_mesh(args.distributed)
    logger.info(f"Starting job: {args.name}")

    # build dataloader
    # need dp world size and rank
    dp_mesh = world_mesh["dp_replicate"]
    dp_degree = dp_mesh.size()
    dp_rank = dp_mesh.get_local_rank()
    if args.distributed.dp_shard > 1:
        dp_rank = dp_rank * dp_degree + world_mesh["dp_shard"].get_local_rank()
        dp_degree *= world_mesh["dp_shard"].size()

    logger.info(f"Running on dp rank : {dp_rank}")
    logger.info(f"Running on dp size : {dp_degree}")

    torch.manual_seed(args.seed)
    logger.info("Building model")

    # Initializing Model in meta device allows us to initialize models much bigger than 1 gpu's memory
    with torch.device("meta"):
        model = MinimalModel(768, tokenizer.n_words)

    model = parallelize_model(
        model,
        world_mesh,
        args.model,
        args.distributed,
        fsdp_grouping_plan=build_fsdp_grouping_plan(),
        tp_parallelize=None,
        no_recompute_ops=get_no_recompute_ops(),
    )

    # Once we shard the model on different gpus we can actually initialize the model
    # First we create empty tensors of the correct shapes
    model = model.to_empty(device="cuda")
    # Then we init the model. Please make sure this function initializes *ALL* parameters
    # and buffers, otherwise you will have random values in the uninitialized tensors
    # which will silently fail (give nan gradients for example)

    with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
        torch.manual_seed(42)
        model.init_weights()
    check_model_value_range(model, range=10.0, std=1.0)

    # data_loader = args.data.build_from_rank(dp_rank, dp_degree)

    # train loop
    model.train()
    # data_loader = train_state.data_loader_state.build()
    # batch_iterator = data_loader.create_iter()
    # batch = next(batch_iterator)
    # with open(f"/storage/home/par/toy-data/batch_{dp_rank}.pickle", "wb") as f:
    #     pickle.dump(batch, f)
    with open(f"/storage/home/par/toy-data/batch_{dp_rank}.pickle", "rb") as f:
        batch = pickle.load(f)

    batch_x = torch.from_numpy(
        batch.x,
    ).cuda()
    batch_y = torch.from_numpy(batch.y).cuda()
    mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
    pred = model(batch_x)
    loss, _ = compute_loss(pred, batch_y, mask, 1.0)

    # We scale loss with grad_acc_steps so the gradient is the same
    # regardless of grad_acc_steps
    loss = loss / args.grad_acc_steps

    # backward on scaled loss to create scaled gradients
    loss.backward()
    # For logging we undo that scaling
    loss = loss.detach() * args.grad_acc_steps

    world_size = get_world_size()
    if 1 < world_size <= 8 and False:
        # For some reason, there are errors in reduces due to
        # not working for non-bf16 numbers. This function is a patched
        # version that converts gradients to bf16 before computing norms.
        # The error only happens in distributed training on one node,
        # hence the guard
        grad_norm = fixed_clip_grad_norm_(
            model.parameters(), max_norm=args.optim.clip, foreach=True
        )
    else:
        grad_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(), max_norm=args.optim.clip, foreach=True
        )

    grad_norm = (
        grad_norm.full_tensor() if isinstance(grad_norm, DTensor) else grad_norm
    ).item()

    # if isinstance(data_loader, MultiprocessIterator):
    #     logger.info("Closing MP iterator before exiting")
    #     data_loader.shutdown()


def main():
    """
    The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments
    This accepts arguments as a dot list
    So if the dataclass looks like

    @dataclass
    class DummyArgs:
        name: str
        model: LMTransformerArgsgs

    @dataclass
    class LMTransformerArgsgs:
        dim: int

    Then you can pass model.dim=32 to change values in LMTransformerArgsgs
    or just name=tictac for top level attributes.

    The behavior here is as follows:
    1. We instantiate TrainArgs with its default values
    2. We override those default values with the ones in the provided config file
    3. We override the result with the additional arguments provided through command line

    For example, if the config is the following

    model:
        dim: 128
        n_layers: 4

    and you call train.py with train.py model.dim=64

    Then the final TrainArgs will have

    model:
        dim: 64
        n_layers: 4

    Plus all the default values in TrainArgs dataclass.
    """
    train()


if __name__ == "__main__":
    main()
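The main() docstring above describes a defaults -> config file -> command-line dot-list merge.
A minimal sketch of that flow (editorial, not part of the commit; the keys are illustrative, and
the real code validates the merged config into TrainArgs rather than printing it):

    from omegaconf import OmegaConf

    default_cfg = OmegaConf.create({"name": "debug", "model": {"dim": 128, "n_layers": 4}})
    file_cfg = OmegaConf.create({"model": {"dim": 256}})  # stands in for the YAML config file
    cli_cfg = OmegaConf.from_dotlist(["model.dim=64", "name=tictac"])
    merged = OmegaConf.merge(default_cfg, file_cfg, cli_cfg)
    print(OmegaConf.to_yaml(merged))  # model.dim ends up 64, n_layers stays 4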
@@ -16,6 +16,7 @@ from bytelatent import ByteLatentError
 from bytelatent.data.data_types import BltExample
 from bytelatent.data.file_util import get_fs
 from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.preprocess.preprocess_entropies import get_id_key, get_text

 logger = getLogger(__name__)

@@ -32,6 +33,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState):
     arrow_batch_size: int = 100
     s3_profile: str | None
     filesystem_type: str | None = None
+    file_format: str

     def build(self) -> "ArrowFileIterator":
         arrow_file = ArrowFileIterator(
@@ -44,6 +46,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState):
             dataset_files=self.dataset_files,
             s3_profile=self.s3_profile,
             filesystem_type=self.filesystem_type,
+            file_format=self.file_format,
         )
         if self.row_num != 0:
             arrow_file._set_row_num(self.row_num)
@@ -70,6 +73,7 @@ class ArrowFileIterator(StatefulIterator):
         dataset_files: list[str] | None = None,
         s3_profile: str | None = None,
         filesystem_type: str | None = None,
+        file_format: str = "arrow",
     ):
         assert 0 <= worker_id < num_workers, (worker_id, num_workers)
         if file_path is None and dataset_files is None:
@@ -87,12 +91,16 @@ class ArrowFileIterator(StatefulIterator):
         self.arrow_batch_size = arrow_batch_size
         self.s3_profile = s3_profile
         self.filesystem_type = filesystem_type
+        self.file_format = file_format
         self.fs = None
         if self.filesystem_type is not None:
             if self.filesystem_type == "file":
                 self.fs = fsspec.filesystem("file")
             elif self.filesystem_type == "s3":
                 self.fs = fsspec.filesystem("s3", profile=s3_profile)
+            else:
+                raise ValueError("Unknown filesystem")
+        logger.info("Arrow iterator using fs=%s", self.fs)

         if dataset_files is None:
             # Prepare arrow shards
@@ -153,6 +161,7 @@ class ArrowFileIterator(StatefulIterator):
             dataset_files=self.dataset_files,
             s3_profile=self.s3_profile,
             filesystem_type=self.filesystem_type,
+            file_format=self.file_format,
         )

     def create_iter(
@@ -164,7 +173,7 @@ class ArrowFileIterator(StatefulIterator):
             else:
                 filesystem = None
             self.dataset = pa.dataset.dataset(
-                self.dataset_files, format="arrow", filesystem=filesystem
+                self.dataset_files, format=self.file_format, filesystem=filesystem
             )
             self.batch_iterator = self.dataset.to_batches(
                 batch_size=self.arrow_batch_size
@@ -173,13 +182,22 @@ class ArrowFileIterator(StatefulIterator):
         if self.batch_to_consume is not None:
             batch_columns: dict[str, list] = self.batch_to_consume
             self.batch_to_consume = None
-            sample_ids = batch_columns["sample_id"]
-            texts = batch_columns["text"]
-            entropies = batch_columns["entropies"]
+            if self.file_format == "arrow":
+                sample_ids = batch_columns["sample_id"]
+                texts = batch_columns["text"]
+                entropies = batch_columns["entropies"]
+            elif self.file_format == "json":
+                # This data hasn't been preprocessed to a uniform format,
+                # so we have to do it now and omit entropies
+                sample_ids = batch_columns[get_id_key(batch_columns)]
+                texts = get_text(batch_columns)
+                entropies = None
+            else:
+                raise ValueError(f"Unknown file format: {self.file_format}")
             for i in range(len(sample_ids)):
                 out = BltExample(
                     sample_id=sample_ids[i],
-                    entropies=entropies[i],
+                    entropies=entropies[i] if entropies is not None else None,
                     text=texts[i],
                     tokens=None,
                     mask=None,
@@ -191,13 +209,22 @@ class ArrowFileIterator(StatefulIterator):

         for batch in self.batch_iterator:
             batch_columns = batch.to_pydict()
-            sample_ids = batch_columns["sample_id"]
-            texts = batch_columns["text"]
-            entropies = batch_columns["entropies"]
+            if self.file_format == "arrow":
+                sample_ids = batch_columns["sample_id"]
+                texts = batch_columns["text"]
+                entropies = batch_columns["entropies"]
+            elif self.file_format == "json":
+                # This data hasn't been preprocessed to a uniform format,
+                # so we have to do it now and omit entropies
+                sample_ids = batch_columns[get_id_key(batch_columns)]
+                texts = get_text(batch_columns)
+                entropies = None
+            else:
+                raise ValueError(f"Unknown file format: {self.file_format}")
             for i in range(len(sample_ids)):
                 out = BltExample(
                     sample_id=sample_ids[i],
-                    entropies=entropies[i],
+                    entropies=entropies[i] if entropies is not None else None,
                     text=texts[i],
                     tokens=None,
                     mask=None,
@@ -231,13 +258,24 @@ class ArrowFileIterator(StatefulIterator):
         for batch in self.batch_iterator:
             if len(batch) > curr_remaining:
                 batch_columns: dict[str, list] = batch.to_pydict()
-                batch_columns["sample_id"] = batch_columns["sample_id"][
-                    curr_remaining:
-                ]
-                batch_columns["entropies"] = batch_columns["entropies"][
-                    curr_remaining:
-                ]
-                batch_columns["text"] = batch_columns["text"][curr_remaining:]
+                if self.file_format == "arrow":
+                    leftover_sample_ids = batch_columns["sample_id"][
+                        curr_remaining:
+                    ]
+                    leftover_entropies = batch_columns["entropies"][curr_remaining:]
+                    leftover_texts = batch_columns["text"][curr_remaining:]
+                elif self.file_format == "json":
+                    leftover_sample_ids = batch_columns[get_id_key(batch_columns)][
+                        curr_remaining:
+                    ]
+                    leftover_entropies = None
+                    leftover_texts = get_text(batch_columns)[curr_remaining:]
+                else:
+                    raise ValueError(f"Unknown file format: {self.file_format}")
+
+                batch_columns["sample_id"] = leftover_sample_ids
+                batch_columns["entropies"] = leftover_entropies
+                batch_columns["text"] = leftover_texts
                 self.batch_to_consume = batch_columns
                 break
             elif len(batch) == curr_remaining:
@@ -250,30 +288,3 @@ class ArrowFileIterator(StatefulIterator):
         logger.info(
             f"Finished setting arrow position to {target_row_num} for {self.dataset_files}"
         )
-
-
-TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"
-
-
-def find_and_sanitize_chunks(
-    dataset_path: str,
-    world_size: int,
-    file_pattern: str = TRAIN_DATA_FILE_PATTERN,
-    s3_profile: str | None = None,
-):
-    fs = get_fs(dataset_path, s3_profile=s3_profile)
-    path_with_glob = os.path.join(dataset_path, file_pattern)
-    dataset_chunks = fs.glob(path_with_glob)
-    n_chunks = len(dataset_chunks)
-
-    if n_chunks > world_size:
-        n_discard = n_chunks - world_size
-        dataset_chunks = dataset_chunks[:world_size]
-    else:
-        assert (
-            world_size % n_chunks == 0
-        ), "World size should be a multiple of number of chunks"
-
-    assert n_chunks > 0, f"No valid chunks in {dataset_path}"
-
-    return dataset_chunks
@@ -15,9 +15,16 @@ from lm_eval.api.model import LM
 from omegaconf import OmegaConf
 from pydantic import BaseModel, ConfigDict

-from bytelatent.args import EvalArgs, ValidationArgs, parse_args
+from bytelatent.args import (
+    EvalArgs,
+    TrainArgs,
+    ValidationArgs,
+    find_and_sanitize_chunks,
+    parse_args,
+)
 from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints
 from bytelatent.data.file_util import get_fs
+from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
 from bytelatent.distributed import (
     DistributedArgs,
     dist_mean_dict,
@@ -117,19 +124,40 @@ class EvalHarnessLM(LM):
         return results


-def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
-    srcs = {}
+def eval_on_val(generator, val_args: ValidationArgs, train_cfg: TrainArgs):
+    srcs = []
     for src in val_args.sources:
         path = os.path.join(val_args.root_dir, src)
-        srcs[path] = 1.0
+        srcs.append(path)

     for src in train_cfg.data.sources:
         path = os.path.join(train_cfg.data.root_dir, src)
-        srcs[path] = 1.0
+        srcs.append(path)

-    multi_state = init_choice_state(
-        "", srcs, 0, get_global_rank(), get_world_size(), "*.val.jsonl"
-    )
-    path_to_iter = setup_sources(multi_state)
+    path_to_iter = {}
+    for path in srcs:
+        chunks = find_and_sanitize_chunks(
+            path,
+            world_size=1,
+            file_pattern="*.val.jsonl",
+            s3_profile=train_cfg.data.s3_profile,
+        )
+        assert (
+            len(chunks) == 1
+        ), f"There should be only 1 chunk per validation file, but found: {chunks}"
+        chunk = chunks[0]
+        iterator = ArrowFileIterator(
+            dataset_files=[chunk],
+            file_path=None,
+            preprocess_dir=None,
+            entropy_model_name=None,
+            worker_id=0,
+            num_workers=1,
+            arrow_batch_size=train_cfg.data.arrow_batch_size,
+            s3_profile=train_cfg.data.s3_profile,
+            file_format="json",
+        )
+        path_to_iter[path] = iterator

     max_gen_len = generator.max_gen_len
     # We temporarily lower max gen len
@@ -137,16 +165,11 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg):

     all_val_metrics = {}
     for src in path_to_iter:
-        jsonl_iterator = path_to_iter[src]
+        example_iterator = path_to_iter[src].create_iter()
         texts = []
         logger.info(f"Running validation on {src}...")
-        for step, (content, state) in enumerate(jsonl_iterator):
-            if state["current_iter"] > 0 or (
-                val_args.max_steps is not None and step >= val_args.max_steps
-            ):
-                break
-            content_key = "text" if ("text" in content) else "content"
-            texts.append(content[content_key])
+        for step, example in enumerate(example_iterator):
+            texts.append(example.text)

         _, loglikelihood, _ = generator.generate(texts)

@@ -191,7 +214,7 @@ def launch_eval(eval_args: EvalArgs):
     else:
         consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER)
         if not fs.exists(consolidate_path) and get_global_rank() == 0:
-            consolidate_path = consolidate_checkpoints(eval_args.ckpt_dir)
+            consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir)

     fs.mkdirs(eval_args.dump_dir, exist_ok=True)
     with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f:
@@ -210,10 +233,12 @@ def launch_eval(eval_args: EvalArgs):

     wrap = EvalHarnessLM(generator)
     # Redo
-    results = simple_evaluate(wrap, eval_args.harness.model_dump())
+    results = simple_evaluate(wrap, **eval_args.harness.model_dump())

     val_results = None
     if eval_args.validation:
         val_results = eval_on_val(generator, eval_args.validation, train_cfg)

     if get_global_rank() == 0:
         with fs.open(os.path.join(eval_args.dump_dir, "results.json"), "w") as f:
             f.write(json.dumps(results))
@@ -222,6 +247,7 @@ def launch_eval(eval_args: EvalArgs):
         with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f:
             f.write(json.dumps(val_results))
         logger.info(f"All validation results: {val_results}")
+
     if eval_args.metric_log_dir and get_global_rank() == 0:
         metric_log_path = os.path.join(eval_args.metric_log_dir, "metrics.eval.jsonl")
@@ -387,8 +387,7 @@ def load_consolidated_model_and_tokenizer(
 ):
     train_args_path = os.path.join(consolidated_path, "params.json")
     fs = get_fs(train_args_path)
-    with fs.open(train_args_path) as f:
-        train_args = TrainArgs.model_validate_json(f.read())
+    train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))

     if train_args.train_entropy_model:
         model_args = train_args.entropy_model
@@ -401,7 +400,8 @@ def load_consolidated_model_and_tokenizer(
         train_args.distributed.model_dtype
     ]
     tokenizer = train_args.data.tokenizer_args.build()
-    st_dict = torch.load(consolidated_path / CONSOLIDATE_NAME, weights_only=True)
+    with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f:
+        st_dict = torch.load(f, weights_only=True)
     model.load_state_dict(st_dict["model"])
     model = model.cuda().eval()
     for param in model.parameters():
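A minimal sketch (editorial, not part of the commit) of the fsspec-based loading pattern this
hunk switches to; the directory is made up and "consolidated.pth" stands in for CONSOLIDATE_NAME:

    import os

    import fsspec
    import torch

    consolidated_path = "/tmp/example_consolidated"
    fs = fsspec.filesystem("file")
    train_args_json = fs.read_text(os.path.join(consolidated_path, "params.json"))
    with fs.open(os.path.join(consolidated_path, "consolidated.pth")) as f:
        st_dict = torch.load(f, weights_only=True)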
@@ -15,29 +15,37 @@ from bytelatent.entropy_model import load_entropy_model
 from bytelatent.tokenizers.build_tokenizer import TokenizerArgs


-def get_id_from_doc(doc: dict) -> int:
+def get_id_key(doc: dict) -> int:
     """
     We need a reliable way to ensure that samples from jsonl
     and arrow are the same, but there is no unique id field,
     so derive the best possible
     """
     if "sample_id" in doc:
-        sample_id = doc["sample_id"]
+        return "sample_id"
     elif "title" in doc:
-        sample_id = doc["title"]
+        return "title"
     elif "qid" in doc:
-        sample_id = doc["qid"]
+        return "qid"
     elif "paper_id" in doc:
-        sample_id = doc["paper_id"]
+        return "paper_id"
     elif "path" in doc:
-        sample_id = doc["path"]
+        return "path"
     elif "url" in doc:
-        sample_id = doc["url"]
+        return "url"
     elif "id" in doc:
-        sample_id = doc["id"]
+        return "id"
     else:
         raise ValueError(f"Could not find a id key from: {doc.keys()}")
-    return str(sample_id)
+
+
+def get_id_from_doc(doc: dict) -> int:
+    """
+    We need a reliable way to ensure that samples from jsonl
+    and arrow are the same, but there is no unique id field,
+    so derive the best possible
+    """
+    return str(doc[get_id_key(doc)])


 def get_text(doc: dict):
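A small usage sketch for the refactor above (editorial, not part of the commit): get_id_key
returns the first matching key in priority order, and get_id_from_doc now reuses it.

    from bytelatent.preprocess.preprocess_entropies import get_id_from_doc, get_id_key

    doc = {"title": "An Example Title", "url": "https://example.com/a"}
    assert get_id_key(doc) == "title"  # "sample_id" would take precedence if present
    assert get_id_from_doc(doc) == "An Example Title"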
@@ -527,7 +527,7 @@ def train(args: TrainArgs):
             step_tok_losses.append(tok_loss / train_state.scale)

             world_size = get_world_size()
-            if 1 < world_size <= 8:
+            if 1 < world_size <= 8 and False:
                 # For some reason, there are errors in reduces due to
                 # not working for non-bf16 numbers. This function is a patched
                 # version that converts gradients to bf16 before computing norms.