Also allow repetition_penalty

This commit is contained in:
lazymio 2025-02-24 21:07:35 +08:00
parent 8704c09192
commit bf36547f98
No known key found for this signature in database
GPG key ID: DFF27E34A47CB873
3 changed files with 11 additions and 8 deletions

View file

@ -202,18 +202,20 @@ class TransformersInterface(BackendInterfaceBase):
self.seq_length += 1
return self.streamer.put(new_tokens)
def prepare_logits_wrapper(self, inputs, device, temperature: Optional[float] = None, top_p: Optional[float] = None):
def prepare_logits_wrapper(self, inputs, device, temperature: Optional[float] = None, top_p: Optional[float] = None, repetition_penalty: Optional[float] = None):
if temperature is None:
temperature = self.args.temperature
if top_p is None:
top_p = self.args.top_p
if repetition_penalty is None:
repetition_penalty = self.args.repetition_penalty
generation_config, model_kwargs = self.model._prepare_generation_config(
None, max_length=self.args.max_new_tokens,
do_sample=True,
top_k=self.args.top_k,
top_p=top_p,
temperature=temperature,
repetition_penalty=self.args.repetition_penalty # change this to modify generate config
repetition_penalty=repetition_penalty # change this to modify generate config
)
self.inputs = inputs
self.generation_config = generation_config
@ -259,7 +261,7 @@ class TransformersInterface(BackendInterfaceBase):
return self.logits_to_token(logits)
@torch.no_grad
def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None):
def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, repetition_penalty: Optional[float] = None):
input_ids_length = input_ids.shape[-1]
logger.debug(f"input_ids: {input_ids.shape}")
@ -327,7 +329,7 @@ class TransformersInterface(BackendInterfaceBase):
else:
logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
self.prepare_logits_wrapper(input_ids, device, temperature, top_p, repetition_penalty)
next_token = self.logits_to_token(logits[0, -1, :])
yield self.append_new_tokens(next_token)
@ -363,7 +365,7 @@ class TransformersInterface(BackendInterfaceBase):
self.last_request_id = thread_id
return True
async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, repetition_penalty: Optional[float] = None):
self.streamer.reset()
self.profiler.create_and_start_timer("tokenize")
if isinstance(local_messages, List):
@ -390,7 +392,7 @@ class TransformersInterface(BackendInterfaceBase):
print(think, end="",flush=True)
yield think
for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p, repetition_penalty):
# output think token after prefill done
if t is not None:
print(t, end="",flush=True)