mirror of
https://github.com/facebookresearch/blt.git
synced 2025-04-21 00:59:09 +00:00
Update ppl evals to work with blt model, in addition to entropy model
Summary: Test Plan: Run ``` python -m bytelatent.eval config=../internal-blt/configs/eval_blt.yaml validation.max_n_docs=null python -m bytelatent.eval config=../internal-blt/configs/eval_entropy.yaml validation.max_n_docs=null ```
This commit is contained in:
parent
f84ee635bd
commit
719900d4bd
3 changed files with 61 additions and 111 deletions
bytelatent
|
@ -263,6 +263,7 @@ class ValidationArgs(BaseModel):
|
||||||
use_val_from_train_src: bool = True # Use the validation set from training sources
|
use_val_from_train_src: bool = True # Use the validation set from training sources
|
||||||
root_dir: str = ""
|
root_dir: str = ""
|
||||||
sources: list[str] = [] # Other sources to eval on
|
sources: list[str] = [] # Other sources to eval on
|
||||||
|
batch_size: int = 8
|
||||||
|
|
||||||
|
|
||||||
class EvalArgs(BaseModel):
|
class EvalArgs(BaseModel):
|
||||||
|
|
|
@ -221,6 +221,7 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
|
||||||
enable_byte_ngrams = self.packing_args.enable_byte_ngrams
|
enable_byte_ngrams = self.packing_args.enable_byte_ngrams
|
||||||
max_length = self.packing_args.max_length
|
max_length = self.packing_args.max_length
|
||||||
assert max_length is not None
|
assert max_length is not None
|
||||||
|
final_leftover_batch = False
|
||||||
while True:
|
while True:
|
||||||
tokens: list[list[int]] = []
|
tokens: list[list[int]] = []
|
||||||
masks: list[list[bool]] = []
|
masks: list[list[bool]] = []
|
||||||
|
@ -252,6 +253,9 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
|
||||||
break
|
break
|
||||||
|
|
||||||
x_patch_lengths = np.array(patch_lengths)
|
x_patch_lengths = np.array(patch_lengths)
|
||||||
|
assert (
|
||||||
|
x_patch_lengths.shape[1] == seq_len
|
||||||
|
), f"{x_patch_lengths.shape[1]} vs {seq_len}"
|
||||||
# pad batch to same length
|
# pad batch to same length
|
||||||
tok_seq_len = max([len(toks) for toks in tokens]) - 1
|
tok_seq_len = max([len(toks) for toks in tokens]) - 1
|
||||||
x = np.full((batch_size, tok_seq_len), fill_value=pad_id)
|
x = np.full((batch_size, tok_seq_len), fill_value=pad_id)
|
||||||
|
@ -263,7 +267,30 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
|
||||||
# Adjust patch lengths to match x
|
# Adjust patch lengths to match x
|
||||||
x_patch_lengths[i, -1] += tok_seq_len - (len(tok_seq) - 1)
|
x_patch_lengths[i, -1] += tok_seq_len - (len(tok_seq) - 1)
|
||||||
|
|
||||||
assert x_patch_lengths.shape == (batch_size, seq_len)
|
if x_patch_lengths.shape[0] < batch_size:
|
||||||
|
if final_leftover_batch:
|
||||||
|
raise ValueError(
|
||||||
|
"There should only be one partial batch, but found multiple"
|
||||||
|
)
|
||||||
|
final_leftover_batch = True
|
||||||
|
assert len(masks) == len(x_patch_lengths)
|
||||||
|
n_missing = batch_size - x_patch_lengths.shape[0]
|
||||||
|
# Repeat the last patch length to validly pad it out, but
|
||||||
|
# update the mask to ignore the row
|
||||||
|
x_patch_lengths = np.vstack(
|
||||||
|
[
|
||||||
|
x_patch_lengths,
|
||||||
|
np.repeat(x_patch_lengths[-1:, :], n_missing, axis=0),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
for _ in range(n_missing):
|
||||||
|
masks.append([0] * tok_seq_len)
|
||||||
|
assert len(masks) == batch_size
|
||||||
|
|
||||||
|
assert x_patch_lengths.shape == (
|
||||||
|
batch_size,
|
||||||
|
seq_len,
|
||||||
|
), f"{x_patch_lengths.shape} vs {(batch_size, seq_len)}"
|
||||||
|
|
||||||
if enable_byte_ngrams:
|
if enable_byte_ngrams:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
|
@ -148,35 +148,25 @@ def eval_ppl_on_path(
|
||||||
model: LMTransformer | ByteLatentTransformer,
|
model: LMTransformer | ByteLatentTransformer,
|
||||||
tokenizer_args: TokenizerArgs,
|
tokenizer_args: TokenizerArgs,
|
||||||
patcher_args: PatcherArgs,
|
patcher_args: PatcherArgs,
|
||||||
|
packing_args: PackingArgs,
|
||||||
add_patches: bool,
|
add_patches: bool,
|
||||||
path: str,
|
path: str,
|
||||||
batch_size: int,
|
|
||||||
arrow_batch_size: int,
|
arrow_batch_size: int,
|
||||||
max_n_docs: int | None,
|
max_n_docs: int | None,
|
||||||
s3_profile: str | None = None,
|
s3_profile: str | None = None,
|
||||||
):
|
):
|
||||||
model.eval()
|
model.eval()
|
||||||
tokenizer = tokenizer_args.build()
|
|
||||||
seq_len = model.get_output_seq_len()
|
seq_len = model.get_output_seq_len()
|
||||||
chunks = find_and_sanitize_chunks(
|
|
||||||
path,
|
|
||||||
world_size=1,
|
|
||||||
file_pattern="*.val.jsonl",
|
|
||||||
s3_profile=s3_profile,
|
|
||||||
)
|
|
||||||
assert (
|
|
||||||
len(chunks) == 1
|
|
||||||
), f"There should be only 1 chunk per validation file, but found: {chunks}"
|
|
||||||
chunk = chunks[0]
|
|
||||||
arrow_iterator = ArrowFileIterator(
|
arrow_iterator = ArrowFileIterator(
|
||||||
file_path=chunk,
|
file_path=None,
|
||||||
preprocess_dir=None,
|
dataset_files=[path],
|
||||||
entropy_model_name=None,
|
entropy_model_name=None,
|
||||||
worker_id=world_rank,
|
worker_id=world_rank,
|
||||||
num_workers=world_size,
|
num_workers=world_size,
|
||||||
arrow_batch_size=arrow_batch_size,
|
arrow_batch_size=arrow_batch_size,
|
||||||
|
preprocess_dir=None,
|
||||||
s3_profile=s3_profile,
|
s3_profile=s3_profile,
|
||||||
file_format="json",
|
file_format="arrow" if path.endswith("arrow") else "json",
|
||||||
)
|
)
|
||||||
if max_n_docs is not None:
|
if max_n_docs is not None:
|
||||||
arrow_iterator = LimitIterator(arrow_iterator, limit=max_n_docs)
|
arrow_iterator = LimitIterator(arrow_iterator, limit=max_n_docs)
|
||||||
|
@ -195,16 +185,6 @@ def eval_ppl_on_path(
|
||||||
),
|
),
|
||||||
rng_state=None,
|
rng_state=None,
|
||||||
)
|
)
|
||||||
packing_args = PackingArgs(
|
|
||||||
batch_size=batch_size,
|
|
||||||
seq_len=seq_len,
|
|
||||||
# TODO: make these seq lens worth with blt
|
|
||||||
max_length=seq_len,
|
|
||||||
pad_to_max_length=True,
|
|
||||||
enable_byte_ngrams=False,
|
|
||||||
pad_id=tokenizer.boe_id,
|
|
||||||
packing_mode=PackingMode.BYTES,
|
|
||||||
)
|
|
||||||
packing_iterator = PackingIterator(sequence_iterator, packing_args=packing_args)
|
packing_iterator = PackingIterator(sequence_iterator, packing_args=packing_args)
|
||||||
total_loss = 0.0
|
total_loss = 0.0
|
||||||
n_bytes = 0
|
n_bytes = 0
|
||||||
|
@ -213,9 +193,16 @@ def eval_ppl_on_path(
|
||||||
x = torch.from_numpy(batch.x).cuda()
|
x = torch.from_numpy(batch.x).cuda()
|
||||||
y = torch.from_numpy(batch.y).cuda()
|
y = torch.from_numpy(batch.y).cuda()
|
||||||
mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
|
mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
|
||||||
|
patch_lengths = batch.patch_lengths
|
||||||
|
if patch_lengths is not None:
|
||||||
|
patch_lengths = torch.from_numpy(patch_lengths).cuda()
|
||||||
|
|
||||||
if tokenizer_args.name in ["bytes", "blt"]:
|
if tokenizer_args.name in ["bytes", "blt"]:
|
||||||
n_bytes += y.numel() if mask is None else mask.sum().item()
|
n_bytes += y.numel() if mask is None else mask.sum().item()
|
||||||
pred = model(x)
|
if isinstance(model, ByteLatentTransformer):
|
||||||
|
pred = model(x, patch_lengths=patch_lengths)
|
||||||
|
else:
|
||||||
|
pred = model(x)
|
||||||
loss = F.cross_entropy(pred.flatten(0, 1), y.flatten(0, 1), reduction="sum")
|
loss = F.cross_entropy(pred.flatten(0, 1), y.flatten(0, 1), reduction="sum")
|
||||||
total_loss += loss.item()
|
total_loss += loss.item()
|
||||||
else:
|
else:
|
||||||
|
@ -234,82 +221,6 @@ def eval_ppl_on_path(
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def eval_on_val(generator, val_args: ValidationArgs, train_cfg: TrainArgs):
|
|
||||||
srcs = []
|
|
||||||
for src in val_args.sources:
|
|
||||||
path = os.path.join(val_args.root_dir, src)
|
|
||||||
srcs.append(path)
|
|
||||||
|
|
||||||
for src in train_cfg.data.sources:
|
|
||||||
path = os.path.join(train_cfg.data.root_dir, src)
|
|
||||||
srcs.append(path)
|
|
||||||
|
|
||||||
path_to_iter = {}
|
|
||||||
for path in srcs:
|
|
||||||
chunks = find_and_sanitize_chunks(
|
|
||||||
path,
|
|
||||||
world_size=1,
|
|
||||||
file_pattern="*.val.jsonl",
|
|
||||||
s3_profile=train_cfg.data.s3_profile,
|
|
||||||
)
|
|
||||||
assert (
|
|
||||||
len(chunks) == 1
|
|
||||||
), f"There should be only 1 chunk per validation file, but found: {chunks}"
|
|
||||||
chunk = chunks[0]
|
|
||||||
iterator = ArrowFileIterator(
|
|
||||||
dataset_files=[chunk],
|
|
||||||
file_path=None,
|
|
||||||
preprocess_dir=None,
|
|
||||||
entropy_model_name=None,
|
|
||||||
worker_id=0,
|
|
||||||
num_workers=1,
|
|
||||||
arrow_batch_size=train_cfg.data.arrow_batch_size,
|
|
||||||
s3_profile=train_cfg.data.s3_profile,
|
|
||||||
file_format="json",
|
|
||||||
)
|
|
||||||
path_to_iter[path] = iterator
|
|
||||||
|
|
||||||
max_gen_len = generator.max_gen_len
|
|
||||||
# We temporarily lower max gen len
|
|
||||||
generator.max_gen_len = 1
|
|
||||||
|
|
||||||
all_val_metrics = {}
|
|
||||||
for src in path_to_iter:
|
|
||||||
example_iterator = path_to_iter[src].create_iter()
|
|
||||||
texts = []
|
|
||||||
logger.info(f"Running validation on {src}...")
|
|
||||||
for step, example in enumerate(example_iterator):
|
|
||||||
texts.append(example.text)
|
|
||||||
|
|
||||||
_, loglikelihood, _ = generator.generate(texts)
|
|
||||||
|
|
||||||
metrics = defaultdict(list)
|
|
||||||
for i, ll in enumerate(loglikelihood):
|
|
||||||
tmp = ll.sum().item()
|
|
||||||
metrics["nll"].append(tmp)
|
|
||||||
metrics["nll_per_token"].append(tmp / len(ll))
|
|
||||||
metrics["nll_per_char"].append(tmp / len(texts[i]))
|
|
||||||
|
|
||||||
metrics["avg_seqlen"].append(len(ll))
|
|
||||||
|
|
||||||
for m in metrics:
|
|
||||||
metrics[m] = sum(metrics[m]) / len(metrics[m])
|
|
||||||
metrics.update(dist_mean_dict(metrics))
|
|
||||||
logger.info(f"Validation on {src} done. Metrics: {metrics}")
|
|
||||||
|
|
||||||
name = os.path.basename(src)
|
|
||||||
if name in all_val_metrics:
|
|
||||||
logger.warning(
|
|
||||||
f"Duplicate source name {name}, path {src} in validation sources, renaming to {name}_1"
|
|
||||||
)
|
|
||||||
name = f"{name}_1"
|
|
||||||
all_val_metrics[name] = metrics
|
|
||||||
|
|
||||||
generator.max_gen_len = max_gen_len
|
|
||||||
|
|
||||||
return all_val_metrics
|
|
||||||
|
|
||||||
|
|
||||||
def launch_eval(eval_args: EvalArgs):
|
def launch_eval(eval_args: EvalArgs):
|
||||||
assert eval_args.dump_dir is not None
|
assert eval_args.dump_dir is not None
|
||||||
assert eval_args.ckpt_dir is not None
|
assert eval_args.ckpt_dir is not None
|
||||||
|
@ -342,17 +253,29 @@ def launch_eval(eval_args: EvalArgs):
|
||||||
|
|
||||||
torch.distributed.barrier()
|
torch.distributed.barrier()
|
||||||
logger.info("Loading model")
|
logger.info("Loading model")
|
||||||
# TODO: Make this general so that it works with either
|
|
||||||
# LMTransformer or Blt, similar with args
|
|
||||||
model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
|
model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
|
||||||
consolidate_path,
|
consolidate_path,
|
||||||
)
|
)
|
||||||
|
pad_id = 0 if train_cfg.data.tokenizer_args.name == "bytes" else tokenizer.boe_id
|
||||||
model.eval()
|
model.eval()
|
||||||
logger.info("Model loaded")
|
logger.info("Model loaded")
|
||||||
|
|
||||||
ppl_results = None
|
ppl_results = None
|
||||||
if eval_args.run_ppl:
|
if eval_args.run_ppl:
|
||||||
assert eval_args.validation is not None
|
assert eval_args.validation is not None
|
||||||
|
packing_args = PackingArgs(
|
||||||
|
batch_size=eval_args.validation.batch_size,
|
||||||
|
seq_len=train_cfg.data.seq_len,
|
||||||
|
max_length=train_cfg.data.max_encoder_seq_length,
|
||||||
|
pad_to_max_length=True,
|
||||||
|
enable_byte_ngrams=False,
|
||||||
|
pad_id=pad_id,
|
||||||
|
packing_mode=(
|
||||||
|
PackingMode.BYTES
|
||||||
|
if train_cfg.data.patcher_args.patching_mode == PatchingModeEnum.byte
|
||||||
|
else PackingMode.PATCHING
|
||||||
|
),
|
||||||
|
)
|
||||||
if len(eval_args.validation.sources) > 0:
|
if len(eval_args.validation.sources) > 0:
|
||||||
ppl_results = {}
|
ppl_results = {}
|
||||||
logger.info("Starting PPL evaluation on validation sets")
|
logger.info("Starting PPL evaluation on validation sets")
|
||||||
|
@ -362,14 +285,13 @@ def launch_eval(eval_args: EvalArgs):
|
||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer_args=train_cfg.data.tokenizer_args,
|
tokenizer_args=train_cfg.data.tokenizer_args,
|
||||||
# TODO: Don't hardcode, modify based on model
|
patcher_args=train_cfg.data.patcher_args,
|
||||||
patcher_args=PatcherArgs(patching_mode=PatchingModeEnum.byte),
|
packing_args=packing_args,
|
||||||
add_patches=False,
|
add_patches=train_cfg.data.add_patches,
|
||||||
path=os.path.join(eval_args.validation.root_dir, source),
|
path=os.path.join(eval_args.validation.root_dir, source),
|
||||||
max_n_docs=eval_args.validation.max_n_docs,
|
max_n_docs=eval_args.validation.max_n_docs,
|
||||||
batch_size=8,
|
arrow_batch_size=20,
|
||||||
arrow_batch_size=100,
|
s3_profile=eval_args.s3_profile,
|
||||||
s3_profile="blt",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
task_results = None
|
task_results = None
|
||||||
|
|
Loading…
Add table
Reference in a new issue