Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-09 13:55:27 +00:00
use generation config from json file in official repo
commit e645d84794
parent 9660b2cc1e

4 changed files with 57 additions and 21 deletions
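
Only the MLA wrapper and its self-test are shown below; the titular change, reading generation defaults from the generation_config.json shipped in the official model repo instead of hard-coding them, lives in the other changed files. A minimal sketch of that pattern, assuming the Hugging Face transformers API and an illustrative model path (not the exact ktransformers call site):

from transformers import GenerationConfig

model_path = "deepseek-ai/DeepSeek-V3"  # illustrative; any repo that ships generation_config.json

try:
    # from_pretrained reads generation_config.json from the model repo.
    gen_config = GenerationConfig.from_pretrained(model_path)
except OSError:
    # Fall back to library defaults if the repo ships no generation_config.json.
    gen_config = GenerationConfig()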
@@ -122,7 +122,7 @@ class MLAWrapper():
         if kv_indices is None:
             assert self.max_batch_size == 1
             kv_indices = self.kv_indices_buf
 
         self.wrapper.plan(
             qo_indptr,
             kv_indptr,
@@ -139,6 +139,11 @@ class MLAWrapper():
         )
 
     def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False):
+        #print("run")
+        #print(self.wrapper._qo_indptr_buf)
+        #print(self.wrapper._kv_indptr_buf)
+        #print(self.wrapper._kv_indices_buf)
+        #print(self.wrapper._kv_len_arr_buf)
         return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse = return_lse)
 
 class MLAWrapperSingleton():
@@ -201,11 +206,12 @@ class MLAWrapperSingleton():
 
 if __name__ == "__main__":
     max_batch_size = 1
-    max_pages = 1
+    max_pages = 128
     page_size = 64
     num_heads = 128
+
+    q_len = 10
     kv_len = 2069
-    q_len = 1
     q_nope = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
     q_pe = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
     ckv = torch.randn((max_pages, page_size, 512), dtype=torch.bfloat16, device="cuda")
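
The test dimensions follow DeepSeek-style MLA: the query splits into a 512-dim compressed ("nope") part and a 64-dim rotary ("pe") part, the paged cache stores the 512-dim compressed KV plus a 64-dim k_pe, and the softmax scale 192 ** (-0.5) comes from the pre-split head dim 128 + 64 = 192. A sketch of how the reference k/v are assembled from those buffers (the v line appears verbatim in the last hunk; k_pe and the k line are the assumed counterparts):

import torch

max_pages, page_size, num_heads = 128, 64, 128
ckv = torch.randn((max_pages, page_size, 512), dtype=torch.bfloat16, device="cuda")
k_pe = torch.randn((max_pages, page_size, 64), dtype=torch.bfloat16, device="cuda")

# Flatten pages to tokens and broadcast the single shared KV head to all 128 query heads.
k = torch.cat([ckv, k_pe], dim=-1).view(-1, 1, 576).repeat_interleave(num_heads, dim=1)
v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)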
@@ -218,7 +224,7 @@ if __name__ == "__main__":
         max_pages,
     )
 
-    kv_len_arr = torch.tensor([q_len], dtype=torch.int32, device="cuda")
+    kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
     qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
     wrapper.plan(
         qo_indptr,
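
The kv_len_arr fix is the point of this hunk: the length array must carry the true cache length (kv_len = 2069 tokens), not the query length, otherwise the plan covers only the first few tokens of the cache. For a single sequence the paging metadata reduces to the following sketch (qo_indptr and kv_len_arr appear verbatim above; the kv_indptr/kv_indices layout is the assumed single-sequence case, where ceil-division gives the 33 pages that 2069 tokens occupy at page_size 64):

import math
import torch

q_len, kv_len, page_size = 10, 2069, 64
pages_used = math.ceil(kv_len / page_size)  # 33 pages cover 2069 tokens

qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
kv_indptr = torch.tensor([0, pages_used], dtype=torch.int32, device="cuda")
kv_indices = torch.arange(pages_used, dtype=torch.int32, device="cuda")  # pages in cache order
kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")    # true token count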
@@ -244,15 +250,15 @@ if __name__ == "__main__":
     )
     v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)
 
-    print(k[:10].shape)
-    print(v[:10].shape)
+    print(k[:kv_len].shape)
+    print(v[:kv_len].shape)
 
     attn_ref, lse_ref = attention_ref(
         max_batch_size,
         torch.cat([q_nope, q_pe], dim=-1),
-        k[:10],
-        v[:10],
-        False,
+        k[:kv_len],
+        v[:kv_len],
+        True,
         192 ** (-0.5)
     )
     print(attn_ref.shape)
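
Flipping the causal flag to True matches a decode-style query block aligned to the end of the 2069-token cache. attention_ref itself is defined elsewhere in this file; a minimal sketch of what such a reference plausibly computes, assuming the (batch, q, k, v, causal, sm_scale) signature seen at the call site, a bottom-right-aligned causal mask, and a natural-log LSE convention:

import torch

def attention_ref(batch, q, k, v, causal, sm_scale):
    # q: [q_len, heads, 576]; k: [kv_len, heads, 576]; v: [kv_len, heads, 512]; batch is 1 here.
    q_len, kv_len = q.shape[0], k.shape[0]
    scores = torch.einsum("qhd,khd->hqk", q.float(), k.float()) * sm_scale
    if causal:
        # Align the query block to the last q_len positions of the KV sequence.
        q_pos = torch.arange(kv_len - q_len, kv_len, device=q.device)
        k_pos = torch.arange(kv_len, device=q.device)
        scores = scores.masked_fill(k_pos[None, None, :] > q_pos[None, :, None], float("-inf"))
    lse = torch.logsumexp(scores, dim=-1)  # [heads, q_len]
    out = torch.einsum("hqk,khd->qhd", scores.softmax(dim=-1), v.float())
    return out.to(v.dtype), lse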