diff --git a/ktransformers/server/backend/interfaces/ktransformers.py b/ktransformers/server/backend/interfaces/ktransformers.py
index 85bfb29..88b7e4b 100644
--- a/ktransformers/server/backend/interfaces/ktransformers.py
+++ b/ktransformers/server/backend/interfaces/ktransformers.py
@@ -127,7 +127,7 @@ class KTransformersInterface(TransformersInterface):
     @torch.no_grad
-    def prefill(self, input_ids: torch.Tensor, is_new: bool):
+    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float], top_p: Optional[float]):
         input_ids_length = input_ids.shape[-1]
         logger.debug(f"input_ids: {input_ids.shape}")
@@ -198,7 +198,7 @@ class KTransformersInterface(TransformersInterface):
         else:
             logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
-        self.prepare_logits_wrapper(input_ids, device)
+        self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
         next_token = self.logits_to_token(logits[0, -1, :])
         yield self.append_new_tokens(next_token)
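
The change threads per-request sampling parameters (`temperature`, `top_p`) from `prefill` into `prepare_logits_wrapper`, rather than relying on server-wide defaults. A minimal sketch of how such a wrapper could consume these parameters, assuming a Hugging Face-style logits warper chain; the helper `build_logits_warper` and the usage below are illustrative, not the actual ktransformers implementation:

```python
# Hypothetical sketch (not the ktransformers code): build a sampling
# pipeline from per-request temperature/top_p using HF logits warpers.
from typing import Optional

import torch
from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopPLogitsWarper


def build_logits_warper(temperature: Optional[float], top_p: Optional[float]) -> LogitsProcessorList:
    """Add warpers only for the parameters the request actually set."""
    warpers = LogitsProcessorList()
    if temperature is not None and temperature > 0 and temperature != 1.0:
        warpers.append(TemperatureLogitsWarper(temperature))
    if top_p is not None and 0.0 < top_p < 1.0:
        warpers.append(TopPLogitsWarper(top_p))
    return warpers


# Usage: warp the last-position logits before sampling the next token.
warpers = build_logits_warper(temperature=0.7, top_p=0.9)
logits = torch.randn(1, 32000)                    # placeholder [batch, vocab] logits
input_ids = torch.zeros(1, 8, dtype=torch.long)   # placeholder prompt ids
warped = warpers(input_ids, logits)
next_token = torch.multinomial(torch.softmax(warped, dim=-1), num_samples=1)
```

Passing the values through `prefill` keeps sampling behavior per request: two concurrent requests with different `temperature`/`top_p` settings no longer share one global wrapper configuration.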