init support for MLA using Attention kernel

Atream 2025-02-13 15:01:14 +00:00
parent 62011fd63e
commit bb35dc5b0d
5 changed files with 551 additions and 262 deletions


@@ -133,7 +133,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
         )
     else:
         past_key_values = None
-    cache_position = torch.arange(seq_length, device=torch_device)
+    cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.long)
     generated_ids = torch.zeros(
         batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device
     )
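Both hunks pin cache_position to an explicit torch.long. torch.arange and torch.tensor with integer inputs already default to int64, so the change is defensive; but position tensors are ultimately used as indices into the KV cache, and index ops (and custom attention kernels like the one this commit wires up) expect int64. A minimal sketch of that constraint; kv_cache and new_values are illustrative placeholders, not names from this diff:

import torch

torch_device = "cuda" if torch.cuda.is_available() else "cpu"
seq_length = 8

# Explicit int64: index/position tensors fed to index_copy_/gather
# must be torch.long, regardless of the default float dtype.
cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.long)

# Illustrative KV-cache write: the positions index the sequence
# dimension of the cache, which requires a long index tensor.
kv_cache = torch.zeros(1, 64, 16, device=torch_device)            # (batch, max_seq, head_dim)
new_values = torch.randn(1, seq_length, 16, device=torch_device)  # prefill K or V slice
kv_cache.index_copy_(1, cache_position, new_values)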
@@ -178,7 +178,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
         generated_ids[:, seq_length] = next_token
         tokens.append(int(next_token))
         inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
-        cache_position = torch.tensor([seq_length], device=torch_device)
+        cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.long)
         position_ids = cache_position.unsqueeze(0)
         seq_length += 1
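The second hunk makes the same change inside the decode loop, where cache_position advances one slot per generated token. A minimal sketch of that loop shape, assuming seq_length and torch_device as in the diff and using a random stand-in for the model's next-token output:

import torch

torch_device = "cpu"
seq_length = 5
inputs = torch.ones(1, seq_length, dtype=torch.long, device=torch_device)

for _ in range(3):  # stand-in for the max_new_tokens loop
    # One position per step; the explicit long keeps downstream cache
    # writes and kernels on an int64 index even if seq_length's type
    # ever changes.
    cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.long)
    position_ids = cache_position.unsqueeze(0)                     # shape (1, 1)
    next_token = torch.randint(0, 100, (1,), device=torch_device)  # placeholder for model output
    inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
    seq_length += 1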