mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 15:29:39 +00:00
fix flashinfer precision
This commit is contained in:
parent
96d75d53df
commit
d453c320f1
5 changed files with 151 additions and 61 deletions
|
@ -25,7 +25,7 @@ from ktransformers.operators.triton_attention import decode_attention_fwd_groupe
|
|||
import os
|
||||
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
|
||||
if flashinfer_enabled:
|
||||
from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton, attention_ref
|
||||
from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
|
||||
|
||||
logger = logging.getLogger("attention")
|
||||
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
'''
|
||||
Description : flashinfer MLA wrapper
|
||||
Author : Boxin Zhang
|
||||
Version : 0.2.2
|
||||
Version : 0.2.3
|
||||
'''
|
||||
import torch
|
||||
import os
|
||||
from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
|
||||
|
||||
flashinfer_enabled = False
|
||||
|
||||
|
@ -17,7 +19,7 @@ except ImportError:
|
|||
|
||||
import math
|
||||
|
||||
def attention_ref(
|
||||
def attention_ref_torch(
|
||||
batch_size,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
|
@ -139,11 +141,6 @@ class MLAWrapper():
|
|||
)
|
||||
|
||||
def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False):
|
||||
#print("run")
|
||||
#print(self.wrapper._qo_indptr_buf)
|
||||
#print(self.wrapper._kv_indptr_buf)
|
||||
#print(self.wrapper._kv_indices_buf)
|
||||
#print(self.wrapper._kv_len_arr_buf)
|
||||
return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse = return_lse)
|
||||
|
||||
class MLAWrapperSingleton():
|
||||
|
@ -203,20 +200,58 @@ class MLAWrapperSingleton():
|
|||
wrapper.kv_indices_buf = torch.arange(0, max_pages, dtype=torch.int32, device=device)
|
||||
wrapper.wrapper._kv_indices_buf = wrapper.kv_indices_buf
|
||||
|
||||
def checksame():
|
||||
flashinfer_folder = "./flashinfer_output"
|
||||
flashinfer_folder = "./kv_cache_flashinfer"
|
||||
triton_folder = "./triton_output"
|
||||
triton_folder = "./kv_cache_triton"
|
||||
|
||||
max_layer_id = 1
|
||||
max_forward_id = 2
|
||||
|
||||
for forward_id in range(0, 19):
|
||||
print("forward_id", forward_id)
|
||||
for layer_id in range(max_layer_id):
|
||||
print(layer_id)
|
||||
#file_name = f"layer_{layer_id}_forward_{forward_id}_attn_output.pt"
|
||||
#file_name = f"layer_{layer_id}_forward_{forward_id}_q_pe.pt"
|
||||
file_name = f"layer_{layer_id}.pt"
|
||||
|
||||
flashinfer_path = os.path.join(flashinfer_folder, file_name)
|
||||
triton_path = os.path.join(triton_folder, file_name)
|
||||
|
||||
if not os.path.exists(triton_path):
|
||||
print(f"{file_name} not exist in {triton_folder}")
|
||||
continue
|
||||
if not os.path.exists(flashinfer_path):
|
||||
print(f"{file_name} not exist in {flashinfer_folder}")
|
||||
continue
|
||||
|
||||
|
||||
flashinfer_tensor = torch.load(flashinfer_path)[1:2, :62]#
|
||||
triton_tensor = torch.load(triton_path)[1:2, :62]#.squeeze(1)#
|
||||
try:
|
||||
torch.testing.assert_close(flashinfer_tensor, triton_tensor, rtol=1e-9, atol=1e-9)
|
||||
except AssertionError as e:
|
||||
print(e)
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.set_default_dtype(torch.bfloat16)
|
||||
|
||||
#checksame()
|
||||
#exit(0)
|
||||
|
||||
max_batch_size = 1
|
||||
max_pages = 128
|
||||
max_pages = 64
|
||||
page_size = 64
|
||||
num_heads = 128
|
||||
|
||||
# warm-up
|
||||
kv_len = 4023
|
||||
q_len = 1
|
||||
q_nope = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
|
||||
q_pe = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
|
||||
ckv = torch.randn((max_pages, page_size, 512), dtype=torch.bfloat16, device="cuda")
|
||||
k_pe = torch.randn((max_pages, page_size, 64), dtype=torch.bfloat16, device="cuda")
|
||||
q_nope_buf = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
|
||||
q_pe_buf = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
|
||||
kv_buf = torch.randn((max_pages, page_size, 576), dtype=torch.bfloat16, device="cuda")
|
||||
ckv, k_pe = torch.split(kv_buf, [512, 64], dim=-1)
|
||||
|
||||
|
||||
wrapper = MLAWrapperSingleton.get_instance(
|
||||
|
@ -241,18 +276,41 @@ if __name__ == "__main__":
|
|||
torch.bfloat16,
|
||||
)
|
||||
|
||||
attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)
|
||||
attn_output = wrapper.run(q_nope_buf, q_pe_buf, ckv, k_pe)
|
||||
print(attn_output.shape)
|
||||
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph):
|
||||
attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)
|
||||
attn_output = wrapper.run(q_nope_buf, q_pe_buf, ckv, k_pe)
|
||||
# warm-up finished
|
||||
|
||||
for forward_id in range(0, 1):
|
||||
print("forward_id", forward_id)
|
||||
for layer_id in range(1):
|
||||
print(layer_id)
|
||||
flashinfer_folder = "./kv_cache_flashinfer"
|
||||
forward_id = 17
|
||||
layer_id = 0
|
||||
file_name = f"layer_{layer_id}.pt"
|
||||
kv_cache_path = os.path.join(flashinfer_folder, file_name)
|
||||
flashinfer_folder = "./flashinfer_output"
|
||||
|
||||
q_len = 1
|
||||
kv_len = 126
|
||||
file_name = f"layer_{layer_id}_forward_{forward_id}_q_nope.pt"
|
||||
q_nope = torch.load(os.path.join(flashinfer_folder, file_name)).view(q_len,128,512).to(device="cuda")
|
||||
file_name = f"layer_{layer_id}_forward_{forward_id}_q_pe.pt"
|
||||
q_pe = torch.load(os.path.join(flashinfer_folder, file_name)).view(q_len,128,64).to(device="cuda")
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
kv_cache = torch.load(kv_cache_path).to(device="cuda")
|
||||
pages, page_size, _, head_dim = kv_cache.shape
|
||||
kv_cache = kv_cache.view(pages, page_size, head_dim)
|
||||
ckv, k_pe = torch.split(kv_cache, [512, 64], dim=-1)
|
||||
|
||||
kv_len = 6789
|
||||
kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
|
||||
qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
|
||||
wrapper.plan(
|
||||
qo_indptr,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
kv_len_arr,
|
||||
|
@ -265,27 +323,58 @@ if __name__ == "__main__":
|
|||
torch.bfloat16,
|
||||
)
|
||||
|
||||
graph.replay()
|
||||
q_nope_buf.copy_(q_nope)
|
||||
q_pe_buf.copy_(q_pe)
|
||||
kv_buf[:pages].copy_(kv_cache)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# ref_torch
|
||||
k = (
|
||||
torch.cat([ckv, k_pe], dim=-1)
|
||||
.view(-1, 1, 512 + 64)
|
||||
.repeat_interleave(num_heads, dim=1)
|
||||
)
|
||||
v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)
|
||||
|
||||
print(k[:kv_len].shape)
|
||||
print(v[:kv_len].shape)
|
||||
|
||||
attn_ref, lse_ref = attention_ref(
|
||||
attn_ref, lse_ref = attention_ref_torch(
|
||||
max_batch_size,
|
||||
torch.cat([q_nope, q_pe], dim=-1),
|
||||
q,
|
||||
k[:kv_len],
|
||||
v[:kv_len],
|
||||
True,
|
||||
False,
|
||||
192 ** (-0.5)
|
||||
)
|
||||
print(attn_ref.shape)
|
||||
|
||||
torch.testing.assert_close(attn_output, attn_ref, rtol=1e-3, atol=1e-3)
|
||||
|
||||
# ref_triton
|
||||
attn_logits = torch.empty(
|
||||
(
|
||||
max_batch_size,
|
||||
num_heads,
|
||||
4, #num_kv_splits # follow vLLM, fix it TODO
|
||||
512 + 1,
|
||||
),
|
||||
dtype=torch.float32,
|
||||
device = "cuda"
|
||||
)
|
||||
|
||||
triton_ref = torch.zeros_like(q_nope)
|
||||
page_table = torch.arange(max_pages, dtype=torch.int32, device="cuda")
|
||||
ckv_with_pe = torch.cat([ckv, k_pe], dim=-1).contiguous().view(pages, page_size, 1, 576)
|
||||
ckv = ckv.view(pages, page_size, 1, 512)
|
||||
decode_attention_fwd_grouped(q, ckv_with_pe, ckv, triton_ref,
|
||||
page_table,
|
||||
kv_len_arr, attn_logits,
|
||||
4, #num_kv_splits # follow vLLM, fix it TODO
|
||||
192 ** (-0.5),
|
||||
page_size)
|
||||
|
||||
torch.testing.assert_close(attn_output, triton_ref, rtol=1e-3, atol=1e-3)
|
||||
|
||||
#file_name = f"./flashinfer_output/layer_{layer_id}_forward_{forward_id}_attn_output.pt"
|
||||
#ktrans_output = torch.load(file_name)
|
||||
#torch.testing.assert_close(attn_output, ktrans_output.squeeze(1), rtol=1e-3, atol=1e-3)
|
||||
print("test past")
|
||||
|
||||
|
|
|
@ -344,7 +344,7 @@ class TransformersInterface(BackendInterfaceBase):
|
|||
MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1,
|
||||
num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
|
||||
head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size,
|
||||
sm_scale=(self.model.config.qk_rope_head_dim + self.model.config.qk_nope_head_dim) ** (-0.5), q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
|
||||
sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
|
||||
next_token = self.decode_one_tokens()
|
||||
self.profiler.inc("decode")
|
||||
if next_token == self.tokenizer.eos_token_id or "<|im_end|>" == self.tokenizer.decode(next_token):
|
||||
|
|
|
@ -85,7 +85,8 @@ def main(output_path, api_url, model_name, auth_token, format_tabs,problem_file,
|
|||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="API Generate Tester")
|
||||
parser.add_argument("--api_url", type=str, default="https://api.siliconflow.cn/v1/chat/completions", help="API URL")
|
||||
#parser.add_argument("--api_url", type=str, default="https://api.siliconflow.cn/v1/chat/completions", help="API URL")
|
||||
parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL")
|
||||
parser.add_argument("--model_name", type=str, default="Pro/deepseek-ai/DeepSeek-V3", help="Model Name")
|
||||
parser.add_argument("--out_path", type=str, default="results/api/eval_b.jsonl", help="Output Path")
|
||||
parser.add_argument("--auth_token", type=str, default=None, help="Auth Token")
|
||||
|
|
|
@ -239,7 +239,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
|
|||
if use_flashinfer_mla:
|
||||
MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,
|
||||
num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
|
||||
q_head_dim ** (-0.5), torch.bfloat16, torch.bfloat16)
|
||||
model.model.layers[0].self_attn.softmax_scale, torch.bfloat16, torch.bfloat16)
|
||||
global warm_uped
|
||||
if use_cuda_graph and ( (warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2) ):
|
||||
warm_uped = True
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue