Mirror of https://github.com/facebookresearch/blt.git
Synced 2025-01-31 10:02:15 +00:00
Fix realtime entropy patching (#26)

* allow loading of the entropy model directly
* remove unused argument
* remove spammy warning
* allow patch_batch_size to be adjusted in the forward() method
* revert to original patcher style, fix warning
* allow grads when calculating entropies
* fix grad flow
* return preds from calculate_entropies()
* remove legacy arg
* fix an error with monotonicity and small sequence lengths
* ensure patcher is serializable
* revert patcher to original
* remove unused import
This commit is contained in:
parent 6ffeb66b53
commit 392117bff2
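The hunks below change the calculate_entropies() interface: it gains an enable_grad flag and now returns predictions alongside the entropies. A minimal caller sketch under those assumptions; the import path, entropy_model, and token values are placeholders for illustration, not taken from this commit.

import torch

# Assumed import path for illustration only:
# from bytelatent.data.patcher import calculate_entropies

tokens = torch.randint(0, 260, (1, 512))  # [batch_size, seq_len] of byte ids (placeholder)
entropies, preds = calculate_entropies(
    tokens,
    entropy_model,             # placeholder: any callable returning per-token logits
    patching_batch_size=32,
    device="cuda",
    enable_grad=False,         # default keeps the torch.no_grad() behaviour
)
# entropies: [batch_size, seq_len]; preds: [batch_size, seq_len, vocab]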

@@ -2,6 +2,7 @@
 import math
 import time
 from collections import defaultdict
+from contextlib import nullcontext
 from enum import Enum
 
 import torch

@@ -58,7 +59,11 @@ def entropy(scores):
 
 
 def calculate_entropies(
-    tokens: torch.tensor, entropy_model, patching_batch_size, device: str | None = None
+    tokens: torch.tensor,
+    entropy_model,
+    patching_batch_size,
+    device: str | None = None,
+    enable_grad: bool = False,
 ):
     """
     tokens: 2D tensor of shape [batch_size, seq_len]

@@ -67,8 +72,12 @@ def calculate_entropies(
     Splits the tokens into chunks of size max_length and calculates entropies for each chunk.
     Entropy model can be executed on cpu or gpu, specify either 'cuda' or 'cpu' in the device argument.
     """
-    with torch.no_grad():
+    grad_context = nullcontext() if enable_grad else torch.no_grad()
+
+    with grad_context:
         entropies = []
+        preds = []
         max_length = getattr(entropy_model, "max_length", 8192)
         batch_numel = max_length * patching_batch_size
         splits = torch.split(tokens.flatten(), batch_numel)
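The hunk above swaps the hard-coded torch.no_grad() for a context chosen at call time. A self-contained sketch of that pattern, independent of this repo:

import torch
from contextlib import nullcontext

def scores_with_optional_grad(model, x, enable_grad=False):
    # nullcontext() is a no-op context manager, so gradients flow when
    # enable_grad=True; torch.no_grad() disables them otherwise.
    grad_context = nullcontext() if enable_grad else torch.no_grad()
    with grad_context:
        return model(x)

model = torch.nn.Linear(4, 2)
x = torch.randn(3, 4, requires_grad=True)
assert not scores_with_optional_grad(model, x).requires_grad
assert scores_with_optional_grad(model, x, enable_grad=True).requires_grad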

@@ -86,12 +95,15 @@ def calculate_entropies(
             pred = pred.reshape(-1, pred.shape[-1])[
                 : split.numel() - pad_size, :
             ]  # [batch_size * seq_len, vocab]
+            preds.append(pred)
             pred_entropies = entropy(pred)
             entropies.append(pred_entropies)
 
     concat_entropies = torch.cat(entropies, dim=0)
     concat_entropies = concat_entropies.reshape(tokens.shape)
-    return concat_entropies
+    concat_preds = torch.cat(preds, dim=0)
+    concat_preds = concat_preds.reshape(tokens.shape[0], tokens.shape[1], -1)
+    return concat_entropies, concat_preds
 
 
 def patch_start_mask_from_entropy_with_monotonicity(entropies, t):

@@ -101,6 +113,10 @@ def patch_start_mask_from_entropy_with_monotonicity(entropies, t):
     returns [bs, seq_len] mask where True indicates the start of a patch
     """
     bs, seq_len = entropies.shape
+
+    if seq_len == 0:
+        return entropies > t
+
     mask = torch.zeros_like(entropies, dtype=torch.bool)
     mask[:, 0] = True
 

@@ -123,6 +139,10 @@ def patch_start_mask_global_and_monotonicity(entropies, t, t_add=0):
     returns [bs, seq_len] mask where True indicates the start of a patch
     """
     bs, seq_len = entropies.shape
+
+    if seq_len == 0:
+        return entropies > t
+
     mask = torch.zeros_like(entropies, dtype=torch.bool)
     mask[:, 0] = True
 
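Both patch-start helpers now guard against zero-length sequences: mask[:, 0] = True raises an IndexError when dimension 1 has size 0, while entropies > t simply yields an empty boolean mask of the same shape. A small synthetic illustration:

import torch

entropies = torch.empty(2, 0)   # batch of 2 zero-length sequences

# Pre-fix behaviour: indexing column 0 of a zero-width tensor raises IndexError.
# torch.zeros_like(entropies, dtype=torch.bool)[:, 0] = True

mask = entropies > 1.0          # what the early return produces: an empty bool mask
print(mask.shape, mask.dtype)   # torch.Size([2, 0]) torch.bool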

@@ -521,12 +541,12 @@ class Patcher:
         if self.log_time:
             s = time.time()
         if entropies is not None:
-            scores = torch.tensor(entropies, dtype=torch.float32)
+            scores = entropies.to(dtype=torch.float32)
         elif preds is not None:
             scores = entropy(preds)
         else:
             start_entropies = time.time()
-            scores = calculate_entropies(
+            scores, _ = calculate_entropies(
                 tokens,
                 self.entropy_model,
                 self.patching_batch_size,
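The switch from torch.tensor(entropies, dtype=torch.float32) to entropies.to(dtype=torch.float32) keeps gradients flowing and avoids PyTorch's copy-construct warning: torch.tensor() on an existing tensor copies the data and detaches it from the autograd graph, while .to() preserves the graph. A short demonstration with a synthetic tensor:

import torch

entropies = torch.randn(2, 8, requires_grad=True)

copied = torch.tensor(entropies, dtype=torch.float32)  # emits a UserWarning, copies, and detaches
converted = entropies.to(dtype=torch.float32)          # no warning, autograd graph preserved

print(copied.requires_grad)     # False: gradients cannot flow back through `copied`
print(converted.requires_grad)  # True: gradients still flow through `converted`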

@@ -199,9 +199,6 @@ class LocalModelBase(nn.Module):
 class LocalEncoder(LocalModelBase):
     def __init__(self, args: LocalModelArgs):
         super().__init__(args)
-        self.output_proj = (
-            args.patching_mode in ["entropy", "probmax"]
-        ) and args.entropy_model_checkpoint_dir is None
-
         self.apply_transformer = args.use_local_encoder_transformer
         self.downsampling_by_pooling = args.downsampling_by_pooling

@@ -162,9 +162,6 @@ def create_causal_mask(
            return "causal"
 
        if BLT_SUPPRESS_ATTN_ERROR == 1:
-            logging.warning(
-                "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. Allowing model to run since BLT_SUPPRESS_ATTN_ERROR=1"
-            )
            return "causal"
        else:
            raise ValueError(

@@ -117,7 +117,7 @@ def main(
         text = get_text(doc)
         tokens = torch.tensor(tokenizer.encode(text))
         patch_start = time.time()
-        scores = calculate_entropies(
+        scores, _ = calculate_entropies(
             tokens,
             entropy_model,
             patching_batch_size,