diff --git a/ktransformers/server/backend/interfaces/ktransformers.py b/ktransformers/server/backend/interfaces/ktransformers.py index 47d99c6..6de0998 100644 --- a/ktransformers/server/backend/interfaces/ktransformers.py +++ b/ktransformers/server/backend/interfaces/ktransformers.py @@ -201,10 +201,9 @@ class KTransformersInterface(TransformersInterface): else: logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0] - self.prepare_logits_wrapper(input_ids, device, temperature, top_p) if flashinfer_enabled: MLAWrapperSingleton.reset_buffer() - self.prepare_logits_wrapper(input_ids, device) + self.prepare_logits_wrapper(input_ids, device, temperature, top_p) next_token = self.logits_to_token(logits[0, -1, :]) yield self.append_new_tokens(next_token)