Fix realtime entropy patching (#26)

* allow loading of the entropy model directly

* remove unused argument

* remove spammy warning

* allow patch_batch_size to be adjusted in the forward() method

* revert to original patcher style, fix warning

* allow grads when calculating entropies

* fix grad flow

* return preds from calculate_entropies()

* remove legacy arg

* fix an error with monotonicity and small sequence lengths

* ensure patcher is serializable

* revert patcher to original

* remove unused import
Ink 2025-01-21 18:34:23 -06:00 committed by GitHub
parent 6ffeb66b53
commit 392117bff2
4 changed files with 26 additions and 12 deletions
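
The diff below changes calculate_entropies() to take an enable_grad flag and to return the per-token predictions alongside the entropies. As a quick orientation before the diff, here is a minimal caller-side sketch of the updated API; it is illustrative only, and `tokens` and `entropy_model` are placeholders for whatever the caller has loaded (calculate_entropies is the function patched below):

import torch

# Placeholders: any [batch_size, seq_len] byte-id tensor and any loaded entropy model.
tokens = torch.randint(0, 260, (2, 512))
entropies, preds = calculate_entropies(
    tokens,
    entropy_model,
    patching_batch_size=32,
    device="cuda",
    enable_grad=False,
)
# entropies: [batch_size, seq_len]; preds: [batch_size, seq_len, vocab]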


@@ -2,6 +2,7 @@
 import math
 import time
 from collections import defaultdict
+from contextlib import nullcontext
 from enum import Enum

 import torch
@@ -58,7 +59,11 @@ def entropy(scores):
 def calculate_entropies(
-    tokens: torch.tensor, entropy_model, patching_batch_size, device: str | None = None
+    tokens: torch.tensor,
+    entropy_model,
+    patching_batch_size,
+    device: str | None = None,
+    enable_grad: bool = False,
 ):
     """
     tokens: 2D tensor of shape [batch_size, seq_len]
@@ -67,8 +72,12 @@ def calculate_entropies(
     Splits the tokens into chunks of size max_length and calculates entropies for each chunk.
     Entropy model can be executed on cpu or gpu, specify either 'cuda' or 'cpu' in the device argument.
     """
-    with torch.no_grad():
+    grad_context = nullcontext() if enable_grad else torch.no_grad()
+    with grad_context:
         entropies = []
+        preds = []
         max_length = getattr(entropy_model, "max_length", 8192)
         batch_numel = max_length * patching_batch_size
         splits = torch.split(tokens.flatten(), batch_numel)
@@ -86,12 +95,15 @@ def calculate_entropies(
             pred = pred.reshape(-1, pred.shape[-1])[
                 : split.numel() - pad_size, :
             ]  # [batch_size * seq_len, vocab]
+            preds.append(pred)
             pred_entropies = entropy(pred)
             entropies.append(pred_entropies)

         concat_entropies = torch.cat(entropies, dim=0)
         concat_entropies = concat_entropies.reshape(tokens.shape)
-        return concat_entropies
+        concat_preds = torch.cat(preds, dim=0)
+        concat_preds = concat_preds.reshape(tokens.shape[0], tokens.shape[1], -1)
+        return concat_entropies, concat_preds
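
Per the "allow grads" and "fix grad flow" items above, passing enable_grad=True skips the no_grad() guard, so a loss computed from the returned entropies (or preds) can backpropagate into the entropy model. A minimal sketch, with the objective and model as placeholders:

# Sketch only: entropy_model is assumed to be a torch.nn.Module with trainable parameters.
entropies, preds = calculate_entropies(
    tokens,
    entropy_model,
    patching_batch_size=32,
    device="cuda",
    enable_grad=True,
)
loss = entropies.mean()   # placeholder objective over the entropies
loss.backward()           # gradients flow back through preds into entropy_model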

 def patch_start_mask_from_entropy_with_monotonicity(entropies, t):
@@ -101,6 +113,10 @@ def patch_start_mask_from_entropy_with_monotonicity(entropies, t):
     returns [bs, seq_len] mask where True indicates the start of a patch
     """
     bs, seq_len = entropies.shape
+    if seq_len == 0:
+        return entropies > t
+
     mask = torch.zeros_like(entropies, dtype=torch.bool)
     mask[:, 0] = True
@@ -123,6 +139,10 @@ def patch_start_mask_global_and_monotonicity(entropies, t, t_add=0):
     returns [bs, seq_len] mask where True indicates the start of a patch
     """
     bs, seq_len = entropies.shape
+    if seq_len == 0:
+        return entropies > t
+
     mask = torch.zeros_like(entropies, dtype=torch.bool)
     mask[:, 0] = True
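
The seq_len == 0 guard added to both functions matters because the comparison entropies > t yields an empty boolean mask of the correct [bs, 0] shape, whereas falling through to mask[:, 0] = True would raise an IndexError on a zero-width dimension. A small check, assuming the function above is importable:

import torch

entropies = torch.empty(4, 0)   # bs=4, seq_len=0
mask = patch_start_mask_from_entropy_with_monotonicity(entropies, t=1.0)
print(mask.shape, mask.dtype)   # torch.Size([4, 0]) torch.bool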
@@ -521,12 +541,12 @@ class Patcher:
             if self.log_time:
                 s = time.time()
             if entropies is not None:
-                scores = torch.tensor(entropies, dtype=torch.float32)
+                scores = entropies.to(dtype=torch.float32)
             elif preds is not None:
                 scores = entropy(preds)
             else:
                 start_entropies = time.time()
-                scores = calculate_entropies(
+                scores, _ = calculate_entropies(
                     tokens,
                     self.entropy_model,
                     self.patching_batch_size,
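
One consequence of switching from torch.tensor(entropies, dtype=...) to entropies.to(dtype=...) (a reading of the diff, not stated in the commit message): pre-computed entropies passed to Patcher.patch are now expected to already be a torch.Tensor, and the conversion keeps their device and any autograd history, whereas torch.tensor() copies, detaches, and emits a UserWarning when handed a tensor, which may be the warning the "fix warning" item above refers to. An illustrative comparison:

# Illustrative only; `e` stands for an entropies tensor produced with enable_grad=True.
old_scores = torch.tensor(e, dtype=torch.float32)  # copies, detaches, warns on tensor input
new_scores = e.to(dtype=torch.float32)             # keeps device and autograd history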


@@ -199,9 +199,6 @@ class LocalModelBase(nn.Module):
 class LocalEncoder(LocalModelBase):
     def __init__(self, args: LocalModelArgs):
         super().__init__(args)
-        self.output_proj = (
-            args.patching_mode in ["entropy", "probmax"]
-        ) and args.entropy_model_checkpoint_dir is None

         self.apply_transformer = args.use_local_encoder_transformer
         self.downsampling_by_pooling = args.downsampling_by_pooling


@@ -162,9 +162,6 @@ def create_causal_mask(
             return "causal"

         if BLT_SUPPRESS_ATTN_ERROR == 1:
-            logging.warning(
-                "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. Allowing model to run since BLT_SUPPRESS_ATTN_ERROR=1"
-            )
             return "causal"
         else:
             raise ValueError(


@@ -117,7 +117,7 @@ def main(
         text = get_text(doc)
         tokens = torch.tensor(tokenizer.encode(text))
         patch_start = time.time()
-        scores = calculate_entropies(
+        scores, _ = calculate_entropies(
             tokens,
             entropy_model,
             patching_batch_size,