mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-07 04:59:55 +00:00
Left out
This commit is contained in:
parent
91062a834f
commit
07eb712a73
1 changed files with 2 additions and 2 deletions
|
@ -127,7 +127,7 @@ class KTransformersInterface(TransformersInterface):
|
|||
|
||||
|
||||
@torch.no_grad
|
||||
def prefill(self, input_ids: torch.Tensor, is_new: bool):
|
||||
def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float], top_p: Optional[float]):
|
||||
input_ids_length = input_ids.shape[-1]
|
||||
logger.debug(f"input_ids: {input_ids.shape}")
|
||||
|
||||
|
@ -198,7 +198,7 @@ class KTransformersInterface(TransformersInterface):
|
|||
else:
|
||||
logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
|
||||
|
||||
self.prepare_logits_wrapper(input_ids, device)
|
||||
self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
|
||||
next_token = self.logits_to_token(logits[0, -1, :])
|
||||
yield self.append_new_tokens(next_token)
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue