mirror of https://github.com/facebookresearch/blt.git
synced 2025-09-10 06:14:35 +00:00

Initial commit
Commit bcc039bb75
86 changed files with 12203 additions and 0 deletions

bytelatent/data/patcher.py · 609 lines · new file

@@ -0,0 +1,609 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import math
import time
from collections import defaultdict
from enum import Enum

import torch
from pydantic import BaseModel
from torch.nn import functional as F

from bytelatent.distributed import get_local_rank
from bytelatent.entropy_model import load_entropy_model

# from src.slurm import get_local_rank
from bytelatent.tokenizers.constants import BPE_ID, OFFSET


class PatchingModeEnum(str, Enum):
    entropy = "entropy"
    bpe = "bpe"
    bpe_patcher = "bpe_patcher"
    space = "space"


class PatcherArgs(BaseModel):
    patching_mode: PatchingModeEnum = PatchingModeEnum.entropy
    patching_device: str = "cuda"
    entropy_model_checkpoint_dir: str | None = None
    realtime_patching: bool = False
    threshold: float = 1.335442066192627
    threshold_add: float | None = None
    max_patch_length: int | None = None
    patch_size: float = 4.5
    patching_batch_size: int = 1
    data_loader_patching: bool = False
    device: str = "cuda"
    monotonicity: bool = False
    log_time: bool = False

    def build(self) -> "Patcher":
        return Patcher(self)


def entropy(scores):
    """
    scores: [bs, seq_len, vocab]
    returns [bs, seq_len]

    Computes the entropy for each token in the batch.
    Note: uses natural log.
    """
    log_probs = F.log_softmax(scores, dim=-1)
    probs = torch.exp(log_probs)
    p_log_p = log_probs * probs
    entropy = -p_log_p.sum(dim=-1)
    return entropy


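# Worked example for `entropy` (illustrative, not part of the upstream logic): if
# `scores` is all zeros, softmax gives a uniform distribution over the vocab, so each
# entry equals the natural log of the vocab size, e.g. ln(260) ≈ 5.56 for a 260-entry
# byte vocabulary:
#   entropy(torch.zeros(1, 1, 260))  # -> tensor([[5.5607]])
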
def calculate_entropies(
    tokens: torch.tensor, entropy_model, patching_batch_size, device: str | None = None
):
    """
    tokens: 2D tensor of shape [batch_size, seq_len]
    Return 2D tensor of shape [batch_size, seq_len] with entropies for each token.

    Splits the tokens into chunks of size max_length and calculates entropies for each chunk.
    Entropy model can be executed on cpu or gpu, specify either 'cuda' or 'cpu' in the device argument.
    """
    with torch.no_grad():
        entropies = []
        max_length = getattr(entropy_model, "max_length", 8192)
        batch_numel = max_length * patching_batch_size
        splits = torch.split(tokens.flatten(), batch_numel)
        for split in splits:
            pad_size = (max_length - (split.numel() % max_length)) % max_length
            pad = torch.zeros(
                pad_size, dtype=split.dtype, device=split.device, requires_grad=False
            )
            split = torch.cat((split, pad), dim=0)
            split = split.reshape(-1, max_length)
            if device is not None:
                split = split.to(device)
            assert torch.all(split >= 0) and torch.all(split < 260)
            pred, _ = entropy_model(split)
            pred = pred.reshape(-1, pred.shape[-1])[
                : split.numel() - pad_size, :
            ]  # [batch_size * seq_len, vocab]
            pred_entropies = entropy(pred)
            entropies.append(pred_entropies)

        entropies = torch.cat(entropies, dim=0)
        entropies = entropies.reshape(tokens.shape)
    return entropies


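# Padding arithmetic in `calculate_entropies` (illustrative): with max_length=8192 and
# patching_batch_size=1, a flattened input of 10000 tokens is split into chunks of
# 8192 and 1808 tokens; the second chunk gets pad_size = (8192 - 1808 % 8192) % 8192
# = 6384 zero tokens so it can be reshaped to [1, 8192], and the padded positions are
# sliced away from the predictions before the entropies are concatenated.
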
def patch_start_mask_from_entropy_with_monotonicity(entropies, t):
    """
    entropies: [bs, seq_len] torch tensor of entropies
    t: threshold
    returns [bs, seq_len] mask where True indicates the start of a patch
    """
    bs, seq_len = entropies.shape
    mask = torch.zeros_like(entropies, dtype=torch.bool)
    mask[:, 0] = True

    # Calculate differences between consecutive elements along the sequence length
    differences = entropies[:, 1:] - entropies[:, :-1]

    # Calculate conditions for all elements except the first one in each sequence
    condition = differences > t

    # Update the mask based on the condition
    mask[:, 1:] = condition

    return mask


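# Worked example for the monotonicity-based mask above (illustrative): with
# entropies = [[2.0, 3.6, 3.2, 4.8]] and t = 1.0, the consecutive differences are
# [1.6, -0.4, 1.6], so the returned mask is [[True, True, False, True]]: a patch
# starts wherever the entropy jumps by more than the threshold, plus the forced
# start at position 0. The variant below adds a global threshold on the entropy
# value itself.
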
def patch_start_mask_global_and_monotonicity(entropies, t, t_add=0):
    """
    entropies: [bs, seq_len] torch tensor of entropies
    t: threshold
    returns [bs, seq_len] mask where True indicates the start of a patch
    """
    bs, seq_len = entropies.shape
    mask = torch.zeros_like(entropies, dtype=torch.bool)
    mask[:, 0] = True

    # Calculate differences between consecutive elements along the sequence length
    differences = entropies[:, 1:] - entropies[:, :-1]

    # Calculate conditions for all elements except the first one in each sequence
    condition = (differences > t_add) & (entropies[:, 1:] > t) & (~mask[:, :-1])

    # Update the mask based on the condition
    mask[:, 1:] = condition

    return mask


def patch_start_ids_from_patch_start_mask(patch_start_mask):
    bs, trunc_seq_len = patch_start_mask.shape
    max_patches = patch_start_mask.sum(dim=1).max()
    if max_patches == 0:
        patch_start_ids = torch.full(
            (bs, trunc_seq_len),
            trunc_seq_len,
            dtype=torch.long,
            device=patch_start_mask.device,
        )
    else:
        patch_ids = (
            torch.arange(trunc_seq_len, device=patch_start_mask.device)
            .unsqueeze(0)
            .repeat(bs, 1)
        )
        extra_patch_ids = torch.full(
            (bs, trunc_seq_len),
            trunc_seq_len,
            dtype=torch.long,
            device=patch_start_mask.device,
        )
        all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1)
        patch_start_mask_padded = torch.cat(
            (patch_start_mask, ~patch_start_mask), dim=1
        )
        patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape(
            bs, trunc_seq_len
        )[:, :max_patches]
    return patch_start_ids


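# Worked example for `patch_start_ids_from_patch_start_mask` (illustrative): for
# mask = [[True, False, True, False]], max_patches is 2 and the selected positions
# are [[0, 2]]; rows with fewer starts than max_patches end up right-padded with
# trunc_seq_len, which is why the extra ids are concatenated before the boolean
# indexing.
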
def check_non_zero_after_zero(tensor):
    zero_mask = tensor == 0
    shifted_mask = torch.cat(
        [
            torch.zeros(tensor.shape[0], 1, dtype=torch.bool, device=tensor.device),
            zero_mask[:, :-1],
        ],
        dim=1,
    )
    non_zero_after_zero = (tensor != 0) & shifted_mask
    return non_zero_after_zero.any()


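# Worked example for `check_non_zero_after_zero` (illustrative): it flags rows where a
# non-zero value appears after a zero, i.e. padding that is not at the end:
#   check_non_zero_after_zero(torch.tensor([[2, 0, 3]]))  # -> tensor(True)
#   check_non_zero_after_zero(torch.tensor([[2, 3, 0]]))  # -> tensor(False)
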
def patch_lengths_from_start_ids(patch_start_ids, seq_len):
    """
    Calculate patch lengths from start ids.
    start ids: ex: [0, 1, 7, 7, 7, 7, 7], it has the start ids of the patches (here 0, 1), and then
    the rest are filled to the seq len.
    seq_len: ex: 7 length of the sequence

    returns the patch lengths:
    [1, 6] for the above example.
    """
    last_ids = torch.full_like(patch_start_ids[:, :1], seq_len - 1)
    patch_end_ids = torch.cat((patch_start_ids[:, 1:] - 1, last_ids), dim=1)
    patch_lengths = patch_end_ids - patch_start_ids + 1
    assert torch.all(patch_lengths >= 0), f"{patch_lengths}"
    assert not check_non_zero_after_zero(patch_lengths), f"{patch_lengths}"
    return patch_lengths


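# Worked example for `patch_lengths_from_start_ids` (illustrative): for
# patch_start_ids = [[0, 1, 7, 7, 7, 7, 7]] and seq_len = 7, the end ids become
# [[0, 6, 6, 6, 6, 6, 6]], so the returned lengths are [[1, 6, 0, 0, 0, 0, 0]];
# the trailing zeros correspond to the padded start ids and are trimmed later in
# Patcher.patch.
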
def find_space_patch_start_ids(tokens):
    bs, seq_len = tokens.shape
    tokens_no_offset = tokens - OFFSET
    patch_end_mask = (
        (tokens_no_offset < ord("0"))
        | ((ord("9") < tokens_no_offset) & (tokens_no_offset < ord("A")))
        | ((ord("Z") < tokens_no_offset) & (tokens_no_offset < ord("a")))
        | ((ord("z") < tokens_no_offset) & (tokens_no_offset < 0b1000_0000))
        | (0b1100_0000 <= tokens_no_offset)
    )
    patch_end_mask[:, 1:] &= patch_end_mask[:, :-1].bitwise_not()
    patch_end_mask |= tokens < OFFSET

    patch_start_mask = torch.cat(
        [
            torch.tensor([1, 1], device=tokens.device, dtype=torch.bool)
            .unsqueeze(0)
            .repeat(bs, 1),
            patch_end_mask[:, 1:],
        ],
        dim=1,
    )
    max_patches = patch_start_mask.sum(dim=1).max()

    patch_ids = (
        torch.arange(seq_len + 1, device=tokens.device).unsqueeze(0).repeat(bs, 1)
    )
    extra_patch_ids = torch.full(
        (bs, seq_len + 1), seq_len + 1, dtype=torch.long, device=tokens.device
    )
    all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1)
    patch_start_mask_padded = torch.cat((patch_start_mask, ~patch_start_mask), dim=1)

    patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape(bs, -1)[
        :, :max_patches
    ]
    return patch_start_ids


def to_device(entropy_model, device=None):
    if device == "cuda":
        rank = get_local_rank()
        device = f"cuda:{rank}"
    entropy_model = entropy_model.to(device)
    return entropy_model, device


def model_pred_to_bpe_patching_pred(pred):
    _, indices = torch.max(pred, dim=1)
    return indices == BPE_ID


def apply_bpe_patcher(tokens, bpe_patcher, patching_batch_size, device=None):
    assert tokens.device == torch.device(
        "cpu"
    ), f"{tokens.device} != cpu expects tokens to be on cpu"
    with torch.no_grad():
        bpe_patcher_device, device = to_device(
            bpe_patcher, device
        )  # Get entropy model to right rank device.
        bpe_patching_mask = []
        max_length = getattr(bpe_patcher, "max_length", 8192)
        batch_numel = max_length * patching_batch_size
        splits = torch.split(tokens.flatten(), batch_numel)
        for split in splits:
            pad_size = (max_length - (split.numel() % max_length)) % max_length
            pad = torch.zeros(
                pad_size, dtype=split.dtype, device=split.device, requires_grad=False
            )
            split = torch.cat((split, pad), dim=0)
            split = split.reshape(-1, max_length).to(device)
            assert torch.all(split >= 0) and torch.all(split < 260)
            pred = bpe_patcher_device(split)
            pred_cpu = pred[0].cpu()
            pred_cpu = pred_cpu.reshape(-1, pred_cpu.shape[-1])[
                : split.numel() - pad_size, :
            ]  # [batch_size * seq_len, vocab]
            bpe_patching_pred = model_pred_to_bpe_patching_pred(pred_cpu)
            bpe_patching_mask.append(bpe_patching_pred)
        bpe_patching_mask = torch.cat(bpe_patching_mask, dim=0)
        bpe_patching_mask = bpe_patching_mask.reshape(tokens.shape)
    return bpe_patching_mask


def find_bpe_patcher_patch_start_ids(
    tokens, bpe_patcher, patching_batch_size, device=None, include_next_token=True
):
    bs, seq_len = tokens.shape

    first_ids = (
        torch.tensor([0, 1], dtype=torch.long, device=tokens.device)
        .unsqueeze(0)
        .repeat(bs, 1)
    )
    preds_truncation_len = first_ids.shape[1]
    token_input = tokens[:, 1:] if include_next_token else tokens[:, 1:-1]
    if token_input.shape[1] >= 1:
        patch_start_mask = apply_bpe_patcher(
            token_input, bpe_patcher, patching_batch_size, device
        )
        assert (
            patch_start_mask.shape[1]
            == tokens.shape[1] + include_next_token - preds_truncation_len
        ), f"{patch_start_mask.shape[1]} != {tokens.shape[1] + include_next_token - preds_truncation_len}"
        patch_start_ids = patch_start_ids_from_patch_start_mask(patch_start_mask)
        patch_start_ids = torch.cat(
            (first_ids, patch_start_ids + preds_truncation_len), dim=1
        )
    else:
        patch_start_ids = first_ids
    return patch_start_ids


def find_entropy_patch_start_ids(
    entropies,
    patch_size=None,
    threshold=None,
    threshold_add=None,
    monotonicity=False,
    include_next_token=True,
):
    """
    Use entropies to find the start ids of each patch.
    Use patch_size or threshold to figure out the total number of patches to allocate.

    When threshold is not None the number of patches is not constant between
    different sequences, but patches can be identified incrementally rather than
    decided globally using the entire sequence.
    """
    bs, seq_len = entropies.shape[:2]

    first_ids = (
        torch.tensor([0, 1], dtype=torch.long, device=entropies.device)
        .unsqueeze(0)
        .repeat(bs, 1)
    )
    preds_truncation_len = first_ids.shape[
        1
    ]  # remove the first preds because they will be start of patches.
    entropies = entropies[:, 1:]
    if threshold is None:
        num_patches = seq_len // patch_size
        patch_start_ids = entropies.topk(num_patches - 2, dim=1).indices
        patch_start_ids = patch_start_ids.sort(dim=1).values
    else:
        # Assumes that there is at least one token going over the threshold
        if monotonicity:
            patch_start_mask = patch_start_mask_from_entropy_with_monotonicity(
                entropies, threshold
            )
        elif threshold_add is not None and threshold is not None:
            patch_start_mask = patch_start_mask_global_and_monotonicity(
                entropies, threshold, threshold_add
            )
        else:
            patch_start_mask = entropies > threshold
        if not include_next_token:
            patch_start_mask = patch_start_mask[:, :-1]
        # patch_start_mask[1:] |= tokens[:-1] < OFFSET
        patch_start_ids = patch_start_ids_from_patch_start_mask(patch_start_mask)

    patch_start_ids = torch.cat(
        (first_ids, patch_start_ids + preds_truncation_len), dim=1
    )
    return patch_start_ids


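# Sizing example for `find_entropy_patch_start_ids` (illustrative): with
# threshold=None, seq_len = 12 and patch_size = 4, num_patches = 3; two starts are
# always reserved at positions 0 and 1 via first_ids, so topk picks the remaining
# 3 - 2 = 1 highest-entropy position (shifted by preds_truncation_len) as the third
# patch start.
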
def rightpad(seq, pad_id, max_len):
    return seq + [pad_id] * (max_len - len(seq))


def find_bpe_delim_patch_start_ids(tokens, delim):
    ids = (tokens[:, :-1] == delim).nonzero(as_tuple=False)
    out = [[0, 1] for _ in range(tokens.shape[0])]
    for x, y in ids:
        # start is at delim + 1, delim should be the last element in the patch.
        out[x.item()].append(y.item() + 1)
    max_len = max([len(elt) for elt in out])
    out = [rightpad(elt, tokens.shape[1], max_len) for elt in out]
    patch_start_ids = torch.tensor(out, dtype=tokens.dtype, device=tokens.device)
    return patch_start_ids


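# Worked example for `find_bpe_delim_patch_start_ids` (illustrative): with
# tokens = [[5, 7, 9, 7, 3]] and delim = 7, the delimiters in tokens[:, :-1] sit at
# columns 1 and 3, so the start ids are [[0, 1, 2, 4]]: each new patch begins right
# after a delimiter, and rows with fewer delimiters are right-padded with seq_len.
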
def find_lookup_table_start_mask(
    tokens: torch.Tensor, lookup_table: torch.Tensor, include_next_token=True
):
    window_size = lookup_table.ndim
    # Unfold the tensor to get sliding windows
    unfolded = tokens.unfold(1, window_size, 1)
    # Gather indices for each dimension
    indices = [unfolded[..., i] for i in range(window_size)]
    # Access the lookup table using the gathered indices
    result = lookup_table[indices]
    return result


def find_lookup_table_patch_start_ids(
    tokens: torch.Tensor, lookup_table: torch.Tensor, include_next_token=True
):
    bs, seq_len = tokens.shape

    first_ids = (
        torch.tensor([0, 1], dtype=torch.long, device=tokens.device)
        .unsqueeze(0)
        .repeat(bs, 1)
    )
    preds_truncation_len = first_ids.shape[1]
    window_size = lookup_table.ndim
    assert window_size == 2, f"{window_size} != 2"
    # output dimensions: token_input shape - window_size + 1 --> we want first ids + this = tokens shape + 1 if next token otherwise just token shape
    token_input = (
        tokens if include_next_token else tokens[:, : -preds_truncation_len + 1]
    )
    if token_input.shape[1] >= window_size:
        patch_start_mask = find_lookup_table_start_mask(
            token_input, lookup_table, include_next_token
        )
        assert (
            patch_start_mask.shape[1]
            == tokens.shape[1] + include_next_token - preds_truncation_len
        ), f"{patch_start_mask.shape[1]} != {tokens.shape[1] + include_next_token - preds_truncation_len}"
        patch_start_ids = patch_start_ids_from_patch_start_mask(patch_start_mask)
        patch_start_ids = torch.cat(
            (first_ids, patch_start_ids + preds_truncation_len), dim=1
        )
    else:
        patch_start_ids = first_ids
    return patch_start_ids


def split_large_numbers(lst, m):
    new_lst = []
    for i in lst:
        if i > m:
            while i > m:
                new_lst.append(m)
                i -= m
            new_lst.append(i)
        else:
            new_lst.append(i)
    assert sum(new_lst) == sum(lst), f"{sum(new_lst)} != {sum(lst)}"
    return new_lst


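# Worked example for `split_large_numbers` (illustrative): it caps patch lengths at m
# while preserving the total number of tokens:
#   split_large_numbers([3, 9, 2], 4)  # -> [3, 4, 4, 1, 2]
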
class Patcher:
    def __init__(self, patcher_args: PatcherArgs):
        self.patcher_args = patcher_args
        self.patching_mode = patcher_args.patching_mode
        self.realtime_patching = patcher_args.realtime_patching
        if self.realtime_patching:
            assert (
                patcher_args.entropy_model_checkpoint_dir is not None
            ), "Cannot require realtime patching without an entropy model checkpoint"
            entropy_model = load_entropy_model(
                patcher_args.entropy_model_checkpoint_dir
            )
            entropy_model, _ = to_device(entropy_model, patcher_args.patching_device)
            self.entropy_model = entropy_model
        else:
            self.entropy_model = None
        self.threshold = patcher_args.threshold
        self.threshold_add = patcher_args.threshold_add
        self.max_patch_length = patcher_args.max_patch_length
        self.patch_size = patcher_args.patch_size
        self.patching_batch_size = patcher_args.patching_batch_size
        self.data_loader_patching = patcher_args.data_loader_patching
        self.device = patcher_args.device
        self.monotonicity = patcher_args.monotonicity
        self.log_time = patcher_args.log_time
        if self.log_time:
            self.log = defaultdict(float)

    def patch(
        self,
        tokens: torch.Tensor,
        include_next_token: bool = False,
        preds: torch.Tensor | None = None,
        entropies: torch.Tensor | None = None,
        threshold: float | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        tokens: 2D tensor of shape [batch_size, seq_len] that needs to be patched
        Returns patch lengths and optionally scores associated with the tokens (i.e. entropies, logprobs etc.)
        -> output tensor: [batch_size, max_num_patches]
        each tensor is processed independently and gets right padded with zeros.

        Patching with the following modes:
        1. patching_mode = None: static patch size
        2. patching_mode = "entropy":
           calculate entropy of each token, allocate patches so that the total
           number of patches is the same as static patching but choose to begin
           patches on tokens where the model is most uncertain (highest entropy).

           When threshold is provided, it uses the threshold to decide when to
           start a new patch.
        3. patching_mode = "space":
           use space like tokens to define the patches.
        4. patching_mode = "bpe":
           use bpe delim tokens to define the patches.

        To correctly patch the last token, it may be necessary to include the next token in the patch
        lengths calculations. This is controlled by the include_next_token argument.
        """
        bs, seq_len = tokens.shape
        seq_len_next_tok = seq_len + 1 if include_next_token else seq_len
        scores = None
        # STATIC
        if self.patching_mode is None:
            patch_lengths = torch.zeros(
                (bs, math.ceil(seq_len_next_tok / self.patch_size)),
                dtype=tokens.dtype,
                device=tokens.device,
            ).fill_(self.patch_size)
            if seq_len_next_tok % self.patch_size != 0:
                patch_lengths[:, -1] = seq_len_next_tok % self.patch_size
        # ENTROPY
        elif self.patching_mode == PatchingModeEnum.entropy:
            if self.log_time:
                s = time.time()
            if entropies is not None:
                scores = torch.tensor(entropies, dtype=torch.float32)
            elif preds is not None:
                scores = entropy(preds)
            else:
                start_entropies = time.time()
                scores = calculate_entropies(
                    tokens,
                    self.entropy_model,
                    self.patching_batch_size,
                    self.device,
                )
            if self.log_time:
                self.log["calculate_entropies"] += time.time() - s
                s = time.time()
            patch_start_ids = find_entropy_patch_start_ids(
                scores,
                self.patch_size,
                include_next_token=include_next_token,
                threshold=threshold if threshold is not None else self.threshold,
                threshold_add=self.threshold_add,
                monotonicity=self.monotonicity,
            )
            if self.log_time:
                self.log["find_entropy_patch_start_ids"] += time.time() - s
                s = time.time()
            patch_lengths = patch_lengths_from_start_ids(
                patch_start_ids, seq_len_next_tok
            )
            if self.log_time:
                self.log["patch_lengths_from_start_ids"] += time.time() - s
                s = time.time()
        # BPE
        elif self.patching_mode == PatchingModeEnum.bpe:
            patch_start_ids = find_bpe_delim_patch_start_ids(tokens, delim=BPE_ID)
            patch_lengths = patch_lengths_from_start_ids(
                patch_start_ids, seq_len_next_tok
            )
        elif self.patching_mode == PatchingModeEnum.bpe_patcher:
            patch_start_ids = find_bpe_patcher_patch_start_ids(
                tokens,
                self.entropy_model,
                self.patching_batch_size,
                self.device,
                include_next_token,
            )
            patch_lengths = patch_lengths_from_start_ids(
                patch_start_ids, seq_len_next_tok
            )
        # SPACE
        elif self.patching_mode == PatchingModeEnum.space:
            patch_start_ids = find_space_patch_start_ids(tokens)
            patch_lengths = patch_lengths_from_start_ids(
                patch_start_ids, seq_len_next_tok
            )
        else:
            raise NotImplementedError(f"self.patching_mode {self.patching_mode}")

        # Apply any processing to patch lengths
        if self.max_patch_length is not None:
            # TODO: avoid going back to a list here.
            patch_lengths = [
                split_large_numbers(pl, self.max_patch_length)
                for pl in patch_lengths.tolist()
            ]
            max_len = max([len(pl) for pl in patch_lengths])
            patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths]
            patch_lengths = torch.tensor(
                patch_lengths, dtype=tokens.dtype, device=tokens.device
            )
        assert not check_non_zero_after_zero(patch_lengths)
        # Find the last non-zero column index using argmax on a reversed version of the tensor
        last_non_zero_col_reversed = (
            (patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min()
        )
        # Slice the tensor up to the last non-zero column
        patch_lengths = patch_lengths[
            :, : patch_lengths.shape[1] - last_non_zero_col_reversed
        ]
        assert (
            torch.sum(patch_lengths)
            == tokens.numel() + include_next_token * tokens.shape[0]
        ), f"{torch.sum(patch_lengths)} != {tokens.numel() + include_next_token * tokens.shape[0]}"
        if self.log_time:
            self.log["postprocessing_patch_lengths"] += time.time() - s
            self.log["tokens"] += patch_lengths.sum().item()
        return patch_lengths, scores
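
# Usage sketch (illustrative, not part of the upstream file): builds a Patcher in
# "space" mode, which needs no entropy model, and patches a single byte sequence as
# described in the Patcher.patch docstring. It assumes the byte token ids are already
# shifted by OFFSET, as the byte-level tokenizer is expected to do.
#
#   args = PatcherArgs(
#       patching_mode=PatchingModeEnum.space,
#       realtime_patching=False,
#       patching_device="cpu",
#       device="cpu",
#   )
#   patcher = args.build()
#   tokens = torch.tensor([[b + OFFSET for b in b"hello world"]], dtype=torch.long)
#   patch_lengths, scores = patcher.patch(tokens)
#   assert patch_lengths.sum().item() == tokens.numel()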