support chunk prefill, support 139K context for 24G VRAM

Atream 2025-03-01 11:28:25 +00:00
parent 494469d4c5
commit f35e8d41d8
10 changed files with 227 additions and 83 deletions
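The commit title refers to chunked prefill. As a rough, self-contained sketch of the idea (plain PyTorch, not the ktransformers MLA path; every name below is illustrative), the prompt is split into fixed-size chunks, each chunk's queries attend to the keys/values cached so far, and the result matches a one-shot causal prefill while the per-step attention workspace stays bounded by the chunk size rather than the full sequence length:

    import torch

    # Toy causal attention over (len, heads, dim) tensors -- illustration only.
    torch.manual_seed(0)
    seq_len, num_heads, head_dim, chunk_size = 512, 8, 64, 128
    scale = head_dim ** -0.5
    q = torch.randn(seq_len, num_heads, head_dim)
    k = torch.randn(seq_len, num_heads, head_dim)
    v = torch.randn(seq_len, num_heads, head_dim)

    def causal_attention(q, k, v, kv_offset=0):
        # query at local index i (absolute position kv_offset + i) may attend
        # to keys at absolute positions <= kv_offset + i
        scores = torch.einsum("qhd,khd->hqk", q, k) * scale
        q_idx = torch.arange(q.shape[0]).unsqueeze(1) + kv_offset
        k_idx = torch.arange(k.shape[0]).unsqueeze(0)
        scores = scores.masked_fill(k_idx > q_idx, float("-inf"))
        return torch.einsum("hqk,khd->qhd", scores.softmax(dim=-1), v)

    # Reference: prefill the whole prompt in one pass.
    ref = causal_attention(q, k, v)

    # Chunked prefill: process the prompt chunk by chunk against the growing KV cache.
    outputs, k_cache, v_cache = [], [], []
    for start in range(0, seq_len, chunk_size):
        end = start + chunk_size
        k_cache.append(k[start:end])
        v_cache.append(v[start:end])
        out = causal_attention(q[start:end],
                               torch.cat(k_cache, dim=0),
                               torch.cat(v_cache, dim=0),
                               kv_offset=start)
        outputs.append(out)

    print(torch.allclose(ref, torch.cat(outputs, dim=0), atol=1e-5))  # True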


@@ -205,12 +205,13 @@ class MLAWrapperSingleton():
if __name__ == "__main__":
    torch.set_default_dtype(torch.bfloat16)
    max_batch_size = 1
    max_pages = 128
    page_size = 64
    num_heads = 128
-    kv_len = 2069
+    kv_len = 4023
    q_len = 1
    q_nope = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
    q_pe = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
@@ -242,6 +243,29 @@ if __name__ == "__main__":
    attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)
    print(attn_output.shape)
+    graph = torch.cuda.CUDAGraph()
+    with torch.cuda.graph(graph):
+        attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)
+    kv_len = 6789
+    kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
+    qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
+    wrapper.plan(
+        qo_indptr,
+        None,
+        None,
+        kv_len_arr,
+        128,
+        512,
+        64,
+        page_size,
+        192 ** (-0.5),
+        torch.bfloat16,
+        torch.bfloat16,
+    )
+    graph.replay()
    k = (
        torch.cat([ckv, k_pe], dim=-1)
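The added test lines capture wrapper.run in a CUDA graph, re-plan the wrapper with a larger kv_len, and then replay the captured graph. A minimal sketch of that capture/replay pattern with plain PyTorch buffers (hypothetical stand-in, requires a CUDA device; it is not the MLA wrapper itself):

    import torch

    if torch.cuda.is_available():
        static_in = torch.randn(1024, device="cuda")
        static_out = torch.empty_like(static_in)

        # Warm up on a side stream before capture, as the PyTorch docs recommend.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            static_out.copy_(static_in.relu())
        torch.cuda.current_stream().wait_stream(s)

        # Capture once, with fixed shapes and fixed buffer addresses.
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            static_out.copy_(static_in.relu())

        # To run on new data, overwrite the captured input buffer in place and replay;
        # the test above analogously calls wrapper.plan(...) with a new kv_len
        # before graph.replay().
        static_in.copy_(torch.randn(1024, device="cuda"))
        graph.replay()
        print(static_out[:4])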