Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-10 06:14:58 +00:00
init support for MLA using Attention kernel
parent 62011fd63e
commit bb35dc5b0d
5 changed files with 551 additions and 262 deletions
@@ -133,7 +133,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
         )
     else:
         past_key_values = None
-    cache_position = torch.arange(seq_length, device=torch_device)
+    cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.long)
     generated_ids = torch.zeros(
         batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device
     )
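The hunk above pins cache_position to torch.long during prefill. As a minimal sketch of why the explicit dtype helps (variable names other than those in the diff are illustrative, not taken from the repo): cache_position is typically reused as an index tensor and to build position_ids, and PyTorch index tensors must be integer (int64), so spelling out the dtype removes any dependence on how torch.arange infers it.

    import torch

    torch_device = "cuda" if torch.cuda.is_available() else "cpu"
    seq_length = 8

    # Explicit dtype guarantees an int64 (long) tensor no matter how
    # seq_length was produced upstream.
    cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.long)

    # Illustrative downstream uses: building position_ids and indexing into a
    # preallocated buffer both expect a long tensor.
    position_ids = cache_position.unsqueeze(0)  # shape (1, seq_length)
    buffer = torch.zeros(1, seq_length + 16, dtype=torch.int, device=torch_device)
    buffer[:, cache_position] = torch.ones(seq_length, dtype=torch.int, device=torch_device)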
@@ -178,7 +178,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
     generated_ids[:, seq_length] = next_token
     tokens.append(int(next_token))
     inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
-    cache_position = torch.tensor([seq_length], device=torch_device)
+    cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.long)
     position_ids = cache_position.unsqueeze(0)
     seq_length += 1
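The same dtype fix is applied where generation advances one token at a time. A hedged sketch of that step, with a stand-in for the model's sampled token (the real code gets next_token from a forward pass): each iteration writes the new token at slot seq_length of the preallocated buffer, then rebuilds the single-element cache_position and position_ids for the next step.

    import torch

    torch_device = "cpu"
    seq_length, max_new_tokens, batch_size = 4, 3, 1
    generated_ids = torch.zeros(batch_size, seq_length + max_new_tokens + 1,
                                dtype=torch.int, device=torch_device)

    for _ in range(max_new_tokens):
        # Stand-in for the token the model would sample at this step.
        next_token = torch.randint(0, 32000, (batch_size,), device=torch_device)
        generated_ids[:, seq_length] = next_token
        # One-element long tensor pointing at the slot just written;
        # position_ids for the next forward pass is the same index with a
        # batch dimension.
        cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.long)
        position_ids = cache_position.unsqueeze(0)  # shape (1, 1)
        seq_length += 1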