# Copyright (c) Meta Platforms, Inc. and affiliates.

import json
import logging
import os
from collections import defaultdict
from dataclasses import asdict, dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, List, Optional, Tuple, Union

import torch
from lingua.args import dump_config
from lingua.data import init_choice_state, setup_sources
from lm_eval import simple_evaluate
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
from omegaconf import OmegaConf

from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints
from bytelatent.distributed import (
    DistributedArgs,
    dist_mean_dict,
    get_global_rank,
    get_world_size,
    setup_torch_distributed,
)
from bytelatent.transformer import LMTransformer, LMTransformerArgs

from apps.main.generate import (
    PackedCausalTransformerGenerator,
    PackedCausalTransformerGeneratorArgs,
    load_consolidated_model_and_tokenizer,
)

EVAL_FOLDER_NAME = "{:010d}"

logger = logging.getLogger()


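# Keyword arguments forwarded verbatim to lm-eval's `simple_evaluate` (see `launch_eval`,
# which calls `simple_evaluate(wrap, **asdict(cfg.harness))`).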
@dataclass
class LMHarnessArgs:
    tasks: Optional[List[Any]] = None
    num_fewshot: Optional[int] = None
    device: Optional[str] = None
    use_cache: Optional[str] = None
    cache_requests: bool = False
    rewrite_requests_cache: bool = False
    delete_requests_cache: bool = False
    limit: Optional[Union[int, float]] = None
    bootstrap_iters: int = 100000
    check_integrity: bool = False
    write_out: bool = False
    log_samples: bool = True
    system_instruction: Optional[str] = None
    apply_chat_template: Union[bool, str] = False
    fewshot_as_multiturn: bool = False
    gen_kwargs: Optional[str] = None
    verbosity: str = "INFO"
    predict_only: bool = False
    random_seed: int = 0
    numpy_random_seed: int = 1234
    torch_random_seed: int = 1234
    fewshot_random_seed: int = 1234


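# Settings for perplexity-style validation on held-out jsonl shards (see `eval_on_val`).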
@dataclass
class ValidationArgs:
    max_steps: Optional[int] = (
        None  # If None, the whole validation file is used. Note: this is a per-GPU step count (100 max_steps on 8 GPUs covers as much data as 800 steps on 1 GPU).
    )
    use_val_from_train_src: bool = True  # Use the validation set from training sources
    root_dir: str = ""
    sources: List[str] = field(default_factory=list)  # Other sources to eval on


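# Top-level evaluation config: where to find the checkpoint, where to dump results, and
# which harness/validation evaluations to run.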
@dataclass
class EvalArgs:
    name: str = "evals"
    dump_dir: Optional[str] = None
    metric_log_dir: Optional[str] = None
    ckpt_dir: str = ""
    generator: PackedCausalTransformerGeneratorArgs = field(
        default_factory=PackedCausalTransformerGeneratorArgs
    )
    harness: Optional[LMHarnessArgs] = field(default_factory=LMHarnessArgs)
    validation: Optional[ValidationArgs] = field(default_factory=ValidationArgs)

    wandb: Optional[Any] = None

    global_step: Optional[int] = None  # for in-training evaluation


def all_dicts_same(dict_list):
    if not dict_list:  # Check if the list is empty
        return True

    # Compare each dictionary to the first one
    first_dict = dict_list[0]
    return all(d == first_dict for d in dict_list)


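# Minimal stand-in for an `accelerate`-style accelerator object, exposing just the two
# methods implemented here (`gather` via all_gather, `wait_for_everyone` via a barrier);
# it is attached to `EvalHarnessLM.accelerator` below.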
class MockAccelerator:
    def gather(self, tensor):
        tensor_list = [torch.zeros_like(tensor) for _ in range(get_world_size())]
        torch.distributed.all_gather(tensor_list, tensor)
        return torch.stack(tensor_list)

    def wait_for_everyone(self):
        torch.distributed.barrier()


# Light wrapper around generator for lm-eval harness
class EvalHarnessLM(LM):
    def __init__(self, generator):
        super().__init__()
        self.generator = generator
        self.accelerator = MockAccelerator()
        self._rank = get_global_rank()
        self._world_size = get_world_size()
        self.device = generator.device

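    # lm-eval `generate_until` API: each request carries (prompt, gen_kwargs). All requests
    # are assumed to share the same gen_kwargs; we push them onto the generator and strip
    # the stop strings ("until") from the returned generations.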
    def generate_until(self, requests: List[Instance]) -> List[str]:
        prompts, gen_args = zip(*[req.args for req in requests])
        assert all_dicts_same(gen_args), "Doesn't support different gen args for now"
        gen_args = gen_args[0]
        temperature = gen_args.get("temperature", 0.0)
        top_p = gen_args.get("top_p", None)
        top_k = gen_args.get("top_k", None)
        until = gen_args.get("until", [])

        self.generator.temperature = temperature
        self.generator.top_p = top_p
        self.generator.top_k = top_k
        self.generator.until = until
        generations, _, _ = self.generator.generate(prompts)
        filtered_gen = []
        for g in generations:
            for e in until:
                g = g.replace(e, "")
            filtered_gen.append(g)
        return filtered_gen

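    # lm-eval `loglikelihood` API: each request carries (prompt, continuation). We score
    # prompt+continuation with max_gen_len temporarily set to 1, sum the log-likelihood of
    # the continuation tokens only, and report whether those tokens were the greedy choice.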
    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        prompts, continuations = zip(*[req.args for req in requests])
        inputs = [req.args[0] + req.args[1] for req in requests]
        max_gen_len = self.generator.max_gen_len
        # We temporarily lower max gen len
        self.generator.max_gen_len = 1
        _, lls, greedy = self.generator.generate(inputs)
        results = []
        for p, ll, gr in zip(prompts, lls, greedy):
            p_len = len(
                self.generator.tokenizer.encode(p, add_bos=False, add_eos=False)
            )
            results.append((ll[p_len:].sum().item(), gr[p_len:].all().item()))

        self.generator.max_gen_len = max_gen_len
        return results

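    # lm-eval `loglikelihood_rolling` API: score each full prompt on its own (used by
    # perplexity-style tasks); the summed log-likelihood of each prompt is returned.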
    def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
        prompts = [req.args[0] for req in requests]
        max_gen_len = self.generator.max_gen_len
        # We temporarily lower max gen len
        self.generator.max_gen_len = 1
        _, lls, _ = self.generator.generate(prompts)
        results = []
        for ll in lls:
            results.append((ll.sum().item(),))
        self.generator.max_gen_len = max_gen_len

        return results


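# Perplexity-style validation: iterate over `*.val.jsonl` shards from the training sources
# (plus any extra `ValidationArgs.sources`), score each document with the generator, and
# report per-source nll, nll_per_token, nll_per_char, and avg_seqlen, averaged across ranks.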
def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
    srcs = {}
    for src in val_args.sources:
        path = os.path.join(val_args.root_dir, src)
        srcs[path] = 1.0
    for src in train_cfg.data.sources:
        path = os.path.join(train_cfg.data.root_dir, src)
        srcs[path] = 1.0

    multi_state = init_choice_state(
        "", srcs, 0, get_global_rank(), get_world_size(), "*.val.jsonl"
    )
    path_to_iter = setup_sources(multi_state)

    max_gen_len = generator.max_gen_len
    # We temporarily lower max gen len
    generator.max_gen_len = 1

    all_val_metrics = {}
    for src in path_to_iter:
        jsonl_iterator = path_to_iter[src]
        texts = []
        logger.info(f"Running validation on {src}...")
        for step, (content, state) in enumerate(jsonl_iterator):
            if state["current_iter"] > 0 or (
                val_args.max_steps is not None and step >= val_args.max_steps
            ):
                break
            content_key = "text" if ("text" in content) else "content"
            texts.append(content[content_key])

        _, loglikelihood, _ = generator.generate(texts)

        metrics = defaultdict(list)
        for i, ll in enumerate(loglikelihood):
            tmp = ll.sum().item()
            metrics["nll"].append(tmp)
            metrics["nll_per_token"].append(tmp / len(ll))
            metrics["nll_per_char"].append(tmp / len(texts[i]))

            metrics["avg_seqlen"].append(len(ll))

        for m in metrics:
            metrics[m] = sum(metrics[m]) / len(metrics[m])
        metrics.update(dist_mean_dict(metrics))
        logger.info(f"Validation on {src} done. Metrics: {metrics}")

        name = os.path.basename(src)
        if name in all_val_metrics:
            logger.warning(
                f"Duplicate source name {name}, path {src} in validation sources, renaming to {name}_1"
            )
            name = f"{name}_1"
        all_val_metrics[name] = metrics

    generator.max_gen_len = max_gen_len

    return all_val_metrics


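# Entry point for a single evaluation run: consolidate the checkpoint if needed, load the
# model and tokenizer, run the lm-eval harness (and optionally the jsonl validation), then
# dump results to `dump_dir` and append them to the metric logs on rank 0.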
def launch_eval(cfg: EvalArgs):
    if not torch.distributed.is_initialized():
        setup_torch_distributed(DistributedArgs())
    if (
        Path(cfg.ckpt_dir).exists()
        and (Path(cfg.ckpt_dir) / "params.json").exists()
        and next(Path(cfg.ckpt_dir).glob("*.pth"), None) is not None
    ):
        consolidate_path = Path(cfg.ckpt_dir)
    else:
        consolidate_path = Path(cfg.ckpt_dir) / CONSOLIDATE_FOLDER
        if not consolidate_path.exists() and get_global_rank() == 0:
            consolidate_path = consolidate_checkpoints(cfg.ckpt_dir)

    Path(cfg.dump_dir).mkdir(parents=True, exist_ok=True)
    dump_config(cfg, Path(cfg.dump_dir) / "config.yaml", log_config=False)

    consolidate_path = str(consolidate_path)
    torch.distributed.barrier()
    logger.info("Loading model")
    model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
        consolidate_path,
        model_cls=LMTransformer,
        model_args_cls=LMTransformerArgs,
    )
    logger.info("Model loaded")
    model.eval()
    generator = PackedCausalTransformerGenerator(cfg.generator, model, tokenizer)

    wrap = EvalHarnessLM(generator)
    results = simple_evaluate(wrap, **asdict(cfg.harness))
    val_results = None
    if cfg.validation:
        val_results = eval_on_val(generator, cfg.validation, train_cfg)
    if get_global_rank() == 0:
        with open(Path(cfg.dump_dir) / "results.json", "w") as f:
            f.write(json.dumps(results))
        logger.info(f"All evaluation results: {results['results']}")
        if val_results is not None:
            with open(Path(cfg.dump_dir) / "validation.json", "w") as f:
                f.write(json.dumps(val_results))
            logger.info(f"All validation results: {val_results}")
    if cfg.metric_log_dir and get_global_rank() == 0:
        metric_log_path = Path(cfg.metric_log_dir) / "metrics.eval.jsonl"

        logger.info(f"Writing metric logs to {metric_log_path}")
        timestamp = {
            "created_at": datetime.utcnow().isoformat(),
        }
        if cfg.global_step is not None:
            timestamp["global_step"] = cfg.global_step
        with open(metric_log_path, mode="a") as f:
            print(
                json.dumps(timestamp | results["results"]),
                file=f,
                flush=True,
            )

        val_log_path = Path(cfg.metric_log_dir) / "metrics.validation.jsonl"
        if val_results is not None:
            with open(val_log_path, mode="a") as f:
                print(
                    json.dumps(timestamp | val_results),
                    file=f,
                    flush=True,
                )

    del generator


def main():
    """
    The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments
    It accepts arguments as a dot list.
    So if the dataclass looks like:

    @dataclass
    class DummyArgs:
        name: str
        model: LMTransformerArgs

    @dataclass
    class LMTransformerArgs:
        dim: int

    Then you can pass model.dim=32 to change values in LMTransformerArgs
    or just name=tictac for top level attributes.

    The behavior here is as follows:
    1. We instantiate EvalArgs with its default values
    2. We override those default values with the ones in the provided config file
    3. We override the result with the additional arguments provided through the command line

    For example, if the config is the following

    model:
        dim: 128
        n_layers: 4

    and you call eval.py with model.dim=64

    Then the final EvalArgs will have

    model:
        dim: 64
        n_layers: 4

    Plus all the default values in the EvalArgs dataclass.
    """
    cli_args = OmegaConf.from_cli()
    file_cfg = OmegaConf.load(cli_args.config)
    # We remove the 'config' attribute from the CLI args as the underlying dataclass does not have it
    del cli_args.config

    default_cfg = OmegaConf.structured(EvalArgs())
    cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
    cfg = OmegaConf.to_object(cfg)
    launch_eval(cfg)


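# Example invocation (module path and file paths below are illustrative, adjust to your setup):
#   python -m apps.main.eval config=path/to/eval_config.yaml ckpt_dir=path/to/checkpoint dump_dir=path/to/eval_dump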
if __name__ == "__main__":
    main()