blt/bytelatent/eval.py
Pedro Rodriguez 7517ac2a9f
Get evals working again. ()
- PPL/validation: works now and uses multi-GPU. For some reason single-GPU results differ from multi-GPU; this can be debugged in a follow-up PR.
- Generation evals likely work, but they are very slow, so they are disabled for now.


Test Plan:
```
torchrun --nproc-per-node 8 -m bytelatent.eval config=../internal-blt/configs/eval.yaml
```
2025-03-11 09:57:19 -07:00


# Copyright (c) Meta Platforms, Inc. and affiliates.
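"""Evaluation entry point for bytelatent: computes byte-level validation perplexity /
bits-per-byte and (optionally) lm-eval harness tasks for a consolidated checkpoint,
running under torch.distributed (see launch_eval)."""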
import json
import logging
import math
import os
from collections import defaultdict
from datetime import datetime

import torch
from lm_eval import simple_evaluate
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
from rich.progress import track
from torch.nn import functional as F

from bytelatent.args import (
    EvalArgs,
    TrainArgs,
    ValidationArgs,
    find_and_sanitize_chunks,
)
from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints
from bytelatent.config_parser import parse_args_to_pydantic_model
from bytelatent.data.file_util import get_fs
from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
from bytelatent.data.iterators.limit_iterator import LimitIterator
from bytelatent.data.iterators.packing_iterator import (
    PackingArgs,
    PackingIterator,
    PackingMode,
)
from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
from bytelatent.data.iterators.sequence_iterator import (
    SequenceIterator,
    SequencePackingArgs,
)
from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum
from bytelatent.distributed import (
    DistributedArgs,
    dist_mean_dict,
    dist_sum,
    get_device_mesh,
    get_global_rank,
    get_world_size,
    setup_torch_distributed,
    to_py_num,
)
from bytelatent.generate import (
    PackedCausalTransformerGenerator,
    load_consolidated_model_and_tokenizer,
)
from bytelatent.model.blt import ByteLatentTransformer
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
from bytelatent.transformer import LMTransformer

EVAL_FOLDER_NAME = "{:010d}"

logger = logging.getLogger()


def all_dicts_same(dict_list):
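    """Return True if the list is empty or every dict in it equals the first one."""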
    if not dict_list:  # Check if the list is empty
        return True
    # Compare each dictionary to the first one
    first_dict = dict_list[0]
    return all(d == first_dict for d in dict_list)


class MockAccelerator:
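    """Stand-in for an accelerate-style Accelerator: provides the gather and
    wait_for_everyone hooks used by the eval-harness wrapper, implemented
    directly on top of torch.distributed."""
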
    def gather(self, tensor):
        l = [torch.zeros_like(tensor) for _ in range(get_world_size())]
        torch.distributed.all_gather(l, tensor)
        return torch.stack(l)

    def wait_for_everyone(self):
        torch.distributed.barrier()


# Light wrapper around generator for lm-eval harness
class EvalHarnessLM(LM):
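    """Adapts a PackedCausalTransformerGenerator to the lm-eval `LM` interface,
    implementing generate_until, loglikelihood, and loglikelihood_rolling."""
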
    def __init__(self, generator):
        super().__init__()
        self.generator = generator
        self.accelerator = MockAccelerator()
        self._rank = get_global_rank()
        self._world_size = get_world_size()
        self.device = generator.device

    def generate_until(self, requests: list[Instance]) -> list[str]:
        prompts, gen_args = zip(*[req.args for req in requests])
        assert all_dicts_same(gen_args), "Doesn't support different gen args for now"
        gen_args = gen_args[0]
        temperature = gen_args.get("temperature", 0.0)
        top_p = gen_args.get("top_p", None)
        top_k = gen_args.get("top_k", None)
        until = gen_args.get("until", [])
        self.generator.temperature = temperature
        self.generator.top_p = top_p
        self.generator.top_k = top_k
        self.generator.until = until
        generations, _, _ = self.generator.generate(prompts)
        filtered_gen = []
        for g in generations:
            for e in until:
                g = g.replace(e, "")
            filtered_gen.append(g)
        return filtered_gen

    def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
        prompts, continuations = zip(*[req.args for req in requests])
        inputs = [req.args[0] + req.args[1] for req in requests]
        max_gen_len = self.generator.max_gen_len
        # We temporarily lower max gen len
        self.generator.max_gen_len = 1
        _, lls, greedy = self.generator.generate(inputs)
        results = []
        for p, ll, gr in zip(prompts, lls, greedy):
            p_len = len(
                self.generator.tokenizer.encode(p, add_bos=False, add_eos=False)
            )
            results.append((ll[p_len:].sum().item(), gr[p_len:].all().item()))
        self.generator.max_gen_len = max_gen_len
        return results

    def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]:
        prompts = [req.args[0] for req in requests]
        max_gen_len = self.generator.max_gen_len
        # We temporarily lower max gen len
        self.generator.max_gen_len = 1
        _, lls, _ = self.generator.generate(prompts)
        results = []
        for ll in lls:
            results.append((ll.sum().item(),))
        self.generator.max_gen_len = max_gen_len
        return results


@torch.no_grad()
def eval_ppl_on_path(
    *,
    world_rank: int,
    world_size: int,
    model: LMTransformer | ByteLatentTransformer,
    tokenizer_args: TokenizerArgs,
    patcher_args: PatcherArgs,
    add_patches: bool,
    path: str,
    batch_size: int,
    arrow_batch_size: int,
    max_n_docs: int | None,
    s3_profile: str | None = None,
):
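    """Compute byte-level loss, perplexity, and bits-per-byte over one validation file.

    Documents are sharded across ranks through the Arrow iterator, per-rank loss and
    byte counts are reduced with dist_sum, and both global and per-GPU statistics are
    returned.
    """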
    model.eval()
    tokenizer = tokenizer_args.build()
    seq_len = model.get_output_seq_len()
    chunks = find_and_sanitize_chunks(
        path,
        world_size=1,
        file_pattern="*.val.jsonl",
        s3_profile=s3_profile,
    )
    assert (
        len(chunks) == 1
    ), f"There should be only 1 chunk per validation file, but found: {chunks}"
    chunk = chunks[0]
    arrow_iterator = ArrowFileIterator(
        file_path=chunk,
        preprocess_dir=None,
        entropy_model_name=None,
        worker_id=world_rank,
        num_workers=world_size,
        arrow_batch_size=arrow_batch_size,
        s3_profile=s3_profile,
        file_format="json",
    )
    if max_n_docs is not None:
        arrow_iterator = LimitIterator(arrow_iterator, limit=max_n_docs)
    preprocess_iterator = PreprocessIterator(
        arrow_iterator,
        patcher_args=patcher_args,
        tokenizer_args=tokenizer_args,
        add_patches=add_patches,
    )
    sequence_iterator = SequenceIterator(
        preprocess_iterator,
        sequence_packing_args=SequencePackingArgs(
            output_seq_len=seq_len,
            # Effectively disables shuffling
            buffer_size=1,
        ),
        rng_state=None,
    )
    packing_args = PackingArgs(
        batch_size=batch_size,
        seq_len=seq_len,
        # TODO: make these seq lens work with blt
        max_length=seq_len,
        pad_to_max_length=True,
        enable_byte_ngrams=False,
        pad_id=tokenizer.boe_id,
        packing_mode=PackingMode.BYTES,
    )
    packing_iterator = PackingIterator(sequence_iterator, packing_args=packing_args)

    total_loss = 0.0
    n_bytes = 0
    batch_iterator = packing_iterator.create_iter()
    for batch in batch_iterator:
        x = torch.from_numpy(batch.x).cuda()
        y = torch.from_numpy(batch.y).cuda()
        mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
        if tokenizer_args.name in ["bytes", "blt"]:
            n_bytes += y.numel() if mask is None else mask.sum().item()
            pred = model(x)
            loss = F.cross_entropy(pred.flatten(0, 1), y.flatten(0, 1), reduction="sum")
            total_loss += loss.item()
        else:
            raise NotImplementedError()

    all_n_bytes = to_py_num(dist_sum(n_bytes))
    all_total_loss = to_py_num(dist_sum(total_loss))
    return {
        "n_bytes": all_n_bytes,
        "n_bytes_gpu": n_bytes,
        "loss_sum": all_total_loss,
        "loss_sum_gpu": total_loss,
        "loss_mean": all_total_loss / all_n_bytes,
        "loss_mean_gpu": total_loss / n_bytes,
        "ppl": math.exp(all_total_loss / all_n_bytes) if all_n_bytes > 0 else 0.0,
        "bpb": all_total_loss / math.log(2) / all_n_bytes,
    }


def eval_on_val(generator, val_args: ValidationArgs, train_cfg: TrainArgs):
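    """Score validation documents with the generator and report per-source NLL metrics.

    Iterates over the validation sources plus the training data sources, sums each
    document's log-likelihood (reported as nll, nll_per_token, nll_per_char,
    avg_seqlen), and averages the metrics across ranks with dist_mean_dict.
    """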
    srcs = []
    for src in val_args.sources:
        path = os.path.join(val_args.root_dir, src)
        srcs.append(path)
    for src in train_cfg.data.sources:
        path = os.path.join(train_cfg.data.root_dir, src)
        srcs.append(path)

    path_to_iter = {}
    for path in srcs:
        chunks = find_and_sanitize_chunks(
            path,
            world_size=1,
            file_pattern="*.val.jsonl",
            s3_profile=train_cfg.data.s3_profile,
        )
        assert (
            len(chunks) == 1
        ), f"There should be only 1 chunk per validation file, but found: {chunks}"
        chunk = chunks[0]
        iterator = ArrowFileIterator(
            dataset_files=[chunk],
            file_path=None,
            preprocess_dir=None,
            entropy_model_name=None,
            worker_id=0,
            num_workers=1,
            arrow_batch_size=train_cfg.data.arrow_batch_size,
            s3_profile=train_cfg.data.s3_profile,
            file_format="json",
        )
        path_to_iter[path] = iterator

    max_gen_len = generator.max_gen_len
    # We temporarily lower max gen len
    generator.max_gen_len = 1

    all_val_metrics = {}
    for src in path_to_iter:
        example_iterator = path_to_iter[src].create_iter()
        texts = []
        logger.info(f"Running validation on {src}...")
        for step, example in enumerate(example_iterator):
            texts.append(example.text)

        _, loglikelihood, _ = generator.generate(texts)

        metrics = defaultdict(list)
        for i, ll in enumerate(loglikelihood):
            tmp = ll.sum().item()
            metrics["nll"].append(tmp)
            metrics["nll_per_token"].append(tmp / len(ll))
            metrics["nll_per_char"].append(tmp / len(texts[i]))
            metrics["avg_seqlen"].append(len(ll))

        for m in metrics:
            metrics[m] = sum(metrics[m]) / len(metrics[m])
        metrics.update(dist_mean_dict(metrics))
        logger.info(f"Validation on {src} done. Metrics: {metrics}")

        name = os.path.basename(src)
        if name in all_val_metrics:
            logger.warning(
                f"Duplicate source name {name}, path {src} in validation sources, renaming to {name}_1"
            )
            name = f"{name}_1"
        all_val_metrics[name] = metrics

    generator.max_gen_len = max_gen_len
    return all_val_metrics


def launch_eval(eval_args: EvalArgs):
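    """Run evaluation end to end: consolidate the checkpoint if needed, load the model
    and tokenizer, optionally compute validation PPL and lm-eval harness tasks, and
    write results to the dump directory (and JSONL metric logs if configured).
    """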
    assert eval_args.dump_dir is not None
    assert eval_args.ckpt_dir is not None
    distributed_args = DistributedArgs()
    distributed_args.configure_world()
    if not torch.distributed.is_initialized():
        setup_torch_distributed(distributed_args)

    world_mesh = get_device_mesh(distributed_args)
    dp_mesh = world_mesh["dp_replicate"]
    assert distributed_args.dp_shard == 1
    world_size = dp_mesh.size()
    world_rank = dp_mesh.get_local_rank()

    fs = get_fs(eval_args.ckpt_dir, s3_profile=eval_args.s3_profile)
    if (
        fs.exists(eval_args.ckpt_dir)
        and fs.exists(os.path.join(eval_args.ckpt_dir, "params.json"))
        and len(fs.glob(os.path.join(eval_args.ckpt_dir, "*.pth"))) != 0
    ):
        consolidate_path = eval_args.ckpt_dir
    else:
        consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER)
        if not fs.exists(consolidate_path) and get_global_rank() == 0:
            consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir)

    fs.mkdirs(eval_args.dump_dir, exist_ok=True)
    with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f:
        f.write(eval_args.model_dump_json())

    torch.distributed.barrier()
    logger.info("Loading model")
    # TODO: Make this general so that it works with either
    # LMTransformer or Blt, similar with args
    model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
        consolidate_path,
    )
    model.eval()
    logger.info("Model loaded")

    ppl_results = None
    if eval_args.run_ppl:
        assert eval_args.validation is not None
        if len(eval_args.validation.sources) > 0:
            ppl_results = {}
            logger.info("Starting PPL evaluation on validation sets")
            for source in eval_args.validation.sources:
                ppl_results[source] = eval_ppl_on_path(
                    world_rank=world_rank,
                    world_size=world_size,
                    model=model,
                    tokenizer_args=train_cfg.data.tokenizer_args,
                    # TODO: Don't hardcode, modify based on model
                    patcher_args=PatcherArgs(patching_mode=PatchingModeEnum.byte),
                    add_patches=False,
                    path=os.path.join(eval_args.validation.root_dir, source),
                    max_n_docs=eval_args.validation.max_n_docs,
                    batch_size=8,
                    arrow_batch_size=100,
                    s3_profile="blt",
                )

    task_results = None
    if eval_args.run_tasks:
        assert eval_args.generator is not None
        assert eval_args.harness is not None
        generator = PackedCausalTransformerGenerator(
            eval_args.generator, model, tokenizer
        )
        wrap = EvalHarnessLM(generator)
        # TODO: This needs to be checked/sped up
        task_results = simple_evaluate(wrap, **eval_args.harness.model_dump())

    results = {"ppl": ppl_results, "tasks": task_results}
    # TODO: Serial and parallel runs yield slightly different byte counts; debug this
    # later. Leaving this log statement here to help with that.
    # logging.info("Rank: %s Results: %s", world_rank, results)
    if get_global_rank() == 0:
        with fs.open(os.path.join(eval_args.dump_dir, "results.json"), "w") as f:
            f.write(json.dumps(results))
        logger.info(f"All evaluation results: {results}")
        if ppl_results is not None:
            with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f:
                f.write(json.dumps(ppl_results))
            logger.info(f"All validation results: {ppl_results}")

    if eval_args.metric_log_dir and get_global_rank() == 0:
        metric_log_path = os.path.join(eval_args.metric_log_dir, "metrics.eval.jsonl")
        logger.info(f"Writing metric logs to {metric_log_path}")
        timestamp: dict[str, int | str] = {
            "created_at": datetime.utcnow().isoformat(),
        }
        if eval_args.global_step is not None:
            timestamp["global_step"] = eval_args.global_step
        print(
            json.dumps(timestamp | results),
            file=fs.open(metric_log_path, mode="a"),
            flush=True,
        )

        val_log_path = os.path.join(
            eval_args.metric_log_dir, "metrics.validation.jsonl"
        )
        if ppl_results is not None:
            print(
                json.dumps(timestamp | ppl_results),
                file=fs.open(val_log_path, mode="a"),
                flush=True,
            )


def main():
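    """CLI entry point: parse EvalArgs from the command line and launch the eval."""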
    eval_args = parse_args_to_pydantic_model(EvalArgs)
    launch_eval(eval_args)


if __name__ == "__main__":
    main()