# Copyright (c) Meta Platforms, Inc. and affiliates.
import json
import logging
import os
from collections import defaultdict
from datetime import datetime

import torch
from lm_eval import simple_evaluate
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM

from bytelatent.args import EvalArgs, ValidationArgs
from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints
from bytelatent.config_parser import parse_args_to_pydantic_model
from bytelatent.data.file_util import get_fs
from bytelatent.distributed import (
    DistributedArgs,
    dist_mean_dict,
    get_global_rank,
    get_world_size,
    setup_torch_distributed,
)
from bytelatent.generate import (
    PackedCausalTransformerGenerator,
    load_consolidated_model_and_tokenizer,
)

EVAL_FOLDER_NAME = "{:010d}"

logger = logging.getLogger()


def all_dicts_same(dict_list):
    if not dict_list:  # Check if the list is empty
        return True

    # Compare each dictionary to the first one
    first_dict = dict_list[0]
    return all(d == first_dict for d in dict_list)


class MockAccelerator:
    def gather(self, tensor):
        l = [torch.zeros_like(tensor) for _ in range(get_world_size())]
        torch.distributed.all_gather(l, tensor)
        return torch.stack(l)

    def wait_for_everyone(self):
        torch.distributed.barrier()


# Light wrapper around the generator for the lm-eval harness
class EvalHarnessLM(LM):
    def __init__(self, generator):
        super().__init__()
        self.generator = generator
        self.accelerator = MockAccelerator()
        self._rank = get_global_rank()
        self._world_size = get_world_size()
        self.device = generator.device

    def generate_until(self, requests: list[Instance]) -> list[str]:
        prompts, gen_args = zip(*[req.args for req in requests])
        assert all_dicts_same(gen_args), "Doesn't support different gen args for now"
        gen_args = gen_args[0]
        temperature = gen_args.get("temperature", 0.0)
        top_p = gen_args.get("top_p", None)
        top_k = gen_args.get("top_k", None)
        until = gen_args.get("until", [])

        self.generator.temperature = temperature
        self.generator.top_p = top_p
        self.generator.top_k = top_k
        self.generator.until = until
        generations, _, _ = self.generator.generate(prompts)
        filtered_gen = []
        for g in generations:
            # Strip the stop sequences from the returned generations
            for e in until:
                g = g.replace(e, "")
            filtered_gen.append(g)
        return filtered_gen

    def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
        prompts, continuations = zip(*[req.args for req in requests])
        inputs = [req.args[0] + req.args[1] for req in requests]
        max_gen_len = self.generator.max_gen_len
        # We temporarily lower max gen len
        self.generator.max_gen_len = 1
        _, lls, greedy = self.generator.generate(inputs)
        results = []
        for p, ll, gr in zip(prompts, lls, greedy):
            # Score only the continuation: skip the prompt tokens
            p_len = len(
                self.generator.tokenizer.encode(p, add_bos=False, add_eos=False)
            )
            results.append((ll[p_len:].sum().item(), gr[p_len:].all().item()))

        self.generator.max_gen_len = max_gen_len
        return results

    def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]:
        prompts = [req.args[0] for req in requests]
        max_gen_len = self.generator.max_gen_len
        # We temporarily lower max gen len
        self.generator.max_gen_len = 1
        _, lls, _ = self.generator.generate(prompts)
        results = []
        for ll in lls:
            # Return plain floats to match the declared list[float] return type
            results.append(ll.sum().item())
        self.generator.max_gen_len = max_gen_len

        return results
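

# Example (illustrative sketch only): once a generator has been built, the wrapper
# can be handed straight to the harness. `gen_args` stands for the generator config
# (eval_args.generator in launch_eval below), and the task name is an assumption
# for illustration, not taken from this repo's configs.
#
#   generator = PackedCausalTransformerGenerator(gen_args, model, tokenizer)
#   wrapped = EvalHarnessLM(generator)
#   results = simple_evaluate(wrapped, tasks=["hellaswag"], num_fewshot=0)
#   print(results["results"])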


def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
    """Compute average NLL metrics (total, per token, per character) on the
    validation split of each configured source, keyed by the source's basename."""
    srcs = {}
    for src in val_args.sources:
        path = os.path.join(val_args.root_dir, src)
        srcs[path] = 1.0
    for src in train_cfg.data.sources:
        path = os.path.join(train_cfg.data.root_dir, src)
        srcs[path] = 1.0

    # NOTE: init_choice_state and setup_sources are not imported above; they are
    # expected to come from the repo's data-loading utilities.
    multi_state = init_choice_state(
        "", srcs, 0, get_global_rank(), get_world_size(), "*.val.jsonl"
    )
    path_to_iter = setup_sources(multi_state)

    max_gen_len = generator.max_gen_len
    # We temporarily lower max gen len
    generator.max_gen_len = 1

    all_val_metrics = {}
    for src in path_to_iter:
        jsonl_iterator = path_to_iter[src]
        texts = []
        logger.info(f"Running validation on {src}...")
        for step, (content, state) in enumerate(jsonl_iterator):
            if state["current_iter"] > 0 or (
                val_args.max_steps is not None and step >= val_args.max_steps
            ):
                break
            content_key = "text" if ("text" in content) else "content"
            texts.append(content[content_key])

        _, loglikelihood, _ = generator.generate(texts)

        metrics = defaultdict(list)
        for i, ll in enumerate(loglikelihood):
            tmp = ll.sum().item()
            metrics["nll"].append(tmp)
            metrics["nll_per_token"].append(tmp / len(ll))
            metrics["nll_per_char"].append(tmp / len(texts[i]))
            metrics["avg_seqlen"].append(len(ll))

        for m in metrics:
            metrics[m] = sum(metrics[m]) / len(metrics[m])
        metrics.update(dist_mean_dict(metrics))
        logger.info(f"Validation on {src} done. Metrics: {metrics}")

        name = os.path.basename(src)
        if name in all_val_metrics:
            logger.warning(
                f"Duplicate source name {name}, path {src} in validation sources, renaming to {name}_1"
            )
            name = f"{name}_1"
        all_val_metrics[name] = metrics

    generator.max_gen_len = max_gen_len

    return all_val_metrics


def launch_eval(eval_args: EvalArgs):
    if not torch.distributed.is_initialized():
        setup_torch_distributed(DistributedArgs())

    fs = get_fs(eval_args.ckpt_dir, s3_profile=eval_args.s3_profile)
    if (
        fs.exists(eval_args.ckpt_dir)
        and fs.exists(os.path.join(eval_args.ckpt_dir, "params.json"))
        and len(fs.glob(os.path.join(eval_args.ckpt_dir, "*.pth"))) != 0
    ):
        consolidate_path = eval_args.ckpt_dir
    else:
        consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER)
        if not fs.exists(consolidate_path) and get_global_rank() == 0:
            consolidate_path = consolidate_checkpoints(eval_args.ckpt_dir)

    fs.mkdirs(eval_args.dump_dir, exist_ok=True)
    with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f:
        f.write(eval_args.model_dump_json())

    torch.distributed.barrier()
    logger.info("Loading model")
    # TODO: Make this general so that it works with either
    # LMTransformer or Blt, similar with args
    model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
        consolidate_path,
    )
    logger.info("Model loaded")
    model.eval()
    generator = PackedCausalTransformerGenerator(eval_args.generator, model, tokenizer)

    wrap = EvalHarnessLM(generator)
    # Run the harness tasks; the harness config is unpacked as keyword arguments
    results = simple_evaluate(wrap, **eval_args.harness.model_dump())

    val_results = None
    if eval_args.validation:
        val_results = eval_on_val(generator, eval_args.validation, train_cfg)

    if get_global_rank() == 0:
        with fs.open(os.path.join(eval_args.dump_dir, "results.json"), "w") as f:
            f.write(json.dumps(results))
        logger.info(f"All evaluation results: {results['results']}")
        if val_results is not None:
            with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f:
                f.write(json.dumps(val_results))
            logger.info(f"All validation results: {val_results}")

    if eval_args.metric_log_dir and get_global_rank() == 0:
        metric_log_path = os.path.join(eval_args.metric_log_dir, "metrics.eval.jsonl")
        logger.info(f"Writing metric logs to {metric_log_path}")

        timestamp = {
            "created_at": datetime.utcnow().isoformat(),
        }
        if eval_args.global_step is not None:
            timestamp["global_step"] = eval_args.global_step

        print(
            json.dumps(timestamp | results["results"]),
            file=fs.open(metric_log_path, mode="a"),
            flush=True,
        )

        val_log_path = os.path.join(
            eval_args.metric_log_dir, "metrics.validation.jsonl"
        )
        if val_results is not None:
            print(
                json.dumps(timestamp | val_results),
                file=fs.open(val_log_path, mode="a"),
                flush=True,
            )

    del generator


def main():
    eval_args = parse_args_to_pydantic_model(EvalArgs)
    launch_eval(eval_args)


if __name__ == "__main__":
    main()