add balance-serve, support concurrence

Commit 25cee5810e (parent 8d0292aa44) in the mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-09 13:55:27 +00:00.
196 changed files with 22077 additions and 565 deletions.

The hunks shown below touch the flashinfer MLA wrapper (class MLAWrapper) and its __main__ self-test.
```diff
@@ -86,6 +86,7 @@ class MLAWrapper():
             self.qo_indptr_buf = torch.empty(max_batch_size+1, dtype=torch.int32, device=device)
             self.kv_indptr_buf = torch.empty(max_batch_size+1, dtype=torch.int32, device=device)
             self.kv_indices_buf = torch.empty(max_pages, dtype=torch.int32, device=device)
+            self.batch_size_tensor_buf = torch.tensor([self.max_batch_size], dtype=torch.int32, device=device)
             self.kv_len_arr_buf = torch.empty(max_batch_size, dtype=torch.int32, device=device)
         else:
             self.qo_indptr_buf = None
```
```diff
@@ -94,19 +95,22 @@ class MLAWrapper():
             self.kv_len_arr_buf = None
         self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
             self.float_workspace_buffer,
-            use_cuda_graph=False,
+            use_cuda_graph=use_cuda_graph,
             qo_indptr=self.qo_indptr_buf,
             kv_indptr=self.kv_indptr_buf,
             kv_indices=self.kv_indices_buf,
             kv_len_arr=self.kv_len_arr_buf,
+            bsz_tensor=self.batch_size_tensor_buf
         )
         self.need_plan = True


     def plan(self,
              qo_indptr,
              kv_indptr,
              kv_indices,
              kv_len_arr,
+             bsz_tensor,
              num_heads,
              head_dim_ckv,
              head_dim_kpe,
```
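The batch size now travels as a device-side int32 tensor (`batch_size_tensor_buf`, passed on as `bsz_tensor`) instead of a Python int, which is presumably what keeps the wrapper CUDA-graph friendly when the number of concurrent requests varies: a captured graph keeps reading the same buffer, so the active batch size can change between replays by rewriting the buffer rather than re-capturing. A minimal sketch of that idea (names are illustrative, not ktransformers or flashinfer API):

```python
import torch

# Illustrative only: a device-side scalar that kernels captured in a CUDA graph
# re-read on every replay, so the effective batch size can change without re-capture.
MAX_BATCH_SIZE = 4
bsz_buf = torch.tensor([MAX_BATCH_SIZE], dtype=torch.int32, device="cuda")

def set_active_batch(n: int) -> None:
    assert n <= MAX_BATCH_SIZE
    bsz_buf.fill_(n)    # in-place update; the pointer captured by the graph stays valid

set_active_batch(2)     # e.g. two concurrent requests in the next replay
```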
```diff
@@ -138,6 +142,7 @@ class MLAWrapper():
             sm_scale,
             q_data_type,
             kv_data_type,
+            bsz_tensor
         )

     def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False):
```
```diff
@@ -240,16 +245,17 @@ if __name__ == "__main__":
     #checksame()
     #exit(0)

-    max_batch_size = 1
-    max_pages = 64
+    max_batch_size = 2
+    max_batch_tokens = 256
+    max_pages = 128
     page_size = 64
     num_heads = 128

     # warm-up
     kv_len = 4023
     q_len = 1
-    q_nope_buf = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
-    q_pe_buf = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
+    q_nope_buf = torch.randn((max_batch_tokens, num_heads, 512), dtype=torch.bfloat16, device="cuda")
+    q_pe_buf = torch.randn((max_batch_tokens, num_heads, 64), dtype=torch.bfloat16, device="cuda")
     kv_buf = torch.randn((max_pages, page_size, 576), dtype=torch.bfloat16, device="cuda")
     ckv, k_pe = torch.split(kv_buf, [512, 64], dim=-1)
```
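Each cache slot stores 576 features per token: a 512-dim compressed-KV part (`ckv`) and a 64-dim rotary key part (`k_pe`). `torch.split` returns views of `kv_buf`, which is why the batched test further down can refresh the cache with `kv_buf[...].copy_(...)` and have the new values visible through `ckv` and `k_pe` without re-splitting. A tiny check of that aliasing with toy shapes:

```python
import torch

kv_buf = torch.zeros((2, 4, 576))                    # (pages, page_size, 512 + 64)
ckv, k_pe = torch.split(kv_buf, [512, 64], dim=-1)   # views, not copies

kv_buf[0, 0, 0] = 1.0                                # write through the backing buffer
kv_buf[0, 0, 575] = 2.0
assert ckv[0, 0, 0].item() == 1.0                    # the views see both writes
assert k_pe[0, 0, 63].item() == 2.0
```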
```diff
@@ -260,13 +266,19 @@ if __name__ == "__main__":
         max_pages,
     )

     used_pages = (kv_len + page_size - 1) // page_size
     kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
+    qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
+    kv_indptr = torch.tensor([0, used_pages], dtype=torch.int32, device="cuda")
+    kv_indices = torch.empty(max_pages, dtype=torch.int32, device="cuda")
+    kv_indices[:used_pages] = torch.arange(0, used_pages, dtype=torch.int32, device="cuda")
+    bsz_tensor = torch.tensor([1], dtype=torch.int32, device="cuda")
     wrapper.plan(
         qo_indptr,
-        None,
-        None,
+        kv_indptr,
+        kv_indices,
         kv_len_arr,
+        bsz_tensor,
         128,
         512,
         64,
```
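The warm-up plans a single request: `qo_indptr` and `kv_indptr` are prefix sums over per-request query-token and page counts, `kv_indices` packs each request's page ids back to back, and `kv_len_arr` holds the raw KV lengths. The same layout generalizes to any batch; a sketch of how these arrays could be built (the helper is made up for illustration and assumes pages are handed out contiguously per request):

```python
import torch
from itertools import accumulate

def build_batch_indices(q_lens, kv_lens, page_size, device="cuda"):
    # Illustrative helper, not part of ktransformers.
    pages_per_req = [(kv + page_size - 1) // page_size for kv in kv_lens]
    qo_indptr = torch.tensor([0] + list(accumulate(q_lens)), dtype=torch.int32, device=device)
    kv_indptr = torch.tensor([0] + list(accumulate(pages_per_req)), dtype=torch.int32, device=device)
    kv_indices = torch.arange(sum(pages_per_req), dtype=torch.int32, device=device)
    kv_len_arr = torch.tensor(kv_lens, dtype=torch.int32, device=device)
    return qo_indptr, kv_indptr, kv_indices, kv_len_arr

# For the two-request case tested below (q_len=128, kv_len=512, page_size=64):
# qo_indptr -> [0, 128, 256], kv_indptr -> [0, 8, 16], kv_indices -> [0 .. 15]
```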
```diff
@@ -276,14 +288,98 @@ if __name__ == "__main__":
         torch.bfloat16,
     )

-    attn_output = wrapper.run(q_nope_buf, q_pe_buf, ckv, k_pe)
+    attn_output = wrapper.run(q_nope_buf[:q_len], q_pe_buf[:q_len], ckv, k_pe)
     print(attn_output.shape)

+    graph = torch.cuda.CUDAGraph()
+    with torch.cuda.graph(graph):
+        attn_output = wrapper.run(q_nope_buf, q_pe_buf, ckv, k_pe)
+    graph.replay()
+
+    q = torch.cat([q_nope_buf, q_pe_buf], dim=-1)
+    k = (
+        torch.cat([ckv, k_pe], dim=-1)
+        .view(-1, 1, 512 + 64)
+        .repeat_interleave(num_heads, dim=1)
+    )
+    v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)
+    attn_ref, lse_ref = attention_ref_torch(
+        1,
+        q[:q_len],
+        k[:kv_len],
+        v[:kv_len],
+        True,
+        192 ** (-0.5)
+    )
+    torch.testing.assert_close(attn_output[:q_len], attn_ref, rtol=5e-3, atol=5e-3)
+    # warm-up finished
```
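`attention_ref_torch` is the plain-PyTorch reference defined earlier in this file (outside the diff); the warm-up output is compared against it at 5e-3 tolerance, roughly what bfloat16 accumulation allows. As an assumed sketch of what such a causal reference computes for one request (the real helper also returns the log-sum-exp, omitted here):

```python
import torch

def ref_attention_single(q, k, v, causal, sm_scale):
    # q: (q_len, heads, d_qk), k: (kv_len, heads, d_qk), v: (kv_len, heads, d_v)
    q_len, kv_len = q.shape[0], k.shape[0]
    scores = torch.einsum("qhd,khd->hqk", q.float(), k.float()) * sm_scale
    if causal:
        # query i (the i-th of the last q_len positions) may attend to keys up to kv_len - q_len + i
        offset = kv_len - q_len
        j = torch.arange(kv_len, device=q.device)[None, :]
        i = torch.arange(q_len, device=q.device)[:, None]
        scores.masked_fill_(j > offset + i, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("hqk,khd->qhd", probs, v.float()).to(q.dtype)
```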
```diff
+
+    kv_len = 512
+    q_len = 128
+    pages = max_pages
+    used_pages = (kv_len + page_size - 1) // page_size
+    q_nope = torch.randn((q_len*2, num_heads, 512), dtype=torch.bfloat16, device="cuda")
+    q_nope[q_len:] = q_nope[:q_len]
+    q_pe = torch.randn((q_len*2, num_heads, 64), dtype=torch.bfloat16, device="cuda")
+    q_pe[q_len:] = q_pe[:q_len]
+    kv_cache = torch.randn((max_pages, page_size, 576), dtype=torch.bfloat16, device="cuda")
+    kv_cache[used_pages:2*used_pages] = kv_cache[:used_pages]
+    ckv, k_pe = torch.split(kv_cache, [512, 64], dim=-1)
+
+    kv_len_arr = torch.tensor([kv_len, kv_len], dtype=torch.int32, device="cuda")
+    qo_indptr = torch.tensor([0, q_len, q_len*2], dtype=torch.int32, device="cuda")
+    kv_indptr = torch.tensor([0, used_pages, used_pages*2], dtype=torch.int32, device="cuda")
+    kv_indices = torch.empty(max_pages, dtype=torch.int32, device="cuda")
+    kv_indices[:2*used_pages] = torch.arange(0, 2*used_pages, dtype=torch.int32, device="cuda")
+    bsz_tensor = torch.tensor([2], dtype=torch.int32, device="cuda")
+    wrapper.plan(
+        qo_indptr,
+        kv_indptr,
+        kv_indices,
+        kv_len_arr,
+        bsz_tensor,
+        128,
+        512,
+        64,
+        page_size,
+        192 ** (-0.5),
+        torch.bfloat16,
+        torch.bfloat16,
+    )
+
+    q_nope_buf.copy_(q_nope)
+    q_pe_buf.copy_(q_pe)
+    kv_buf[:pages].copy_(kv_cache)
+
+    torch.cuda.synchronize()
+    graph.replay()
+    torch.cuda.synchronize()
+
+    # ref_torch
+    q = torch.cat([q_nope, q_pe], dim=-1)
+    k = (
+        torch.cat([ckv, k_pe], dim=-1)
+        .view(-1, 1, 512 + 64)
+        .repeat_interleave(num_heads, dim=1)
+    )
+    v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)
+    attn_ref, lse_ref = attention_ref_torch(
+        max_batch_size,
+        q,
+        k[:2*kv_len],
+        v[:2*kv_len],
+        True,
+        192 ** (-0.5)
+    )
+
+    torch.testing.assert_close(attn_ref[:q_len], attn_ref[q_len:q_len*2], rtol=1e-9, atol=1e-9)
+    torch.testing.assert_close(attn_output[:q_len], attn_output[q_len:q_len*2], rtol=1e-9, atol=1e-9)
+    torch.testing.assert_close(attn_output[:q_len], attn_ref[:q_len], rtol=5e-3, atol=5e-3)
+    torch.testing.assert_close(attn_output[q_len:q_len*2], attn_ref[q_len:q_len*2], rtol=5e-3, atol=5e-3)
+    #torch.testing.assert_close(attn_output[:q_len], attn_output[q_len:q_len*2], rtol=1e-9, atol=1e-9)
+    #torch.testing.assert_close(attn_output, attn_ref, rtol=5e-3, atol=5e-3)
+
+    exit(0)
+
     for forward_id in range(0, 1):
         print("forward_id", forward_id)
         for layer_id in range(1):
```
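The batched path is the standard CUDA-graph recipe: capture `wrapper.run` once over the full-size static buffers during warm-up, then for every new batch copy fresh inputs into those same buffers with `copy_` and replay, so the captured kernels see new data without being re-captured. The same pattern in isolation, using only generic PyTorch ops (a sketch, not the ktransformers code):

```python
import torch

# Static buffers; everything captured in the graph keeps pointing at these.
static_in = torch.zeros(256, 1024, device="cuda")
static_out = torch.zeros(256, 1024, device="cuda")

def step():
    # Stand-in for wrapper.run(): a fixed sequence of CUDA kernels on the static buffers.
    static_out.copy_(torch.nn.functional.gelu(static_in) * 2.0)

# Warm up on a side stream (recommended before capture), then capture once.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    step()
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    step()

# A new "batch" arrives: overwrite the static inputs and replay the captured work.
static_in.copy_(torch.randn(256, 1024, device="cuda"))
torch.cuda.synchronize()
graph.replay()
torch.cuda.synchronize()
print(static_out.abs().sum())
```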
```diff
@@ -376,5 +472,4 @@ if __name__ == "__main__":
             #file_name = f"./flashinfer_output/layer_{layer_id}_forward_{forward_id}_attn_output.pt"
             #ktrans_output = torch.load(file_name)
             #torch.testing.assert_close(attn_output, ktrans_output.squeeze(1), rtol=1e-3, atol=1e-3)
             print("test past")

     print("test past")
```