mirror of https://github.com/kvcache-ai/ktransformers.git
Mirror #1247 in server mode
parent ce75fcd7dd
commit 00949d5e8d
1 changed file with 62 additions and 8 deletions
@@ -11,6 +11,14 @@ from transformers import (
     StaticCache,
     AutoModelForCausalLM,
     BitsAndBytesConfig,
+    LogitsProcessorList,
+    TemperatureLogitsWarper,
+    TopKLogitsWarper,
+    TopPLogitsWarper,
+    MinPLogitsWarper,
+    TypicalLogitsWarper,
+    EpsilonLogitsWarper,
+    EtaLogitsWarper,
 )
 
 from ktransformers.server.config.config import Config
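
For context, these are public classes from transformers' generation utilities: each warper rescales or masks one row of logits, and a LogitsProcessorList applies them in order. A minimal sketch of that chaining, not part of this commit and with made-up tensor shapes:

import torch
from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper

# Chain two warpers: temperature scaling first, then top-k masking.
warpers = LogitsProcessorList([
    TemperatureLogitsWarper(0.7),
    TopKLogitsWarper(top_k=5, min_tokens_to_keep=1),
])

input_ids = torch.tensor([[1, 2, 3]])   # (batch, seq_len) token ids generated so far
logits = torch.randn(1, 100)            # (batch, vocab_size) raw scores for the next token
warped = warpers(input_ids, logits)     # divided by 0.7, everything outside the top 5 set to -inf

The new tf_logits_warper method added below builds this same kind of list, selecting warpers according to the fields of a GenerationConfig.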
@@ -206,6 +214,58 @@ class TransformersInterface(BackendInterfaceBase):
         self.seq_length += 1
         return self.streamer.put(new_tokens)
 
+    @staticmethod
+    def tf_logits_warper(generation_config):
+        """
+        This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
+        used for multinomial sampling.
+        """
+
+        # instantiate warpers list
+        warpers = LogitsProcessorList()
+
+        # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
+        # better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1)
+        if generation_config.num_beams > 1:
+            if isinstance(generation_config._eos_token_tensor, list):
+                min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1
+            elif isinstance(generation_config._eos_token_tensor, torch.Tensor):
+                min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1
+            else:
+                min_tokens_to_keep = 2
+        else:
+            min_tokens_to_keep = 1
+
+        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
+        # all samplers can be found in `generation_utils_samplers.py`
+        if generation_config.temperature is not None and generation_config.temperature != 1.0:
+            warpers.append(TemperatureLogitsWarper(generation_config.temperature))
+        if generation_config.top_k is not None and generation_config.top_k != 0:
+            warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))
+        if generation_config.top_p is not None and generation_config.top_p < 1.0:
+            warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))
+        if generation_config.min_p is not None:
+            # Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084)
+            warpers.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep))
+        if generation_config.typical_p is not None and generation_config.typical_p < 1.0:
+            warpers.append(
+                TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep)
+            )
+        if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0:
+            warpers.append(
+                EpsilonLogitsWarper(epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep)
+            )
+        if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:
+            warpers.append(
+                EtaLogitsWarper(
+                    epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device
+                )
+            )
+        # `LogitNormalization` should always be the last logit processor, when present
+        if generation_config.renormalize_logits is True:
+            warpers.append(LogitNormalization())
+        return warpers
+
     def prepare_logits_wrapper(self, inputs, device, temperature: Optional[float] = None, top_p: Optional[float] = None):
         if temperature is None or temperature == 0:
             temperature = self.model.generation_config.temperature
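
The method above can be exercised on its own. The sketch below is not from the commit: the import path for TransformersInterface is assumed, and the vocabulary size is a toy value. With num_beams left at its default of 1 and no epsilon/eta cutoffs set, only the temperature, top-k, and top-p branches add warpers:

import torch
from transformers import GenerationConfig
# Module path assumed for illustration; adjust to wherever TransformersInterface is defined.
from ktransformers.server.backend.interfaces.transformers import TransformersInterface

gen_config = GenerationConfig(do_sample=True, temperature=0.7, top_k=40, top_p=0.95)
warpers = TransformersInterface.tf_logits_warper(gen_config)   # a LogitsProcessorList

input_ids = torch.tensor([[1, 2, 3]])   # (batch, seq_len)
logits = torch.randn(1, 32000)          # (batch, vocab_size), toy values
warped = warpers(input_ids, logits)     # temperature, top-k, top-p applied in order
probs = torch.softmax(warped, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)   # sample one token id

As the next hunk shows, prepare_logits_wrapper now builds this list once via tf_logits_warper, and logits_to_token applies it to the current logits before the next token is chosen.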
@@ -222,14 +282,8 @@ class TransformersInterface(BackendInterfaceBase):
             repetition_penalty=self.args.repetition_penalty # change this to modify generate config
         )
         self.inputs = inputs
-        try: # transformers==4.43
-            self.logits_warper = (
-                self.model._get_logits_warper(generation_config, device=device)
-            )
-        except:
-            self.logits_warper = (
-                self.model._get_logits_warper(generation_config)
-            )
+
+        self.logits_warper = self.tf_logits_warper(generation_config)
 
     def logits_to_token(self, logits: torch.Tensor):
         logits = self.logits_warper(self.inputs.view(1, -1), logits.view(1, -1))