Mirror of https://github.com/facebookresearch/blt.git (synced 2025-01-31 01:52:15 +00:00)
Fix realtime entropy patching (#26)
* allow loading of the entropy model directly
* remove unused argument
* remove spammy warning
* allow patch_batch_size to be adjusted in the forward() method
* revert to original patcher style, fix warning
* allow grads when calculating entropies
* fix grad flow
* return preds from calculate_entropies()
* remove legacy arg
* fix an error with monotonicity and small sequence lengths
* ensure patcher is serializable
* revert patcher to original
* remove unused import
This commit is contained in:
parent 6ffeb66b53
commit 392117bff2
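For context, a minimal sketch of how the updated calculate_entropies contract might be called after this commit. The import path bytelatent.data.patcher, the helper name realtime_patch_scores, and the patching_batch_size value are illustrative assumptions (file paths are not shown in this diff); the entropy model is assumed to be loaded elsewhere.

import torch

from bytelatent.data.patcher import calculate_entropies  # module path assumed


def realtime_patch_scores(entropy_model, tokens: torch.Tensor, device: str = "cuda"):
    """tokens: [batch_size, seq_len] byte ids; returns per-token entropies and logits."""
    # After this commit calculate_entropies returns a 2-tuple and can keep
    # gradients when enable_grad=True (previously it always ran under no_grad).
    entropies, preds = calculate_entropies(
        tokens,
        entropy_model,
        patching_batch_size=32,  # illustrative value
        device=device,
        enable_grad=True,
    )
    # entropies: [batch_size, seq_len]; preds: [batch_size, seq_len, vocab]
    return entropies, preds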
@@ -2,6 +2,7 @@
 import math
 import time
 from collections import defaultdict
+from contextlib import nullcontext
 from enum import Enum
 
 import torch
@@ -58,7 +59,11 @@ def entropy(scores):
 
 
 def calculate_entropies(
-    tokens: torch.tensor, entropy_model, patching_batch_size, device: str | None = None
+    tokens: torch.tensor,
+    entropy_model,
+    patching_batch_size,
+    device: str | None = None,
+    enable_grad: bool = False,
 ):
     """
     tokens: 2D tensor of shape [batch_size, seq_len]
@@ -67,8 +72,12 @@ def calculate_entropies(
     Splits the tokens into chunks of size max_length and calculates entropies for each chunk.
     Entropy model can be executed on cpu or gpu, specify either 'cuda' or 'cpu' in the device argument.
     """
-    with torch.no_grad():
+
+    grad_context = nullcontext() if enable_grad else torch.no_grad()
+
+    with grad_context:
         entropies = []
+        preds = []
         max_length = getattr(entropy_model, "max_length", 8192)
         batch_numel = max_length * patching_batch_size
         splits = torch.split(tokens.flatten(), batch_numel)
@@ -86,12 +95,15 @@ def calculate_entropies(
             pred = pred.reshape(-1, pred.shape[-1])[
                 : split.numel() - pad_size, :
             ] # [batch_size * seq_len, vocab]
+            preds.append(pred)
             pred_entropies = entropy(pred)
             entropies.append(pred_entropies)
 
         concat_entropies = torch.cat(entropies, dim=0)
         concat_entropies = concat_entropies.reshape(tokens.shape)
-        return concat_entropies
+        concat_preds = torch.cat(preds, dim=0)
+        concat_preds = concat_preds.reshape(tokens.shape[0], tokens.shape[1], -1)
+        return concat_entropies, concat_preds
 
 
 def patch_start_mask_from_entropy_with_monotonicity(entropies, t):
@@ -101,6 +113,10 @@ def patch_start_mask_from_entropy_with_monotonicity(entropies, t):
     returns [bs, seq_len] mask where True indicates the start of a patch
     """
     bs, seq_len = entropies.shape
+
+    if seq_len == 0:
+        return entropies > t
+
     mask = torch.zeros_like(entropies, dtype=torch.bool)
     mask[:, 0] = True
 
@@ -123,6 +139,10 @@ def patch_start_mask_global_and_monotonicity(entropies, t, t_add=0):
     returns [bs, seq_len] mask where True indicates the start of a patch
     """
     bs, seq_len = entropies.shape
+
+    if seq_len == 0:
+        return entropies > t
+
     mask = torch.zeros_like(entropies, dtype=torch.bool)
     mask[:, 0] = True
 
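The two hunks above guard the empty-sequence case. As a standalone illustration (not repo code), the snippet below shows why the early return is needed: writing to column 0 of a [bs, 0] mask raises, whereas entropies > t already yields an empty boolean mask of the correct shape.

import torch

entropies = torch.empty(4, 0)  # [bs=4, seq_len=0]
t = 1.25

mask = entropies > t
print(mask.shape, mask.dtype)  # torch.Size([4, 0]) torch.bool

try:
    bad = torch.zeros_like(entropies, dtype=torch.bool)
    bad[:, 0] = True  # fails: dimension 1 has size 0
except IndexError as err:
    print("without the guard:", err)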
@@ -521,12 +541,12 @@ class Patcher:
         if self.log_time:
             s = time.time()
         if entropies is not None:
-            scores = torch.tensor(entropies, dtype=torch.float32)
+            scores = entropies.to(dtype=torch.float32)
         elif preds is not None:
            scores = entropy(preds)
         else:
             start_entropies = time.time()
-            scores = calculate_entropies(
+            scores, _ = calculate_entropies(
                 tokens,
                 self.entropy_model,
                 self.patching_batch_size,
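The scores conversion above switches from torch.tensor(entropies, ...) to entropies.to(dtype=...). The standalone snippet below (not repo code) illustrates the difference: re-wrapping an existing tensor with torch.tensor copies and detaches it, and PyTorch emits a copy-construct warning, while .to keeps the tensor attached to autograd, which matters once entropies are computed with enable_grad=True.

import torch

entropies = torch.rand(2, 16, requires_grad=True) * 2.0  # non-leaf tensor with grad_fn

detached = torch.tensor(entropies, dtype=torch.float32)  # UserWarning; copy is detached
kept = entropies.to(dtype=torch.float32)                 # stays in the autograd graph

print(detached.requires_grad, detached.grad_fn)  # False None
print(kept.requires_grad, kept.grad_fn)          # True <MulBackward0 ...>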
@@ -199,9 +199,6 @@ class LocalModelBase(nn.Module):
 class LocalEncoder(LocalModelBase):
     def __init__(self, args: LocalModelArgs):
         super().__init__(args)
-        self.output_proj = (
-            args.patching_mode in ["entropy", "probmax"]
-        ) and args.entropy_model_checkpoint_dir is None
 
         self.apply_transformer = args.use_local_encoder_transformer
         self.downsampling_by_pooling = args.downsampling_by_pooling
@@ -162,9 +162,6 @@ def create_causal_mask(
             return "causal"
 
         if BLT_SUPPRESS_ATTN_ERROR == 1:
-            logging.warning(
-                "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. Allowing model to run since BLT_SUPPRESS_ATTN_ERROR=1"
-            )
             return "causal"
         else:
             raise ValueError(
@@ -117,7 +117,7 @@ def main(
         text = get_text(doc)
         tokens = torch.tensor(tokenizer.encode(text))
         patch_start = time.time()
-        scores = calculate_entropies(
+        scores, _ = calculate_entropies(
            tokens,
            entropy_model,
            patching_batch_size,