Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-06 20:49:55 +00:00
ktransformers/utils: fix _get_logits_warper error
This commit is contained in:
parent 7530491f5b
commit b3a1fcf471

1 changed file with 64 additions and 8 deletions
@@ -11,6 +11,17 @@ from torch import nn
 import itertools
 import time
 import enum
+from transformers import (
+    LogitsProcessorList,
+    TemperatureLogitsWarper,
+    TopKLogitsWarper,
+    TopPLogitsWarper,
+    MinPLogitsWarper,
+    TypicalLogitsWarper,
+    EpsilonLogitsWarper,
+    EtaLogitsWarper,
+)
+
 from ktransformers.util.custom_gguf import translate_name_to_gguf
 from ktransformers.util.custom_gguf import GGUFLoader
 from ktransformers.operators import base_operator
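The warper classes pulled in by the import hunk above are public exports of recent transformers releases, which is what lets the commit stop relying on the model's private helper. As a quick sanity check (a hypothetical helper of my own, not part of the diff), you can verify that the installed transformers version actually exposes all of them:

import importlib

def check_warper_imports():
    # Hypothetical helper, not in the commit: confirm the installed transformers release
    # provides every warper class used by the import block above.
    transformers = importlib.import_module("transformers")
    required = [
        "LogitsProcessorList", "TemperatureLogitsWarper", "TopKLogitsWarper",
        "TopPLogitsWarper", "MinPLogitsWarper", "TypicalLogitsWarper",
        "EpsilonLogitsWarper", "EtaLogitsWarper",
    ]
    missing = [name for name in required if not hasattr(transformers, name)]
    if missing:
        raise ImportError(f"installed transformers does not provide: {missing}")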
@@ -126,6 +137,57 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
     else:
         module.load()
 
+def tf_logits_warper(generation_config):
+    """
+    This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
+    used for multinomial sampling.
+    """
+
+    # instantiate warpers list
+    warpers = LogitsProcessorList()
+
+    # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
+    # better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1)
+    if generation_config.num_beams > 1:
+        if isinstance(generation_config._eos_token_tensor, list):
+            min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1
+        elif isinstance(generation_config._eos_token_tensor, torch.Tensor):
+            min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1
+        else:
+            min_tokens_to_keep = 2
+    else:
+        min_tokens_to_keep = 1
+
+    # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
+    # all samplers can be found in `generation_utils_samplers.py`
+    if generation_config.temperature is not None and generation_config.temperature != 1.0:
+        warpers.append(TemperatureLogitsWarper(generation_config.temperature))
+    if generation_config.top_k is not None and generation_config.top_k != 0:
+        warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))
+    if generation_config.top_p is not None and generation_config.top_p < 1.0:
+        warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))
+    if generation_config.min_p is not None:
+        # Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084)
+        warpers.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep))
+    if generation_config.typical_p is not None and generation_config.typical_p < 1.0:
+        warpers.append(
+            TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep)
+        )
+    if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0:
+        warpers.append(
+            EpsilonLogitsWarper(epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep)
+        )
+    if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:
+        warpers.append(
+            EtaLogitsWarper(
+                epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device
+            )
+        )
+    # `LogitNormalization` should always be the last logit processor, when present
+    if generation_config.renormalize_logits is True:
+        warpers.append(LogitNormalization())
+    return warpers
+
 def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,
                          mode = 'normal', force_think: bool = False, chunk_size = 16384, use_flashinfer_mla = False,
                          num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None):
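The new tf_logits_warper mirrors the warper-building logic that previously lived inside transformers' private _get_logits_warper: it reads the sampling fields of a GenerationConfig and returns a callable LogitsProcessorList. A minimal usage sketch (my own example, not part of the diff; it assumes tf_logits_warper from the hunk above is in scope):

import torch
from transformers import GenerationConfig

# Hypothetical standalone example: build a sampling config, derive the warper list,
# and apply it to one step of synthetic next-token logits.
generation_config = GenerationConfig(do_sample=True, temperature=0.7, top_k=50, top_p=0.9)
warpers = tf_logits_warper(generation_config)   # Temperature/TopK/TopP warpers, in that order

input_ids = torch.tensor([[1, 2, 3]])           # (batch, seq_len) tokens generated so far
next_token_logits = torch.randn(1, 32000)       # (batch, vocab_size) raw logits from the model

warped = warpers(input_ids, next_token_logits)  # each warper rescales or filters the logits in turn
probs = torch.softmax(warped, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)

Note that, as written in the hunk, the eta_cutoff branch passes device=device and the renormalize_logits branch uses LogitNormalization; neither name is introduced by the import hunk or the function signature, so those two options would need the corresponding import and argument if they are ever enabled.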
@@ -201,14 +263,8 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
         # change this to modify generate config
         #top_k=5, top_p=0.85, temperature=0.1
     )
-    try: # transformers==4.43
-        logits_warper = (
-            model._get_logits_warper(generation_config,device=inputs.device)
-        )
-    except:
-        logits_warper = (
-            model._get_logits_warper(generation_config)
-        )
+    logits_warper = tf_logits_warper(generation_config)
+
 
     cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32)
     generated_ids = torch.zeros(
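This last hunk is the actual fix. model._get_logits_warper is a private transformers helper: the removed code's own comment shows its signature gained a device argument in transformers==4.43, and newer releases no longer provide the method at all, so the old try/except path could still fail with the _get_logits_warper error named in the commit title. Building the list locally with tf_logits_warper removes the dependency on that private API. For comparison, a defensive shim that keeps using the private helper when it happens to exist might look like the following (my sketch, not what the commit does):

def get_logits_warper_compat(model, generation_config, device=None):
    # Hypothetical shim, not part of the commit: prefer the private transformers helper
    # when it is available, fall back to the local tf_logits_warper otherwise.
    if hasattr(model, "_get_logits_warper"):
        try:  # transformers==4.43 accepts a `device` keyword
            return model._get_logits_warper(generation_config, device=device)
        except TypeError:  # older signatures take only the generation config
            return model._get_logits_warper(generation_config)
    return tf_logits_warper(generation_config)

The commit instead drops the shim idea entirely and always builds the warper list locally, which keeps prefill_and_generate independent of transformers internals.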