This commit is contained in:
lazymio 2025-02-24 21:51:14 +08:00
parent 91062a834f
commit 07eb712a73
No known key found for this signature in database
GPG key ID: DFF27E34A47CB873

View file

@ -127,7 +127,7 @@ class KTransformersInterface(TransformersInterface):
@torch.no_grad
def prefill(self, input_ids: torch.Tensor, is_new: bool):
def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float], top_p: Optional[float]):
input_ids_length = input_ids.shape[-1]
logger.debug(f"input_ids: {input_ids.shape}")
@ -198,7 +198,7 @@ class KTransformersInterface(TransformersInterface):
else:
logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
self.prepare_logits_wrapper(input_ids, device)
self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
next_token = self.logits_to_token(logits[0, -1, :])
yield self.append_new_tokens(next_token)