ktransformers/utils: fix _get_logits_warper error

This commit is contained in:
Aubrey Li 2025-05-01 08:05:35 +08:00
parent 7530491f5b
commit b3a1fcf471

View file

@ -11,6 +11,17 @@ from torch import nn
import itertools
import time
import enum
from transformers import (
LogitsProcessorList,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
MinPLogitsWarper,
TypicalLogitsWarper,
EpsilonLogitsWarper,
EtaLogitsWarper,
)
from ktransformers.util.custom_gguf import translate_name_to_gguf
from ktransformers.util.custom_gguf import GGUFLoader
from ktransformers.operators import base_operator
@ -126,6 +137,57 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
else:
module.load()
def tf_logits_warper(generation_config):
"""
This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
used for multinomial sampling.
"""
# instantiate warpers list
warpers = LogitsProcessorList()
# In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
# better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1)
if generation_config.num_beams > 1:
if isinstance(generation_config._eos_token_tensor, list):
min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1
elif isinstance(generation_config._eos_token_tensor, torch.Tensor):
min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1
else:
min_tokens_to_keep = 2
else:
min_tokens_to_keep = 1
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
# all samplers can be found in `generation_utils_samplers.py`
if generation_config.temperature is not None and generation_config.temperature != 1.0:
warpers.append(TemperatureLogitsWarper(generation_config.temperature))
if generation_config.top_k is not None and generation_config.top_k != 0:
warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))
if generation_config.top_p is not None and generation_config.top_p < 1.0:
warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))
if generation_config.min_p is not None:
# Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084)
warpers.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep))
if generation_config.typical_p is not None and generation_config.typical_p < 1.0:
warpers.append(
TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep)
)
if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0:
warpers.append(
EpsilonLogitsWarper(epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep)
)
if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:
warpers.append(
EtaLogitsWarper(
epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device
)
)
# `LogitNormalization` should always be the last logit processor, when present
if generation_config.renormalize_logits is True:
warpers.append(LogitNormalization())
return warpers
def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,
mode = 'normal', force_think: bool = False, chunk_size = 16384, use_flashinfer_mla = False,
num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None):
@ -201,14 +263,8 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
# change this to modify generate config
#top_k=5, top_p=0.85, temperature=0.1
)
try: # transformers==4.43
logits_warper = (
model._get_logits_warper(generation_config,device=inputs.device)
)
except:
logits_warper = (
model._get_logits_warper(generation_config)
)
logits_warper = tf_logits_warper(generation_config)
cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32)
generated_ids = torch.zeros(