mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-09 13:55:27 +00:00
support chunk prefill, support 139K context for 24G VRAM
This commit is contained in:
parent
494469d4c5
commit
f35e8d41d8
10 changed files with 227 additions and 83 deletions
|
@ -242,12 +242,10 @@ class TransformersInterface(BackendInterfaceBase):
|
|||
|
||||
def decode_one_tokens(self):
|
||||
if self.use_static_cache:
|
||||
mask = torch.ones((1, self.seq_length)).to(self.args.device)
|
||||
logits = self.model(
|
||||
self.current_ids,
|
||||
cache_position=self.active_cache_position,
|
||||
past_key_values=self.cache,
|
||||
attention_mask=mask,
|
||||
return_dict=False,
|
||||
use_cache=True,
|
||||
)[0]
|
||||
|
@ -309,7 +307,6 @@ class TransformersInterface(BackendInterfaceBase):
|
|||
cache_position = torch.arange(former_seq_length, self.seq_length, device=self.args.device)
|
||||
self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
|
||||
|
||||
mask = torch.ones((1, self.seq_length)).to(self.args.device)
|
||||
device = input_ids.device
|
||||
if not (type(self) is TransformersInterface):
|
||||
input_ids = input_ids.to("cpu")
|
||||
|
@ -321,7 +318,6 @@ class TransformersInterface(BackendInterfaceBase):
|
|||
past_key_values=self.cache,
|
||||
return_dict=False,
|
||||
use_cache=True,
|
||||
attention_mask=mask,
|
||||
)[0]
|
||||
else:
|
||||
logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue