mirror of
https://github.com/facebookresearch/blt.git
synced 2025-09-01 01:59:05 +00:00
Fix eval mask
This commit is contained in:
parent
8c1b1a78bb
commit
a6acccdc71
2 changed files with 9 additions and 2 deletions
|
@ -260,6 +260,9 @@ class ValidationArgs(BaseModel):
|
||||||
max_n_docs: int | None = (
|
max_n_docs: int | None = (
|
||||||
None # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu)
|
None # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu)
|
||||||
)
|
)
|
||||||
|
max_n_batches: int | None = (
|
||||||
|
None # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu)
|
||||||
|
)
|
||||||
use_val_from_train_src: bool = True # Use the validation set from training sources
|
use_val_from_train_src: bool = True # Use the validation set from training sources
|
||||||
root_dir: str = ""
|
root_dir: str = ""
|
||||||
sources: list[str] = [] # Other sources to eval on
|
sources: list[str] = [] # Other sources to eval on
|
||||||
|
|
|
@ -153,6 +153,7 @@ def eval_ppl_on_path(
|
||||||
path: str,
|
path: str,
|
||||||
arrow_batch_size: int,
|
arrow_batch_size: int,
|
||||||
max_n_docs: int | None,
|
max_n_docs: int | None,
|
||||||
|
max_n_batches: int | None,
|
||||||
s3_profile: str | None = None,
|
s3_profile: str | None = None,
|
||||||
):
|
):
|
||||||
model.eval()
|
model.eval()
|
||||||
|
@ -189,7 +190,9 @@ def eval_ppl_on_path(
|
||||||
total_loss = 0.0
|
total_loss = 0.0
|
||||||
n_bytes = 0
|
n_bytes = 0
|
||||||
batch_iterator = packing_iterator.create_iter()
|
batch_iterator = packing_iterator.create_iter()
|
||||||
for batch in batch_iterator:
|
for i, batch in enumerate(batch_iterator):
|
||||||
|
if i == max_n_batches:
|
||||||
|
break
|
||||||
x = torch.from_numpy(batch.x).cuda()
|
x = torch.from_numpy(batch.x).cuda()
|
||||||
y = torch.from_numpy(batch.y).cuda()
|
y = torch.from_numpy(batch.y).cuda()
|
||||||
mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
|
mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
|
||||||
|
@ -203,7 +206,7 @@ def eval_ppl_on_path(
|
||||||
pred = model(x, patch_lengths=patch_lengths)
|
pred = model(x, patch_lengths=patch_lengths)
|
||||||
else:
|
else:
|
||||||
pred = model(x)
|
pred = model(x)
|
||||||
loss = F.cross_entropy(pred.flatten(0, 1), y.flatten(0, 1), reduction="sum")
|
loss = F.cross_entropy(pred.flatten(0, 1), y.flatten(0, 1), reduction="sum", ignore_index=0)
|
||||||
total_loss += loss.item()
|
total_loss += loss.item()
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
@ -301,6 +304,7 @@ def launch_eval(eval_args: EvalArgs):
|
||||||
add_patches=train_cfg.data.add_patches,
|
add_patches=train_cfg.data.add_patches,
|
||||||
path=os.path.join(eval_args.validation.root_dir, source),
|
path=os.path.join(eval_args.validation.root_dir, source),
|
||||||
max_n_docs=eval_args.validation.max_n_docs,
|
max_n_docs=eval_args.validation.max_n_docs,
|
||||||
|
max_n_batches=eval_args.validation.max_n_batches,
|
||||||
arrow_batch_size=20,
|
arrow_batch_size=20,
|
||||||
s3_profile=eval_args.s3_profile,
|
s3_profile=eval_args.s3_profile,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Reference in a new issue