Update ppl evals to work with blt model, in addition to entropy model

Summary:

Test Plan:

Run
```
python -m bytelatent.eval config=../internal-blt/configs/eval_blt.yaml validation.max_n_docs=null
python -m bytelatent.eval config=../internal-blt/configs/eval_entropy.yaml validation.max_n_docs=null
```
Pedro Rodriguez 2025-03-13 17:14:53 +00:00
parent f84ee635bd
commit 719900d4bd
3 changed files with 61 additions and 111 deletions
bytelatent

View file

@@ -263,6 +263,7 @@ class ValidationArgs(BaseModel):
use_val_from_train_src: bool = True # Use the validation set from training sources
root_dir: str = ""
sources: list[str] = [] # Other sources to eval on
batch_size: int = 8
class EvalArgs(BaseModel):
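
For reference, a minimal sketch of the updated `ValidationArgs` after this hunk, assuming the pydantic `BaseModel` base shown in the diff; `max_n_docs` is an assumed field (it appears in the test plan), and any other fields above the hunk are omitted:

```
from pydantic import BaseModel


class ValidationArgs(BaseModel):
    max_n_docs: int | None = None  # assumed field; referenced in the test plan
    use_val_from_train_src: bool = True  # Use the validation set from training sources
    root_dir: str = ""
    sources: list[str] = []  # Other sources to eval on
    batch_size: int = 8  # new: eval batch size, previously hardcoded at the call site
```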

View file

@@ -221,6 +221,7 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
enable_byte_ngrams = self.packing_args.enable_byte_ngrams
max_length = self.packing_args.max_length
assert max_length is not None
final_leftover_batch = False
while True:
tokens: list[list[int]] = []
masks: list[list[bool]] = []
@@ -252,6 +253,9 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
break
x_patch_lengths = np.array(patch_lengths)
assert (
x_patch_lengths.shape[1] == seq_len
), f"{x_patch_lengths.shape[1]} vs {seq_len}"
# pad batch to same length
tok_seq_len = max([len(toks) for toks in tokens]) - 1
x = np.full((batch_size, tok_seq_len), fill_value=pad_id)
@@ -263,7 +267,30 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
# Adjust patch lengths to match x
x_patch_lengths[i, -1] += tok_seq_len - (len(tok_seq) - 1)
assert x_patch_lengths.shape == (batch_size, seq_len)
if x_patch_lengths.shape[0] < batch_size:
if final_leftover_batch:
raise ValueError(
"There should only be one partial batch, but found multiple"
)
final_leftover_batch = True
assert len(masks) == len(x_patch_lengths)
n_missing = batch_size - x_patch_lengths.shape[0]
# Repeat the last patch length to validly pad it out, but
# update the mask to ignore the row
x_patch_lengths = np.vstack(
[
x_patch_lengths,
np.repeat(x_patch_lengths[-1:, :], n_missing, axis=0),
]
)
for _ in range(n_missing):
masks.append([0] * tok_seq_len)
assert len(masks) == batch_size
assert x_patch_lengths.shape == (
batch_size,
seq_len,
), f"{x_patch_lengths.shape} vs {(batch_size, seq_len)}"
if enable_byte_ngrams:
raise NotImplementedError()
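
For clarity, a minimal standalone sketch of the leftover-batch padding added above, with made-up shapes: when the final batch has fewer rows than `batch_size`, the last patch-length row is repeated so the array keeps its expected shape, and all-zero masks are appended so the padded rows are ignored downstream.

```
import numpy as np

batch_size, seq_len, tok_seq_len = 4, 3, 6

# Suppose only 2 real rows remain in the final (partial) batch.
x_patch_lengths = np.array([[2, 2, 2], [3, 2, 1]])
masks = [[1] * tok_seq_len, [1] * tok_seq_len]

n_missing = batch_size - x_patch_lengths.shape[0]
# Repeat the last patch-length row to pad the array out to batch_size rows.
x_patch_lengths = np.vstack(
    [x_patch_lengths, np.repeat(x_patch_lengths[-1:, :], n_missing, axis=0)]
)
# Zero masks mark the padded rows so they contribute nothing to the metrics.
for _ in range(n_missing):
    masks.append([0] * tok_seq_len)

assert x_patch_lengths.shape == (batch_size, seq_len)
assert len(masks) == batch_size
```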

View file

@@ -148,35 +148,25 @@ def eval_ppl_on_path(
model: LMTransformer | ByteLatentTransformer,
tokenizer_args: TokenizerArgs,
patcher_args: PatcherArgs,
packing_args: PackingArgs,
add_patches: bool,
path: str,
batch_size: int,
arrow_batch_size: int,
max_n_docs: int | None,
s3_profile: str | None = None,
):
model.eval()
tokenizer = tokenizer_args.build()
seq_len = model.get_output_seq_len()
chunks = find_and_sanitize_chunks(
path,
world_size=1,
file_pattern="*.val.jsonl",
s3_profile=s3_profile,
)
assert (
len(chunks) == 1
), f"There should be only 1 chunk per validation file, but found: {chunks}"
chunk = chunks[0]
arrow_iterator = ArrowFileIterator(
file_path=chunk,
preprocess_dir=None,
file_path=None,
dataset_files=[path],
entropy_model_name=None,
worker_id=world_rank,
num_workers=world_size,
arrow_batch_size=arrow_batch_size,
preprocess_dir=None,
s3_profile=s3_profile,
file_format="json",
file_format="arrow" if path.endswith("arrow") else "json",
)
if max_n_docs is not None:
arrow_iterator = LimitIterator(arrow_iterator, limit=max_n_docs)
@@ -195,16 +185,6 @@ def eval_ppl_on_path(
),
rng_state=None,
)
packing_args = PackingArgs(
batch_size=batch_size,
seq_len=seq_len,
# TODO: make these seq lens work with blt
max_length=seq_len,
pad_to_max_length=True,
enable_byte_ngrams=False,
pad_id=tokenizer.boe_id,
packing_mode=PackingMode.BYTES,
)
packing_iterator = PackingIterator(sequence_iterator, packing_args=packing_args)
total_loss = 0.0
n_bytes = 0
@@ -213,9 +193,16 @@ def eval_ppl_on_path(
x = torch.from_numpy(batch.x).cuda()
y = torch.from_numpy(batch.y).cuda()
mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
patch_lengths = batch.patch_lengths
if patch_lengths is not None:
patch_lengths = torch.from_numpy(patch_lengths).cuda()
if tokenizer_args.name in ["bytes", "blt"]:
n_bytes += y.numel() if mask is None else mask.sum().item()
pred = model(x)
if isinstance(model, ByteLatentTransformer):
pred = model(x, patch_lengths=patch_lengths)
else:
pred = model(x)
loss = F.cross_entropy(pred.flatten(0, 1), y.flatten(0, 1), reduction="sum")
total_loss += loss.item()
else:
@@ -234,82 +221,6 @@ def eval_ppl_on_path(
}
def eval_on_val(generator, val_args: ValidationArgs, train_cfg: TrainArgs):
srcs = []
for src in val_args.sources:
path = os.path.join(val_args.root_dir, src)
srcs.append(path)
for src in train_cfg.data.sources:
path = os.path.join(train_cfg.data.root_dir, src)
srcs.append(path)
path_to_iter = {}
for path in srcs:
chunks = find_and_sanitize_chunks(
path,
world_size=1,
file_pattern="*.val.jsonl",
s3_profile=train_cfg.data.s3_profile,
)
assert (
len(chunks) == 1
), f"There should be only 1 chunk per validation file, but found: {chunks}"
chunk = chunks[0]
iterator = ArrowFileIterator(
dataset_files=[chunk],
file_path=None,
preprocess_dir=None,
entropy_model_name=None,
worker_id=0,
num_workers=1,
arrow_batch_size=train_cfg.data.arrow_batch_size,
s3_profile=train_cfg.data.s3_profile,
file_format="json",
)
path_to_iter[path] = iterator
max_gen_len = generator.max_gen_len
# We temporarily lower max gen len
generator.max_gen_len = 1
all_val_metrics = {}
for src in path_to_iter:
example_iterator = path_to_iter[src].create_iter()
texts = []
logger.info(f"Running validation on {src}...")
for step, example in enumerate(example_iterator):
texts.append(example.text)
_, loglikelihood, _ = generator.generate(texts)
metrics = defaultdict(list)
for i, ll in enumerate(loglikelihood):
tmp = ll.sum().item()
metrics["nll"].append(tmp)
metrics["nll_per_token"].append(tmp / len(ll))
metrics["nll_per_char"].append(tmp / len(texts[i]))
metrics["avg_seqlen"].append(len(ll))
for m in metrics:
metrics[m] = sum(metrics[m]) / len(metrics[m])
metrics.update(dist_mean_dict(metrics))
logger.info(f"Validation on {src} done. Metrics: {metrics}")
name = os.path.basename(src)
if name in all_val_metrics:
logger.warning(
f"Duplicate source name {name}, path {src} in validation sources, renaming to {name}_1"
)
name = f"{name}_1"
all_val_metrics[name] = metrics
generator.max_gen_len = max_gen_len
return all_val_metrics
def launch_eval(eval_args: EvalArgs):
assert eval_args.dump_dir is not None
assert eval_args.ckpt_dir is not None
@@ -342,17 +253,29 @@ def launch_eval(eval_args: EvalArgs):
torch.distributed.barrier()
logger.info("Loading model")
# TODO: Make this general so that it works with either
# LMTransformer or Blt, similar with args
model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
consolidate_path,
)
pad_id = 0 if train_cfg.data.tokenizer_args.name == "bytes" else tokenizer.boe_id
model.eval()
logger.info("Model loaded")
ppl_results = None
if eval_args.run_ppl:
assert eval_args.validation is not None
packing_args = PackingArgs(
batch_size=eval_args.validation.batch_size,
seq_len=train_cfg.data.seq_len,
max_length=train_cfg.data.max_encoder_seq_length,
pad_to_max_length=True,
enable_byte_ngrams=False,
pad_id=pad_id,
packing_mode=(
PackingMode.BYTES
if train_cfg.data.patcher_args.patching_mode == PatchingModeEnum.byte
else PackingMode.PATCHING
),
)
if len(eval_args.validation.sources) > 0:
ppl_results = {}
logger.info("Starting PPL evaluation on validation sets")
@@ -362,14 +285,13 @@ def launch_eval(eval_args: EvalArgs):
world_size=world_size,
model=model,
tokenizer_args=train_cfg.data.tokenizer_args,
# TODO: Don't hardcode, modify based on model
patcher_args=PatcherArgs(patching_mode=PatchingModeEnum.byte),
add_patches=False,
patcher_args=train_cfg.data.patcher_args,
packing_args=packing_args,
add_patches=train_cfg.data.add_patches,
path=os.path.join(eval_args.validation.root_dir, source),
max_n_docs=eval_args.validation.max_n_docs,
batch_size=8,
arrow_batch_size=100,
s3_profile="blt",
arrow_batch_size=20,
s3_profile=eval_args.s3_profile,
)
task_results = None
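
Putting the eval.py changes together, the per-batch PPL loop now dispatches on the model type instead of assuming an entropy-style `LMTransformer`. Below is a condensed, hypothetical sketch of one eval step; the names follow the diff, but the `ppl_step` helper and its `is_blt` flag are illustrative, not part of the commit. As in the lines shown, the mask adjusts only the byte count used for normalization.

```
import torch
import torch.nn.functional as F


def ppl_step(model, batch, is_blt: bool) -> tuple[float, int]:
    """One PPL eval step: BLT models additionally take patch_lengths.

    Returns the summed cross-entropy loss and the number of target bytes.
    """
    x = torch.from_numpy(batch.x).cuda()
    y = torch.from_numpy(batch.y).cuda()
    mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()

    patch_lengths = batch.patch_lengths
    if patch_lengths is not None:
        patch_lengths = torch.from_numpy(patch_lengths).cuda()

    n_bytes = y.numel() if mask is None else int(mask.sum().item())
    if is_blt:
        # ByteLatentTransformer consumes the precomputed patch lengths.
        pred = model(x, patch_lengths=patch_lengths)
    else:
        pred = model(x)
    loss = F.cross_entropy(pred.flatten(0, 1), y.flatten(0, 1), reduction="sum")
    return loss.item(), n_bytes
```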