From 5adf1c71336d40dfc5f0785d95fb2ea99fd84c35 Mon Sep 17 00:00:00 2001
From: Luciferian Ink
Date: Sat, 18 Jan 2025 01:42:00 -0600
Subject: [PATCH] allow grads when calculating entropies

---
 bytelatent/data/patcher.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/bytelatent/data/patcher.py b/bytelatent/data/patcher.py
index 1e64bd4..b495d9b 100644
--- a/bytelatent/data/patcher.py
+++ b/bytelatent/data/patcher.py
@@ -2,6 +2,7 @@
 import math
 import time
 from collections import defaultdict
+from contextlib import nullcontext
 from enum import Enum
 
 import torch
@@ -63,7 +64,11 @@ def entropy(scores):
 
 
 def calculate_entropies(
-    tokens: torch.tensor, entropy_model, patching_batch_size, device: str | None = None
+    tokens: torch.tensor,
+    entropy_model,
+    patching_batch_size,
+    device: str | None = None,
+    enable_grad: bool = False,
 ):
     """
     tokens: 2D tensor of shape [batch_size, seq_len]
@@ -72,7 +77,10 @@
     Splits the tokens into chunks of size max_length and calculates entropies for each chunk.
     Entropy model can be executed on cpu or gpu, specify either 'cuda' or 'cpu' in the device argument.
     """
-    with torch.no_grad():
+
+    grad_context = nullcontext() if enable_grad else torch.no_grad()
+
+    with grad_context:
         entropies = []
         max_length = getattr(entropy_model, "max_length", 8192)
         batch_numel = max_length * patching_batch_size
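
For context, a minimal self-contained sketch of the pattern this patch introduces and what the new flag enables. `calculate_entropies_sketch` and the toy `Sequential` model are illustrative stand-ins, not the repo's code: the real `calculate_entropies` additionally splits tokens into `max_length` chunks and handles device placement. The `entropy` helper here mirrors the `entropy(scores)` function visible in the hunk context.

from contextlib import nullcontext

import torch
import torch.nn.functional as F


def entropy(scores):
    # scores: [batch_size, seq_len, vocab] logits -> [batch_size, seq_len]
    log_probs = F.log_softmax(scores, dim=-1)
    return -(log_probs.exp() * log_probs).sum(dim=-1)


def calculate_entropies_sketch(tokens, entropy_model, enable_grad=False):
    # The pattern from the patch: nullcontext() leaves autograd on,
    # torch.no_grad() turns it off, and one `with` block serves both paths.
    grad_context = nullcontext() if enable_grad else torch.no_grad()
    with grad_context:
        return entropy(entropy_model(tokens))


# Toy byte-level model standing in for the trained entropy model.
model = torch.nn.Sequential(torch.nn.Embedding(256, 32), torch.nn.Linear(32, 256))
tokens = torch.randint(0, 256, (2, 16))

# Default path: entropies are detached, matching the pre-patch behavior.
assert not calculate_entropies_sketch(tokens, model).requires_grad

# New path: gradients flow back into the entropy model.
calculate_entropies_sketch(tokens, model, enable_grad=True).mean().backward()
assert model[1].weight.grad is not None

Selecting between `nullcontext()` and `torch.no_grad()` up front keeps a single `with` block, so the default call still avoids building an autograd graph, while `enable_grad=True` lets callers backpropagate through the entropy model, e.g. to fine-tune it jointly with downstream components.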