diff --git a/bytelatent/data/patcher.py b/bytelatent/data/patcher.py
index f8477a3..afcfa2e 100644
--- a/bytelatent/data/patcher.py
+++ b/bytelatent/data/patcher.py
@@ -2,6 +2,7 @@
 import math
 import time
 from collections import defaultdict
+from contextlib import nullcontext
 from enum import Enum
 
 import torch
@@ -58,7 +59,11 @@ def entropy(scores):
 
 
 def calculate_entropies(
-    tokens: torch.tensor, entropy_model, patching_batch_size, device: str | None = None
+    tokens: torch.tensor,
+    entropy_model,
+    patching_batch_size,
+    device: str | None = None,
+    enable_grad: bool = False,
 ):
     """
     tokens: 2D tensor of shape [batch_size, seq_len]
@@ -67,8 +72,12 @@ def calculate_entropies(
     Splits the tokens into chunks of size max_length and calculates entropies for each chunk.
     Entropy model can be executed on cpu or gpu, specify either 'cuda' or 'cpu' in the device argument.
     """
-    with torch.no_grad():
+
+    grad_context = nullcontext() if enable_grad else torch.no_grad()
+
+    with grad_context:
         entropies = []
+        preds = []
         max_length = getattr(entropy_model, "max_length", 8192)
         batch_numel = max_length * patching_batch_size
         splits = torch.split(tokens.flatten(), batch_numel)
@@ -86,12 +95,15 @@ def calculate_entropies(
             pred = pred.reshape(-1, pred.shape[-1])[
                 : split.numel() - pad_size, :
             ]  # [batch_size * seq_len, vocab]
+            preds.append(pred)
             pred_entropies = entropy(pred)
             entropies.append(pred_entropies)
 
     concat_entropies = torch.cat(entropies, dim=0)
     concat_entropies = concat_entropies.reshape(tokens.shape)
-    return concat_entropies
+    concat_preds = torch.cat(preds, dim=0)
+    concat_preds = concat_preds.reshape(tokens.shape[0], tokens.shape[1], -1)
+    return concat_entropies, concat_preds
 
 
 def patch_start_mask_from_entropy_with_monotonicity(entropies, t):
@@ -101,6 +113,10 @@ def patch_start_mask_from_entropy_with_monotonicity(entropies, t):
     returns [bs, seq_len] mask where True indicates the start of a patch
     """
     bs, seq_len = entropies.shape
+
+    if seq_len == 0:
+        return entropies > t
+
     mask = torch.zeros_like(entropies, dtype=torch.bool)
     mask[:, 0] = True
 
@@ -123,6 +139,10 @@ def patch_start_mask_global_and_monotonicity(entropies, t, t_add=0):
     returns [bs, seq_len] mask where True indicates the start of a patch
     """
     bs, seq_len = entropies.shape
+
+    if seq_len == 0:
+        return entropies > t
+
     mask = torch.zeros_like(entropies, dtype=torch.bool)
     mask[:, 0] = True
 
@@ -521,12 +541,12 @@ class Patcher:
             if self.log_time:
                 s = time.time()
             if entropies is not None:
-                scores = torch.tensor(entropies, dtype=torch.float32)
+                scores = entropies.to(dtype=torch.float32)
             elif preds is not None:
                 scores = entropy(preds)
             else:
                 start_entropies = time.time()
-                scores = calculate_entropies(
+                scores, _ = calculate_entropies(
                     tokens,
                     self.entropy_model,
                     self.patching_batch_size,
diff --git a/bytelatent/model/local_models.py b/bytelatent/model/local_models.py
index 59fa76d..c16f62e 100644
--- a/bytelatent/model/local_models.py
+++ b/bytelatent/model/local_models.py
@@ -199,9 +199,6 @@ class LocalModelBase(nn.Module):
 class LocalEncoder(LocalModelBase):
     def __init__(self, args: LocalModelArgs):
         super().__init__(args)
-        self.output_proj = (
-            args.patching_mode in ["entropy", "probmax"]
-        ) and args.entropy_model_checkpoint_dir is None
 
         self.apply_transformer = args.use_local_encoder_transformer
         self.downsampling_by_pooling = args.downsampling_by_pooling
diff --git a/bytelatent/model/utils.py b/bytelatent/model/utils.py
index 7ca979d..e01672e 100644
--- a/bytelatent/model/utils.py
+++ b/bytelatent/model/utils.py
@@ -162,9 +162,6 @@ def create_causal_mask(
             return "causal"
 
         if BLT_SUPPRESS_ATTN_ERROR == 1:
-            logging.warning(
-                "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. Allowing model to run since BLT_SUPPRESS_ATTN_ERROR=1"
-            )
             return "causal"
         else:
             raise ValueError(
diff --git a/bytelatent/preprocess/preprocess_entropies.py b/bytelatent/preprocess/preprocess_entropies.py
index 1c19a5a..519da94 100644
--- a/bytelatent/preprocess/preprocess_entropies.py
+++ b/bytelatent/preprocess/preprocess_entropies.py
@@ -117,7 +117,7 @@ def main(
         text = get_text(doc)
         tokens = torch.tensor(tokenizer.encode(text))
         patch_start = time.time()
-        scores = calculate_entropies(
+        scores, _ = calculate_entropies(
             tokens,
             entropy_model,
             patching_batch_size,
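Note on the API change above: `calculate_entropies` now returns an `(entropies, preds)` tuple and accepts an `enable_grad` flag that swaps `torch.no_grad()` for `nullcontext()`, letting gradients flow through the entropy model's logits. The sketch below is illustrative only; `ToyEntropyModel` is a placeholder defined here (not part of the repo), standing in for the trained byte-level entropy LM that is normally loaded from a checkpoint.

```python
# Minimal usage sketch of the updated calculate_entropies signature.
import torch
import torch.nn as nn

from bytelatent.data.patcher import calculate_entropies


class ToyEntropyModel(nn.Module):
    """Placeholder stand-in: maps byte ids [bs, seq_len] to logits [bs, seq_len, vocab]."""

    def __init__(self, vocab_size: int = 260, dim: int = 32):
        super().__init__()
        self.max_length = 512  # read via getattr(entropy_model, "max_length", 8192)
        self.embed = nn.Embedding(vocab_size, dim)
        self.proj = nn.Linear(dim, vocab_size)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.proj(self.embed(tokens))


model = ToyEntropyModel()
tokens = torch.randint(0, 260, (2, 128))  # [batch_size, seq_len] byte ids

# Default behavior is unchanged apart from the extra return value:
# gradients stay disabled, and the per-token predictions are returned alongside entropies.
entropies, preds = calculate_entropies(tokens, model, patching_batch_size=4)
assert entropies.shape == tokens.shape   # [bs, seq_len]
assert preds.shape[:2] == tokens.shape   # [bs, seq_len, vocab]

# With enable_grad=True the no_grad() guard becomes a nullcontext(), so the
# entropies remain attached to the autograd graph of the entropy model.
entropies, _ = calculate_entropies(
    tokens, model, patching_batch_size=4, enable_grad=True
)
entropies.mean().backward()
assert model.proj.weight.grad is not None
```

Callers that only need the scores (e.g. `Patcher` and `preprocess_entropies.py` in this diff) simply discard the second return value with `scores, _ = calculate_entropies(...)`.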