From 719900d4bd2924fcafd83c3d26832adf24226b68 Mon Sep 17 00:00:00 2001
From: Pedro Rodriguez <me@pedro.ai>
Date: Thu, 13 Mar 2025 17:14:53 +0000
Subject: [PATCH] Update ppl evals to work with blt model, in addition to
 entropy model

Summary:

Test Plan:

Run
```
python -m bytelatent.eval config=../internal-blt/configs/eval_blt.yaml validation.max_n_docs=null
python -m bytelatent.eval config=../internal-blt/configs/eval_entropy.yaml validation.max_n_docs=null
```
---
 bytelatent/args.py                            |   1 +
 bytelatent/data/iterators/packing_iterator.py |  29 +++-
 bytelatent/eval.py                            | 142 ++++--------------
 3 files changed, 61 insertions(+), 111 deletions(-)

diff --git a/bytelatent/args.py b/bytelatent/args.py
index 13acfc0..6927f1c 100644
--- a/bytelatent/args.py
+++ b/bytelatent/args.py
@@ -263,6 +263,7 @@ class ValidationArgs(BaseModel):
     use_val_from_train_src: bool = True  # Use the validation set from training sources
     root_dir: str = ""
     sources: list[str] = []  # Other sources to eval on
+    batch_size: int = 8
 
 
 class EvalArgs(BaseModel):
diff --git a/bytelatent/data/iterators/packing_iterator.py b/bytelatent/data/iterators/packing_iterator.py
index f407f9f..d220342 100644
--- a/bytelatent/data/iterators/packing_iterator.py
+++ b/bytelatent/data/iterators/packing_iterator.py
@@ -221,6 +221,7 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
         enable_byte_ngrams = self.packing_args.enable_byte_ngrams
         max_length = self.packing_args.max_length
         assert max_length is not None
+        final_leftover_batch = False
         while True:
             tokens: list[list[int]] = []
             masks: list[list[bool]] = []
@@ -252,6 +253,9 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
                 break
 
             x_patch_lengths = np.array(patch_lengths)
+            assert (
+                x_patch_lengths.shape[1] == seq_len
+            ), f"{x_patch_lengths.shape[1]} vs {seq_len}"
             # pad batch to same length
             tok_seq_len = max([len(toks) for toks in tokens]) - 1
             x = np.full((batch_size, tok_seq_len), fill_value=pad_id)
@@ -263,7 +267,30 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
                 # Adjust patch lengths to match x
                 x_patch_lengths[i, -1] += tok_seq_len - (len(tok_seq) - 1)
 
-            assert x_patch_lengths.shape == (batch_size, seq_len)
+            if x_patch_lengths.shape[0] < batch_size:
+                if final_leftover_batch:
+                    raise ValueError(
+                        "There should only be one partial batch, but found multiple"
+                    )
+                final_leftover_batch = True
+                assert len(masks) == len(x_patch_lengths)
+                n_missing = batch_size - x_patch_lengths.shape[0]
+                # Repeat the last patch length to validly pad it out, but
+                # update the mask to ignore the row
+                x_patch_lengths = np.vstack(
+                    [
+                        x_patch_lengths,
+                        np.repeat(x_patch_lengths[-1:, :], n_missing, axis=0),
+                    ]
+                )
+                for _ in range(n_missing):
+                    masks.append([0] * tok_seq_len)
+                assert len(masks) == batch_size
+
+            assert x_patch_lengths.shape == (
+                batch_size,
+                seq_len,
+            ), f"{x_patch_lengths.shape} vs {(batch_size, seq_len)}"
 
             if enable_byte_ngrams:
                 raise NotImplementedError()
diff --git a/bytelatent/eval.py b/bytelatent/eval.py
index 0622979..61e4a2d 100644
--- a/bytelatent/eval.py
+++ b/bytelatent/eval.py
@@ -148,35 +148,25 @@ def eval_ppl_on_path(
     model: LMTransformer | ByteLatentTransformer,
     tokenizer_args: TokenizerArgs,
     patcher_args: PatcherArgs,
+    packing_args: PackingArgs,
     add_patches: bool,
     path: str,
-    batch_size: int,
     arrow_batch_size: int,
     max_n_docs: int | None,
     s3_profile: str | None = None,
 ):
     model.eval()
-    tokenizer = tokenizer_args.build()
     seq_len = model.get_output_seq_len()
-    chunks = find_and_sanitize_chunks(
-        path,
-        world_size=1,
-        file_pattern="*.val.jsonl",
-        s3_profile=s3_profile,
-    )
-    assert (
-        len(chunks) == 1
-    ), f"There should be only 1 chunk per validation file, but found: {chunks}"
-    chunk = chunks[0]
     arrow_iterator = ArrowFileIterator(
-        file_path=chunk,
-        preprocess_dir=None,
+        file_path=None,
+        dataset_files=[path],
         entropy_model_name=None,
         worker_id=world_rank,
         num_workers=world_size,
         arrow_batch_size=arrow_batch_size,
+        preprocess_dir=None,
         s3_profile=s3_profile,
-        file_format="json",
+        file_format="arrow" if path.endswith("arrow") else "json",
     )
     if max_n_docs is not None:
         arrow_iterator = LimitIterator(arrow_iterator, limit=max_n_docs)
@@ -195,16 +185,6 @@ def eval_ppl_on_path(
         ),
         rng_state=None,
     )
-    packing_args = PackingArgs(
-        batch_size=batch_size,
-        seq_len=seq_len,
-        # TODO: make these seq lens worth with blt
-        max_length=seq_len,
-        pad_to_max_length=True,
-        enable_byte_ngrams=False,
-        pad_id=tokenizer.boe_id,
-        packing_mode=PackingMode.BYTES,
-    )
     packing_iterator = PackingIterator(sequence_iterator, packing_args=packing_args)
     total_loss = 0.0
     n_bytes = 0
@@ -213,9 +193,16 @@ def eval_ppl_on_path(
         x = torch.from_numpy(batch.x).cuda()
         y = torch.from_numpy(batch.y).cuda()
         mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
+        patch_lengths = batch.patch_lengths
+        if patch_lengths is not None:
+            patch_lengths = torch.from_numpy(patch_lengths).cuda()
+
         if tokenizer_args.name in ["bytes", "blt"]:
             n_bytes += y.numel() if mask is None else mask.sum().item()
-            pred = model(x)
+            if isinstance(model, ByteLatentTransformer):
+                pred = model(x, patch_lengths=patch_lengths)
+            else:
+                pred = model(x)
             loss = F.cross_entropy(pred.flatten(0, 1), y.flatten(0, 1), reduction="sum")
             total_loss += loss.item()
         else:
@@ -234,82 +221,6 @@ def eval_ppl_on_path(
     }
 
 
-def eval_on_val(generator, val_args: ValidationArgs, train_cfg: TrainArgs):
-    srcs = []
-    for src in val_args.sources:
-        path = os.path.join(val_args.root_dir, src)
-        srcs.append(path)
-
-    for src in train_cfg.data.sources:
-        path = os.path.join(train_cfg.data.root_dir, src)
-        srcs.append(path)
-
-    path_to_iter = {}
-    for path in srcs:
-        chunks = find_and_sanitize_chunks(
-            path,
-            world_size=1,
-            file_pattern="*.val.jsonl",
-            s3_profile=train_cfg.data.s3_profile,
-        )
-        assert (
-            len(chunks) == 1
-        ), f"There should be only 1 chunk per validation file, but found: {chunks}"
-        chunk = chunks[0]
-        iterator = ArrowFileIterator(
-            dataset_files=[chunk],
-            file_path=None,
-            preprocess_dir=None,
-            entropy_model_name=None,
-            worker_id=0,
-            num_workers=1,
-            arrow_batch_size=train_cfg.data.arrow_batch_size,
-            s3_profile=train_cfg.data.s3_profile,
-            file_format="json",
-        )
-        path_to_iter[path] = iterator
-
-    max_gen_len = generator.max_gen_len
-    # We temporarily lower max gen len
-    generator.max_gen_len = 1
-
-    all_val_metrics = {}
-    for src in path_to_iter:
-        example_iterator = path_to_iter[src].create_iter()
-        texts = []
-        logger.info(f"Running validation on {src}...")
-        for step, example in enumerate(example_iterator):
-            texts.append(example.text)
-
-        _, loglikelihood, _ = generator.generate(texts)
-
-        metrics = defaultdict(list)
-        for i, ll in enumerate(loglikelihood):
-            tmp = ll.sum().item()
-            metrics["nll"].append(tmp)
-            metrics["nll_per_token"].append(tmp / len(ll))
-            metrics["nll_per_char"].append(tmp / len(texts[i]))
-
-            metrics["avg_seqlen"].append(len(ll))
-
-        for m in metrics:
-            metrics[m] = sum(metrics[m]) / len(metrics[m])
-        metrics.update(dist_mean_dict(metrics))
-        logger.info(f"Validation on {src} done. Metrics: {metrics}")
-
-        name = os.path.basename(src)
-        if name in all_val_metrics:
-            logger.warning(
-                f"Duplicate source name {name}, path {src} in validation sources, renaming to {name}_1"
-            )
-            name = f"{name}_1"
-        all_val_metrics[name] = metrics
-
-    generator.max_gen_len = max_gen_len
-
-    return all_val_metrics
-
-
 def launch_eval(eval_args: EvalArgs):
     assert eval_args.dump_dir is not None
     assert eval_args.ckpt_dir is not None
@@ -342,17 +253,29 @@ def launch_eval(eval_args: EvalArgs):
 
     torch.distributed.barrier()
     logger.info("Loading model")
-    # TODO: Make this general so that it works with either
-    # LMTransformer or Blt, similar with args
     model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
         consolidate_path,
     )
+    pad_id = 0 if train_cfg.data.tokenizer_args.name == "bytes" else tokenizer.boe_id
     model.eval()
     logger.info("Model loaded")
 
     ppl_results = None
     if eval_args.run_ppl:
         assert eval_args.validation is not None
+        packing_args = PackingArgs(
+            batch_size=eval_args.validation.batch_size,
+            seq_len=train_cfg.data.seq_len,
+            max_length=train_cfg.data.max_encoder_seq_length,
+            pad_to_max_length=True,
+            enable_byte_ngrams=False,
+            pad_id=pad_id,
+            packing_mode=(
+                PackingMode.BYTES
+                if train_cfg.data.patcher_args.patching_mode == PatchingModeEnum.byte
+                else PackingMode.PATCHING
+            ),
+        )
         if len(eval_args.validation.sources) > 0:
             ppl_results = {}
             logger.info("Starting PPL evaluation on validation sets")
@@ -362,14 +285,13 @@ def launch_eval(eval_args: EvalArgs):
                     world_size=world_size,
                     model=model,
                     tokenizer_args=train_cfg.data.tokenizer_args,
-                    # TODO: Don't hardcode, modify based on model
-                    patcher_args=PatcherArgs(patching_mode=PatchingModeEnum.byte),
-                    add_patches=False,
+                    patcher_args=train_cfg.data.patcher_args,
+                    packing_args=packing_args,
+                    add_patches=train_cfg.data.add_patches,
                     path=os.path.join(eval_args.validation.root_dir, source),
                     max_n_docs=eval_args.validation.max_n_docs,
-                    batch_size=8,
-                    arrow_batch_size=100,
-                    s3_profile="blt",
+                    arrow_batch_size=20,
+                    s3_profile=eval_args.s3_profile,
                 )
 
     task_results = None