mirror of https://github.com/facebookresearch/blt.git
synced 2025-01-18 16:37:46 +00:00
allow grads when calculating entropies
This commit is contained in:
parent 9e42f5dd1d
commit 5adf1c7133
@@ -2,6 +2,7 @@
 import math
 import time
 from collections import defaultdict
+from contextlib import nullcontext
 from enum import Enum
 
 import torch
@@ -63,7 +64,11 @@ def entropy(scores):
 
 
 def calculate_entropies(
-    tokens: torch.tensor, entropy_model, patching_batch_size, device: str | None = None
+    tokens: torch.tensor,
+    entropy_model,
+    patching_batch_size,
+    device: str | None = None,
+    enable_grad: bool = False,
 ):
     """
     tokens: 2D tensor of shape [batch_size, seq_len]
@@ -72,7 +77,10 @@ def calculate_entropies(
     Splits the tokens into chunks of size max_length and calculates entropies for each chunk.
     Entropy model can be executed on cpu or gpu, specify either 'cuda' or 'cpu' in the device argument.
     """
-    with torch.no_grad():
+
+    grad_context = nullcontext() if enable_grad else torch.no_grad()
+
+    with grad_context:
         entropies = []
         max_length = getattr(entropy_model, "max_length", 8192)
         batch_numel = max_length * patching_batch_size
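For context, the pattern this commit introduces, choosing between nullcontext() and torch.no_grad() at call time, can be shown with a small standalone sketch. The toy entropy model, tensor shapes, and the compute_entropies helper below are illustrative assumptions and are not taken from the blt code; only the grad_context idiom mirrors the diff above.

import torch
from contextlib import nullcontext


def compute_entropies(logits_fn, tokens, enable_grad: bool = False):
    # Pick the autograd context at runtime: nullcontext() keeps the graph,
    # torch.no_grad() disables gradient tracking for cheaper inference.
    grad_context = nullcontext() if enable_grad else torch.no_grad()
    with grad_context:
        logits = logits_fn(tokens)                    # [batch, seq_len, vocab]
        log_probs = torch.log_softmax(logits, dim=-1)
        probs = torch.exp(log_probs)
        entropies = -(probs * log_probs).sum(dim=-1)  # [batch, seq_len]
    return entropies


# Toy entropy model: an embedding followed by a linear head (illustrative only).
vocab, dim = 256, 32
embed = torch.nn.Embedding(vocab, dim)
head = torch.nn.Linear(dim, vocab)
tokens = torch.randint(0, vocab, (2, 16))

frozen = compute_entropies(lambda t: head(embed(t)), tokens, enable_grad=False)
trainable = compute_entropies(lambda t: head(embed(t)), tokens, enable_grad=True)
print(frozen.requires_grad, trainable.requires_grad)  # False True

In the patched function itself, the equivalent would be calling calculate_entropies(..., enable_grad=True), which keeps the entropy model's activations in the autograd graph instead of discarding them under torch.no_grad().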