Allow gradient flow when calculating entropies (opt-in via `enable_grad` flag)

This commit is contained in:
Luciferian Ink 2025-01-18 01:42:00 -06:00
parent 9e42f5dd1d
commit 5adf1c7133

View file

@ -2,6 +2,7 @@
import math
import time
from collections import defaultdict
from contextlib import nullcontext
from enum import Enum
import torch
@ -63,7 +64,11 @@ def entropy(scores):
def calculate_entropies(
tokens: torch.tensor, entropy_model, patching_batch_size, device: str | None = None
tokens: torch.tensor,
entropy_model,
patching_batch_size,
device: str | None = None,
enable_grad: bool = False,
):
"""
tokens: 2D tensor of shape [batch_size, seq_len]
@ -72,7 +77,10 @@ def calculate_entropies(
Splits the tokens into chunks of size max_length and calculates entropies for each chunk.
Entropy model can be executed on cpu or gpu, specify either 'cuda' or 'cpu' in the device argument.
"""
with torch.no_grad():
grad_context = nullcontext() if enable_grad else torch.no_grad()
with grad_context:
entropies = []
max_length = getattr(entropy_model, "max_length", 8192)
batch_numel = max_length * patching_batch_size