allow grads when calculating entropies

Luciferian Ink 2025-01-18 01:42:00 -06:00
parent 9e42f5dd1d
commit 5adf1c7133

@@ -2,6 +2,7 @@
 import math
 import time
 from collections import defaultdict
+from contextlib import nullcontext
 from enum import Enum
 
 import torch
@@ -63,7 +64,11 @@ def entropy(scores):
 def calculate_entropies(
-    tokens: torch.tensor, entropy_model, patching_batch_size, device: str | None = None
+    tokens: torch.tensor,
+    entropy_model,
+    patching_batch_size,
+    device: str | None = None,
+    enable_grad: bool = False,
 ):
     """
     tokens: 2D tensor of shape [batch_size, seq_len]
@@ -72,7 +77,10 @@ def calculate_entropies(
     Splits the tokens into chunks of size max_length and calculates entropies for each chunk.
     Entropy model can be executed on cpu or gpu, specify either 'cuda' or 'cpu' in the device argument.
     """
-    with torch.no_grad():
+    grad_context = nullcontext() if enable_grad else torch.no_grad()
+    with grad_context:
         entropies = []
         max_length = getattr(entropy_model, "max_length", 8192)
         batch_numel = max_length * patching_batch_size
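
For reference, a minimal, self-contained sketch of the toggle this commit introduces (nullcontext() versus torch.no_grad(), selected by a flag). The small linear model and the helper name below are hypothetical stand-ins, not code from this repository.

# Sketch only: demonstrates the enable_grad pattern with a stand-in model.
from contextlib import nullcontext

import torch
import torch.nn as nn

entropy_model = nn.Linear(8, 256)  # hypothetical stand-in for the entropy model
tokens = torch.randn(4, 8)         # hypothetical stand-in input batch

def scores_with_optional_grad(x, enable_grad: bool = False):
    # nullcontext() leaves autograd enabled; torch.no_grad() disables it.
    grad_context = nullcontext() if enable_grad else torch.no_grad()
    with grad_context:
        return entropy_model(x)

out = scores_with_optional_grad(tokens)                     # default: no graph is built
print(out.requires_grad)                                    # False

out = scores_with_optional_grad(tokens, enable_grad=True)   # autograd stays on
out.sum().backward()                                        # gradients flow into entropy_model

With the default enable_grad=False the behavior is identical to the old torch.no_grad() path, so existing callers are unaffected; passing enable_grad=True lets gradients propagate through the entropy calculation when the entropy model itself is being trained.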