From 719900d4bd2924fcafd83c3d26832adf24226b68 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez <me@pedro.ai> Date: Thu, 13 Mar 2025 17:14:53 +0000 Subject: [PATCH] Update ppl evals to work with blt model, in addition to entropy model Summary: Test Plan: Run ``` python -m bytelatent.eval config=../internal-blt/configs/eval_blt.yaml validation.max_n_docs=null python -m bytelatent.eval config=../internal-blt/configs/eval_entropy.yaml validation.max_n_docs=null ``` --- bytelatent/args.py | 1 + bytelatent/data/iterators/packing_iterator.py | 29 +++- bytelatent/eval.py | 142 ++++-------------- 3 files changed, 61 insertions(+), 111 deletions(-) diff --git a/bytelatent/args.py b/bytelatent/args.py index 13acfc0..6927f1c 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -263,6 +263,7 @@ class ValidationArgs(BaseModel): use_val_from_train_src: bool = True # Use the validation set from training sources root_dir: str = "" sources: list[str] = [] # Other sources to eval on + batch_size: int = 8 class EvalArgs(BaseModel): diff --git a/bytelatent/data/iterators/packing_iterator.py b/bytelatent/data/iterators/packing_iterator.py index f407f9f..d220342 100644 --- a/bytelatent/data/iterators/packing_iterator.py +++ b/bytelatent/data/iterators/packing_iterator.py @@ -221,6 +221,7 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]): enable_byte_ngrams = self.packing_args.enable_byte_ngrams max_length = self.packing_args.max_length assert max_length is not None + final_leftover_batch = False while True: tokens: list[list[int]] = [] masks: list[list[bool]] = [] @@ -252,6 +253,9 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]): break x_patch_lengths = np.array(patch_lengths) + assert ( + x_patch_lengths.shape[1] == seq_len + ), f"{x_patch_lengths.shape[1]} vs {seq_len}" # pad batch to same length tok_seq_len = max([len(toks) for toks in tokens]) - 1 x = np.full((batch_size, tok_seq_len), fill_value=pad_id) @@ -263,7 +267,30 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]): # Adjust patch lengths to match x x_patch_lengths[i, -1] += tok_seq_len - (len(tok_seq) - 1) - assert x_patch_lengths.shape == (batch_size, seq_len) + if x_patch_lengths.shape[0] < batch_size: + if final_leftover_batch: + raise ValueError( + "There should only be one partial batch, but found multiple" + ) + final_leftover_batch = True + assert len(masks) == len(x_patch_lengths) + n_missing = batch_size - x_patch_lengths.shape[0] + # Repeat the last patch length to validly pad it out, but + # update the mask to ignore the row + x_patch_lengths = np.vstack( + [ + x_patch_lengths, + np.repeat(x_patch_lengths[-1:, :], n_missing, axis=0), + ] + ) + for _ in range(n_missing): + masks.append([0] * tok_seq_len) + assert len(masks) == batch_size + + assert x_patch_lengths.shape == ( + batch_size, + seq_len, + ), f"{x_patch_lengths.shape} vs {(batch_size, seq_len)}" if enable_byte_ngrams: raise NotImplementedError() diff --git a/bytelatent/eval.py b/bytelatent/eval.py index 0622979..61e4a2d 100644 --- a/bytelatent/eval.py +++ b/bytelatent/eval.py @@ -148,35 +148,25 @@ def eval_ppl_on_path( model: LMTransformer | ByteLatentTransformer, tokenizer_args: TokenizerArgs, patcher_args: PatcherArgs, + packing_args: PackingArgs, add_patches: bool, path: str, - batch_size: int, arrow_batch_size: int, max_n_docs: int | None, s3_profile: str | None = None, ): model.eval() - tokenizer = tokenizer_args.build() seq_len = model.get_output_seq_len() - chunks = find_and_sanitize_chunks( - path, - 
world_size=1, - file_pattern="*.val.jsonl", - s3_profile=s3_profile, - ) - assert ( - len(chunks) == 1 - ), f"There should be only 1 chunk per validation file, but found: {chunks}" - chunk = chunks[0] arrow_iterator = ArrowFileIterator( - file_path=chunk, - preprocess_dir=None, + file_path=None, + dataset_files=[path], entropy_model_name=None, worker_id=world_rank, num_workers=world_size, arrow_batch_size=arrow_batch_size, + preprocess_dir=None, s3_profile=s3_profile, - file_format="json", + file_format="arrow" if path.endswith("arrow") else "json", ) if max_n_docs is not None: arrow_iterator = LimitIterator(arrow_iterator, limit=max_n_docs) @@ -195,16 +185,6 @@ def eval_ppl_on_path( ), rng_state=None, ) - packing_args = PackingArgs( - batch_size=batch_size, - seq_len=seq_len, - # TODO: make these seq lens worth with blt - max_length=seq_len, - pad_to_max_length=True, - enable_byte_ngrams=False, - pad_id=tokenizer.boe_id, - packing_mode=PackingMode.BYTES, - ) packing_iterator = PackingIterator(sequence_iterator, packing_args=packing_args) total_loss = 0.0 n_bytes = 0 @@ -213,9 +193,16 @@ def eval_ppl_on_path( x = torch.from_numpy(batch.x).cuda() y = torch.from_numpy(batch.y).cuda() mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda() + patch_lengths = batch.patch_lengths + if patch_lengths is not None: + patch_lengths = torch.from_numpy(patch_lengths).cuda() + if tokenizer_args.name in ["bytes", "blt"]: n_bytes += y.numel() if mask is None else mask.sum().item() - pred = model(x) + if isinstance(model, ByteLatentTransformer): + pred = model(x, patch_lengths=patch_lengths) + else: + pred = model(x) loss = F.cross_entropy(pred.flatten(0, 1), y.flatten(0, 1), reduction="sum") total_loss += loss.item() else: @@ -234,82 +221,6 @@ def eval_ppl_on_path( } -def eval_on_val(generator, val_args: ValidationArgs, train_cfg: TrainArgs): - srcs = [] - for src in val_args.sources: - path = os.path.join(val_args.root_dir, src) - srcs.append(path) - - for src in train_cfg.data.sources: - path = os.path.join(train_cfg.data.root_dir, src) - srcs.append(path) - - path_to_iter = {} - for path in srcs: - chunks = find_and_sanitize_chunks( - path, - world_size=1, - file_pattern="*.val.jsonl", - s3_profile=train_cfg.data.s3_profile, - ) - assert ( - len(chunks) == 1 - ), f"There should be only 1 chunk per validation file, but found: {chunks}" - chunk = chunks[0] - iterator = ArrowFileIterator( - dataset_files=[chunk], - file_path=None, - preprocess_dir=None, - entropy_model_name=None, - worker_id=0, - num_workers=1, - arrow_batch_size=train_cfg.data.arrow_batch_size, - s3_profile=train_cfg.data.s3_profile, - file_format="json", - ) - path_to_iter[path] = iterator - - max_gen_len = generator.max_gen_len - # We temporarily lower max gen len - generator.max_gen_len = 1 - - all_val_metrics = {} - for src in path_to_iter: - example_iterator = path_to_iter[src].create_iter() - texts = [] - logger.info(f"Running validation on {src}...") - for step, example in enumerate(example_iterator): - texts.append(example.text) - - _, loglikelihood, _ = generator.generate(texts) - - metrics = defaultdict(list) - for i, ll in enumerate(loglikelihood): - tmp = ll.sum().item() - metrics["nll"].append(tmp) - metrics["nll_per_token"].append(tmp / len(ll)) - metrics["nll_per_char"].append(tmp / len(texts[i])) - - metrics["avg_seqlen"].append(len(ll)) - - for m in metrics: - metrics[m] = sum(metrics[m]) / len(metrics[m]) - metrics.update(dist_mean_dict(metrics)) - logger.info(f"Validation on {src} done. 
Metrics: {metrics}") - - name = os.path.basename(src) - if name in all_val_metrics: - logger.warning( - f"Duplicate source name {name}, path {src} in validation sources, renaming to {name}_1" - ) - name = f"{name}_1" - all_val_metrics[name] = metrics - - generator.max_gen_len = max_gen_len - - return all_val_metrics - - def launch_eval(eval_args: EvalArgs): assert eval_args.dump_dir is not None assert eval_args.ckpt_dir is not None @@ -342,17 +253,29 @@ def launch_eval(eval_args: EvalArgs): torch.distributed.barrier() logger.info("Loading model") - # TODO: Make this general so that it works with either - # LMTransformer or Blt, similar with args model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer( consolidate_path, ) + pad_id = 0 if train_cfg.data.tokenizer_args.name == "bytes" else tokenizer.boe_id model.eval() logger.info("Model loaded") ppl_results = None if eval_args.run_ppl: assert eval_args.validation is not None + packing_args = PackingArgs( + batch_size=eval_args.validation.batch_size, + seq_len=train_cfg.data.seq_len, + max_length=train_cfg.data.max_encoder_seq_length, + pad_to_max_length=True, + enable_byte_ngrams=False, + pad_id=pad_id, + packing_mode=( + PackingMode.BYTES + if train_cfg.data.patcher_args.patching_mode == PatchingModeEnum.byte + else PackingMode.PATCHING + ), + ) if len(eval_args.validation.sources) > 0: ppl_results = {} logger.info("Starting PPL evaluation on validation sets") @@ -362,14 +285,13 @@ def launch_eval(eval_args: EvalArgs): world_size=world_size, model=model, tokenizer_args=train_cfg.data.tokenizer_args, - # TODO: Don't hardcode, modify based on model - patcher_args=PatcherArgs(patching_mode=PatchingModeEnum.byte), - add_patches=False, + patcher_args=train_cfg.data.patcher_args, + packing_args=packing_args, + add_patches=train_cfg.data.add_patches, path=os.path.join(eval_args.validation.root_dir, source), max_n_docs=eval_args.validation.max_n_docs, - batch_size=8, - arrow_batch_size=100, - s3_profile="blt", + arrow_batch_size=20, + s3_profile=eval_args.s3_profile, ) task_results = None
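
Note (illustration, not part of the patch): the metrics returned by `eval_ppl_on_path` are derived from the two sums accumulated in the loop, `total_loss` (summed cross-entropy, in nats) and `n_bytes` (bytes counted via the mask). The return block is truncated above, so the formulas below are the standard definitions rather than code read from the patch.

```
import math

# Illustrative values for the sums accumulated in the evaluation loop.
total_loss = 5417.3   # summed cross-entropy over all scored bytes, in nats
n_bytes = 4096        # number of bytes counted via the mask

nll_per_byte = total_loss / n_bytes          # average negative log-likelihood per byte
bits_per_byte = nll_per_byte / math.log(2)   # same quantity expressed in bits
ppl = math.exp(nll_per_byte)                 # byte-level perplexity
print(nll_per_byte, bits_per_byte, ppl)
```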