Get evals working again.

- PPL/validation: works now and uses multi-GPU. For reasons not yet understood, single-GPU results differ slightly from multi-GPU results; this can be debugged in a follow-up PR (the cross-rank metric reduction is sketched below).
- Generation evals likely work, but they are very slow, so they are disabled for now.
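
As background for the multi-GPU PPL numbers: the new eval path accumulates a summed cross-entropy loss and a byte count on each rank, then reduces them across ranks before computing the final metrics. A minimal, self-contained sketch of that last step, with made-up per-rank values standing in for the real `dist_sum`/`to_py_num` reduction in the diff below:

```python
import math

# Toy per-rank sums (made-up numbers); in the real code these come from
# summing F.cross_entropy(..., reduction="sum") and byte counts per rank,
# then all-reducing with dist_sum() and converting with to_py_num().
per_rank_loss_sums = [10240.0, 10112.5]   # one entry per data-parallel rank
per_rank_byte_counts = [16384, 16384]

all_total_loss = sum(per_rank_loss_sums)   # stand-in for dist_sum(total_loss)
all_n_bytes = sum(per_rank_byte_counts)    # stand-in for dist_sum(n_bytes)

loss_mean = all_total_loss / all_n_bytes            # mean NLL per byte (nats)
ppl = math.exp(loss_mean)                           # perplexity
bpb = all_total_loss / math.log(2) / all_n_bytes    # bits per byte

print(f"loss_mean={loss_mean:.4f} ppl={ppl:.2f} bpb={bpb:.4f}")
```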


Test Plan:
```
torchrun --nproc-per-node 8 -m bytelatent.eval config=../internal-blt/configs/eval.yaml
```
Pedro Rodriguez 2025-02-28 00:40:04 +00:00
parent 08b8c7cd05
commit 2cae41fe1f
6 changed files with 276 additions and 101 deletions


@@ -263,6 +263,10 @@ class EvalArgs(BaseModel):
     dump_dir: str | None = None
     ckpt_dir: str | None = None
     metric_log_dir: str | None = None
+    run_ppl: bool = True
+    run_tasks: bool = False
     generator: PackedCausalTransformerGeneratorArgs = (
         PackedCausalTransformerGeneratorArgs()
     )
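
For reference, a hypothetical sketch of toggling the two new flags when constructing `EvalArgs` directly in Python instead of via YAML config; the paths are placeholders and fields not shown are assumed to keep their defaults:

```python
# Hypothetical usage sketch; paths are placeholders, not from this PR.
from bytelatent.args import EvalArgs

eval_args = EvalArgs(
    ckpt_dir="/path/to/checkpoints/0000001000",  # placeholder
    dump_dir="/path/to/eval_dump",               # placeholder
    run_ppl=True,     # PPL/validation evals (multi-GPU, enabled by default)
    run_tasks=False,  # generation-based harness tasks (slow, disabled for now)
)
```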


@@ -15,6 +15,7 @@ from functools import lru_cache, partial, reduce
 from itertools import chain
 from typing import List, Optional, Tuple, Union

+import numpy as np
 import torch

 # for no recompute ops
@@ -78,6 +79,40 @@ class DistributedArgs(BaseModel):
     spawn_method: str = "forkserver"

+    def configure_world(self):
+        pass
+        if self.dp_replicate * self.dp_shard * self.tp_size != get_world_size():
+            logging.info("Modifying TrainArgs distributed config")
+            assert get_world_size() % self.dp_shard == 0
+            logging.info("World size: %s", get_world_size())
+            logging.info(
+                "Existing setting: train_args.distributed.dp_shard=%s",
+                self.dp_shard,
+            )
+            logging.info(
+                "Setting train_args.distributed.dp_replicate=%s, was dp_replicate=%s",
+                get_world_size() // self.dp_shard,
+                self.dp_replicate,
+            )
+            self.dp_replicate = get_world_size() // self.dp_shard
+
+            logging.info(
+                "Changing dp_replicate from %s to %s, to account for tp_size=%s",
+                self.dp_replicate,
+                self.dp_replicate // self.tp_size,
+                self.tp_size,
+            )
+            assert self.dp_replicate % self.tp_size == 0
+            self.dp_replicate = self.dp_replicate // self.tp_size
+
+            logger.warning(
+                f"Setting Data Parallel size to {self.dp_replicate * self.dp_shard}"
+            )
+            assert self.dp_replicate * self.dp_shard * self.tp_size == get_world_size()
+
+            if self.fsdp_type == "no_shard":
+                assert self.dp_shard == 1 and self.dp_replicate == get_world_size()
+

 class EnvironmentArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
@@ -151,6 +186,13 @@ def dist_mean_dict(x):
     return r


+def to_py_num(num: int | float | torch.Tensor | np.ndarray) -> int | float:
+    if isinstance(num, (torch.Tensor, np.ndarray)):
+        return num.item()
+    else:
+        return num
+
+
 @lru_cache()
 def get_is_torch_run() -> bool:
     return os.environ.get("LOCAL_RANK") is not None
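
For intuition, the sizing arithmetic that `configure_world()` applies can be written as a small standalone function. This is a simplified sketch that returns the derived value instead of mutating `DistributedArgs` in place:

```python
# Simplified sketch of the arithmetic in configure_world(); illustrative only.
def derive_dp_replicate(world_size: int, dp_shard: int, tp_size: int) -> int:
    assert world_size % dp_shard == 0
    dp_replicate = world_size // dp_shard  # replicas before accounting for TP
    assert dp_replicate % tp_size == 0
    return dp_replicate // tp_size         # replicas after carving out TP groups

# e.g. the 8-GPU eval run from the test plan, with dp_shard=1 and tp_size=1:
assert derive_dp_replicate(world_size=8, dp_shard=1, tp_size=1) == 8
```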


@@ -2,6 +2,7 @@

 import json
 import logging
+import math
 import os
 from collections import defaultdict
 from datetime import datetime
@@ -10,22 +11,48 @@ import torch
 from lm_eval import simple_evaluate
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import LM
+from rich.progress import track
+from torch.nn import functional as F

-from bytelatent.args import EvalArgs, ValidationArgs
+from bytelatent.args import (
+    EvalArgs,
+    TrainArgs,
+    ValidationArgs,
+    find_and_sanitize_chunks,
+)
 from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints
 from bytelatent.config_parser import parse_args_to_pydantic_model
 from bytelatent.data.file_util import get_fs
+from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
+from bytelatent.data.iterators.limit_iterator import LimitIterator
+from bytelatent.data.iterators.packing_iterator import (
+    PackingArgs,
+    PackingIterator,
+    PackingMode,
+)
+from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
+from bytelatent.data.iterators.sequence_iterator import (
+    SequenceIterator,
+    SequencePackingArgs,
+)
+from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum
 from bytelatent.distributed import (
     DistributedArgs,
     dist_mean_dict,
+    dist_sum,
+    get_device_mesh,
     get_global_rank,
     get_world_size,
     setup_torch_distributed,
+    to_py_num,
 )
 from bytelatent.generate import (
     PackedCausalTransformerGenerator,
     load_consolidated_model_and_tokenizer,
 )
+from bytelatent.model.blt import ByteLatentTransformer
+from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
+from bytelatent.transformer import LMTransformer

 EVAL_FOLDER_NAME = "{:010d}"
@@ -113,19 +140,134 @@ class EvalHarnessLM(LM):
         return results


-def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
-    srcs = {}
+@torch.no_grad()
+def eval_ppl_on_path(
+    *,
+    world_rank: int,
+    world_size: int,
+    model: LMTransformer | ByteLatentTransformer,
+    tokenizer_args: TokenizerArgs,
+    patcher_args: PatcherArgs,
+    add_patches: bool,
+    path: str,
+    batch_size: int,
+    arrow_batch_size: int,
+    max_n_docs: int | None,
+    s3_profile: str | None = None,
+):
+    model.eval()
+    tokenizer = tokenizer_args.build()
+    seq_len = model.get_output_seq_len()
+    chunks = find_and_sanitize_chunks(
+        path,
+        world_size=1,
+        file_pattern="*.val.jsonl",
+        s3_profile=s3_profile,
+    )
+    assert (
+        len(chunks) == 1
+    ), f"There should be only 1 chunk per validation file, but found: {chunks}"
+    chunk = chunks[0]
+    arrow_iterator = ArrowFileIterator(
+        file_path=chunk,
+        preprocess_dir=None,
+        entropy_model_name=None,
+        worker_id=world_rank,
+        num_workers=world_size,
+        arrow_batch_size=arrow_batch_size,
+        s3_profile=s3_profile,
+        file_format="json",
+    )
+    if max_n_docs is not None:
+        arrow_iterator = LimitIterator(arrow_iterator, limit=max_n_docs)
+    preprocess_iterator = PreprocessIterator(
+        arrow_iterator,
+        patcher_args=patcher_args,
+        tokenizer_args=tokenizer_args,
+        add_patches=add_patches,
+    )
+    sequence_iterator = SequenceIterator(
+        preprocess_iterator,
+        sequence_packing_args=SequencePackingArgs(
+            output_seq_len=seq_len,
+            # Effectively disables shuffles
+            buffer_size=1,
+        ),
+        rng_state=None,
+    )
+    packing_args = PackingArgs(
+        batch_size=batch_size,
+        seq_len=seq_len,
+        # TODO: make these seq lens worth with blt
+        max_length=seq_len,
+        pad_to_max_length=True,
+        enable_byte_ngrams=False,
+        pad_id=tokenizer.boe_id,
+        packing_mode=PackingMode.BYTES,
+    )
+    packing_iterator = PackingIterator(sequence_iterator, packing_args=packing_args)
+    total_loss = 0.0
+    n_bytes = 0
+    batch_iterator = packing_iterator.create_iter()
+    for batch in batch_iterator:
+        x = torch.from_numpy(batch.x).cuda()
+        y = torch.from_numpy(batch.y).cuda()
+        mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
+        if tokenizer_args.name in ["bytes", "blt"]:
+            n_bytes += y.numel() if mask is None else mask.sum().item()
+            pred = model(x)
+            loss = F.cross_entropy(pred.flatten(0, 1), y.flatten(0, 1), reduction="sum")
+            total_loss += loss.item()
+        else:
+            raise NotImplementedError()
+    all_n_bytes = to_py_num(dist_sum(n_bytes))
+    all_total_loss = to_py_num(dist_sum(total_loss))
+    return {
+        "n_bytes": all_n_bytes,
+        "n_bytes_gpu": n_bytes,
+        "loss_sum": all_total_loss,
+        "loss_sum_gpu": total_loss,
+        "loss_mean": all_total_loss / all_n_bytes,
+        "loss_mean_gpu": total_loss / n_bytes,
+        "ppl": math.exp(all_total_loss / all_n_bytes) if all_n_bytes > 0 else 0.0,
+        "bpb": all_total_loss / math.log(2) / all_n_bytes,
+    }
+
+
+def eval_on_val(generator, val_args: ValidationArgs, train_cfg: TrainArgs):
+    srcs = []
     for src in val_args.sources:
         path = os.path.join(val_args.root_dir, src)
-        srcs[path] = 1.0
+        srcs.append(path)
     for src in train_cfg.data.sources:
         path = os.path.join(train_cfg.data.root_dir, src)
-        srcs[path] = 1.0
-    multi_state = init_choice_state(
-        "", srcs, 0, get_global_rank(), get_world_size(), "*.val.jsonl"
-    )
-    path_to_iter = setup_sources(multi_state)
+        srcs.append(path)
+
+    path_to_iter = {}
+    for path in srcs:
+        chunks = find_and_sanitize_chunks(
+            path,
+            world_size=1,
+            file_pattern="*.val.jsonl",
+            s3_profile=train_cfg.data.s3_profile,
+        )
+        assert (
+            len(chunks) == 1
+        ), f"There should be only 1 chunk per validation file, but found: {chunks}"
+        chunk = chunks[0]
+        iterator = ArrowFileIterator(
+            dataset_files=[chunk],
+            file_path=None,
+            preprocess_dir=None,
+            entropy_model_name=None,
+            worker_id=0,
+            num_workers=1,
+            arrow_batch_size=train_cfg.data.arrow_batch_size,
+            s3_profile=train_cfg.data.s3_profile,
+            file_format="json",
+        )
+        path_to_iter[path] = iterator

     max_gen_len = generator.max_gen_len
     # We temporarily lower max gen len
@@ -133,16 +275,11 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
     all_val_metrics = {}
     for src in path_to_iter:
-        jsonl_iterator = path_to_iter[src]
+        example_iterator = path_to_iter[src].create_iter()
         texts = []
         logger.info(f"Running validation on {src}...")
-        for step, (content, state) in enumerate(jsonl_iterator):
-            if state["current_iter"] > 0 or (
-                val_args.max_steps is not None and step >= val_args.max_steps
-            ):
-                break
-            content_key = "text" if ("text" in content) else "content"
-            texts.append(content[content_key])
+        for step, example in enumerate(example_iterator):
+            texts.append(example.text)

         _, loglikelihood, _ = generator.generate(texts)
@@ -174,8 +311,18 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg):

 def launch_eval(eval_args: EvalArgs):
+    assert eval_args.dump_dir is not None
+    assert eval_args.ckpt_dir is not None
+    distributed_args = DistributedArgs()
+    distributed_args.configure_world()
     if not torch.distributed.is_initialized():
-        setup_torch_distributed(DistributedArgs())
+        setup_torch_distributed(distributed_args)
+
+    world_mesh = get_device_mesh(distributed_args)
+    dp_mesh = world_mesh["dp_replicate"]
+    assert distributed_args.dp_shard == 1
+    world_size = dp_mesh.size()
+    world_rank = dp_mesh.get_local_rank()
+
     fs = get_fs(eval_args.ckpt_dir, s3_profile=eval_args.s3_profile)
     if (
@@ -187,7 +334,7 @@ def launch_eval(eval_args: EvalArgs):
     else:
         consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER)
         if not fs.exists(consolidate_path) and get_global_rank() == 0:
-            consolidate_path = consolidate_checkpoints(eval_args.ckpt_dir)
+            consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir)

     fs.mkdirs(eval_args.dump_dir, exist_ok=True)
     with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f:
@@ -200,35 +347,67 @@ def launch_eval(eval_args: EvalArgs):
     model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
         consolidate_path,
     )
-    logger.info("Model loaded")
     model.eval()
-    generator = PackedCausalTransformerGenerator(eval_args.generator, model, tokenizer)
+    logger.info("Model loaded")

-    wrap = EvalHarnessLM(generator)
-    # Redo
-    results = simple_evaluate(wrap, eval_args.harness.model_dump())
-    val_results = None
-    if eval_args.validation:
-        val_results = eval_on_val(generator, eval_args.validation, train_cfg)
+    ppl_results = None
+    if eval_args.run_ppl:
+        assert eval_args.validation is not None
+        if len(eval_args.validation.sources) > 0:
+            ppl_results = {}
+            logger.info("Starting PPL evaluation on validation sets")
+            for source in eval_args.validation.sources:
+                ppl_results[source] = eval_ppl_on_path(
+                    world_rank=world_rank,
+                    world_size=world_size,
+                    model=model,
+                    tokenizer_args=train_cfg.data.tokenizer_args,
+                    # TODO: Don't hardcode, modify based on model
+                    patcher_args=PatcherArgs(patching_mode=PatchingModeEnum.byte),
+                    add_patches=False,
+                    path=os.path.join(eval_args.validation.root_dir, source),
+                    max_n_docs=eval_args.validation.max_n_docs,
+                    batch_size=8,
+                    arrow_batch_size=100,
+                    s3_profile="blt",
+                )
+
+    task_results = None
+    if eval_args.run_tasks:
+        assert eval_args.generator is not None
+        assert eval_args.harness is not None
+        generator = PackedCausalTransformerGenerator(
+            eval_args.generator, model, tokenizer
+        )
+        wrap = EvalHarnessLM(generator)
+        # TODO: This needs to be checked/sped up
+        task_results = simple_evaluate(wrap, **eval_args.harness.model_dump())
+
+    results = {"ppl": ppl_results, "tasks": task_results}
+    # TODO: Serial and Parallel yield slightly different number of bytes, debug this later,
+    # leaving this log statement here to help with that.
+    # logging.info("Rank: %s Results: %s", world_rank, results)
+
     if get_global_rank() == 0:
         with fs.open(os.path.join(eval_args.dump_dir, "results.json"), "w") as f:
             f.write(json.dumps(results))
-        logger.info(f"All evaluation results: {results['results']}")
-        if val_results is not None:
+        logger.info(f"All evaluation results: {results}")
+        if ppl_results is not None:
             with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f:
-                f.write(json.dumps(val_results))
-            logger.info(f"All validation results: {val_results}")
+                f.write(json.dumps(ppl_results))
+            logger.info(f"All validation results: {ppl_results}")

     if eval_args.metric_log_dir and get_global_rank() == 0:
         metric_log_path = os.path.join(eval_args.metric_log_dir, "metrics.eval.jsonl")
         logger.info(f"Writing metric logs to {metric_log_path}")
-        timestamp = {
+        timestamp: dict[str, int | str] = {
             "created_at": datetime.utcnow().isoformat(),
         }
         if eval_args.global_step is not None:
             timestamp["global_step"] = eval_args.global_step
         print(
-            json.dumps(timestamp | results["results"]),
+            json.dumps(timestamp | results),
             file=fs.open(metric_log_path, mode="a"),
             flush=True,
         )
@@ -236,18 +415,16 @@ def launch_eval(eval_args: EvalArgs):
         val_log_path = os.path.join(
             eval_args.metric_log_dir, "metrics.validation.jsonl"
         )
-        if val_results is not None:
+        if ppl_results is not None:
             print(
-                json.dumps(timestamp | val_results),
+                json.dumps(timestamp | ppl_results),
                 file=fs.open(val_log_path, mode="a"),
                 flush=True,
             )

-    del generator
-

 def main():
-    eval_args = parse_args(EvalArgs)
+    eval_args = parse_args_to_pydantic_model(EvalArgs)
     launch_eval(eval_args)
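
After a run, rank 0 writes `results.json` with the `{"ppl": ..., "tasks": ...}` structure assembled in `launch_eval`. A small hypothetical snippet for pulling per-source metrics back out of that file (assumes a local, non-S3 dump_dir; the path is a placeholder):

```python
# Hypothetical post-processing snippet; assumes results.json is on local disk.
import json

with open("/path/to/eval_dump/results.json") as f:  # placeholder path
    results = json.load(f)

for source, metrics in (results["ppl"] or {}).items():
    # Keys follow eval_ppl_on_path's return dict: n_bytes, loss_sum, loss_mean, ppl, bpb, ...
    print(f"{source}: ppl={metrics['ppl']:.3f} bpb={metrics['bpb']:.4f}")
```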


@@ -387,8 +387,7 @@ def load_consolidated_model_and_tokenizer(
 ):
     train_args_path = os.path.join(consolidated_path, "params.json")
     fs = get_fs(train_args_path)
-    with fs.open(train_args_path) as f:
-        train_args = TrainArgs.model_validate_json(f.read())
+    train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))

     if train_args.train_entropy_model:
         model_args = train_args.entropy_model
@@ -401,7 +400,8 @@ def load_consolidated_model_and_tokenizer(
         train_args.distributed.model_dtype
     ]
     tokenizer = train_args.data.tokenizer_args.build()
-    st_dict = torch.load(consolidated_path / CONSOLIDATE_NAME, weights_only=True)
+    with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f:
+        st_dict = torch.load(f, weights_only=True)
     model.load_state_dict(st_dict["model"])
     model = model.cuda().eval()
     for param in model.parameters():


@@ -55,7 +55,7 @@ class LoggingArgs(BaseModel):
 class MetricLogger:
     def __init__(
         self,
-        outdir: Path,
+        outdir: str,
         # args: TrainArgs
         args: Any | None = None,
         fs: fsspec.AbstractFileSystem | None = None,


@@ -46,6 +46,7 @@ from bytelatent.distributed import (
     requeue_slurm_job,
     setup_env,
     setup_torch_distributed,
+    to_py_num,
 )
 from bytelatent.eval import EVAL_FOLDER_NAME, launch_eval
 from bytelatent.logger import init_logger
@@ -89,13 +90,6 @@ def get_iterator_state_name(iterator_state):
         raise ValueError(f"Unsupported iterator to get name from: {iterator_state}")


-def to_py_num(num: int | float | torch.Tensor | np.ndarray) -> int | float:
-    if isinstance(num, (torch.Tensor, np.ndarray)):
-        return num.item()
-    else:
-        return num
-
-
 # TODO: Make this pydantic based instead of data class based
 # TODO: Generalize this to any iterator state
 @dataclass
@@ -152,57 +146,13 @@ def validate_train_args(args: TrainArgs, output_size: int):
         logger.info(f"Setting checkpoint path to {args.checkpoint.path}")
         args.checkpoint.path = os.path.join(args.dump_dir, "checkpoints")

-    data_fs = get_fs(args.data.root_dir, s3_profile=args.data.s3_profile)
-    for source in args.data.sources:
-        data_path = os.path.join(args.data.root_dir, source)
-        assert data_fs.exists(data_path), f"{data_path} doesn't exist"
-
-    if (
-        args.distributed.dp_replicate
-        * args.distributed.dp_shard
-        * args.distributed.tp_size
-        != get_world_size()
-    ):
-        logging.info("Modifying TrainArgs distributed config")
-        assert get_world_size() % args.distributed.dp_shard == 0
-        logging.info("World size: %s", get_world_size())
-        logging.info(
-            "Existing setting: train_args.distributed.dp_shard=%s",
-            args.distributed.dp_shard,
-        )
-        logging.info(
-            "Setting train_args.distributed.dp_replicate=%s, was dp_replicate=%s",
-            get_world_size() // args.distributed.dp_shard,
-            args.distributed.dp_replicate,
-        )
-        args.distributed.dp_replicate = get_world_size() // args.distributed.dp_shard
-
-        logging.info(
-            "Changing dp_replicate from %s to %s, to account for tp_size=%s",
-            args.distributed.dp_replicate,
-            args.distributed.dp_replicate // args.distributed.tp_size,
-            args.distributed.tp_size,
-        )
-        assert args.distributed.dp_replicate % args.distributed.tp_size == 0
-        args.distributed.dp_replicate = (
-            args.distributed.dp_replicate // args.distributed.tp_size
-        )
-
-        logger.warning(
-            f"Setting Data Parallel size to {args.distributed.dp_replicate * args.distributed.dp_shard}"
-        )
-        assert (
-            args.distributed.dp_replicate
-            * args.distributed.dp_shard
-            * args.distributed.tp_size
-            == get_world_size()
-        )
-
-        if args.distributed.fsdp_type == "no_shard":
-            assert (
-                args.distributed.dp_shard == 1
-                and args.distributed.dp_replicate == get_world_size()
-            )
+    if args.data.root_dir is not None:
+        data_fs = get_fs(args.data.root_dir, s3_profile=args.data.s3_profile)
+        for source in args.data.sources:
+            data_path = os.path.join(args.data.root_dir, source)
+            assert data_fs.exists(data_path), f"{data_path} doesn't exist"
+
+    args.distributed.configure_world()

     if args.model is not None:
         args.model.max_seqlen = args.data.seq_len
@@ -241,7 +191,9 @@ def set_preemption_flag(signum, frame):
     preemption_flag["flag"] = True


-def every_n_steps(train_state, freq, acc_step=None, acc_freq=None):
+def every_n_steps(train_state, freq: int, acc_step=None, acc_freq=None):
+    if freq < 0:
+        return False
     test = train_state.step % freq == 0
     if acc_step is not None:
         test = test and (train_state.acc_step == acc_step)
@@ -268,7 +220,7 @@ def train(args: TrainArgs):
     tokenizer = args.data.tokenizer_args.build()
     validate_train_args(
         args,
-        tokenizer.n_words,
+        tokenizer.get_vocab_size(),
     )
     dump_fs = get_fs(args.dump_dir, s3_profile=args.checkpoint.s3_profile)
     if get_is_master():