revert to original patcher style, fix warning

This commit is contained in:
Luciferian Ink 2025-01-17 20:36:07 -06:00
parent cff0dcb7ab
commit 9e42f5dd1d

View file

@ -532,7 +532,7 @@ class Patcher:
if self.log_time:
s = time.time()
if entropies is not None:
scores = torch.tensor(entropies, dtype=torch.float32)
scores = entropies.clone().detach().to(dtype=torch.float32)
elif preds is not None:
scores = entropy(preds)
else:
@ -540,11 +540,7 @@ class Patcher:
scores = calculate_entropies(
tokens,
self.entropy_model,
(
patching_batch_size
if patching_batch_size is not None
else self.patching_batch_size
),
self.patching_batch_size,
self.device,
)
if self.log_time: