mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-05 20:19:51 +00:00
ktransformers/utils: fix _get_logits_warper error
This commit is contained in:
parent
7530491f5b
commit
b3a1fcf471
1 changed files with 64 additions and 8 deletions
|
@ -11,6 +11,17 @@ from torch import nn
|
|||
import itertools
|
||||
import time
|
||||
import enum
|
||||
from transformers import (
|
||||
LogitsProcessorList,
|
||||
TemperatureLogitsWarper,
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
MinPLogitsWarper,
|
||||
TypicalLogitsWarper,
|
||||
EpsilonLogitsWarper,
|
||||
EtaLogitsWarper,
|
||||
)
|
||||
|
||||
from ktransformers.util.custom_gguf import translate_name_to_gguf
|
||||
from ktransformers.util.custom_gguf import GGUFLoader
|
||||
from ktransformers.operators import base_operator
|
||||
|
@ -126,6 +137,57 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
|
|||
else:
|
||||
module.load()
|
||||
|
||||
def tf_logits_warper(generation_config):
|
||||
"""
|
||||
This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
|
||||
used for multinomial sampling.
|
||||
"""
|
||||
|
||||
# instantiate warpers list
|
||||
warpers = LogitsProcessorList()
|
||||
|
||||
# In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
|
||||
# better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1)
|
||||
if generation_config.num_beams > 1:
|
||||
if isinstance(generation_config._eos_token_tensor, list):
|
||||
min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1
|
||||
elif isinstance(generation_config._eos_token_tensor, torch.Tensor):
|
||||
min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1
|
||||
else:
|
||||
min_tokens_to_keep = 2
|
||||
else:
|
||||
min_tokens_to_keep = 1
|
||||
|
||||
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
|
||||
# all samplers can be found in `generation_utils_samplers.py`
|
||||
if generation_config.temperature is not None and generation_config.temperature != 1.0:
|
||||
warpers.append(TemperatureLogitsWarper(generation_config.temperature))
|
||||
if generation_config.top_k is not None and generation_config.top_k != 0:
|
||||
warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))
|
||||
if generation_config.top_p is not None and generation_config.top_p < 1.0:
|
||||
warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))
|
||||
if generation_config.min_p is not None:
|
||||
# Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084)
|
||||
warpers.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep))
|
||||
if generation_config.typical_p is not None and generation_config.typical_p < 1.0:
|
||||
warpers.append(
|
||||
TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep)
|
||||
)
|
||||
if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0:
|
||||
warpers.append(
|
||||
EpsilonLogitsWarper(epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep)
|
||||
)
|
||||
if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:
|
||||
warpers.append(
|
||||
EtaLogitsWarper(
|
||||
epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device
|
||||
)
|
||||
)
|
||||
# `LogitNormalization` should always be the last logit processor, when present
|
||||
if generation_config.renormalize_logits is True:
|
||||
warpers.append(LogitNormalization())
|
||||
return warpers
|
||||
|
||||
def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,
|
||||
mode = 'normal', force_think: bool = False, chunk_size = 16384, use_flashinfer_mla = False,
|
||||
num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None):
|
||||
|
@ -201,14 +263,8 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
|
|||
# change this to modify generate config
|
||||
#top_k=5, top_p=0.85, temperature=0.1
|
||||
)
|
||||
try: # transformers==4.43
|
||||
logits_warper = (
|
||||
model._get_logits_warper(generation_config,device=inputs.device)
|
||||
)
|
||||
except:
|
||||
logits_warper = (
|
||||
model._get_logits_warper(generation_config)
|
||||
)
|
||||
|
||||
logits_warper = tf_logits_warper(generation_config)
|
||||
|
||||
cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32)
|
||||
generated_ids = torch.zeros(
|
||||
|
|
Loading…
Add table
Reference in a new issue