mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 23:34:35 +00:00
fix temperature
This commit is contained in:
parent
5e3c6b4f97
commit
22df52e94e
2 changed files with 16 additions and 16 deletions
|
@ -203,10 +203,10 @@ class TransformersInterface(BackendInterfaceBase):
|
|||
return self.streamer.put(new_tokens)
|
||||
|
||||
def prepare_logits_wrapper(self, inputs, device, temperature: Optional[float] = None, top_p: Optional[float] = None):
|
||||
if temperature is None:
|
||||
temperature = self.args.temperature
|
||||
if temperature is None or temperature == 0:
|
||||
temperature = self.model.generation_config.temperature
|
||||
if top_p is None:
|
||||
top_p = self.args.top_p
|
||||
top_p = self.model.generation_config.top_p
|
||||
generation_config, model_kwargs = self.model._prepare_generation_config(
|
||||
None, max_length=self.args.max_new_tokens,
|
||||
do_sample=True,
|
||||
|
@ -216,10 +216,9 @@ class TransformersInterface(BackendInterfaceBase):
|
|||
repetition_penalty=self.args.repetition_penalty # change this to modify generate config
|
||||
)
|
||||
self.inputs = inputs
|
||||
self.generation_config = generation_config
|
||||
try: # transformers==4.43
|
||||
self.logits_warper = (
|
||||
self.model._get_logits_warper(generation_config,device=device)
|
||||
self.model._get_logits_warper(generation_config, device=device)
|
||||
)
|
||||
except:
|
||||
self.logits_warper = (
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue