allow patching_batch_size to be adjusted in the forward() method

Luciferian Ink 2025-01-17 19:41:17 -06:00
parent 175fce61df
commit cff0dcb7ab


@@ -490,6 +490,7 @@ class Patcher:
         preds: torch.Tensor | None = None,
         entropies: torch.Tensor | None = None,
         threshold: float = None,
+        patching_batch_size: int | None = None,
     ) -> torch.Tensor:
         """
         tokens: 2D tensor of shape [batch_size, seq_len] that needs to be patched
@@ -539,7 +540,11 @@ class Patcher:
             scores = calculate_entropies(
                 tokens,
                 self.entropy_model,
-                self.patching_batch_size,
+                (
+                    patching_batch_size
+                    if patching_batch_size is not None
+                    else self.patching_batch_size
+                ),
                 self.device,
             )
             if self.log_time:
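
For context, below is a minimal, self-contained sketch of the fallback pattern this commit introduces: a per-call argument that, when supplied, overrides the instance-level default. The PatcherSketch class, its constructor, and its return value are illustrative assumptions; only the patching_batch_size name, the forward() method name, and the [batch_size, seq_len] token shape come from the commit itself.

import torch


class PatcherSketch:
    """Illustrative stand-in for Patcher; not the real class."""

    def __init__(self, patching_batch_size: int = 1):
        # Instance-level default, fixed at construction time.
        self.patching_batch_size = patching_batch_size

    def forward(
        self,
        tokens: torch.Tensor,  # [batch_size, seq_len]
        patching_batch_size: int | None = None,
    ) -> int:
        # The per-call argument wins when provided; otherwise fall back
        # to the default stored on the instance -- the same conditional
        # expression the diff inserts around calculate_entropies().
        effective_batch_size = (
            patching_batch_size
            if patching_batch_size is not None
            else self.patching_batch_size
        )
        return effective_batch_size


patcher = PatcherSketch(patching_batch_size=8)
tokens = torch.zeros(2, 16, dtype=torch.long)
assert patcher.forward(tokens) == 8                           # instance default
assert patcher.forward(tokens, patching_batch_size=32) == 32  # per-call override

Testing "is not None" rather than truthiness keeps the override explicit: any integer the caller passes takes effect, and only omitting the argument falls back to the constructor default.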