mirror of
https://github.com/facebookresearch/blt.git
synced 2025-01-19 00:47:44 +00:00
allow grads when calculating entropies
This commit is contained in:
parent
9e42f5dd1d
commit
5adf1c7133
|
@ -2,6 +2,7 @@
|
||||||
import math
|
import math
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from contextlib import nullcontext
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
@ -63,7 +64,11 @@ def entropy(scores):
|
||||||
|
|
||||||
|
|
||||||
def calculate_entropies(
|
def calculate_entropies(
|
||||||
tokens: torch.tensor, entropy_model, patching_batch_size, device: str | None = None
|
tokens: torch.tensor,
|
||||||
|
entropy_model,
|
||||||
|
patching_batch_size,
|
||||||
|
device: str | None = None,
|
||||||
|
enable_grad: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
tokens: 2D tensor of shape [batch_size, seq_len]
|
tokens: 2D tensor of shape [batch_size, seq_len]
|
||||||
|
@ -72,7 +77,10 @@ def calculate_entropies(
|
||||||
Splits the tokens into chunks of size max_length and calculates entropies for each chunk.
|
Splits the tokens into chunks of size max_length and calculates entropies for each chunk.
|
||||||
The entropy model can be executed on CPU or GPU; specify either 'cuda' or 'cpu' in the device argument.
|
The entropy model can be executed on CPU or GPU; specify either 'cuda' or 'cpu' in the device argument.
|
||||||
"""
|
"""
|
||||||
with torch.no_grad():
|
|
||||||
|
grad_context = nullcontext() if enable_grad else torch.no_grad()
|
||||||
|
|
||||||
|
with grad_context:
|
||||||
entropies = []
|
entropies = []
|
||||||
max_length = getattr(entropy_model, "max_length", 8192)
|
max_length = getattr(entropy_model, "max_length", 8192)
|
||||||
batch_numel = max_length * patching_batch_size
|
batch_numel = max_length * patching_batch_size
|
||||||
|
|
Loading…
Reference in a new issue