mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-09 22:05:30 +00:00
Left out
This commit is contained in:
parent
91062a834f
commit
07eb712a73
1 changed files with 2 additions and 2 deletions
|
@ -127,7 +127,7 @@ class KTransformersInterface(TransformersInterface):
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad
|
@torch.no_grad
|
||||||
def prefill(self, input_ids: torch.Tensor, is_new: bool):
|
def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float], top_p: Optional[float]):
|
||||||
input_ids_length = input_ids.shape[-1]
|
input_ids_length = input_ids.shape[-1]
|
||||||
logger.debug(f"input_ids: {input_ids.shape}")
|
logger.debug(f"input_ids: {input_ids.shape}")
|
||||||
|
|
||||||
|
@ -198,7 +198,7 @@ class KTransformersInterface(TransformersInterface):
|
||||||
else:
|
else:
|
||||||
logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
|
logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
|
||||||
|
|
||||||
self.prepare_logits_wrapper(input_ids, device)
|
self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
|
||||||
next_token = self.logits_to_token(logits[0, -1, :])
|
next_token = self.logits_to_token(logits[0, -1, :])
|
||||||
yield self.append_new_tokens(next_token)
|
yield self.append_new_tokens(next_token)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue