From 7517ac2a9f3fbe2106a37d1817584e02202f2de5 Mon Sep 17 00:00:00 2001
From: Pedro Rodriguez <par@meta.com>
Date: Tue, 11 Mar 2025 09:57:19 -0700
Subject: [PATCH] Get evals working again. (#46)

- PPL/validation: Works now and uses multi-gpu. For some reason 1 GPU differs from multi-GPU, can debug in a followup PR
- Generation evals likely work, but are very slow, so disabled for now


Test Plan:
```
torchrun --nproc-per-node 8 -m bytelatent.eval config=../internal-blt/configs/eval.yaml
```
---
 bytelatent/args.py        |   4 +
 bytelatent/distributed.py |  42 +++++++
 bytelatent/eval.py        | 253 ++++++++++++++++++++++++++++++++------
 bytelatent/generate.py    |   6 +-
 bytelatent/metrics.py     |   2 +-
 bytelatent/train.py       |  70 ++---------
 6 files changed, 276 insertions(+), 101 deletions(-)

diff --git a/bytelatent/args.py b/bytelatent/args.py
index bad4d17..13acfc0 100644
--- a/bytelatent/args.py
+++ b/bytelatent/args.py
@@ -270,6 +270,10 @@ class EvalArgs(BaseModel):
     dump_dir: str | None = None
     ckpt_dir: str | None = None
     metric_log_dir: str | None = None
+
+    run_ppl: bool = True
+    run_tasks: bool = False
+
     generator: PackedCausalTransformerGeneratorArgs = (
         PackedCausalTransformerGeneratorArgs()
     )
diff --git a/bytelatent/distributed.py b/bytelatent/distributed.py
index 80661d5..284c717 100644
--- a/bytelatent/distributed.py
+++ b/bytelatent/distributed.py
@@ -15,6 +15,7 @@ from functools import lru_cache, partial, reduce
 from itertools import chain
 from typing import List, Optional, Tuple, Union
 
+import numpy as np
 import torch
 
 # for no recompute ops
@@ -78,6 +79,40 @@ class DistributedArgs(BaseModel):
 
     spawn_method: str = "forkserver"
 
+    def configure_world(self):
+        pass
+        if self.dp_replicate * self.dp_shard * self.tp_size != get_world_size():
+            logging.info("Modifying TrainArgs distributed config")
+            assert get_world_size() % self.dp_shard == 0
+            logging.info("World size: %s", get_world_size())
+            logging.info(
+                "Existing setting: train_args.distributed.dp_shard=%s",
+                self.dp_shard,
+            )
+            logging.info(
+                "Setting train_args.distributed.dp_replicate=%s, was dp_replicate=%s",
+                get_world_size() // self.dp_shard,
+                self.dp_replicate,
+            )
+            self.dp_replicate = get_world_size() // self.dp_shard
+
+            logging.info(
+                "Changing dp_replicate from %s to %s, to account for tp_size=%s",
+                self.dp_replicate,
+                self.dp_replicate // self.tp_size,
+                self.tp_size,
+            )
+            assert self.dp_replicate % self.tp_size == 0
+            self.dp_replicate = self.dp_replicate // self.tp_size
+
+            logger.warning(
+                f"Setting Data Parallel size to {self.dp_replicate * self.dp_shard}"
+            )
+            assert self.dp_replicate * self.dp_shard * self.tp_size == get_world_size()
+
+            if self.fsdp_type == "no_shard":
+                assert self.dp_shard == 1 and self.dp_replicate == get_world_size()
+
 
 class EnvironmentArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
@@ -151,6 +186,13 @@ def dist_mean_dict(x):
     return r
 
 
+def to_py_num(num: int | float | torch.Tensor | np.ndarray) -> int | float:
+    if isinstance(num, (torch.Tensor, np.ndarray)):
+        return num.item()
+    else:
+        return num
+
+
 @lru_cache()
 def get_is_torch_run() -> bool:
     return os.environ.get("LOCAL_RANK") is not None
diff --git a/bytelatent/eval.py b/bytelatent/eval.py
index 50e17cd..0622979 100644
--- a/bytelatent/eval.py
+++ b/bytelatent/eval.py
@@ -2,6 +2,7 @@
 
 import json
 import logging
+import math
 import os
 from collections import defaultdict
 from datetime import datetime
@@ -10,22 +11,48 @@ import torch
 from lm_eval import simple_evaluate
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import LM
+from rich.progress import track
+from torch.nn import functional as F
 
-from bytelatent.args import EvalArgs, ValidationArgs
+from bytelatent.args import (
+    EvalArgs,
+    TrainArgs,
+    ValidationArgs,
+    find_and_sanitize_chunks,
+)
 from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints
 from bytelatent.config_parser import parse_args_to_pydantic_model
 from bytelatent.data.file_util import get_fs
+from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
+from bytelatent.data.iterators.limit_iterator import LimitIterator
+from bytelatent.data.iterators.packing_iterator import (
+    PackingArgs,
+    PackingIterator,
+    PackingMode,
+)
+from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
+from bytelatent.data.iterators.sequence_iterator import (
+    SequenceIterator,
+    SequencePackingArgs,
+)
+from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum
 from bytelatent.distributed import (
     DistributedArgs,
     dist_mean_dict,
+    dist_sum,
+    get_device_mesh,
     get_global_rank,
     get_world_size,
     setup_torch_distributed,
+    to_py_num,
 )
 from bytelatent.generate import (
     PackedCausalTransformerGenerator,
     load_consolidated_model_and_tokenizer,
 )
+from bytelatent.model.blt import ByteLatentTransformer
+from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
+from bytelatent.transformer import LMTransformer
 
 EVAL_FOLDER_NAME = "{:010d}"
 
@@ -113,19 +140,134 @@ class EvalHarnessLM(LM):
         return results
 
 
-def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
-    srcs = {}
+@torch.no_grad()
+def eval_ppl_on_path(
+    *,
+    world_rank: int,
+    world_size: int,
+    model: LMTransformer | ByteLatentTransformer,
+    tokenizer_args: TokenizerArgs,
+    patcher_args: PatcherArgs,
+    add_patches: bool,
+    path: str,
+    batch_size: int,
+    arrow_batch_size: int,
+    max_n_docs: int | None,
+    s3_profile: str | None = None,
+):
+    model.eval()
+    tokenizer = tokenizer_args.build()
+    seq_len = model.get_output_seq_len()
+    chunks = find_and_sanitize_chunks(
+        path,
+        world_size=1,
+        file_pattern="*.val.jsonl",
+        s3_profile=s3_profile,
+    )
+    assert (
+        len(chunks) == 1
+    ), f"There should be only 1 chunk per validation file, but found: {chunks}"
+    chunk = chunks[0]
+    arrow_iterator = ArrowFileIterator(
+        file_path=chunk,
+        preprocess_dir=None,
+        entropy_model_name=None,
+        worker_id=world_rank,
+        num_workers=world_size,
+        arrow_batch_size=arrow_batch_size,
+        s3_profile=s3_profile,
+        file_format="json",
+    )
+    if max_n_docs is not None:
+        arrow_iterator = LimitIterator(arrow_iterator, limit=max_n_docs)
+    preprocess_iterator = PreprocessIterator(
+        arrow_iterator,
+        patcher_args=patcher_args,
+        tokenizer_args=tokenizer_args,
+        add_patches=add_patches,
+    )
+    sequence_iterator = SequenceIterator(
+        preprocess_iterator,
+        sequence_packing_args=SequencePackingArgs(
+            output_seq_len=seq_len,
+            # Effectively disables shuffles
+            buffer_size=1,
+        ),
+        rng_state=None,
+    )
+    packing_args = PackingArgs(
+        batch_size=batch_size,
+        seq_len=seq_len,
+        # TODO: make these seq lens worth with blt
+        max_length=seq_len,
+        pad_to_max_length=True,
+        enable_byte_ngrams=False,
+        pad_id=tokenizer.boe_id,
+        packing_mode=PackingMode.BYTES,
+    )
+    packing_iterator = PackingIterator(sequence_iterator, packing_args=packing_args)
+    total_loss = 0.0
+    n_bytes = 0
+    batch_iterator = packing_iterator.create_iter()
+    for batch in batch_iterator:
+        x = torch.from_numpy(batch.x).cuda()
+        y = torch.from_numpy(batch.y).cuda()
+        mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
+        if tokenizer_args.name in ["bytes", "blt"]:
+            n_bytes += y.numel() if mask is None else mask.sum().item()
+            pred = model(x)
+            loss = F.cross_entropy(pred.flatten(0, 1), y.flatten(0, 1), reduction="sum")
+            total_loss += loss.item()
+        else:
+            raise NotImplementedError()
+    all_n_bytes = to_py_num(dist_sum(n_bytes))
+    all_total_loss = to_py_num(dist_sum(total_loss))
+    return {
+        "n_bytes": all_n_bytes,
+        "n_bytes_gpu": n_bytes,
+        "loss_sum": all_total_loss,
+        "loss_sum_gpu": total_loss,
+        "loss_mean": all_total_loss / all_n_bytes,
+        "loss_mean_gpu": total_loss / n_bytes,
+        "ppl": math.exp(all_total_loss / all_n_bytes) if all_n_bytes > 0 else 0.0,
+        "bpb": all_total_loss / math.log(2) / all_n_bytes,
+    }
+
+
+def eval_on_val(generator, val_args: ValidationArgs, train_cfg: TrainArgs):
+    srcs = []
     for src in val_args.sources:
         path = os.path.join(val_args.root_dir, src)
-        srcs[path] = 1.0
+        srcs.append(path)
+
     for src in train_cfg.data.sources:
         path = os.path.join(train_cfg.data.root_dir, src)
-        srcs[path] = 1.0
+        srcs.append(path)
 
-    multi_state = init_choice_state(
-        "", srcs, 0, get_global_rank(), get_world_size(), "*.val.jsonl"
-    )
-    path_to_iter = setup_sources(multi_state)
+    path_to_iter = {}
+    for path in srcs:
+        chunks = find_and_sanitize_chunks(
+            path,
+            world_size=1,
+            file_pattern="*.val.jsonl",
+            s3_profile=train_cfg.data.s3_profile,
+        )
+        assert (
+            len(chunks) == 1
+        ), f"There should be only 1 chunk per validation file, but found: {chunks}"
+        chunk = chunks[0]
+        iterator = ArrowFileIterator(
+            dataset_files=[chunk],
+            file_path=None,
+            preprocess_dir=None,
+            entropy_model_name=None,
+            worker_id=0,
+            num_workers=1,
+            arrow_batch_size=train_cfg.data.arrow_batch_size,
+            s3_profile=train_cfg.data.s3_profile,
+            file_format="json",
+        )
+        path_to_iter[path] = iterator
 
     max_gen_len = generator.max_gen_len
     # We temporarily lower max gen len
@@ -133,16 +275,11 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
 
     all_val_metrics = {}
     for src in path_to_iter:
-        jsonl_iterator = path_to_iter[src]
+        example_iterator = path_to_iter[src].create_iter()
         texts = []
         logger.info(f"Running validation on {src}...")
-        for step, (content, state) in enumerate(jsonl_iterator):
-            if state["current_iter"] > 0 or (
-                val_args.max_steps is not None and step >= val_args.max_steps
-            ):
-                break
-            content_key = "text" if ("text" in content) else "content"
-            texts.append(content[content_key])
+        for step, example in enumerate(example_iterator):
+            texts.append(example.text)
 
         _, loglikelihood, _ = generator.generate(texts)
 
@@ -174,8 +311,18 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
 
 
 def launch_eval(eval_args: EvalArgs):
+    assert eval_args.dump_dir is not None
+    assert eval_args.ckpt_dir is not None
+    distributed_args = DistributedArgs()
+    distributed_args.configure_world()
     if not torch.distributed.is_initialized():
-        setup_torch_distributed(DistributedArgs())
+        setup_torch_distributed(distributed_args)
+
+    world_mesh = get_device_mesh(distributed_args)
+    dp_mesh = world_mesh["dp_replicate"]
+    assert distributed_args.dp_shard == 1
+    world_size = dp_mesh.size()
+    world_rank = dp_mesh.get_local_rank()
 
     fs = get_fs(eval_args.ckpt_dir, s3_profile=eval_args.s3_profile)
     if (
@@ -187,7 +334,7 @@ def launch_eval(eval_args: EvalArgs):
     else:
         consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER)
         if not fs.exists(consolidate_path) and get_global_rank() == 0:
-            consolidate_path = consolidate_checkpoints(eval_args.ckpt_dir)
+            consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir)
 
     fs.mkdirs(eval_args.dump_dir, exist_ok=True)
     with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f:
@@ -200,35 +347,67 @@ def launch_eval(eval_args: EvalArgs):
     model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
         consolidate_path,
     )
-    logger.info("Model loaded")
     model.eval()
-    generator = PackedCausalTransformerGenerator(eval_args.generator, model, tokenizer)
+    logger.info("Model loaded")
+
+    ppl_results = None
+    if eval_args.run_ppl:
+        assert eval_args.validation is not None
+        if len(eval_args.validation.sources) > 0:
+            ppl_results = {}
+            logger.info("Starting PPL evaluation on validation sets")
+            for source in eval_args.validation.sources:
+                ppl_results[source] = eval_ppl_on_path(
+                    world_rank=world_rank,
+                    world_size=world_size,
+                    model=model,
+                    tokenizer_args=train_cfg.data.tokenizer_args,
+                    # TODO: Don't hardcode, modify based on model
+                    patcher_args=PatcherArgs(patching_mode=PatchingModeEnum.byte),
+                    add_patches=False,
+                    path=os.path.join(eval_args.validation.root_dir, source),
+                    max_n_docs=eval_args.validation.max_n_docs,
+                    batch_size=8,
+                    arrow_batch_size=100,
+                    s3_profile="blt",
+                )
+
+    task_results = None
+    if eval_args.run_tasks:
+        assert eval_args.generator is not None
+        assert eval_args.harness is not None
+        generator = PackedCausalTransformerGenerator(
+            eval_args.generator, model, tokenizer
+        )
+        wrap = EvalHarnessLM(generator)
+        # TODO: This needs to be checked/sped up
+        task_results = simple_evaluate(wrap, **eval_args.harness.model_dump())
+
+    results = {"ppl": ppl_results, "tasks": task_results}
+    # TODO: Serial and Parallel yield slightly different number of bytes, debug this later,
+    # leaving this log statement here to help with that.
+    # logging.info("Rank: %s Results: %s", world_rank, results)
 
-    wrap = EvalHarnessLM(generator)
-    # Redo
-    results = simple_evaluate(wrap, eval_args.harness.model_dump())
-    val_results = None
-    if eval_args.validation:
-        val_results = eval_on_val(generator, eval_args.validation, train_cfg)
     if get_global_rank() == 0:
         with fs.open(os.path.join(eval_args.dump_dir, "results.json"), "w") as f:
             f.write(json.dumps(results))
-        logger.info(f"All evaluation results: {results['results']}")
-        if val_results is not None:
+        logger.info(f"All evaluation results: {results}")
+        if ppl_results is not None:
             with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f:
-                f.write(json.dumps(val_results))
-            logger.info(f"All validation results: {val_results}")
+                f.write(json.dumps(ppl_results))
+            logger.info(f"All validation results: {ppl_results}")
+
     if eval_args.metric_log_dir and get_global_rank() == 0:
         metric_log_path = os.path.join(eval_args.metric_log_dir, "metrics.eval.jsonl")
 
         logger.info(f"Writing metric logs to {metric_log_path}")
-        timestamp = {
+        timestamp: dict[str, int | str] = {
             "created_at": datetime.utcnow().isoformat(),
         }
         if eval_args.global_step is not None:
             timestamp["global_step"] = eval_args.global_step
         print(
-            json.dumps(timestamp | results["results"]),
+            json.dumps(timestamp | results),
             file=fs.open(metric_log_path, mode="a"),
             flush=True,
         )
@@ -236,18 +415,16 @@ def launch_eval(eval_args: EvalArgs):
         val_log_path = os.path.join(
             eval_args.metric_log_dir, "metrics.validation.jsonl"
         )
-        if val_results is not None:
+        if ppl_results is not None:
             print(
-                json.dumps(timestamp | val_results),
+                json.dumps(timestamp | ppl_results),
                 file=fs.open(val_log_path, mode="a"),
                 flush=True,
             )
 
-    del generator
-
 
 def main():
-    eval_args = parse_args(EvalArgs)
+    eval_args = parse_args_to_pydantic_model(EvalArgs)
     launch_eval(eval_args)
 
 
diff --git a/bytelatent/generate.py b/bytelatent/generate.py
index eb79d81..9d44f30 100644
--- a/bytelatent/generate.py
+++ b/bytelatent/generate.py
@@ -387,8 +387,7 @@ def load_consolidated_model_and_tokenizer(
 ):
     train_args_path = os.path.join(consolidated_path, "params.json")
     fs = get_fs(train_args_path)
-    with fs.open(train_args_path) as f:
-        train_args = TrainArgs.model_validate_json(f.read())
+    train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))
 
     if train_args.train_entropy_model:
         model_args = train_args.entropy_model
@@ -401,7 +400,8 @@ def load_consolidated_model_and_tokenizer(
         train_args.distributed.model_dtype
     ]
     tokenizer = train_args.data.tokenizer_args.build()
-    st_dict = torch.load(consolidated_path / CONSOLIDATE_NAME, weights_only=True)
+    with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f:
+        st_dict = torch.load(f, weights_only=True)
     model.load_state_dict(st_dict["model"])
     model = model.cuda().eval()
     for param in model.parameters():
diff --git a/bytelatent/metrics.py b/bytelatent/metrics.py
index ed805e5..15d2f48 100644
--- a/bytelatent/metrics.py
+++ b/bytelatent/metrics.py
@@ -55,7 +55,7 @@ class LoggingArgs(BaseModel):
 class MetricLogger:
     def __init__(
         self,
-        outdir: Path,
+        outdir: str,
         # args: TrainArgs
         args: Any | None = None,
         fs: fsspec.AbstractFileSystem | None = None,
diff --git a/bytelatent/train.py b/bytelatent/train.py
index 5a8f937..f9f38e6 100644
--- a/bytelatent/train.py
+++ b/bytelatent/train.py
@@ -48,6 +48,7 @@ from bytelatent.distributed import (
     requeue_slurm_job,
     setup_env,
     setup_torch_distributed,
+    to_py_num,
 )
 from bytelatent.eval import EVAL_FOLDER_NAME, launch_eval
 from bytelatent.logger import init_logger
@@ -91,13 +92,6 @@ def get_iterator_state_name(iterator_state):
         raise ValueError(f"Unsupported iterator to get name from: {iterator_state}")
 
 
-def to_py_num(num: int | float | torch.Tensor | np.ndarray) -> int | float:
-    if isinstance(num, (torch.Tensor, np.ndarray)):
-        return num.item()
-    else:
-        return num
-
-
 # TODO: Make this pydantic based instead of data class based
 # TODO: Generalize this to any iterator state
 @dataclass
@@ -154,57 +148,13 @@ def validate_train_args(args: TrainArgs, output_size: int):
         logger.info(f"Setting checkpoint path to {args.checkpoint.path}")
         args.checkpoint.path = os.path.join(args.dump_dir, "checkpoints")
 
-    data_fs = get_fs(args.data.root_dir, s3_profile=args.data.s3_profile)
-    for source in args.data.sources:
-        data_path = os.path.join(args.data.root_dir, source)
-        assert data_fs.exists(data_path), f"{data_path} doesn't exist"
+    if args.data.root_dir is not None:
+        data_fs = get_fs(args.data.root_dir, s3_profile=args.data.s3_profile)
+        for source in args.data.sources:
+            data_path = os.path.join(args.data.root_dir, source)
+            assert data_fs.exists(data_path), f"{data_path} doesn't exist"
 
-    if (
-        args.distributed.dp_replicate
-        * args.distributed.dp_shard
-        * args.distributed.tp_size
-        != get_world_size()
-    ):
-        logging.info("Modifying TrainArgs distributed config")
-        assert get_world_size() % args.distributed.dp_shard == 0
-        logging.info("World size: %s", get_world_size())
-        logging.info(
-            "Existing setting: train_args.distributed.dp_shard=%s",
-            args.distributed.dp_shard,
-        )
-        logging.info(
-            "Setting train_args.distributed.dp_replicate=%s, was dp_replicate=%s",
-            get_world_size() // args.distributed.dp_shard,
-            args.distributed.dp_replicate,
-        )
-        args.distributed.dp_replicate = get_world_size() // args.distributed.dp_shard
-
-        logging.info(
-            "Changing dp_replicate from %s to %s, to account for tp_size=%s",
-            args.distributed.dp_replicate,
-            args.distributed.dp_replicate // args.distributed.tp_size,
-            args.distributed.tp_size,
-        )
-        assert args.distributed.dp_replicate % args.distributed.tp_size == 0
-        args.distributed.dp_replicate = (
-            args.distributed.dp_replicate // args.distributed.tp_size
-        )
-
-        logger.warning(
-            f"Setting Data Parallel size to {args.distributed.dp_replicate * args.distributed.dp_shard}"
-        )
-        assert (
-            args.distributed.dp_replicate
-            * args.distributed.dp_shard
-            * args.distributed.tp_size
-            == get_world_size()
-        )
-
-        if args.distributed.fsdp_type == "no_shard":
-            assert (
-                args.distributed.dp_shard == 1
-                and args.distributed.dp_replicate == get_world_size()
-            )
+    args.distributed.configure_world()
 
     if args.model is not None:
         args.model.max_seqlen = args.data.seq_len
@@ -243,7 +193,9 @@ def set_preemption_flag(signum, frame):
     preemption_flag["flag"] = True
 
 
-def every_n_steps(train_state, freq, acc_step=None, acc_freq=None):
+def every_n_steps(train_state, freq: int, acc_step=None, acc_freq=None):
+    if freq < 0:
+        return False
     test = train_state.step % freq == 0
     if acc_step is not None:
         test = test and (train_state.acc_step == acc_step)
@@ -272,7 +224,7 @@ def train(args: TrainArgs):
         tokenizer = args.data.tokenizer_args.build()
         validate_train_args(
             args,
-            tokenizer.n_words,
+            tokenizer.get_vocab_size(),
         )
         dump_fs = get_fs(args.dump_dir, s3_profile=args.checkpoint.s3_profile)
         if get_is_master():