mirror of https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 14:51:06 +00:00

fix flashinfer precision

parent 96d75d53df
commit d453c320f1
5 changed files with 151 additions and 61 deletions
@@ -1,9 +1,11 @@
'''
Description : flashinfer MLA wrapper
Author : Boxin Zhang
-Version : 0.2.2
+Version : 0.2.3
'''
import torch
+import os
+from ktransformers.operators.triton_attention import decode_attention_fwd_grouped

flashinfer_enabled = False
@@ -17,7 +19,7 @@ except ImportError:

import math

-def attention_ref(
+def attention_ref_torch(
    batch_size,
    q: torch.Tensor,
    k: torch.Tensor,
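Note: the body of attention_ref_torch is not part of this hunk, so the following is only a minimal sketch (my own illustration, not the repository's code) of what a pure-PyTorch reference attention with a log-sum-exp output typically looks like when used to check flashinfer precision; the tensor layout and the causal-offset handling are assumptions.

# Sketch of a naive torch reference that also returns the log-sum-exp.
# q: [q_len, heads, dim], k/v: [kv_len, heads, dim]; all names are illustrative.
import torch

def naive_attention_with_lse(q, k, v, causal, sm_scale):
    qf, kf, vf = q.float(), k.float(), v.float()
    logits = torch.einsum("qhd,khd->hqk", qf, kf) * sm_scale
    if causal:
        q_len, kv_len = qf.shape[0], kf.shape[0]
        # query position i may attend up to kv position (kv_len - q_len + i)
        mask = torch.arange(kv_len, device=q.device)[None, :] > \
               torch.arange(q_len, device=q.device)[:, None] + (kv_len - q_len)
        logits = logits.masked_fill(mask, float("-inf"))
    lse = torch.logsumexp(logits, dim=-1)                              # [heads, q_len]
    out = torch.einsum("hqk,khd->qhd", torch.softmax(logits, dim=-1), vf)
    return out.to(q.dtype), lse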
@@ -122,7 +124,7 @@ class MLAWrapper():
        if kv_indices is None:
            assert self.max_batch_size == 1
            kv_indices = self.kv_indices_buf

        self.wrapper.plan(
            qo_indptr,
            kv_indptr,
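For context, here is a minimal sketch (an assumed interpretation, not code from this file) of the paged-KV bookkeeping that a plan() call like the one above consumes for a single request: qo_indptr and kv_indptr as CSR-style offsets, kv_indices as the page table, and kv_len_arr as the true token count. The identity page table mirrors the kv_indices_buf = torch.arange(...) assignment seen later in this diff.

import math
import torch

kv_len, q_len, page_size, max_pages = 4023, 1, 64, 128
used_pages = math.ceil(kv_len / page_size)                  # pages actually holding tokens

qo_indptr = torch.tensor([0, q_len], dtype=torch.int32)     # one request: queries [0, q_len)
kv_indptr = torch.tensor([0, used_pages], dtype=torch.int32)
kv_indices = torch.arange(max_pages, dtype=torch.int32)     # identity page table
kv_len_arr = torch.tensor([kv_len], dtype=torch.int32)
print(used_pages, qo_indptr, kv_indptr, kv_indices[:4], kv_len_arr)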
@@ -139,11 +141,6 @@ class MLAWrapper():
        )

    def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False):
-        #print("run")
-        #print(self.wrapper._qo_indptr_buf)
-        #print(self.wrapper._kv_indptr_buf)
-        #print(self.wrapper._kv_indices_buf)
-        #print(self.wrapper._kv_len_arr_buf)
        return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse = return_lse)

class MLAWrapperSingleton():
@@ -202,21 +199,59 @@ class MLAWrapperSingleton():
        wrapper.kv_indptr_buf[1] = max_pages # assert max_batch_size=1 here.
        wrapper.kv_indices_buf = torch.arange(0, max_pages, dtype=torch.int32, device=device)
        wrapper.wrapper._kv_indices_buf = wrapper.kv_indices_buf

def checksame():
-    flashinfer_folder = "./flashinfer_output"
+    flashinfer_folder = "./kv_cache_flashinfer"
-    triton_folder = "./triton_output"
+    triton_folder = "./kv_cache_triton"

    max_layer_id = 1
    max_forward_id = 2

    for forward_id in range(0, 19):
        print("forward_id", forward_id)
        for layer_id in range(max_layer_id):
            print(layer_id)
            #file_name = f"layer_{layer_id}_forward_{forward_id}_attn_output.pt"
            #file_name = f"layer_{layer_id}_forward_{forward_id}_q_pe.pt"
            file_name = f"layer_{layer_id}.pt"

            flashinfer_path = os.path.join(flashinfer_folder, file_name)
            triton_path = os.path.join(triton_folder, file_name)

            if not os.path.exists(triton_path):
                print(f"{file_name} not exist in {triton_folder}")
                continue
            if not os.path.exists(flashinfer_path):
                print(f"{file_name} not exist in {flashinfer_folder}")
                continue

            flashinfer_tensor = torch.load(flashinfer_path)[1:2, :62]#
            triton_tensor = torch.load(triton_path)[1:2, :62]#.squeeze(1)#
            try:
                torch.testing.assert_close(flashinfer_tensor, triton_tensor, rtol=1e-9, atol=1e-9)
            except AssertionError as e:
                print(e)

if __name__ == "__main__":
    torch.set_default_dtype(torch.bfloat16)

    #checksame()
    #exit(0)

    max_batch_size = 1
-    max_pages = 128
+    max_pages = 64
    page_size = 64
    num_heads = 128

    # warm-up
    kv_len = 4023
    q_len = 1
    q_nope = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
    q_pe = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
    ckv = torch.randn((max_pages, page_size, 512), dtype=torch.bfloat16, device="cuda")
    k_pe = torch.randn((max_pages, page_size, 64), dtype=torch.bfloat16, device="cuda")
+    q_nope_buf = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
+    q_pe_buf = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
+    kv_buf = torch.randn((max_pages, page_size, 576), dtype=torch.bfloat16, device="cuda")
+    ckv, k_pe = torch.split(kv_buf, [512, 64], dim=-1)

    wrapper = MLAWrapperSingleton.get_instance(
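The checksame() helper above automates a dump-and-compare workflow: each backend writes per-layer tensors under its own folder, and the comparison is a strict torch.testing.assert_close. Below is a self-contained sketch of that same pattern; the folder and file names reuse the ones above, but the saved tensors are stand-ins rather than real layer dumps.

import os
import torch

os.makedirs("./kv_cache_flashinfer", exist_ok=True)
os.makedirs("./kv_cache_triton", exist_ok=True)

ref = torch.randn(4, 64, 512)
torch.save(ref, "./kv_cache_triton/layer_0.pt")
torch.save(ref + 1e-8, "./kv_cache_flashinfer/layer_0.pt")   # small stand-in perturbation

a = torch.load("./kv_cache_flashinfer/layer_0.pt")
b = torch.load("./kv_cache_triton/layer_0.pt")
try:
    torch.testing.assert_close(a, b, rtol=1e-9, atol=1e-9)
except AssertionError as e:
    print(e)   # the 1e-8 perturbation exceeds atol=1e-9, so the mismatch report is printed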
@@ -241,51 +276,105 @@ if __name__ == "__main__":
        torch.bfloat16,
    )

-    attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)
+    attn_output = wrapper.run(q_nope_buf, q_pe_buf, ckv, k_pe)
    print(attn_output.shape)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)

    kv_len = 6789
    kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
    qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
    wrapper.plan(
        qo_indptr,
        None,
        None,
        kv_len_arr,
        128,
        512,
        64,
        page_size,
        192 ** (-0.5),
        torch.bfloat16,
        torch.bfloat16,
    )
    attn_output = wrapper.run(q_nope_buf, q_pe_buf, ckv, k_pe)
    # warm-up finished

    for forward_id in range(0, 1):
        print("forward_id", forward_id)
        for layer_id in range(1):
            print(layer_id)
            flashinfer_folder = "./kv_cache_flashinfer"
            forward_id = 17
            layer_id = 0
            file_name = f"layer_{layer_id}.pt"
            kv_cache_path = os.path.join(flashinfer_folder, file_name)
            flashinfer_folder = "./flashinfer_output"

            q_len = 1
            kv_len = 126
            file_name = f"layer_{layer_id}_forward_{forward_id}_q_nope.pt"
            q_nope = torch.load(os.path.join(flashinfer_folder, file_name)).view(q_len,128,512).to(device="cuda")
            file_name = f"layer_{layer_id}_forward_{forward_id}_q_pe.pt"
            q_pe = torch.load(os.path.join(flashinfer_folder, file_name)).view(q_len,128,64).to(device="cuda")
            q = torch.cat([q_nope, q_pe], dim=-1)
            kv_cache = torch.load(kv_cache_path).to(device="cuda")
            pages, page_size, _, head_dim = kv_cache.shape
            kv_cache = kv_cache.view(pages, page_size, head_dim)
            ckv, k_pe = torch.split(kv_cache, [512, 64], dim=-1)

-            graph.replay()
-
-            k = (
-                torch.cat([ckv, k_pe], dim=-1)
-                .view(-1, 1, 512 + 64)
-                .repeat_interleave(num_heads, dim=1)
-            )
-            v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)
-
-            print(k[:kv_len].shape)
-            print(v[:kv_len].shape)
-
-            attn_ref, lse_ref = attention_ref(
-                max_batch_size,
-                torch.cat([q_nope, q_pe], dim=-1),
-                k[:kv_len],
-                v[:kv_len],
-                True,
-                192 ** (-0.5)
-            )
-            print(attn_ref.shape)
            kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
            qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
            wrapper.plan(
                None,
                None,
                None,
                kv_len_arr,
                128,
                512,
                64,
                page_size,
                192 ** (-0.5),
                torch.bfloat16,
                torch.bfloat16,
            )

+            q_nope_buf.copy_(q_nope)
+            q_pe_buf.copy_(q_pe)
+            kv_buf[:pages].copy_(kv_cache)
+
+            torch.cuda.synchronize()
+            graph.replay()
+            torch.cuda.synchronize()
+
+            # ref_torch
+            k = (
+                torch.cat([ckv, k_pe], dim=-1)
+                .view(-1, 1, 512 + 64)
+                .repeat_interleave(num_heads, dim=1)
+            )
+            v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)
+            attn_ref, lse_ref = attention_ref_torch(
+                max_batch_size,
+                q,
+                k[:kv_len],
+                v[:kv_len],
+                False,
+                192 ** (-0.5)
+            )
+            torch.testing.assert_close(attn_output, attn_ref, rtol=1e-3, atol=1e-3)
+
+            # ref_triton
+            attn_logits = torch.empty(
+                (
+                    max_batch_size,
+                    num_heads,
+                    4, #num_kv_splits # follow vLLM, fix it TODO
+                    512 + 1,
+                ),
+                dtype=torch.float32,
+                device = "cuda"
+            )
+
+            triton_ref = torch.zeros_like(q_nope)
+            page_table = torch.arange(max_pages, dtype=torch.int32, device="cuda")
+            ckv_with_pe = torch.cat([ckv, k_pe], dim=-1).contiguous().view(pages, page_size, 1, 576)
+            ckv = ckv.view(pages, page_size, 1, 512)
+            decode_attention_fwd_grouped(q, ckv_with_pe, ckv, triton_ref,
+                                         page_table,
+                                         kv_len_arr, attn_logits,
+                                         4, #num_kv_splits # follow vLLM, fix it TODO
+                                         192 ** (-0.5),
+                                         page_size)
+
+            torch.testing.assert_close(attn_output, triton_ref, rtol=1e-3, atol=1e-3)
+
            #file_name = f"./flashinfer_output/layer_{layer_id}_forward_{forward_id}_attn_output.pt"
            #ktrans_output = torch.load(file_name)
            #torch.testing.assert_close(attn_output, ktrans_output.squeeze(1), rtol=1e-3, atol=1e-3)
            print("test past")

-    torch.testing.assert_close(attn_output, attn_ref, rtol=1e-3, atol=1e-3)
-    print("test past")
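The buffer copies followed by graph.replay() above follow the standard CUDA-graph pattern: capture the kernel once on fixed "static" buffers, then refresh those buffers in place for every new input and replay. Below is a minimal sketch of the same pattern with a plain matmul standing in for the MLA kernel (requires a CUDA device; the tolerances mirror the 1e-3 used above).

import torch

assert torch.cuda.is_available()
static_in = torch.zeros(128, 128, device="cuda")
weight = torch.randn(128, 128, device="cuda")

# warm-up once so handles/workspaces exist, then capture on the static buffer
static_out = static_in @ weight
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = static_in @ weight

for _ in range(3):
    new_input = torch.randn(128, 128, device="cuda")
    static_in.copy_(new_input)          # refresh the captured input buffer in place
    torch.cuda.synchronize()
    graph.replay()                      # reruns the captured kernel on the new data
    torch.cuda.synchronize()
    torch.testing.assert_close(static_out, new_input @ weight, rtol=1e-3, atol=1e-3)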