diff --git a/bytelatent/data/patcher.py b/bytelatent/data/patcher.py index 0063c80..1e64bd4 100644 --- a/bytelatent/data/patcher.py +++ b/bytelatent/data/patcher.py @@ -532,7 +532,7 @@ class Patcher: if self.log_time: s = time.time() if entropies is not None: - scores = torch.tensor(entropies, dtype=torch.float32) + scores = entropies.clone().detach().to(dtype=torch.float32) elif preds is not None: scores = entropy(preds) else: @@ -540,11 +540,7 @@ class Patcher: scores = calculate_entropies( tokens, self.entropy_model, - ( - patching_batch_size - if patching_batch_size is not None - else self.patching_batch_size - ), + self.patching_batch_size, self.device, ) if self.log_time: