diff --git a/ktransformers/util/utils.py b/ktransformers/util/utils.py
index 30f8880..5a83a45 100644
--- a/ktransformers/util/utils.py
+++ b/ktransformers/util/utils.py
@@ -11,6 +11,18 @@ from torch import nn
 import itertools
 import time
 import enum
+from transformers import (
+    LogitsProcessorList,
+    TemperatureLogitsWarper,
+    TopKLogitsWarper,
+    TopPLogitsWarper,
+    MinPLogitsWarper,
+    TypicalLogitsWarper,
+    EpsilonLogitsWarper,
+    EtaLogitsWarper,
+    LogitNormalization,
+)
+
 from ktransformers.util.custom_gguf import translate_name_to_gguf
 from ktransformers.util.custom_gguf import GGUFLoader
 from ktransformers.operators import base_operator
@@ -126,6 +138,57 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
     else:
         module.load()
 
+def tf_logits_warper(generation_config, device="cpu"):
+    """
+    This function returns a [`LogitsProcessorList`] object that contains all relevant [`LogitsWarper`] instances
+    used for multinomial sampling.
+    """
+
+    # instantiate warpers list
+    warpers = LogitsProcessorList()
+
+    # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
+    # better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1)
+    if generation_config.num_beams > 1:
+        if isinstance(generation_config._eos_token_tensor, list):
+            min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1
+        elif isinstance(generation_config._eos_token_tensor, torch.Tensor):
+            min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1
+        else:
+            min_tokens_to_keep = 2
+    else:
+        min_tokens_to_keep = 1
+
+    # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
+    # all samplers can be found in `generation_utils_samplers.py`
+    if generation_config.temperature is not None and generation_config.temperature != 1.0:
+        warpers.append(TemperatureLogitsWarper(generation_config.temperature))
+    if generation_config.top_k is not None and generation_config.top_k != 0:
+        warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))
+    if generation_config.top_p is not None and generation_config.top_p < 1.0:
+        warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))
+    if generation_config.min_p is not None:
+        # Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084)
+        warpers.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep))
+    if generation_config.typical_p is not None and generation_config.typical_p < 1.0:
+        warpers.append(
+            TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep)
+        )
+    if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0:
+        warpers.append(
+            EpsilonLogitsWarper(epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep)
+        )
+    if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:
+        warpers.append(
+            EtaLogitsWarper(
+                epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device
+            )
+        )
+    # `LogitNormalization` should always be the last logit processor, when present
+    if generation_config.renormalize_logits is True:
+        warpers.append(LogitNormalization())
+    return warpers
+
 
 def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True, mode = 'normal', force_think: bool = False, chunk_size = 16384,
                          use_flashinfer_mla = False, num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None):
@@ -201,14 +264,8 @@
             # change this to modify generate config
             #top_k=5, top_p=0.85, temperature=0.1
         )
-        try:       # transformers==4.43
-            logits_warper = (
-                model._get_logits_warper(generation_config,device=inputs.device)
-            )
-        except:
-            logits_warper = (
-                model._get_logits_warper(generation_config)
-            )
+
+        logits_warper = tf_logits_warper(generation_config, device=inputs.device)
 
         cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32)
         generated_ids = torch.zeros(
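
For reference, here is a minimal sketch (not part of the patch) of how the `LogitsProcessorList` returned by `tf_logits_warper` is used during decoding: each warper reshapes the next-token logits in turn, and a token is then drawn with `torch.multinomial`. The `GenerationConfig` values, the vocabulary size, and the placeholder tensors are made up for illustration; inside `prefill_and_generate` the real logits come from the model's forward pass.

```python
import torch
from transformers import GenerationConfig
from ktransformers.util.utils import tf_logits_warper  # the function added above

# Illustrative sampling settings; prefill_and_generate builds its config from
# the model's own generation setup rather than hard-coding values like these.
generation_config = GenerationConfig(do_sample=True, temperature=0.7, top_k=50, top_p=0.9)
logits_warper = tf_logits_warper(generation_config)

vocab_size = 32000                                # placeholder vocabulary size
generated_ids = torch.tensor([[1, 2, 3]])         # tokens decoded so far, shape [batch, seq]
next_token_logits = torch.randn(1, vocab_size)    # raw logits for the next position

# Warpers run in the order they were appended (temperature -> top-k -> top-p);
# filtered vocabulary entries are set to -inf before the softmax.
scores = logits_warper(generated_ids, next_token_logits)
probs = torch.softmax(scores, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)  # sampled token id, shape [1, 1]
```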