From cff0dcb7abf29def377c8915de769846c6900ac2 Mon Sep 17 00:00:00 2001 From: Luciferian Ink Date: Fri, 17 Jan 2025 19:41:17 -0600 Subject: [PATCH] allow patching_batch_size to be adjusted in the forward() method --- bytelatent/data/patcher.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/bytelatent/data/patcher.py b/bytelatent/data/patcher.py index d32f168..0063c80 100644 --- a/bytelatent/data/patcher.py +++ b/bytelatent/data/patcher.py @@ -490,6 +490,7 @@ class Patcher: preds: torch.Tensor | None = None, entropies: torch.Tensor | None = None, threshold: float = None, + patching_batch_size: int | None = None, ) -> torch.Tensor: """ tokens: 2D tensor of shape [batch_size, seq_len] that needs to be patched @@ -539,7 +540,11 @@ class Patcher: scores = calculate_entropies( tokens, self.entropy_model, - self.patching_batch_size, + ( + patching_batch_size + if patching_batch_size is not None + else self.patching_batch_size + ), self.device, ) if self.log_time: