# blt/bytelatent/data/patcher.py
# Repository viewer metadata: 610 lines, 22 KiB, Python, 2024-12-12 23:32:30 +00:00
# Copyright (c) Meta Platforms, Inc. and affiliates.
import math
import time
from collections import defaultdict
from enum import Enum
import torch
from pydantic import BaseModel
from torch.nn import functional as F
from bytelatent.distributed import get_local_rank
from bytelatent.entropy_model import load_entropy_model
# from src.slurm import get_local_rank
from bytelatent.tokenizers.blt_tokenizer import BPE_ID, OFFSET
from bytelatent.tokenizers.constants import BPE_ID, OFFSET
class PatchingModeEnum(str, Enum):
    """Names of the strategies available for cutting a token sequence into patches.

    The members also subclass ``str``, so each one compares equal to its
    plain string value.
    """

    # Patch boundaries derived from per-token entropy scores.
    entropy = "entropy"
    # Patch boundaries placed right after each BPE delimiter token.
    bpe = "bpe"
    # Patch boundaries predicted by a model applied to the tokens.
    bpe_patcher = "bpe_patcher"
    # Patch boundaries placed around non-alphanumeric ("space like") bytes.
    space = "space"
class PatcherArgs(BaseModel):
    """Configuration from which a ``Patcher`` is built (see ``build``)."""

    # Which patching strategy Patcher.patch dispatches to.
    patching_mode: PatchingModeEnum = PatchingModeEnum.entropy
    # Device the entropy model is moved to when realtime patching is enabled.
    patching_device: str = "cuda"
    # Checkpoint location of the entropy model; required for realtime patching.
    entropy_model_checkpoint_dir: str | None = None
    # When True, the Patcher loads the entropy model at construction time.
    realtime_patching: bool = False
    # Default entropy threshold used to decide where a new patch starts.
    threshold: float = 1.335442066192627
    # Optional additional threshold on the entropy increase between neighbours.
    threshold_add: float | None = None
    # When set, patches longer than this are split into several patches.
    max_patch_length: int | None = None
    # Patch size used for static patching and for the top-k patch budget.
    patch_size: float = 4.5
    # Number of max_length rows fed to the model per forward pass.
    patching_batch_size: int = 1
    # Copied onto the Patcher; not read anywhere else in this file.
    data_loader_patching: bool = False
    # Device the token chunks are moved to before running the model.
    device: str = "cuda"
    # Select the monotonicity-based patch start rule in entropy mode.
    monotonicity: bool = False
    # When True, the Patcher accumulates timings in its ``log`` dict.
    log_time: bool = False

    def build(self) -> "Patcher":
        """Create a ``Patcher`` configured by these arguments."""
        return Patcher(self)
def entropy(scores):
    """
    Entropy of the categorical distribution defined by each logit vector.

    scores: [bs, seq_len, vocab] unnormalized scores; softmax is taken over the
        last dimension.
    returns [bs, seq_len]
    Note: natural log is used, so the result is in nats.
    """
    log_p = F.log_softmax(scores, dim=-1)
    # Probabilities are recovered from the log-probabilities rather than
    # computed by a separate softmax call.
    weighted = log_p * log_p.exp()
    return -weighted.sum(dim=-1)
def calculate_entropies(
    tokens: torch.Tensor, entropy_model, patching_batch_size, device: str | None = None
):
    """
    Run the entropy model over `tokens` and return one entropy per token.

    tokens: 2D tensor of shape [batch_size, seq_len]
    entropy_model: callable mapping a [rows, max_length] token tensor to scores
        whose last dimension is the vocabulary. Its `max_length` attribute is
        used as the row length (8192 when the attribute is missing).
    patching_batch_size: number of max_length rows per model call.
    device: if given (e.g. 'cuda' or 'cpu'), each chunk is moved there before
        the model is called.

    Returns a tensor of shape [batch_size, seq_len] with the entropy of each token.

    The tokens are flattened, cut into chunks of max_length * patching_batch_size
    elements, and the last chunk is zero-padded up to a multiple of max_length.
    Predictions for the padding are discarded before entropies are computed.
    """
    with torch.no_grad():
        entropies = []
        max_length = getattr(entropy_model, "max_length", 8192)
        batch_numel = max_length * patching_batch_size
        for split in torch.split(tokens.flatten(), batch_numel):
            # Number of zeros needed to reach a multiple of max_length (0 if aligned).
            pad_size = (max_length - (split.numel() % max_length)) % max_length
            pad = torch.zeros(
                pad_size, dtype=split.dtype, device=split.device, requires_grad=False
            )
            split = torch.cat((split, pad), dim=0).reshape(-1, max_length)
            if device is not None:
                split = split.to(device)
            assert torch.all(split >= 0) and torch.all(split < 260)
            pred = entropy_model(split)
            # Flatten to [rows * max_length, vocab] and drop the padded positions.
            pred = pred.reshape(-1, pred.shape[-1])[: split.numel() - pad_size, :]
            entropies.append(entropy(pred))
        concat_entropies = torch.cat(entropies, dim=0)
        return concat_entropies.reshape(tokens.shape)
# 2024-12-12 23:32:30 +00:00
def patch_start_mask_from_entropy_with_monotonicity(entropies, t):
    """
    Mark a patch start wherever the entropy rises by more than `t` from the
    previous position. Position 0 is always a start.

    entropies: [bs, seq_len] torch tensor of entropies
    t: threshold on the increase between consecutive entropies
    returns [bs, seq_len] bool mask where True indicates the start of a patch
    """
    _bs, _seq_len = entropies.shape  # unpacking requires a 2D input
    starts = torch.zeros_like(entropies, dtype=torch.bool)
    starts[:, 0] = True
    # Increase of each entropy over its left neighbour, compared to the threshold
    starts[:, 1:] = (entropies[:, 1:] - entropies[:, :-1]) > t
    return starts
def patch_start_mask_global_and_monotonicity(entropies, t, t_add=0):
    """
    Mark a patch start where the entropy exceeds `t` AND has risen by more than
    `t_add` from the previous position. Position 0 is always a start.

    entropies: [bs, seq_len] torch tensor of entropies
    t: threshold on the entropy value itself
    t_add: threshold on the increase between consecutive entropies
    returns [bs, seq_len] bool mask where True indicates the start of a patch

    Position 1 is never marked: the "previous position is not a start" term is
    evaluated on the initial mask, in which only position 0 is set.
    """
    _bs, _seq_len = entropies.shape  # unpacking requires a 2D input
    starts = torch.zeros_like(entropies, dtype=torch.bool)
    starts[:, 0] = True
    current = entropies[:, 1:]
    rise = current - entropies[:, :-1]
    # Evaluated before the assignment below, so only column 0 is True here.
    previous_not_start = ~starts[:, :-1]
    starts[:, 1:] = previous_not_start & (current > t) & (rise > t_add)
    return starts
def patch_start_ids_from_patch_start_mask(patch_start_mask):
    """
    Convert a boolean patch-start mask into right-padded start positions.

    patch_start_mask: [bs, trunc_seq_len] bool tensor, True at patch starts.
    returns a long tensor [bs, max_patches] holding, per row, the positions of
    the True entries in increasing order; rows with fewer starts are padded
    with trunc_seq_len. If no row has any start, a [bs, trunc_seq_len] tensor
    filled with trunc_seq_len is returned.
    """
    bs, trunc_seq_len = patch_start_mask.shape
    device = patch_start_mask.device
    max_patches = patch_start_mask.sum(dim=1).max()
    # Block of padding values, one per position that is not a start.
    filler = torch.full(
        (bs, trunc_seq_len), trunc_seq_len, dtype=torch.long, device=device
    )
    if max_patches == 0:
        return filler
    positions = torch.arange(trunc_seq_len, device=device).unsqueeze(0).repeat(bs, 1)
    candidates = torch.cat((positions, filler), dim=1)
    # Every row selects exactly trunc_seq_len entries: its start positions from
    # the first half, then one padding value per non-start from the second half.
    keep = torch.cat((patch_start_mask, ~patch_start_mask), dim=1)
    return candidates[keep].reshape(bs, trunc_seq_len)[:, :max_patches]
def check_non_zero_after_zero(tensor):
    """
    Tell whether any row of a 2D tensor has a non-zero entry directly after a zero.

    Returns a 0-dim bool tensor. Used to verify that zero padding only appears
    at the end of each row.
    """
    previous_is_zero = tensor[:, :-1] == 0
    current_non_zero = tensor[:, 1:] != 0
    return (current_non_zero & previous_is_zero).any()
def patch_lengths_from_start_ids(patch_start_ids, seq_len):
    """
    Turn right-padded patch start ids into right-padded patch lengths.

    patch_start_ids: [bs, num_patches], e.g. [0, 1, 7, 7, 7, 7, 7]: the real
        starts (here 0 and 1) followed by padding equal to seq_len.
    seq_len: length of the sequence, e.g. 7.
    returns a tensor of the same shape with the length of each patch; padded
        slots get length 0. For the example: [1, 6, 0, 0, 0, 0, 0].

    Asserts that no length is negative and that zeros only appear as trailing
    padding.
    """
    # A patch ends one position before the next start; the last slot ends at
    # the final position of the sequence.
    final_end = torch.full_like(patch_start_ids[:, :1], seq_len - 1)
    following_starts = patch_start_ids[:, 1:]
    patch_end_ids = torch.cat((following_starts - 1, final_end), dim=1)
    patch_lengths = (patch_end_ids - patch_start_ids) + 1
    assert torch.all(patch_lengths >= 0), f"{patch_lengths}"
    assert not check_non_zero_after_zero(patch_lengths), f"{patch_lengths}"
    return patch_lengths
def find_space_patch_start_ids(tokens):
    """
    Find patch start ids by cutting after "space like" (non-alphanumeric) bytes.

    tokens: [bs, seq_len] token ids; byte values are obtained by subtracting OFFSET.
    returns [bs, max_patches] start ids, right padded with seq_len + 1.

    Positions 0 and 1 are always starts. The mask has seq_len + 1 columns, so a
    boundary after the last token can also be expressed.
    """
    bs, seq_len = tokens.shape
    device = tokens.device
    byte_vals = tokens - OFFSET

    # Byte ranges that are neither digits nor ASCII letters. Values in
    # [0x80, 0xC0) are deliberately not included.
    below_digits = byte_vals < ord("0")
    digits_to_upper = (ord("9") < byte_vals) & (byte_vals < ord("A"))
    upper_to_lower = (ord("Z") < byte_vals) & (byte_vals < ord("a"))
    lower_to_ascii_end = (ord("z") < byte_vals) & (byte_vals < 0b1000_0000)
    high_bytes = 0b1100_0000 <= byte_vals
    patch_end_mask = (
        below_digits | digits_to_upper | upper_to_lower | lower_to_ascii_end | high_bytes
    )
    # Within a run of consecutive boundary bytes keep only the first one. The
    # right-hand side is materialized before the in-place update.
    patch_end_mask[:, 1:] &= patch_end_mask[:, :-1].bitwise_not()
    # Ids below OFFSET (non-byte tokens) always end a patch.
    patch_end_mask |= tokens < OFFSET

    forced_starts = (
        torch.tensor([1, 1], device=device, dtype=torch.bool).unsqueeze(0).repeat(bs, 1)
    )
    # A patch starts right after each end marker (the first token's marker is skipped).
    patch_start_mask = torch.cat([forced_starts, patch_end_mask[:, 1:]], dim=1)
    max_patches = patch_start_mask.sum(dim=1).max()

    width = seq_len + 1
    positions = torch.arange(width, device=device).unsqueeze(0).repeat(bs, 1)
    filler = torch.full((bs, width), width, dtype=torch.long, device=device)
    candidates = torch.cat((positions, filler), dim=1)
    # Select the start positions, then one padding value per non-start.
    keep = torch.cat((patch_start_mask, ~patch_start_mask), dim=1)
    return candidates[keep].reshape(bs, -1)[:, :max_patches]
def to_device(entropy_model, device=None):
    """
    Move `entropy_model` to `device` and return ``(model, device)``.

    The bare string "cuda" is expanded to "cuda:<local rank>" using
    get_local_rank(); any other value is passed to ``.to`` unchanged.
    """
    target = f"cuda:{get_local_rank()}" if device == "cuda" else device
    return entropy_model.to(target), target
def model_pred_to_bpe_patching_pred(pred):
    """Return a bool tensor that is True where the highest score along dim 1 is at index BPE_ID."""
    top_index = torch.max(pred, dim=1).indices
    return top_index == BPE_ID
def apply_bpe_patcher(tokens, bpe_patcher, patching_batch_size, device=None):
    """
    Run the BPE patcher model over `tokens` and return a patch-start mask.

    tokens: [bs, seq_len] tensor that must live on the CPU.
    bpe_patcher: model called on [rows, max_length] token tensors; its
        `max_length` attribute is the row length (8192 when missing).
    patching_batch_size: number of max_length rows per model call.
    device: device the model and each chunk are moved to ("cuda" is expanded
        to the local-rank GPU by to_device).

    returns a bool tensor shaped like `tokens`, True where the model's top
    prediction is BPE_ID.
    """
    assert tokens.device == torch.device(
        "cpu"
    ), f"{tokens.device} != cpu expects tokens to be on cpu"
    with torch.no_grad():
        # Put the model on the device belonging to this rank.
        model, device = to_device(bpe_patcher, device)
        max_length = getattr(bpe_patcher, "max_length", 8192)
        chunk_numel = max_length * patching_batch_size
        mask_chunks = []
        for chunk in torch.split(tokens.flatten(), chunk_numel):
            # Zero-pad the chunk up to a multiple of max_length.
            pad_size = (max_length - (chunk.numel() % max_length)) % max_length
            padding = torch.zeros(
                pad_size, dtype=chunk.dtype, device=chunk.device, requires_grad=False
            )
            batch = torch.cat((chunk, padding), dim=0).reshape(-1, max_length).to(device)
            assert torch.all(batch >= 0) and torch.all(batch < 260)
            # NOTE(review): element 0 of the model output is used; presumably the
            # model returns a tuple whose first item holds the scores - confirm.
            scores = model(batch)[0].cpu()
            valid = batch.numel() - pad_size
            # [batch_size * seq_len, vocab], with the padded positions removed
            scores = scores.reshape(-1, scores.shape[-1])[:valid, :]
            mask_chunks.append(model_pred_to_bpe_patching_pred(scores))
        return torch.cat(mask_chunks, dim=0).reshape(tokens.shape)
def find_bpe_patcher_patch_start_ids(
    tokens, bpe_patcher, patching_batch_size, device=None, include_next_token=True
):
    """
    Find patch start ids using the predictions of a BPE patcher model.

    tokens: [bs, seq_len] token ids.
    include_next_token: when True the model sees tokens[:, 1:], otherwise the
        last token is dropped as well.
    returns [bs, num_patches] start ids. Positions 0 and 1 are always starts;
    predicted starts are shifted by two to account for them.
    """
    bs, _seq_len = tokens.shape
    first_ids = (
        torch.tensor([0, 1], dtype=torch.long, device=tokens.device)
        .unsqueeze(0)
        .repeat(bs, 1)
    )
    preds_truncation_len = first_ids.shape[1]
    if include_next_token:
        token_input = tokens[:, 1:]
    else:
        token_input = tokens[:, 1:-1]
    # Nothing left to predict on: only the two forced starts remain.
    if token_input.shape[1] < 1:
        return first_ids
    patch_start_mask = apply_bpe_patcher(
        token_input, bpe_patcher, patching_batch_size, device
    )
    assert (
        patch_start_mask.shape[1]
        == tokens.shape[1] + include_next_token - preds_truncation_len
    ), f"{patch_start_mask.shape[1]} != {tokens.shape[1] + include_next_token - preds_truncation_len}"
    predicted = patch_start_ids_from_patch_start_mask(patch_start_mask)
    return torch.cat((first_ids, predicted + preds_truncation_len), dim=1)
def find_entropy_patch_start_ids(
    entropies,
    patch_size=None,
    threshold=None,
    threshold_add=None,
    monotonicity=False,
    include_next_token=True,
):
    """
    Use entropies to find the start ids of each patch.

    entropies: [bs, seq_len] tensor of per-token entropies.
    patch_size: average patch size (may be fractional); only used when
        `threshold` is None, to derive the number of patches as
        seq_len // patch_size.
    threshold: when given, a patch starts wherever the entropy rule fires, so
        the number of patches differs between sequences, but patches can be
        identified incrementally rather than decided globally over the whole
        sequence.
    threshold_add: with `threshold`, additionally require the entropy to have
        risen by more than this amount (ignored when `monotonicity` is True).
    monotonicity: with `threshold`, start a patch where the entropy rises by
        more than `threshold` over the previous position.
    include_next_token: when False (threshold mode only) the last mask column
        is dropped, so no patch can start after the final token.

    returns [bs, num_patches] long tensor of start ids. Ids 0 and 1 are always
    present; in threshold mode rows are right padded by
    patch_start_ids_from_patch_start_mask.
    """
    bs, seq_len = entropies.shape[:2]
    first_ids = (
        torch.tensor([0, 1], dtype=torch.long, device=entropies.device)
        .unsqueeze(0)
        .repeat(bs, 1)
    )
    # Number of forced starts; predicted ids are shifted by this amount.
    preds_truncation_len = first_ids.shape[1]
    # Drop the first entropy: the first positions are forced patch starts.
    entropies = entropies[:, 1:]
    if threshold is None:
        # patch_size may be a float (e.g. 4.5), in which case floor division
        # yields a float; topk requires an integer k.
        num_patches = int(seq_len // patch_size)
        # Two patches are already taken by the forced starts. Clamp so short
        # sequences or large patch sizes select nothing instead of failing.
        k = min(max(num_patches - preds_truncation_len, 0), entropies.shape[1])
        patch_start_ids = entropies.topk(k, dim=1).indices
        patch_start_ids = patch_start_ids.sort(dim=1).values
    else:
        # Assumes that there is at least one token going over the threshold
        if monotonicity:
            patch_start_mask = patch_start_mask_from_entropy_with_monotonicity(
                entropies, threshold
            )
        elif threshold_add is not None:
            patch_start_mask = patch_start_mask_global_and_monotonicity(
                entropies, threshold, threshold_add
            )
        else:
            patch_start_mask = entropies > threshold
        if not include_next_token:
            patch_start_mask = patch_start_mask[:, :-1]
        patch_start_ids = patch_start_ids_from_patch_start_mask(patch_start_mask)
    return torch.cat((first_ids, patch_start_ids + preds_truncation_len), dim=1)
def rightpad(seq, pad_id, max_len):
    """Return a new list: `seq` followed by as many `pad_id` as needed to reach `max_len`."""
    missing = max_len - len(seq)
    # A non-positive count yields an empty padding list, leaving the content as is.
    return seq + [pad_id] * missing
def find_bpe_delim_patch_start_ids(tokens, delim):
    """
    Find patch start ids by starting a new patch after every `delim` token.

    tokens: [bs, seq_len] token ids.
    delim: token id that closes a patch.
    returns [bs, max_patches] tensor (same dtype and device as `tokens`) of
    start ids, right padded with seq_len. Ids 0 and 1 are always starts. A
    delimiter in the last position is ignored.
    """
    starts_per_row = [[0, 1] for _ in range(tokens.shape[0])]
    hits = (tokens[:, :-1] == delim).nonzero(as_tuple=False)
    for row, col in hits.tolist():
        # start is at delim + 1, delim should be the last element in the patch.
        starts_per_row[row].append(col + 1)
    widest = max(len(starts) for starts in starts_per_row)
    padded = [rightpad(starts, tokens.shape[1], widest) for starts in starts_per_row]
    return torch.tensor(padded, dtype=tokens.dtype, device=tokens.device)
def find_lookup_table_start_mask(
    tokens: torch.Tensor, lookup_table: torch.Tensor, include_next_token=True
):
    """
    Look up every sliding window of tokens in an n-dimensional table.

    tokens: [bs, seq_len] integer tensor.
    lookup_table: tensor with one dimension per window position; the window
        size is lookup_table.ndim.
    include_next_token: accepted for signature compatibility; not used here.

    returns [bs, seq_len - window_size + 1] tensor whose entry j is
    lookup_table[tokens[j], tokens[j + 1], ...].
    """
    window_size = lookup_table.ndim
    # Sliding windows along the sequence: [bs, num_windows, window_size]
    unfolded = tokens.unfold(1, window_size, 1)
    # One index tensor per table dimension. They are passed as a tuple:
    # indexing with a list of tensors is deprecated in PyTorch.
    indices = unfolded.unbind(dim=-1)
    return lookup_table[indices]
def find_lookup_table_patch_start_ids(
    tokens: torch.Tensor, lookup_table: torch.Tensor, include_next_token=True
):
    """
    Find patch start ids from a 2D lookup table over consecutive token pairs.

    tokens: [bs, seq_len] token ids.
    lookup_table: 2D tensor; entry [a, b] tells whether a patch starts after
        the pair (a, b). Any other dimensionality fails the assertion.
    include_next_token: when False the last token is dropped before the lookup.

    returns [bs, num_patches] start ids. Ids 0 and 1 are always starts;
    looked-up starts are shifted by two to account for them.
    """
    bs, _seq_len = tokens.shape
    first_ids = (
        torch.tensor([0, 1], dtype=torch.long, device=tokens.device)
        .unsqueeze(0)
        .repeat(bs, 1)
    )
    preds_truncation_len = first_ids.shape[1]
    window_size = lookup_table.ndim
    assert window_size == 2, f"{window_size} != 2"
    # output dimensions: token_input shape - window_size + 1 --> first ids plus
    # this must equal tokens shape + 1 with the next token, else the token shape
    if include_next_token:
        token_input = tokens
    else:
        token_input = tokens[:, : -preds_truncation_len + 1]
    # Too short for a single window: only the forced starts remain.
    if token_input.shape[1] < window_size:
        return first_ids
    patch_start_mask = find_lookup_table_start_mask(
        token_input, lookup_table, include_next_token
    )
    assert (
        patch_start_mask.shape[1]
        == tokens.shape[1] + include_next_token - preds_truncation_len
    ), f"{patch_start_mask.shape[1]} != {tokens.shape[1] + include_next_token - preds_truncation_len}"
    looked_up = patch_start_ids_from_patch_start_mask(patch_start_mask)
    return torch.cat((first_ids, looked_up + preds_truncation_len), dim=1)
def split_large_numbers(lst, m):
    """
    Split every number larger than `m` into chunks of `m` plus a remainder.

    lst: list of numbers (patch lengths).
    m: maximum allowed value of an output element.
    returns a new list with the same sum in which no element that was split
    exceeds `m`; numbers that are already <= m are copied unchanged.
    Example: split_large_numbers([10, 3, 8], 4) == [4, 4, 2, 3, 4, 4]

    Raises ValueError when an element needs splitting but `m` is not positive,
    since subtracting a non-positive `m` would never terminate.
    """
    if m <= 0 and any(value > m for value in lst):
        raise ValueError(f"m must be positive to split values, got {m}")
    new_lst = []
    for value in lst:
        # Peel off full chunks of m; what remains is in (0, m] for split
        # values, or the untouched value itself when no split was needed.
        while value > m:
            new_lst.append(m)
            value -= m
        new_lst.append(value)
    assert sum(new_lst) == sum(lst), f"{sum(new_lst)} != {sum(lst)}"
    return new_lst
class Patcher:
def __init__(self, patcher_args: PatcherArgs):
self.patcher_args = patcher_args
self.patching_mode = patcher_args.patching_mode
self.realtime_patching = patcher_args.realtime_patching
if self.realtime_patching:
assert (
patcher_args.entropy_model_checkpoint_dir is not None
), "Cannot require realtime patching without an entropy model checkpoint"
entropy_model = load_entropy_model(
patcher_args.entropy_model_checkpoint_dir
)
entropy_model, _ = to_device(entropy_model, patcher_args.patching_device)
self.entropy_model = entropy_model
else:
self.entropy_model = None
self.threshold = patcher_args.threshold
self.threshold_add = patcher_args.threshold_add
self.max_patch_length = patcher_args.max_patch_length
self.patch_size = patcher_args.patch_size
self.patching_batch_size = patcher_args.patching_batch_size
self.data_loader_patching = patcher_args.data_loader_patching
self.device = patcher_args.device
self.monotonicity = patcher_args.monotonicity
self.log_time = patcher_args.log_time
if self.log_time:
self.log = defaultdict(float)
def patch(
self,
tokens: torch.Tensor,
include_next_token: bool = False,
preds: torch.Tensor | None = None,
entropies: torch.Tensor | None = None,
threshold: float = None,
) -> torch.Tensor:
"""
tokens: 2D tensor of shape [batch_size, seq_len] that needs to be patched
Returns patch lengths and optionally scores associated with the tokens (i.e. entropies, logprobs etc.)
-> output tensor: [batch_size, max_num_patches]
each tensor is processed independently and gets right padded with zeros.
Patching with the following modes:
1. patching_mode = None: static patch size
2. patching_mode = "entropy":
calculate entropy of each token, allocate patches so that the total
number of patches is the same as static patching but choose to begin
patches on tokens where the model is most uncertain (highest entropy).
When threshold is provided, it uses the threshold to decide when to
start a new patch.
3. patching_mode = "space":
use space like tokens to define the patches.
4. patching_mode = "bpe":
use bpe delim tokens to define the patches.
To correctly patch the last token, it may be necessary to include the next token in the patch
lengths calculations. This is controlled by the include_next_token argument.
"""
bs, seq_len = tokens.shape
seq_len_next_tok = seq_len + 1 if include_next_token else seq_len
scores = None
# STATIC
if self.patching_mode is None:
patch_lengths = torch.zeros(
(bs, math.ceil(seq_len_next_tok / self.patch_size)),
dtype=tokens.dtype,
device=tokens.device,
).fill_(self.patch_size)
if seq_len_next_tok % self.patch_size != 0:
patch_lengths[:, -1] = seq_len_next_tok % self.patch_size
# ENTROPY
elif self.patching_mode == PatchingModeEnum.entropy:
if self.log_time:
s = time.time()
if entropies is not None:
scores = torch.tensor(entropies, dtype=torch.float32)
elif preds is not None:
scores = entropy(preds)
else:
start_entropies = time.time()
scores = calculate_entropies(
tokens,
self.entropy_model,
self.patching_batch_size,
self.device,
)
if self.log_time:
self.log["calculate_entropies"] += time.time() - s
s = time.time()
patch_start_ids = find_entropy_patch_start_ids(
scores,
self.patch_size,
include_next_token=include_next_token,
threshold=threshold if threshold is not None else self.threshold,
threshold_add=self.threshold_add,
monotonicity=self.monotonicity,
)
if self.log_time:
self.log["find_entropy_patch_start_ids"] += time.time() - s
s = time.time()
patch_lengths = patch_lengths_from_start_ids(
patch_start_ids, seq_len_next_tok
)
if self.log_time:
self.log["patch_lengths_from_start_ids"] += time.time() - s
s = time.time()
# BPE
elif self.patching_mode == PatchingModeEnum.bpe:
patch_start_ids = find_bpe_delim_patch_start_ids(tokens, delim=BPE_ID)
patch_lengths = patch_lengths_from_start_ids(
patch_start_ids, seq_len_next_tok
)
elif self.patching_mode == PatchingModeEnum.bpe_patcher:
patch_start_ids = find_bpe_patcher_patch_start_ids(
tokens,
self.entropy_model,
self.patching_batch_size,
self.device,
include_next_token,
)
patch_lengths = patch_lengths_from_start_ids(
patch_start_ids, seq_len_next_tok
)
# SPACE
elif self.patching_mode == PatchingModeEnum.space:
patch_start_ids = find_space_patch_start_ids(tokens)
patch_lengths = patch_lengths_from_start_ids(
patch_start_ids, seq_len_next_tok
)
else:
raise NotImplementedError(f"self.patching_mode {self.patching_mode}")
# Apply any processing to patch lengths
if self.max_patch_length is not None:
# TODO: avoid going back to a list here.
patch_lengths = [
split_large_numbers(pl, self.max_patch_length)
for pl in patch_lengths.tolist()
]
max_len = max([len(pl) for pl in patch_lengths])
patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths]
patch_lengths = torch.tensor(
patch_lengths, dtype=tokens.dtype, device=tokens.device
)
assert not check_non_zero_after_zero(patch_lengths)
# Find the last non-zero column index using argmax on a reversed version of the tensor
last_non_zero_col_reversed = (
(patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min()
)
# Slice the tensor up to the last non-zero column
patch_lengths = patch_lengths[
:, : patch_lengths.shape[1] - last_non_zero_col_reversed
]
assert (
torch.sum(patch_lengths)
== tokens.numel() + include_next_token * tokens.shape[0]
), f"{torch.sum(patch_lengths)} != {tokens.numel() + include_next_token * tokens.shape[0]}"
if self.log_time:
self.log["postprocessing_patch_lengths"] += time.time() - s
self.log["tokens"] += patch_lengths.sum().item()
return patch_lengths, scores