Fix eval mask

This commit is contained in:
Srini Iyer 2025-04-08 02:22:29 +00:00
parent 8c1b1a78bb
commit a6acccdc71
2 changed files with 9 additions and 2 deletions

View file

@ -260,6 +260,9 @@ class ValidationArgs(BaseModel):
max_n_docs: int | None = ( max_n_docs: int | None = (
None # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu) None # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu)
) )
max_n_batches: int | None = (
None # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu)
)
use_val_from_train_src: bool = True # Use the validation set from training sources use_val_from_train_src: bool = True # Use the validation set from training sources
root_dir: str = "" root_dir: str = ""
sources: list[str] = [] # Other sources to eval on sources: list[str] = [] # Other sources to eval on

View file

@ -153,6 +153,7 @@ def eval_ppl_on_path(
path: str, path: str,
arrow_batch_size: int, arrow_batch_size: int,
max_n_docs: int | None, max_n_docs: int | None,
max_n_batches: int | None,
s3_profile: str | None = None, s3_profile: str | None = None,
): ):
model.eval() model.eval()
@ -189,7 +190,9 @@ def eval_ppl_on_path(
total_loss = 0.0 total_loss = 0.0
n_bytes = 0 n_bytes = 0
batch_iterator = packing_iterator.create_iter() batch_iterator = packing_iterator.create_iter()
for batch in batch_iterator: for i, batch in enumerate(batch_iterator):
if i == max_n_batches:
break
x = torch.from_numpy(batch.x).cuda() x = torch.from_numpy(batch.x).cuda()
y = torch.from_numpy(batch.y).cuda() y = torch.from_numpy(batch.y).cuda()
mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda() mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
@ -203,7 +206,7 @@ def eval_ppl_on_path(
pred = model(x, patch_lengths=patch_lengths) pred = model(x, patch_lengths=patch_lengths)
else: else:
pred = model(x) pred = model(x)
loss = F.cross_entropy(pred.flatten(0, 1), y.flatten(0, 1), reduction="sum") loss = F.cross_entropy(pred.flatten(0, 1), y.flatten(0, 1), reduction="sum", ignore_index=0)
total_loss += loss.item() total_loss += loss.item()
else: else:
raise NotImplementedError() raise NotImplementedError()
@ -301,6 +304,7 @@ def launch_eval(eval_args: EvalArgs):
add_patches=train_cfg.data.add_patches, add_patches=train_cfg.data.add_patches,
path=os.path.join(eval_args.validation.root_dir, source), path=os.path.join(eval_args.validation.root_dir, source),
max_n_docs=eval_args.validation.max_n_docs, max_n_docs=eval_args.validation.max_n_docs,
max_n_batches=eval_args.validation.max_n_batches,
arrow_batch_size=20, arrow_batch_size=20,
s3_profile=eval_args.s3_profile, s3_profile=eval_args.s3_profile,
) )