allow patch_batch_size to be adjusted in the forward() method

This commit is contained in:
Luciferian Ink 2025-01-17 19:41:17 -06:00
parent 175fce61df
commit cff0dcb7ab

View file

@ -490,6 +490,7 @@ class Patcher:
preds: torch.Tensor | None = None, preds: torch.Tensor | None = None,
entropies: torch.Tensor | None = None, entropies: torch.Tensor | None = None,
threshold: float = None, threshold: float = None,
patching_batch_size: int | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
tokens: 2D tensor of shape [batch_size, seq_len] that needs to be patched tokens: 2D tensor of shape [batch_size, seq_len] that needs to be patched
@ -539,7 +540,11 @@ class Patcher:
scores = calculate_entropies( scores = calculate_entropies(
tokens, tokens,
self.entropy_model, self.entropy_model,
self.patching_batch_size, (
patching_batch_size
if patching_batch_size is not None
else self.patching_batch_size
),
self.device, self.device,
) )
if self.log_time: if self.log_time: