diff --git a/bytelatent/args.py b/bytelatent/args.py
index 8ee5a67..ebe43ba 100644
--- a/bytelatent/args.py
+++ b/bytelatent/args.py
@@ -260,6 +260,9 @@ class ValidationArgs(BaseModel):
     max_n_docs: int | None = (
         None  # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu)
     )
+    max_n_batches: int | None = (
+        None  # If None the whole validation file is used -> /!\ This number of batches is gpu dependent (100 max batches on 8 gpus = 800 batches on 1 gpu)
+    )
     use_val_from_train_src: bool = True  # Use the validation set from training sources
     root_dir: str = ""
     sources: list[str] = []  # Other sources to eval on
diff --git a/bytelatent/eval.py b/bytelatent/eval.py
index 2401fc0..a917f48 100644
--- a/bytelatent/eval.py
+++ b/bytelatent/eval.py
@@ -153,6 +153,7 @@ def eval_ppl_on_path(
     path: str,
     arrow_batch_size: int,
     max_n_docs: int | None,
+    max_n_batches: int | None,
     s3_profile: str | None = None,
 ):
     model.eval()
@@ -189,7 +190,9 @@ def eval_ppl_on_path(
     total_loss = 0.0
     n_bytes = 0
     batch_iterator = packing_iterator.create_iter()
-    for batch in batch_iterator:
+    for i, batch in enumerate(batch_iterator):
+        if i == max_n_batches:
+            break
         x = torch.from_numpy(batch.x).cuda()
         y = torch.from_numpy(batch.y).cuda()
         mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
@@ -203,7 +206,7 @@ def eval_ppl_on_path(
                 pred = model(x, patch_lengths=patch_lengths)
             else:
                 pred = model(x)
-            loss = F.cross_entropy(pred.flatten(0, 1), y.flatten(0, 1), reduction="sum")
+            loss = F.cross_entropy(pred.flatten(0, 1), y.flatten(0, 1), reduction="sum", ignore_index=0)
             total_loss += loss.item()
         else:
             raise NotImplementedError()
@@ -301,6 +304,7 @@ def launch_eval(eval_args: EvalArgs):
             add_patches=train_cfg.data.add_patches,
             path=os.path.join(eval_args.validation.root_dir, source),
             max_n_docs=eval_args.validation.max_n_docs,
+            max_n_batches=eval_args.validation.max_n_batches,
             arrow_batch_size=20,
             s3_profile=eval_args.s3_profile,
         )
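
For reviewers, a minimal self-contained sketch of the two behavioral changes above (the `capped` helper and the toy tensors are illustrative, not part of the repo). The first half shows why `i == max_n_batches` with a `None` default preserves the old run-the-whole-file behavior; the second shows what `ignore_index=0` does to the summed loss, assuming id 0 is the pad id as the diff implies:

```python
import torch
import torch.nn.functional as F


# Hypothetical stand-in for the packing iterator, using the exact cap
# pattern from the diff. When max_n_batches is None, `i == None` is never
# true for an int i, so the loop consumes the full iterator (the default).
def capped(batches, max_n_batches=None):
    for i, batch in enumerate(batches):
        if i == max_n_batches:
            break
        yield batch


assert len(list(capped(range(10), max_n_batches=3))) == 3
assert len(list(capped(range(10)))) == 10  # None -> no cap

# ignore_index=0 excludes positions whose target id is 0 from the summed
# loss, so padded targets no longer inflate the perplexity numerator.
logits = torch.randn(1, 5, 7)              # (batch, seq, vocab)
targets = torch.tensor([[4, 2, 0, 0, 0]])  # trailing zeros = padding
loss_all = F.cross_entropy(
    logits.flatten(0, 1), targets.flatten(0, 1), reduction="sum"
)
loss_masked = F.cross_entropy(
    logits.flatten(0, 1), targets.flatten(0, 1), reduction="sum", ignore_index=0
)
assert loss_masked <= loss_all  # padded positions contribute zero loss
```

Note that `ignore_index=0` also silently drops any genuine target with id 0, so this relies on 0 being reserved for padding in the tokenizer.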