fix flashinfer precision

Atream 2025-03-07 14:07:00 +00:00
parent 96d75d53df
commit d453c320f1
5 changed files with 151 additions and 61 deletions

View file

@@ -25,7 +25,7 @@ from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
import os
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
if flashinfer_enabled:
from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton, attention_ref
from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
logger = logging.getLogger("attention")

View file

@@ -1,9 +1,11 @@
'''
Description : flashinfer MLA wrapper
Author : Boxin Zhang
Version : 0.2.2
Version : 0.2.3
'''
import torch
import os
from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
flashinfer_enabled = False
@@ -17,7 +19,7 @@ except ImportError:
import math
def attention_ref(
def attention_ref_torch(
batch_size,
q: torch.Tensor,
k: torch.Tensor,
@@ -122,7 +124,7 @@ class MLAWrapper():
if kv_indices is None:
assert self.max_batch_size == 1
kv_indices = self.kv_indices_buf
self.wrapper.plan(
qo_indptr,
kv_indptr,
@@ -139,11 +141,6 @@ class MLAWrapper():
)
def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False):
#print("run")
#print(self.wrapper._qo_indptr_buf)
#print(self.wrapper._kv_indptr_buf)
#print(self.wrapper._kv_indices_buf)
#print(self.wrapper._kv_len_arr_buf)
return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse = return_lse)
class MLAWrapperSingleton():
@@ -202,21 +199,59 @@ class MLAWrapperSingleton():
wrapper.kv_indptr_buf[1] = max_pages # assumes max_batch_size == 1 here
wrapper.kv_indices_buf = torch.arange(0, max_pages, dtype=torch.int32, device=device)
wrapper.wrapper._kv_indices_buf = wrapper.kv_indices_buf
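# checksame() below is a debugging helper: it loads per-layer tensors dumped by the
# flashinfer and triton code paths from two folders and reports any element-wise
# mismatch via torch.testing.assert_close.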
def checksame():
flashinfer_folder = "./flashinfer_output"
flashinfer_folder = "./kv_cache_flashinfer"
triton_folder = "./triton_output"
triton_folder = "./kv_cache_triton"
max_layer_id = 1
max_forward_id = 2
for forward_id in range(0, 19):
print("forward_id", forward_id)
for layer_id in range(max_layer_id):
print(layer_id)
#file_name = f"layer_{layer_id}_forward_{forward_id}_attn_output.pt"
#file_name = f"layer_{layer_id}_forward_{forward_id}_q_pe.pt"
file_name = f"layer_{layer_id}.pt"
flashinfer_path = os.path.join(flashinfer_folder, file_name)
triton_path = os.path.join(triton_folder, file_name)
if not os.path.exists(triton_path):
print(f"{file_name} not exist in {triton_folder}")
continue
if not os.path.exists(flashinfer_path):
print(f"{file_name} not exist in {flashinfer_folder}")
continue
flashinfer_tensor = torch.load(flashinfer_path)[1:2, :62]
triton_tensor = torch.load(triton_path)[1:2, :62]  # .squeeze(1)
try:
torch.testing.assert_close(flashinfer_tensor, triton_tensor, rtol=1e-9, atol=1e-9)
except AssertionError as e:
print(e)
if __name__ == "__main__":
torch.set_default_dtype(torch.bfloat16)
#checksame()
#exit(0)
max_batch_size = 1
max_pages = 128
max_pages = 64
page_size = 64
num_heads = 128
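# MLA decode shapes used by the self-test: 128 query heads, a 512-dim compressed
# latent (ckv) plus a 64-dim rope part (k_pe), stored in a paged KV cache with
# page_size 64.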
# warm-up
kv_len = 4023
q_len = 1
q_nope = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
q_pe = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
ckv = torch.randn((max_pages, page_size, 512), dtype=torch.bfloat16, device="cuda")
k_pe = torch.randn((max_pages, page_size, 64), dtype=torch.bfloat16, device="cuda")
q_nope_buf = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
q_pe_buf = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
kv_buf = torch.randn((max_pages, page_size, 576), dtype=torch.bfloat16, device="cuda")
ckv, k_pe = torch.split(kv_buf, [512, 64], dim=-1)
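# torch.split returns views, so ckv and k_pe alias kv_buf; copying new data into
# kv_buf later also updates the tensors captured by the CUDA graph.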
wrapper = MLAWrapperSingleton.get_instance(
@@ -241,51 +276,105 @@ if __name__ == "__main__":
torch.bfloat16,
)
attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)
attn_output = wrapper.run(q_nope_buf, q_pe_buf, ckv, k_pe)
print(attn_output.shape)
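# Capture one decode step into a CUDA graph. replay() always reads the same device
# buffers, so new inputs are fed by copying into the persistent q/kv buffers below.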
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)
kv_len = 6789
kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
wrapper.plan(
qo_indptr,
None,
None,
kv_len_arr,
128,
512,
64,
page_size,
192 ** (-0.5),
torch.bfloat16,
torch.bfloat16,
)
attn_output = wrapper.run(q_nope_buf, q_pe_buf, ckv, k_pe)
# warm-up finished
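# The check below replays the captured graph on tensors dumped from a real forward
# pass (q_nope / q_pe and the paged KV cache for one layer) and compares the result
# against a plain torch reference and the triton decode kernel.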
for forward_id in range(0, 1):
print("forward_id", forward_id)
for layer_id in range(1):
print(layer_id)
flashinfer_folder = "./kv_cache_flashinfer"
forward_id = 17
layer_id = 0
file_name = f"layer_{layer_id}.pt"
kv_cache_path = os.path.join(flashinfer_folder, file_name)
flashinfer_folder = "./flashinfer_output"
q_len = 1
kv_len = 126
file_name = f"layer_{layer_id}_forward_{forward_id}_q_nope.pt"
q_nope = torch.load(os.path.join(flashinfer_folder, file_name)).view(q_len,128,512).to(device="cuda")
file_name = f"layer_{layer_id}_forward_{forward_id}_q_pe.pt"
q_pe = torch.load(os.path.join(flashinfer_folder, file_name)).view(q_len,128,64).to(device="cuda")
q = torch.cat([q_nope, q_pe], dim=-1)
kv_cache = torch.load(kv_cache_path).to(device="cuda")
pages, page_size, _, head_dim = kv_cache.shape
kv_cache = kv_cache.view(pages, page_size, head_dim)
ckv, k_pe = torch.split(kv_cache, [512, 64], dim=-1)
graph.replay()
k = (
torch.cat([ckv, k_pe], dim=-1)
.view(-1, 1, 512 + 64)
.repeat_interleave(num_heads, dim=1)
)
v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)
print(k[:kv_len].shape)
print(v[:kv_len].shape)
attn_ref, lse_ref = attention_ref(
max_batch_size,
torch.cat([q_nope, q_pe], dim=-1),
k[:kv_len],
v[:kv_len],
True,
192 ** (-0.5)
)
print(attn_ref.shape)
kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
wrapper.plan(
None,
None,
None,
kv_len_arr,
128,
512,
64,
page_size,
192 ** (-0.5),
torch.bfloat16,
torch.bfloat16,
)
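# Copy the dumped inputs into the buffers the CUDA graph was captured with, then replay.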
q_nope_buf.copy_(q_nope)
q_pe_buf.copy_(q_pe)
kv_buf[:pages].copy_(kv_cache)
torch.cuda.synchronize()
graph.replay()
torch.cuda.synchronize()
# reference 1: plain torch attention over per-head K/V expanded from the shared MLA
# latent: k = [ckv, k_pe] (576-dim), v = ckv (512-dim), repeated across all heads
k = (
torch.cat([ckv, k_pe], dim=-1)
.view(-1, 1, 512 + 64)
.repeat_interleave(num_heads, dim=1)
)
v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)
attn_ref, lse_ref = attention_ref_torch(
max_batch_size,
q,
k[:kv_len],
v[:kv_len],
False,
192 ** (-0.5)
)
torch.testing.assert_close(attn_output, attn_ref, rtol=1e-3, atol=1e-3)
# reference 2: the triton decode kernel (decode_attention_fwd_grouped) on the same paged cache
attn_logits = torch.empty(
(
max_batch_size,
num_heads,
4,  # num_kv_splits; follows vLLM, TODO: make configurable
512 + 1,
),
dtype=torch.float32,
device = "cuda"
)
triton_ref = torch.zeros_like(q_nope)
page_table = torch.arange(max_pages, dtype=torch.int32, device="cuda")
ckv_with_pe = torch.cat([ckv, k_pe], dim=-1).contiguous().view(pages, page_size, 1, 576)
ckv = ckv.view(pages, page_size, 1, 512)
decode_attention_fwd_grouped(q, ckv_with_pe, ckv, triton_ref,
page_table,
kv_len_arr, attn_logits,
4,  # num_kv_splits; follows vLLM, TODO: make configurable
192 ** (-0.5),
page_size)
torch.testing.assert_close(attn_output, triton_ref, rtol=1e-3, atol=1e-3)
#file_name = f"./flashinfer_output/layer_{layer_id}_forward_{forward_id}_attn_output.pt"
#ktrans_output = torch.load(file_name)
#torch.testing.assert_close(attn_output, ktrans_output.squeeze(1), rtol=1e-3, atol=1e-3)
print("test past")
torch.testing.assert_close(attn_output, attn_ref, rtol=1e-3, atol=1e-3)
print("test past")

View file

@@ -344,7 +344,7 @@ class TransformersInterface(BackendInterfaceBase):
MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1,
num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size,
sm_scale=(self.model.config.qk_rope_head_dim + self.model.config.qk_nope_head_dim) ** (-0.5), q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
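# NOTE: sm_scale now comes from the attention module instead of being recomputed as
# (qk_rope_head_dim + qk_nope_head_dim) ** (-0.5). For DeepSeek models with YaRN rope
# scaling, softmax_scale also carries an mscale correction, so the recomputed value did
# not match the model's own scale; this is likely the precision issue this commit fixes.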
next_token = self.decode_one_tokens()
self.profiler.inc("decode")
if next_token == self.tokenizer.eos_token_id or "<|im_end|>" == self.tokenizer.decode(next_token):

View file

@@ -85,7 +85,8 @@ def main(output_path, api_url, model_name, auth_token, format_tabs,problem_file,
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="API Generate Tester")
parser.add_argument("--api_url", type=str, default="https://api.siliconflow.cn/v1/chat/completions", help="API URL")
#parser.add_argument("--api_url", type=str, default="https://api.siliconflow.cn/v1/chat/completions", help="API URL")
parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL")
parser.add_argument("--model_name", type=str, default="Pro/deepseek-ai/DeepSeek-V3", help="Model Name")
parser.add_argument("--out_path", type=str, default="results/api/eval_b.jsonl", help="Output Path")
parser.add_argument("--auth_token", type=str, default=None, help="Auth Token")

View file

@@ -239,7 +239,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph
if use_flashinfer_mla:
MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,
num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
q_head_dim ** (-0.5), torch.bfloat16, torch.bfloat16)
model.model.layers[0].self_attn.softmax_scale, torch.bfloat16, torch.bfloat16)
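# Same fix as in TransformersInterface: use the layer's softmax_scale rather than
# q_head_dim ** (-0.5).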
global warm_uped
if use_cuda_graph and ( (warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2) ):
warm_uped = True