2024-12-12 23:32:30 +00:00
|
|
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
|
|
import math
|
|
|
|
import time
|
|
|
|
from collections import defaultdict
|
|
|
|
from enum import Enum
|
|
|
|
|
|
|
|
import torch
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from torch.nn import functional as F
|
|
|
|
|
|
|
|
from bytelatent.distributed import get_local_rank
|
|
|
|
from bytelatent.entropy_model import load_entropy_model
|
|
|
|
|
|
|
|
# from src.slurm import get_local_rank
|
|
|
|
from bytelatent.tokenizers.blt_tokenizer import BPE_ID, OFFSET
|
|
|
|
from bytelatent.tokenizers.constants import BPE_ID, OFFSET
|
|
|
|
|
|
|
|
|
|
|
|
class PatchingModeEnum(str, Enum):
    """
    Patching strategies dispatched on by Patcher.patch.

    Members also subclass `str`, so they compare equal to their string values
    (e.g. PatchingModeEnum.space == "space").
    """

    # Patch starts chosen from per-token entropies (threshold criterion or top-k).
    entropy = "entropy"
    # A new patch starts after each BPE_ID delimiter token found in the input.
    bpe = "bpe"
    # Patch starts predicted by a model: positions whose argmax class equals BPE_ID.
    bpe_patcher = "bpe_patcher"
    # Patch boundaries at non-alphanumeric bytes (see find_space_patch_start_ids).
    space = "space"
|
|
|
|
|
|
|
|
|
|
|
|
class PatcherArgs(BaseModel):
    """
    Configuration for Patcher; `build()` constructs a Patcher from these arguments.
    """

    # Strategy used by Patcher.patch.
    patching_mode: PatchingModeEnum = PatchingModeEnum.entropy
    # Device handed to to_device() for the entropy model loaded when realtime_patching is True.
    patching_device: str = "cuda"
    # Checkpoint loaded via load_entropy_model(); required when realtime_patching is True.
    entropy_model_checkpoint_dir: str | None = None
    # If True, Patcher.__init__ loads the entropy model so entropies can be computed on the fly.
    realtime_patching: bool = False
    # Default entropy threshold for "entropy" mode; Patcher.patch(threshold=...) overrides it.
    # NOTE(review): where this exact value comes from is not visible here (presumably tuned) — confirm.
    threshold: float = 1.335442066192627
    # When set (and monotonicity is False), entropy patching additionally requires the
    # entropy to rise by more than this amount over the previous position.
    threshold_add: float | None = None
    # When set, patch lengths larger than this are split into chunks of at most this size.
    max_patch_length: int | None = None
    # Patch length for static mode, and target average patch size for top-k entropy
    # patching (used only when no threshold is in effect).
    # NOTE(review): static mode fills an integer tensor with this value, so a fractional
    # size such as the 4.5 default looks unsupported there — confirm.
    patch_size: float = 4.5
    # Number of max_length rows sent to the entropy / BPE model per forward call.
    patching_batch_size: int = 1
    # Copied onto Patcher; not read anywhere else in the visible code.
    data_loader_patching: bool = False
    # Device that token chunks are moved to before calling the entropy / BPE model.
    device: str = "cuda"
    # If True, entropy patching starts a patch where the entropy increase exceeds threshold.
    monotonicity: bool = False
    # If True, Patcher accumulates per-stage timings in Patcher.log.
    log_time: bool = False

    def build(self) -> "Patcher":
        # Construct a Patcher configured by these arguments.
        return Patcher(self)
|
|
|
|
|
|
|
|
|
|
|
|
def entropy(scores):
    """
    Shannon entropy (natural log) of softmax(scores) over the last dimension.

    scores: [bs, seq_len, vocab] logits (any leading shape works).
    returns: [bs, seq_len], i.e. `scores` with the last dimension reduced away.
    """
    log_p = F.log_softmax(scores, dim=-1)
    weighted = log_p * log_p.exp()
    return -weighted.sum(dim=-1)
|
|
|
|
|
|
|
|
|
|
|
|
def calculate_entropies(
    tokens: torch.Tensor, entropy_model, patching_batch_size, device: str | None = None
):
    """
    Compute the entropy model's per-token prediction entropy.

    tokens: 2D tensor of shape [batch_size, seq_len]; values must lie in [0, 260).
    entropy_model: callable mapping a [rows, max_length] token tensor to logits whose last
        dimension is the vocabulary. Its `max_length` attribute (default 8192) sets the
        row length.
    patching_batch_size: number of max_length rows per model call.
    device: if not None, each chunk is moved there ('cuda' or 'cpu') before the model call.

    Returns a 2D tensor of shape [batch_size, seq_len] with the entropy (natural log) of
    the model's prediction at each token.

    The flattened tokens are processed in chunks of max_length * patching_batch_size
    elements. The last chunk is zero-padded to a multiple of max_length, and predictions
    for the padding are dropped.
    """
    with torch.no_grad():
        entropies = []
        max_length = getattr(entropy_model, "max_length", 8192)
        batch_numel = max_length * patching_batch_size
        splits = torch.split(tokens.flatten(), batch_numel)
        for split in splits:
            # Pad so the chunk reshapes into whole rows of max_length.
            pad_size = (max_length - (split.numel() % max_length)) % max_length
            pad = torch.zeros(
                pad_size, dtype=split.dtype, device=split.device, requires_grad=False
            )
            split = torch.cat((split, pad), dim=0)
            split = split.reshape(-1, max_length)
            if device is not None:
                split = split.to(device)
            assert torch.all(split >= 0) and torch.all(split < 260)
            pred = entropy_model(split)
            # Flatten rows and drop the predictions that belong to the padding.
            pred = pred.reshape(-1, pred.shape[-1])[
                : split.numel() - pad_size, :
            ]  # [batch_size * seq_len, vocab]
            pred_entropies = entropy(pred)
            entropies.append(pred_entropies)

        concat_entropies = torch.cat(entropies, dim=0)
        concat_entropies = concat_entropies.reshape(tokens.shape)
    return concat_entropies
|
2024-12-12 23:32:30 +00:00
|
|
|
|
|
|
|
|
|
|
|
def patch_start_mask_from_entropy_with_monotonicity(entropies, t):
    """
    Mark a patch start wherever the entropy rises by more than `t`.

    entropies: [bs, seq_len] tensor of entropies.
    t: minimum increase entropies[:, i] - entropies[:, i - 1] required for a start.

    Returns a [bs, seq_len] bool mask where True marks a patch start; column 0 is
    always True.
    """
    n_rows, n_cols = entropies.shape  # unpacking also enforces a 2-D input
    starts = torch.zeros_like(entropies, dtype=torch.bool)
    starts[:, 0] = True
    rises = entropies[:, 1:] - entropies[:, :-1]
    starts[:, 1:] = rises > t
    return starts
|
|
|
|
|
|
|
|
|
|
|
|
def patch_start_mask_global_and_monotonicity(entropies, t, t_add=0):
    """
    Patch-start mask combining a global entropy threshold with an entropy-increase test.

    entropies: [bs, seq_len] tensor of entropies.
    t: global threshold; a start at column i requires entropies[:, i] > t.
    t_add: a start at column i also requires entropies[:, i] - entropies[:, i - 1] > t_add.

    Returns a [bs, seq_len] bool mask where True marks a patch start; column 0 is
    always True.
    """
    n_rows, n_cols = entropies.shape  # unpacking also enforces a 2-D input
    mask = torch.zeros_like(entropies, dtype=torch.bool)
    mask[:, 0] = True

    rising = (entropies[:, 1:] - entropies[:, :-1]) > t_add
    above_global = entropies[:, 1:] > t
    # `~mask[:, :-1]` is evaluated on the freshly initialised mask (only column 0 set),
    # so in practice it only forbids a start at column 1.
    # NOTE(review): if the intent was "never two starts in a row", that would need a
    # sequential scan — confirm.
    not_after_start = ~mask[:, :-1]
    mask[:, 1:] = rising & above_global & not_after_start
    return mask
|
|
|
|
|
|
|
|
|
|
|
|
def patch_start_ids_from_patch_start_mask(patch_start_mask):
    """
    Convert a [bs, trunc_seq_len] bool start mask into per-row start positions.

    Each row lists the column indices of its True entries in increasing order and is
    right-padded with trunc_seq_len. The result has max_patches columns, where
    max_patches is the largest number of starts in any row. If no row has a start,
    a [bs, trunc_seq_len] tensor filled with trunc_seq_len is returned instead.
    """
    batch, width = patch_start_mask.shape
    dev = patch_start_mask.device
    most_starts = patch_start_mask.sum(dim=1).max()
    sentinel = torch.full((batch, width), width, dtype=torch.long, device=dev)
    if most_starts == 0:
        return sentinel
    positions = torch.arange(width, device=dev).unsqueeze(0).repeat(batch, 1)
    candidates = torch.cat((positions, sentinel), dim=1)
    # Each row selects exactly `width` entries: its start positions followed by enough
    # sentinels (one per non-start), so the flat selection reshapes back to [batch, width].
    selector = torch.cat((patch_start_mask, ~patch_start_mask), dim=1)
    ordered = candidates[selector].reshape(batch, width)
    return ordered[:, :most_starts]
|
|
|
|
|
|
|
|
|
|
|
|
def check_non_zero_after_zero(tensor):
    """
    Return a 0-d bool tensor that is True if, in any row of the 2-D `tensor`, a non-zero
    value immediately follows a zero (i.e. the row's zeros are not all trailing).
    """
    is_zero = tensor == 0
    prev_is_zero = torch.zeros_like(is_zero)
    prev_is_zero[:, 1:] = is_zero[:, :-1]
    return ((tensor != 0) & prev_is_zero).any()
|
|
|
|
|
|
|
|
|
|
|
|
def patch_lengths_from_start_ids(patch_start_ids, seq_len):
    """
    Convert per-row patch start ids into patch lengths.

    patch_start_ids: [bs, n] tensor. Each row lists patch starts in non-decreasing order
        and is right-padded with seq_len, e.g. [0, 1, 7, 7, 7, 7, 7] for seq_len 7.
    seq_len: total length covered by the patches.

    Returns a [bs, n] tensor of lengths in which padding entries become 0; the example
    above yields [1, 6, 0, 0, 0, 0, 0].

    Asserts that no length is negative and that no non-zero length follows a zero one.
    """
    final_end = torch.full_like(patch_start_ids[:, :1], seq_len - 1)
    # Each patch ends one position before the next patch starts.
    ends = torch.cat((patch_start_ids[:, 1:] - 1, final_end), dim=1)
    patch_lengths = ends - patch_start_ids + 1
    assert torch.all(patch_lengths >= 0), f"{patch_lengths}"
    assert not check_non_zero_after_zero(patch_lengths), f"{patch_lengths}"
    return patch_lengths
|
|
|
|
|
|
|
|
|
|
|
|
def find_space_patch_start_ids(tokens):
    """
    Patch start ids for "space" patching.

    tokens: [bs, seq_len] tensor of token ids; byte values are `token - OFFSET`.

    A token is a boundary (patch end) when its byte is not an ASCII digit or letter and
    is not a UTF-8 continuation byte (0x80-0xBF). Only the first boundary of each run of
    consecutive boundaries counts, and ids below OFFSET always end a patch. Positions 0
    and 1 always start patches; beyond that, a patch starts right after each patch end,
    so start ids range over [0, seq_len].

    Returns a [bs, max_patches] long tensor, right-padded with seq_len + 1.
    """
    bs, seq_len = tokens.shape
    byte_vals = tokens - OFFSET

    is_digit = (byte_vals >= ord("0")) & (byte_vals <= ord("9"))
    is_upper = (byte_vals >= ord("A")) & (byte_vals <= ord("Z"))
    is_lower = (byte_vals >= ord("a")) & (byte_vals <= ord("z"))
    is_continuation = (byte_vals >= 0b1000_0000) & (byte_vals < 0b1100_0000)
    patch_end_mask = ~(is_digit | is_upper | is_lower | is_continuation)

    # Suppress a boundary directly after another boundary. The right-hand side is a new
    # tensor built from the pre-update values, so this is not a sequential scan.
    patch_end_mask[:, 1:] &= ~patch_end_mask[:, :-1]
    patch_end_mask |= tokens < OFFSET

    always_start = torch.ones((bs, 2), dtype=torch.bool, device=tokens.device)
    patch_start_mask = torch.cat([always_start, patch_end_mask[:, 1:]], dim=1)
    max_patches = patch_start_mask.sum(dim=1).max()

    width = seq_len + 1
    positions = torch.arange(width, device=tokens.device).unsqueeze(0).repeat(bs, 1)
    fillers = torch.full((bs, width), width, dtype=torch.long, device=tokens.device)
    candidates = torch.cat((positions, fillers), dim=1)
    selector = torch.cat((patch_start_mask, ~patch_start_mask), dim=1)

    return candidates[selector].reshape(bs, -1)[:, :max_patches]
|
|
|
|
|
|
|
|
|
|
|
|
def to_device(entropy_model, device=None):
    """
    Move `entropy_model` to `device` and return `(moved_model, resolved_device)`.

    A bare "cuda" is resolved to "cuda:<local rank>" via get_local_rank(); any other
    value (including None) is passed to `.to()` unchanged.
    """
    resolved = f"cuda:{get_local_rank()}" if device == "cuda" else device
    return entropy_model.to(resolved), resolved
|
|
|
|
|
|
|
|
|
|
|
|
def model_pred_to_bpe_patching_pred(pred):
    """
    Return a bool tensor marking the rows of `pred` whose highest-scoring index along
    dim 1 equals BPE_ID.
    """
    top_class = pred.max(dim=1).indices
    return top_class == BPE_ID
|
|
|
|
|
|
|
|
|
|
|
|
def apply_bpe_patcher(tokens, bpe_patcher, patching_batch_size, device=None):
    """
    Run `bpe_patcher` over `tokens` and return a bool mask of predicted BPE boundaries.

    tokens: 2D CPU tensor of token ids with values in [0, 260).
    bpe_patcher: model called on [rows, max_length] chunks. Its output may be a logits
        tensor or a tuple/list whose first element is the logits tensor.
    patching_batch_size: number of max_length rows per model call.
    device: target device, resolved through to_device ("cuda" -> "cuda:<local rank>").

    Returns a bool CPU tensor shaped like `tokens`, True where the model's argmax over
    the vocabulary equals BPE_ID.
    """
    assert tokens.device == torch.device(
        "cpu"
    ), f"{tokens.device} != cpu expects tokens to be on cpu"
    with torch.no_grad():
        bpe_patcher_device, device = to_device(
            bpe_patcher, device
        )  # Get entropy model to right rank device.
        bpe_patching_mask = []
        max_length = getattr(bpe_patcher, "max_length", 8192)
        batch_numel = max_length * patching_batch_size
        splits = torch.split(tokens.flatten(), batch_numel)
        for split in splits:
            # Pad so the chunk reshapes into whole rows of max_length.
            pad_size = (max_length - (split.numel() % max_length)) % max_length
            pad = torch.zeros(
                pad_size, dtype=split.dtype, device=split.device, requires_grad=False
            )
            split = torch.cat((split, pad), dim=0)
            split = split.reshape(-1, max_length).to(device)
            assert torch.all(split >= 0) and torch.all(split < 260)
            pred = bpe_patcher_device(split)
            # Accept both (logits, ...) outputs and a bare logits tensor. Indexing a bare
            # [rows, max_length, vocab] tensor with [0] would keep only its first row.
            if isinstance(pred, (tuple, list)):
                pred = pred[0]
            pred_cpu = pred.cpu()
            # Flatten rows and drop the predictions that belong to the padding.
            pred_cpu = pred_cpu.reshape(-1, pred_cpu.shape[-1])[
                : split.numel() - pad_size, :
            ]  # [batch_size * seq_len, vocab]
            bpe_patching_pred = model_pred_to_bpe_patching_pred(pred_cpu)
            bpe_patching_mask.append(bpe_patching_pred)
        bpe_patching_mask = torch.cat(bpe_patching_mask, dim=0)
        bpe_patching_mask = bpe_patching_mask.reshape(tokens.shape)
    return bpe_patching_mask
|
|
|
|
|
|
|
|
|
|
|
|
def find_bpe_patcher_patch_start_ids(
    tokens, bpe_patcher, patching_batch_size, device=None, include_next_token=True
):
    """
    Patch start ids for "bpe_patcher" mode.

    Ids 0 and 1 are always patch starts. The model is run on tokens[:, 1:] (or
    tokens[:, 1:-1] without the next token), and mask column i maps to start id i + 2.
    If that input is empty, only [0, 1] is returned for every row.
    """
    bs, _ = tokens.shape
    first_ids = torch.tensor([[0, 1]], dtype=torch.long, device=tokens.device).repeat(
        bs, 1
    )
    preds_truncation_len = first_ids.shape[1]

    model_input = tokens[:, 1:] if include_next_token else tokens[:, 1:-1]
    if model_input.shape[1] < 1:
        return first_ids

    patch_start_mask = apply_bpe_patcher(
        model_input, bpe_patcher, patching_batch_size, device
    )
    assert (
        patch_start_mask.shape[1]
        == tokens.shape[1] + include_next_token - preds_truncation_len
    ), f"{patch_start_mask.shape[1]} != {tokens.shape[1] + include_next_token - preds_truncation_len}"
    relative_ids = patch_start_ids_from_patch_start_mask(patch_start_mask)
    return torch.cat((first_ids, relative_ids + preds_truncation_len), dim=1)
|
|
|
|
|
|
|
|
|
|
|
|
def find_entropy_patch_start_ids(
    entropies,
    patch_size=None,
    threshold=None,
    threshold_add=None,
    monotonicity=False,
    include_next_token=True,
):
    """
    Use entropies to find the start ids of each patch.

    entropies: [bs, seq_len] tensor of per-token entropies.
    patch_size: target average patch length; used only when threshold is None. In that
        case the int(seq_len // patch_size) - 2 highest-entropy positions of
        entropies[:, 1:] become patch starts (clamped to [0, seq_len - 1]).
    threshold: when not None, starts are chosen per position by a threshold criterion.
        The number of patches then varies between sequences, but patches can be
        identified incrementally rather than decided globally over the whole sequence.
    threshold_add: with threshold (and monotonicity False), use
        patch_start_mask_global_and_monotonicity(entropies, threshold, threshold_add).
    monotonicity: use patch_start_mask_from_entropy_with_monotonicity with threshold as
        the required entropy increase.
    include_next_token: in threshold mode, when False the last mask column is dropped.
        The top-k mode ignores this flag.

    Returns a long tensor whose rows begin with the fixed starts [0, 1], followed by the
    selected starts; column i of entropies[:, 1:] maps to start id i + 2.
    """
    bs, seq_len = entropies.shape[:2]

    first_ids = (
        torch.tensor([0, 1], dtype=torch.long, device=entropies.device)
        .unsqueeze(0)
        .repeat(bs, 1)
    )
    # The first preds are dropped because those positions are always patch starts.
    preds_truncation_len = first_ids.shape[1]
    entropies = entropies[:, 1:]
    if threshold is None:
        # `patch_size` may be a float (PatcherArgs defaults to 4.5), which makes `//`
        # return a float that topk rejects, so convert explicitly.
        num_patches = int(seq_len // patch_size)
        # Two patches are already fixed by first_ids. Clamp so short sequences do not ask
        # for a negative k and k never exceeds the number of candidate positions.
        k = max(0, min(num_patches - preds_truncation_len, entropies.shape[1]))
        patch_start_ids = entropies.topk(k, dim=1).indices
        patch_start_ids = patch_start_ids.sort(dim=1).values
    else:
        # No start above threshold is fine: patch_start_ids_from_patch_start_mask then
        # returns only padding.
        if monotonicity:
            patch_start_mask = patch_start_mask_from_entropy_with_monotonicity(
                entropies, threshold
            )
        elif threshold_add is not None:
            patch_start_mask = patch_start_mask_global_and_monotonicity(
                entropies, threshold, threshold_add
            )
        else:
            patch_start_mask = entropies > threshold
        if not include_next_token:
            patch_start_mask = patch_start_mask[:, :-1]
        patch_start_ids = patch_start_ids_from_patch_start_mask(patch_start_mask)

    patch_start_ids = torch.cat(
        (first_ids, patch_start_ids + preds_truncation_len), dim=1
    )
    return patch_start_ids
|
|
|
|
|
|
|
|
|
|
|
|
def rightpad(seq, pad_id, max_len):
    """Return `seq` extended on the right with `pad_id` up to `max_len` items (no truncation)."""
    missing = max_len - len(seq)
    return seq + missing * [pad_id]
|
|
|
|
|
|
|
|
|
|
|
|
def find_bpe_delim_patch_start_ids(tokens, delim):
    """
    Patch start ids for "bpe" mode: a new patch starts right after each `delim` token.

    tokens: [bs, seq_len] tensor of token ids.
    delim: token id that ends a patch (it is the last element of its patch).

    Returns a [bs, max_patches] tensor with the dtype/device of `tokens`. Every row starts
    with the fixed ids [0, 1] and is right-padded with seq_len. The last column is not
    searched for delimiters.
    """
    ids = (tokens[:, :-1] == delim).nonzero(as_tuple=False)
    out = [[0, 1] for _ in range(tokens.shape[0])]
    for row, col in ids.tolist():
        # Position 1 always starts a patch, so a delimiter at position 0 already ends
        # patch [0]. Appending col + 1 == 1 again would duplicate that start and create
        # a zero-length patch in the middle of the row.
        if col == 0:
            continue
        # start is at delim + 1, delim should be the last element in the patch.
        out[row].append(col + 1)
    max_len = max([len(elt) for elt in out])
    out = [rightpad(elt, tokens.shape[1], max_len) for elt in out]
    patch_start_ids = torch.tensor(out, dtype=tokens.dtype, device=tokens.device)
    return patch_start_ids
|
|
|
|
|
|
|
|
|
|
|
|
def find_lookup_table_start_mask(
    tokens: torch.Tensor, lookup_table: torch.Tensor, include_next_token=True
):
    """
    Look up every sliding window of consecutive tokens in `lookup_table`.

    tokens: [bs, seq_len] tensor of token ids.
    lookup_table: tensor with one dimension per window position; its ndim is the window
        size.
    include_next_token: accepted for signature parity with the caller; unused here.

    Returns lookup_table[t_i, ..., t_{i+w-1}] for every window start i, shaped
    [bs, seq_len - w + 1].
    """
    window_size = lookup_table.ndim
    # Sliding windows: [bs, seq_len - window_size + 1, window_size]
    unfolded = tokens.unfold(1, window_size, 1)
    # Use an explicit tuple of index tensors (one per table dimension). A *list* of
    # tensors only works through PyTorch's legacy sequence-as-tuple indexing heuristic.
    indices = tuple(unfolded[..., i] for i in range(window_size))
    result = lookup_table[indices]
    return result
|
|
|
|
|
|
|
|
|
|
|
|
def find_lookup_table_patch_start_ids(
    tokens: torch.Tensor, lookup_table: torch.Tensor, include_next_token=True
):
    """
    Patch start ids from a 2-D lookup table over consecutive token pairs.

    Ids 0 and 1 are always patch starts. Mask column i (the pair starting at token i)
    maps to start id i + 2. If the input is shorter than the window, only [0, 1] is
    returned for every row.
    """
    bs, _ = tokens.shape
    first_ids = torch.tensor([[0, 1]], dtype=torch.long, device=tokens.device).repeat(
        bs, 1
    )
    preds_truncation_len = first_ids.shape[1]
    window_size = lookup_table.ndim
    assert window_size == 2, f"{window_size} != 2"

    # The lookup yields len(token_input) - window_size + 1 flags. With first_ids
    # prepended, the ids cover seq_len + 1 positions with the next token, seq_len otherwise.
    if include_next_token:
        token_input = tokens
    else:
        token_input = tokens[:, : -preds_truncation_len + 1]
    if token_input.shape[1] < window_size:
        return first_ids

    patch_start_mask = find_lookup_table_start_mask(
        token_input, lookup_table, include_next_token
    )
    assert (
        patch_start_mask.shape[1]
        == tokens.shape[1] + include_next_token - preds_truncation_len
    ), f"{patch_start_mask.shape[1]} != {tokens.shape[1] + include_next_token - preds_truncation_len}"
    relative_ids = patch_start_ids_from_patch_start_mask(patch_start_mask)
    return torch.cat((first_ids, relative_ids + preds_truncation_len), dim=1)
|
|
|
|
|
|
|
|
|
|
|
|
def split_large_numbers(lst, m):
    """
    Split every value larger than `m` into pieces of at most `m`.

    lst: list of numbers (patch lengths).
    m: maximum allowed value.

    Each value v > m becomes [m, m, ..., r] with 0 < r <= m; other values are kept as-is.
    The sum of the list is preserved.

    Raises:
        ValueError: if a value exceeds `m` while `m` is not positive, because the
            splitting would never terminate.
    """
    new_lst = []
    for i in lst:
        if i > m:
            if m <= 0:
                raise ValueError(
                    f"cannot split {i} into pieces of at most {m}; m must be positive"
                )
            while i > m:
                new_lst.append(m)
                i -= m
            new_lst.append(i)
        else:
            new_lst.append(i)
    assert sum(new_lst) == sum(lst), f"{sum(new_lst)} != {sum(lst)}"
    return new_lst
|
|
|
|
|
|
|
|
|
|
|
|
class Patcher:
    """
    Computes per-row patch lengths for token batches according to a PatcherArgs config.
    """

    def __init__(self, patcher_args: PatcherArgs):
        self.patcher_args = patcher_args
        self.patching_mode = patcher_args.patching_mode
        self.realtime_patching = patcher_args.realtime_patching
        if self.realtime_patching:
            assert (
                patcher_args.entropy_model_checkpoint_dir is not None
            ), "Cannot require realtime patching without an entropy model checkpoint"
            entropy_model = load_entropy_model(
                patcher_args.entropy_model_checkpoint_dir
            )
            entropy_model, _ = to_device(entropy_model, patcher_args.patching_device)
            self.entropy_model = entropy_model
        else:
            # Without realtime patching, entropy mode needs `entropies` or `preds`
            # passed to patch().
            self.entropy_model = None
        self.threshold = patcher_args.threshold
        self.threshold_add = patcher_args.threshold_add
        self.max_patch_length = patcher_args.max_patch_length
        self.patch_size = patcher_args.patch_size
        self.patching_batch_size = patcher_args.patching_batch_size
        self.data_loader_patching = patcher_args.data_loader_patching
        self.device = patcher_args.device
        self.monotonicity = patcher_args.monotonicity
        self.log_time = patcher_args.log_time
        if self.log_time:
            # Accumulated wall-clock seconds per stage, plus the total "tokens" patched.
            self.log = defaultdict(float)

    def patch(
        self,
        tokens: torch.Tensor,
        include_next_token: bool = False,
        preds: torch.Tensor | None = None,
        entropies: torch.Tensor | None = None,
        threshold: float | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        tokens: 2D tensor of shape [batch_size, seq_len] that needs to be patched.

        Returns (patch_lengths, scores):
        - patch_lengths: [batch_size, max_num_patches] tensor. Each row is processed
          independently and right padded with zeros; all-zero trailing columns are
          trimmed. Each row sums to seq_len (+1 with include_next_token).
        - scores: the entropies used in "entropy" mode, otherwise None.

        Patching with the following modes:
        1. patching_mode = None: static patch size.
        2. patching_mode = "entropy":
            calculate the entropy of each token and allocate patches so that the total
            number of patches is the same as static patching, but begin patches on
            tokens where the model is most uncertain (highest entropy).

            When a threshold is provided (argument, else self.threshold), it is used to
            decide when to start a new patch.
            Entropies come from `entropies`, else from `preds`, else from the entropy
            model.
        3. patching_mode = "space":
            use space-like tokens to define the patches.
        4. patching_mode = "bpe":
            use bpe delim tokens to define the patches.
        5. patching_mode = "bpe_patcher":
            use a model's BPE-boundary predictions (self.entropy_model) to define the
            patches.

        To correctly patch the last token, it may be necessary to include the next
        token in the patch lengths calculations. This is controlled by the
        include_next_token argument.
        """
        bs, seq_len = tokens.shape
        seq_len_next_tok = seq_len + 1 if include_next_token else seq_len
        scores = None
        # STATIC
        if self.patching_mode is None:
            patch_lengths = torch.zeros(
                (bs, math.ceil(seq_len_next_tok / self.patch_size)),
                dtype=tokens.dtype,
                device=tokens.device,
            ).fill_(self.patch_size)
            if seq_len_next_tok % self.patch_size != 0:
                patch_lengths[:, -1] = seq_len_next_tok % self.patch_size
        # ENTROPY
        elif self.patching_mode == PatchingModeEnum.entropy:
            if self.log_time:
                s = time.time()
            if entropies is not None:
                if isinstance(entropies, torch.Tensor):
                    # Same result as torch.tensor(entropies, dtype=float32): a detached
                    # float32 copy. This avoids torch's copy-construct UserWarning.
                    scores = entropies.detach().to(dtype=torch.float32, copy=True)
                else:
                    scores = torch.tensor(entropies, dtype=torch.float32)
            elif preds is not None:
                scores = entropy(preds)
            else:
                scores = calculate_entropies(
                    tokens,
                    self.entropy_model,
                    self.patching_batch_size,
                    self.device,
                )
            if self.log_time:
                self.log["calculate_entropies"] += time.time() - s
                s = time.time()
            patch_start_ids = find_entropy_patch_start_ids(
                scores,
                self.patch_size,
                include_next_token=include_next_token,
                threshold=threshold if threshold is not None else self.threshold,
                threshold_add=self.threshold_add,
                monotonicity=self.monotonicity,
            )
            if self.log_time:
                self.log["find_entropy_patch_start_ids"] += time.time() - s
                s = time.time()
            patch_lengths = patch_lengths_from_start_ids(
                patch_start_ids, seq_len_next_tok
            )
            if self.log_time:
                self.log["patch_lengths_from_start_ids"] += time.time() - s
        # BPE
        elif self.patching_mode == PatchingModeEnum.bpe:
            patch_start_ids = find_bpe_delim_patch_start_ids(tokens, delim=BPE_ID)
            patch_lengths = patch_lengths_from_start_ids(
                patch_start_ids, seq_len_next_tok
            )
        elif self.patching_mode == PatchingModeEnum.bpe_patcher:
            patch_start_ids = find_bpe_patcher_patch_start_ids(
                tokens,
                self.entropy_model,
                self.patching_batch_size,
                self.device,
                include_next_token,
            )
            patch_lengths = patch_lengths_from_start_ids(
                patch_start_ids, seq_len_next_tok
            )
        # SPACE
        elif self.patching_mode == PatchingModeEnum.space:
            patch_start_ids = find_space_patch_start_ids(tokens)
            patch_lengths = patch_lengths_from_start_ids(
                patch_start_ids, seq_len_next_tok
            )
        else:
            raise NotImplementedError(f"self.patching_mode {self.patching_mode}")

        # Start the post-processing timer here, for every mode, so the timing update
        # at the end always has a start time.
        if self.log_time:
            s = time.time()

        # Apply any processing to patch lengths
        if self.max_patch_length is not None:
            # TODO: avoid going back to a list here.
            patch_lengths = [
                split_large_numbers(pl, self.max_patch_length)
                for pl in patch_lengths.tolist()
            ]
            max_len = max([len(pl) for pl in patch_lengths])
            patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths]
            patch_lengths = torch.tensor(
                patch_lengths, dtype=tokens.dtype, device=tokens.device
            )
        assert not check_non_zero_after_zero(patch_lengths)
        # Find the last non-zero column index using argmax on a reversed version of the tensor
        last_non_zero_col_reversed = (
            (patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min()
        )
        # Slice the tensor up to the last non-zero column
        patch_lengths = patch_lengths[
            :, : patch_lengths.shape[1] - last_non_zero_col_reversed
        ]
        assert (
            torch.sum(patch_lengths)
            == tokens.numel() + include_next_token * tokens.shape[0]
        ), f"{torch.sum(patch_lengths)} != {tokens.numel() + include_next_token * tokens.shape[0]}"
        if self.log_time:
            self.log["postprocessing_patch_lengths"] += time.time() - s
            self.log["tokens"] += patch_lengths.sum().item()
        return patch_lengths, scores
|