Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-09 13:55:27 +00:00)
support chunked prefill, support 139K context on 24 GB VRAM
parent 494469d4c5
commit f35e8d41d8
10 changed files with 227 additions and 83 deletions
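The headline feature is chunked prefill: rather than running the attention prefill over the whole prompt in one pass, the prompt is consumed in fixed-size chunks, so prefill activations stay bounded while the paged KV cache grows incrementally; that is what lets a ~139K-token context fit in 24 GB of VRAM. A minimal sketch of the idea, assuming a hypothetical `model.prefill_chunk` helper and `chunk_size` parameter (not the actual ktransformers API):

```python
import torch

def chunked_prefill(model, input_ids: torch.Tensor, chunk_size: int = 8192) -> torch.Tensor:
    """Prefill a long prompt chunk by chunk to bound peak VRAM.

    `model.prefill_chunk` is hypothetical: it consumes one chunk of tokens,
    appends their keys/values to the paged KV cache, and returns logits for
    the chunk's last position.
    """
    logits = None
    for start in range(0, input_ids.shape[-1], chunk_size):
        chunk = input_ids[..., start : start + chunk_size]
        # Only this chunk's activations are live at once; the KV cache is the
        # only state that grows with the full prompt length.
        logits = model.prefill_chunk(chunk)
    return logits  # logits for the final prompt token, used to start decoding
```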
@@ -205,12 +205,13 @@ class MLAWrapperSingleton():

```python
if __name__ == "__main__":
    torch.set_default_dtype(torch.bfloat16)
    max_batch_size = 1
    max_pages = 128
    page_size = 64
    num_heads = 128

    kv_len = 2069
    kv_len = 4023
    q_len = 1
    q_nope = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
    q_pe = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
```
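This first hunk only touches the self-test's setup. Two details worth noting: the paged cache holds max_pages * page_size = 128 * 64 = 8192 token slots, comfortably above both kv_len values, and the query is split MLA-style into a 512-wide "nope" part (the compressed/latent KV dimension) and a 64-wide rotary part per head. A sketch of paged compressed-KV buffers consistent with those sizes (the exact buffer layout used by ktransformers/FlashInfer may differ; these shapes are an assumption inferred from the dimensions above):

```python
import torch

max_pages, page_size = 128, 64
head_dim_ckv, head_dim_kpe = 512, 64   # latent KV dim and rotary key dim

# One 512-d latent vector per cached token, shared across all 128 heads,
# plus a 64-d decoupled rotary key, laid out page by page.
ckv = torch.randn((max_pages, page_size, head_dim_ckv),
                  dtype=torch.bfloat16, device="cuda")
k_pe = torch.randn((max_pages, page_size, head_dim_kpe),
                   dtype=torch.bfloat16, device="cuda")

print(max_pages * page_size)  # 8192 cached-token capacity in this test
```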
@@ -242,6 +243,29 @@ if __name__ == "__main__":

```python
    attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)
    print(attn_output.shape)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)

    kv_len = 6789
    kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
    qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
    wrapper.plan(
        qo_indptr,
        None,
        None,
        kv_len_arr,
        128,
        512,
        64,
        page_size,
        192 ** (-0.5),
        torch.bfloat16,
        torch.bfloat16,
    )

    graph.replay()

    k = (
        torch.cat([ckv, k_pe], dim=-1)
```
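The second hunk exercises CUDA Graph replay: wrapper.run is captured once, the wrapper is re-planned for a different kv_len (6789), and graph.replay() then re-executes the captured kernels without fresh launch overhead. The same capture-and-replay pattern on a plain PyTorch op, independent of ktransformers/FlashInfer, looks roughly like this (the matmul and buffer names are stand-ins; the essential constraint is that captured work reads and writes pre-allocated "static" buffers):

```python
import torch

# Static buffers: CUDA graphs replay fixed kernels on fixed addresses, so the
# inputs/outputs must keep the same storage across replays.
static_x = torch.randn(1, 128, 512, dtype=torch.bfloat16, device="cuda")
weight = torch.randn(512, 512, dtype=torch.bfloat16, device="cuda")

# Warm up on a side stream before capture, as recommended by the PyTorch docs.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_y = static_x @ weight
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_y = static_x @ weight   # recorded into the graph, not run eagerly

# Update the *contents* of the static input (analogous to re-planning with a
# new kv_len while keeping buffer addresses fixed), then replay.
static_x.copy_(torch.randn_like(static_x))
graph.replay()
torch.cuda.synchronize()
print(static_y.shape)  # torch.Size([1, 128, 512])
```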