mirror of
https://github.com/facebookresearch/blt.git
synced 2025-01-19 00:47:44 +00:00
allow patch_batch_size to be adjusted in the forward() method
This commit is contained in:
parent
175fce61df
commit
cff0dcb7ab
|
@ -490,6 +490,7 @@ class Patcher:
|
||||||
preds: torch.Tensor | None = None,
|
preds: torch.Tensor | None = None,
|
||||||
entropies: torch.Tensor | None = None,
|
entropies: torch.Tensor | None = None,
|
||||||
threshold: float = None,
|
threshold: float = None,
|
||||||
|
patching_batch_size: int | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
tokens: 2D tensor of shape [batch_size, seq_len] that needs to be patched
|
tokens: 2D tensor of shape [batch_size, seq_len] that needs to be patched
|
||||||
|
@ -539,7 +540,11 @@ class Patcher:
|
||||||
scores = calculate_entropies(
|
scores = calculate_entropies(
|
||||||
tokens,
|
tokens,
|
||||||
self.entropy_model,
|
self.entropy_model,
|
||||||
self.patching_batch_size,
|
(
|
||||||
|
patching_batch_size
|
||||||
|
if patching_batch_size is not None
|
||||||
|
else self.patching_batch_size
|
||||||
|
),
|
||||||
self.device,
|
self.device,
|
||||||
)
|
)
|
||||||
if self.log_time:
|
if self.log_time:
|
||||||
|
|
Loading…
Reference in a new issue