From 7517ac2a9f3fbe2106a37d1817584e02202f2de5 Mon Sep 17 00:00:00 2001
From: Pedro Rodriguez <par@meta.com>
Date: Tue, 11 Mar 2025 09:57:19 -0700
Subject: [PATCH] Get evals working again. (#46)

- PPL/validation: works now and uses multi-GPU. For some reason, results on
  1 GPU differ from multi-GPU; this can be debugged in a follow-up PR.
- Generation evals likely work, but they are very slow, so they are disabled
  for now.

Test Plan:

```
torchrun --nproc-per-node 8 -m bytelatent.eval config=../internal-blt/configs/eval.yaml
```
---
 bytelatent/args.py        |   4 +
 bytelatent/distributed.py |  42 +++++++
 bytelatent/eval.py        | 253 ++++++++++++++++++++++++++++++++------
 bytelatent/generate.py    |   6 +-
 bytelatent/metrics.py     |   2 +-
 bytelatent/train.py       |  70 ++---------
 6 files changed, 276 insertions(+), 101 deletions(-)

diff --git a/bytelatent/args.py b/bytelatent/args.py
index bad4d17..13acfc0 100644
--- a/bytelatent/args.py
+++ b/bytelatent/args.py
@@ -270,6 +270,10 @@ class EvalArgs(BaseModel):
     dump_dir: str | None = None
     ckpt_dir: str | None = None
     metric_log_dir: str | None = None
+
+    run_ppl: bool = True
+    run_tasks: bool = False
+
     generator: PackedCausalTransformerGeneratorArgs = (
         PackedCausalTransformerGeneratorArgs()
     )
diff --git a/bytelatent/distributed.py b/bytelatent/distributed.py
index 80661d5..284c717 100644
--- a/bytelatent/distributed.py
+++ b/bytelatent/distributed.py
@@ -15,6 +15,7 @@ from functools import lru_cache, partial, reduce
 from itertools import chain
 from typing import List, Optional, Tuple, Union
 
+import numpy as np
 import torch
 
 # for no recompute ops
@@ -78,6 +79,40 @@ class DistributedArgs(BaseModel):
 
     spawn_method: str = "forkserver"
 
+    def configure_world(self):
+        pass
+        if self.dp_replicate * self.dp_shard * self.tp_size != get_world_size():
+            logging.info("Modifying TrainArgs distributed config")
+            assert get_world_size() % self.dp_shard == 0
+            logging.info("World size: %s", get_world_size())
+            logging.info(
+                "Existing setting: train_args.distributed.dp_shard=%s",
+                self.dp_shard,
+            )
+            logging.info(
+                "Setting train_args.distributed.dp_replicate=%s, was dp_replicate=%s",
+                get_world_size() // self.dp_shard,
+                self.dp_replicate,
+            )
+            self.dp_replicate = get_world_size() // self.dp_shard
+
+            logging.info(
+                "Changing dp_replicate from %s to %s, to account for tp_size=%s",
+                self.dp_replicate,
+                self.dp_replicate // self.tp_size,
+                self.tp_size,
+            )
+            assert self.dp_replicate % self.tp_size == 0
+            self.dp_replicate = self.dp_replicate // self.tp_size
+
+            logger.warning(
+                f"Setting Data Parallel size to {self.dp_replicate * self.dp_shard}"
+            )
+            assert self.dp_replicate * self.dp_shard * self.tp_size == get_world_size()
+
+        if self.fsdp_type == "no_shard":
+            assert self.dp_shard == 1 and self.dp_replicate == get_world_size()
+
 
 class EnvironmentArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
@@ -151,6 +186,13 @@ def dist_mean_dict(x):
     return r
 
 
+def to_py_num(num: int | float | torch.Tensor | np.ndarray) -> int | float:
+    if isinstance(num, (torch.Tensor, np.ndarray)):
+        return num.item()
+    else:
+        return num
+
+
 @lru_cache()
 def get_is_torch_run() -> bool:
     return os.environ.get("LOCAL_RANK") is not None
diff --git a/bytelatent/eval.py b/bytelatent/eval.py
index 50e17cd..0622979 100644
--- a/bytelatent/eval.py
+++ b/bytelatent/eval.py
@@ -2,6 +2,7 @@
 
 import json
 import logging
+import math
 import os
 from collections import defaultdict
 from datetime import datetime
@@ -10,22 +11,48 @@ import torch
 from lm_eval import simple_evaluate
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import LM
+from rich.progress import track
+from torch.nn import functional as F
 
-from bytelatent.args import EvalArgs, ValidationArgs
+from bytelatent.args import (
+    EvalArgs,
+    TrainArgs,
+    ValidationArgs,
+    find_and_sanitize_chunks,
+)
 from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints
 from bytelatent.config_parser import parse_args_to_pydantic_model
 from bytelatent.data.file_util import get_fs
+from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
+from bytelatent.data.iterators.limit_iterator import LimitIterator
+from bytelatent.data.iterators.packing_iterator import (
+    PackingArgs,
+    PackingIterator,
+    PackingMode,
+)
+from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
+from bytelatent.data.iterators.sequence_iterator import (
+    SequenceIterator,
+    SequencePackingArgs,
+)
+from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum
 from bytelatent.distributed import (
     DistributedArgs,
     dist_mean_dict,
+    dist_sum,
+    get_device_mesh,
     get_global_rank,
     get_world_size,
     setup_torch_distributed,
+    to_py_num,
 )
 from bytelatent.generate import (
     PackedCausalTransformerGenerator,
     load_consolidated_model_and_tokenizer,
 )
+from bytelatent.model.blt import ByteLatentTransformer
+from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
+from bytelatent.transformer import LMTransformer
 
 
 EVAL_FOLDER_NAME = "{:010d}"
@@ -113,19 +140,134 @@ class EvalHarnessLM(LM):
         return results
 
 
-def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
-    srcs = {}
+@torch.no_grad()
+def eval_ppl_on_path(
+    *,
+    world_rank: int,
+    world_size: int,
+    model: LMTransformer | ByteLatentTransformer,
+    tokenizer_args: TokenizerArgs,
+    patcher_args: PatcherArgs,
+    add_patches: bool,
+    path: str,
+    batch_size: int,
+    arrow_batch_size: int,
+    max_n_docs: int | None,
+    s3_profile: str | None = None,
+):
+    model.eval()
+    tokenizer = tokenizer_args.build()
+    seq_len = model.get_output_seq_len()
+    chunks = find_and_sanitize_chunks(
+        path,
+        world_size=1,
+        file_pattern="*.val.jsonl",
+        s3_profile=s3_profile,
+    )
+    assert (
+        len(chunks) == 1
+    ), f"There should be only 1 chunk per validation file, but found: {chunks}"
+    chunk = chunks[0]
+    arrow_iterator = ArrowFileIterator(
+        file_path=chunk,
+        preprocess_dir=None,
+        entropy_model_name=None,
+        worker_id=world_rank,
+        num_workers=world_size,
+        arrow_batch_size=arrow_batch_size,
+        s3_profile=s3_profile,
+        file_format="json",
+    )
+    if max_n_docs is not None:
+        arrow_iterator = LimitIterator(arrow_iterator, limit=max_n_docs)
+    preprocess_iterator = PreprocessIterator(
+        arrow_iterator,
+        patcher_args=patcher_args,
+        tokenizer_args=tokenizer_args,
+        add_patches=add_patches,
+    )
+    sequence_iterator = SequenceIterator(
+        preprocess_iterator,
+        sequence_packing_args=SequencePackingArgs(
+            output_seq_len=seq_len,
+            # Effectively disables shuffles
+            buffer_size=1,
+        ),
+        rng_state=None,
+    )
+    packing_args = PackingArgs(
+        batch_size=batch_size,
+        seq_len=seq_len,
+        # TODO: make these seq lens work with blt
+        max_length=seq_len,
+        pad_to_max_length=True,
+        enable_byte_ngrams=False,
+        pad_id=tokenizer.boe_id,
+        packing_mode=PackingMode.BYTES,
+    )
+    packing_iterator = PackingIterator(sequence_iterator, packing_args=packing_args)
+    total_loss = 0.0
+    n_bytes = 0
+    batch_iterator = packing_iterator.create_iter()
+    for batch in batch_iterator:
+        x = torch.from_numpy(batch.x).cuda()
+        y = torch.from_numpy(batch.y).cuda()
+        mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
+        if tokenizer_args.name in ["bytes", "blt"]:
+            n_bytes += y.numel() if mask is None else mask.sum().item()
+            pred = model(x)
+            loss = F.cross_entropy(pred.flatten(0, 1), y.flatten(0, 1), reduction="sum")
+            total_loss += loss.item()
+        else:
+            raise NotImplementedError()
+    all_n_bytes = to_py_num(dist_sum(n_bytes))
+    all_total_loss = to_py_num(dist_sum(total_loss))
+    return {
+        "n_bytes": all_n_bytes,
+        "n_bytes_gpu": n_bytes,
+        "loss_sum": all_total_loss,
+        "loss_sum_gpu": total_loss,
+        "loss_mean": all_total_loss / all_n_bytes,
+        "loss_mean_gpu": total_loss / n_bytes,
+        "ppl": math.exp(all_total_loss / all_n_bytes) if all_n_bytes > 0 else 0.0,
+        "bpb": all_total_loss / math.log(2) / all_n_bytes,
+    }
+
+
+def eval_on_val(generator, val_args: ValidationArgs, train_cfg: TrainArgs):
+    srcs = []
     for src in val_args.sources:
         path = os.path.join(val_args.root_dir, src)
-        srcs[path] = 1.0
+        srcs.append(path)
+
     for src in train_cfg.data.sources:
         path = os.path.join(train_cfg.data.root_dir, src)
-        srcs[path] = 1.0
+        srcs.append(path)
 
-    multi_state = init_choice_state(
-        "", srcs, 0, get_global_rank(), get_world_size(), "*.val.jsonl"
-    )
-    path_to_iter = setup_sources(multi_state)
+    path_to_iter = {}
+    for path in srcs:
+        chunks = find_and_sanitize_chunks(
+            path,
+            world_size=1,
+            file_pattern="*.val.jsonl",
+            s3_profile=train_cfg.data.s3_profile,
+        )
+        assert (
+            len(chunks) == 1
+        ), f"There should be only 1 chunk per validation file, but found: {chunks}"
+        chunk = chunks[0]
+        iterator = ArrowFileIterator(
+            dataset_files=[chunk],
+            file_path=None,
+            preprocess_dir=None,
+            entropy_model_name=None,
+            worker_id=0,
+            num_workers=1,
+            arrow_batch_size=train_cfg.data.arrow_batch_size,
+            s3_profile=train_cfg.data.s3_profile,
+            file_format="json",
+        )
+        path_to_iter[path] = iterator
 
     max_gen_len = generator.max_gen_len
     # We temporarily lower max gen len
@@ -133,16 +275,11 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
 
     all_val_metrics = {}
     for src in path_to_iter:
-        jsonl_iterator = path_to_iter[src]
+        example_iterator = path_to_iter[src].create_iter()
         texts = []
         logger.info(f"Running validation on {src}...")
-        for step, (content, state) in enumerate(jsonl_iterator):
-            if state["current_iter"] > 0 or (
-                val_args.max_steps is not None and step >= val_args.max_steps
-            ):
-                break
-            content_key = "text" if ("text" in content) else "content"
-            texts.append(content[content_key])
+        for step, example in enumerate(example_iterator):
+            texts.append(example.text)
 
         _, loglikelihood, _ = generator.generate(texts)
 
@@ -174,8 +311,18 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
 
 
 def launch_eval(eval_args: EvalArgs):
+    assert eval_args.dump_dir is not None
+    assert eval_args.ckpt_dir is not None
+    distributed_args = DistributedArgs()
+    distributed_args.configure_world()
     if not torch.distributed.is_initialized():
-        setup_torch_distributed(DistributedArgs())
+        setup_torch_distributed(distributed_args)
+
+    world_mesh = get_device_mesh(distributed_args)
+    dp_mesh = world_mesh["dp_replicate"]
+    assert distributed_args.dp_shard == 1
+    world_size = dp_mesh.size()
+    world_rank = dp_mesh.get_local_rank()
 
     fs = get_fs(eval_args.ckpt_dir, s3_profile=eval_args.s3_profile)
     if (
@@ -187,7 +334,7 @@ def launch_eval(eval_args: EvalArgs):
     else:
        consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER)
        if not fs.exists(consolidate_path) and get_global_rank() == 0:
-            consolidate_path = consolidate_checkpoints(eval_args.ckpt_dir)
+            consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir)
 
     fs.mkdirs(eval_args.dump_dir, exist_ok=True)
     with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f:
@@ -200,35 +347,67 @@ def launch_eval(eval_args: EvalArgs):
     model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
         consolidate_path,
     )
-    logger.info("Model loaded")
     model.eval()
-    generator = PackedCausalTransformerGenerator(eval_args.generator, model, tokenizer)
+    logger.info("Model loaded")
+
+    ppl_results = None
+    if eval_args.run_ppl:
+        assert eval_args.validation is not None
+        if len(eval_args.validation.sources) > 0:
+            ppl_results = {}
+            logger.info("Starting PPL evaluation on validation sets")
+            for source in eval_args.validation.sources:
+                ppl_results[source] = eval_ppl_on_path(
+                    world_rank=world_rank,
+                    world_size=world_size,
+                    model=model,
+                    tokenizer_args=train_cfg.data.tokenizer_args,
+                    # TODO: Don't hardcode, modify based on model
+                    patcher_args=PatcherArgs(patching_mode=PatchingModeEnum.byte),
+                    add_patches=False,
+                    path=os.path.join(eval_args.validation.root_dir, source),
+                    max_n_docs=eval_args.validation.max_n_docs,
+                    batch_size=8,
+                    arrow_batch_size=100,
+                    s3_profile="blt",
+                )
+
+    task_results = None
+    if eval_args.run_tasks:
+        assert eval_args.generator is not None
+        assert eval_args.harness is not None
+        generator = PackedCausalTransformerGenerator(
+            eval_args.generator, model, tokenizer
+        )
+        wrap = EvalHarnessLM(generator)
+        # TODO: This needs to be checked/sped up
+        task_results = simple_evaluate(wrap, **eval_args.harness.model_dump())
+
+    results = {"ppl": ppl_results, "tasks": task_results}
+    # TODO: Serial and Parallel yield slightly different number of bytes, debug this later,
+    # leaving this log statement here to help with that.
+    # logging.info("Rank: %s Results: %s", world_rank, results)
 
-    wrap = EvalHarnessLM(generator)
-    # Redo
-    results = simple_evaluate(wrap, eval_args.harness.model_dump())
-    val_results = None
-    if eval_args.validation:
-        val_results = eval_on_val(generator, eval_args.validation, train_cfg)
     if get_global_rank() == 0:
         with fs.open(os.path.join(eval_args.dump_dir, "results.json"), "w") as f:
             f.write(json.dumps(results))
-        logger.info(f"All evaluation results: {results['results']}")
-    if val_results is not None:
+        logger.info(f"All evaluation results: {results}")
+    if ppl_results is not None:
         with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f:
-            f.write(json.dumps(val_results))
-            logger.info(f"All validation results: {val_results}")
+            f.write(json.dumps(ppl_results))
+            logger.info(f"All validation results: {ppl_results}")
+
     if eval_args.metric_log_dir and get_global_rank() == 0:
         metric_log_path = os.path.join(eval_args.metric_log_dir, "metrics.eval.jsonl")
 
         logger.info(f"Writing metric logs to {metric_log_path}")
-        timestamp = {
+        timestamp: dict[str, int | str] = {
             "created_at": datetime.utcnow().isoformat(),
         }
         if eval_args.global_step is not None:
             timestamp["global_step"] = eval_args.global_step
         print(
-            json.dumps(timestamp | results["results"]),
+            json.dumps(timestamp | results),
             file=fs.open(metric_log_path, mode="a"),
             flush=True,
         )
@@ -236,18 +415,16 @@ def launch_eval(eval_args: EvalArgs):
         val_log_path = os.path.join(
             eval_args.metric_log_dir, "metrics.validation.jsonl"
         )
-        if val_results is not None:
+        if ppl_results is not None:
             print(
-                json.dumps(timestamp | val_results),
+                json.dumps(timestamp | ppl_results),
                 file=fs.open(val_log_path, mode="a"),
                 flush=True,
             )
 
-    del generator
-
 
 def main():
-    eval_args = parse_args(EvalArgs)
+    eval_args = parse_args_to_pydantic_model(EvalArgs)
     launch_eval(eval_args)
 
 
diff --git a/bytelatent/generate.py b/bytelatent/generate.py
index eb79d81..9d44f30 100644
--- a/bytelatent/generate.py
+++ b/bytelatent/generate.py
@@ -387,8 +387,7 @@ def load_consolidated_model_and_tokenizer(
 ):
     train_args_path = os.path.join(consolidated_path, "params.json")
     fs = get_fs(train_args_path)
-    with fs.open(train_args_path) as f:
-        train_args = TrainArgs.model_validate_json(f.read())
+    train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))
 
     if train_args.train_entropy_model:
         model_args = train_args.entropy_model
@@ -401,7 +400,8 @@ def load_consolidated_model_and_tokenizer(
         train_args.distributed.model_dtype
     ]
     tokenizer = train_args.data.tokenizer_args.build()
-    st_dict = torch.load(consolidated_path / CONSOLIDATE_NAME, weights_only=True)
+    with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f:
+        st_dict = torch.load(f, weights_only=True)
     model.load_state_dict(st_dict["model"])
     model = model.cuda().eval()
     for param in model.parameters():
diff --git a/bytelatent/metrics.py b/bytelatent/metrics.py
index ed805e5..15d2f48 100644
--- a/bytelatent/metrics.py
+++ b/bytelatent/metrics.py
@@ -55,7 +55,7 @@ class LoggingArgs(BaseModel):
 class MetricLogger:
     def __init__(
         self,
-        outdir: Path,
+        outdir: str,
         # args: TrainArgs
         args: Any | None = None,
         fs: fsspec.AbstractFileSystem | None = None,
diff --git a/bytelatent/train.py b/bytelatent/train.py
index 5a8f937..f9f38e6 100644
--- a/bytelatent/train.py
+++ b/bytelatent/train.py
@@ -48,6 +48,7 @@ from bytelatent.distributed import (
     requeue_slurm_job,
     setup_env,
     setup_torch_distributed,
+    to_py_num,
 )
 from bytelatent.eval import EVAL_FOLDER_NAME, launch_eval
 from bytelatent.logger import init_logger
@@ -91,13 +92,6 @@ def get_iterator_state_name(iterator_state):
     raise ValueError(f"Unsupported iterator to get name from: {iterator_state}")
 
 
-def to_py_num(num: int | float | torch.Tensor | np.ndarray) -> int | float:
-    if isinstance(num, (torch.Tensor, np.ndarray)):
-        return num.item()
-    else:
-        return num
-
-
 # TODO: Make this pydantic based instead of data class based
 # TODO: Generalize this to any iterator state
 @dataclass
@@ -154,57 +148,13 @@ def validate_train_args(args: TrainArgs, output_size: int):
         logger.info(f"Setting checkpoint path to {args.checkpoint.path}")
         args.checkpoint.path = os.path.join(args.dump_dir, "checkpoints")
 
-    data_fs = get_fs(args.data.root_dir, s3_profile=args.data.s3_profile)
-    for source in args.data.sources:
-        data_path = os.path.join(args.data.root_dir, source)
-        assert data_fs.exists(data_path), f"{data_path} doesn't exist"
+    if args.data.root_dir is not None:
+        data_fs = get_fs(args.data.root_dir, s3_profile=args.data.s3_profile)
+        for source in args.data.sources:
+            data_path = os.path.join(args.data.root_dir, source)
+            assert data_fs.exists(data_path), f"{data_path} doesn't exist"
 
-    if (
-        args.distributed.dp_replicate
-        * args.distributed.dp_shard
-        * args.distributed.tp_size
-        != get_world_size()
-    ):
-        logging.info("Modifying TrainArgs distributed config")
-        assert get_world_size() % args.distributed.dp_shard == 0
-        logging.info("World size: %s", get_world_size())
-        logging.info(
-            "Existing setting: train_args.distributed.dp_shard=%s",
-            args.distributed.dp_shard,
-        )
-        logging.info(
-            "Setting train_args.distributed.dp_replicate=%s, was dp_replicate=%s",
-            get_world_size() // args.distributed.dp_shard,
-            args.distributed.dp_replicate,
-        )
-        args.distributed.dp_replicate = get_world_size() // args.distributed.dp_shard
-
-        logging.info(
-            "Changing dp_replicate from %s to %s, to account for tp_size=%s",
-            args.distributed.dp_replicate,
-            args.distributed.dp_replicate // args.distributed.tp_size,
-            args.distributed.tp_size,
-        )
-        assert args.distributed.dp_replicate % args.distributed.tp_size == 0
-        args.distributed.dp_replicate = (
-            args.distributed.dp_replicate // args.distributed.tp_size
-        )
-
-        logger.warning(
-            f"Setting Data Parallel size to {args.distributed.dp_replicate * args.distributed.dp_shard}"
-        )
-        assert (
-            args.distributed.dp_replicate
-            * args.distributed.dp_shard
-            * args.distributed.tp_size
-            == get_world_size()
-        )
-
-        if args.distributed.fsdp_type == "no_shard":
-            assert (
-                args.distributed.dp_shard == 1
-                and args.distributed.dp_replicate == get_world_size()
-            )
+    args.distributed.configure_world()
 
     if args.model is not None:
         args.model.max_seqlen = args.data.seq_len
@@ -243,7 +193,9 @@ def set_preemption_flag(signum, frame):
     preemption_flag["flag"] = True
 
 
-def every_n_steps(train_state, freq, acc_step=None, acc_freq=None):
+def every_n_steps(train_state, freq: int, acc_step=None, acc_freq=None):
+    if freq < 0:
+        return False
     test = train_state.step % freq == 0
     if acc_step is not None:
         test = test and (train_state.acc_step == acc_step)
@@ -272,7 +224,7 @@ def train(args: TrainArgs):
     tokenizer = args.data.tokenizer_args.build()
     validate_train_args(
         args,
-        tokenizer.n_words,
+        tokenizer.get_vocab_size(),
     )
     dump_fs = get_fs(args.dump_dir, s3_profile=args.checkpoint.s3_profile)
     if get_is_master():