[fix](test): fix import kt-kernel (#1728)

ErvinXie 2025-12-17 19:46:32 +08:00 committed by GitHub
parent 6fc4080a7d
commit a8667ddb58
33 changed files with 1063 additions and 1151 deletions

@@ -1,59 +1,62 @@
import logging
-import os,sys
+import os, sys
import time
from typing import Optional
os.environ["BLAS_NUM_THREADS"] = "1"
-sys.path.insert(0, os.path.dirname(__file__) + '/../build')
-import kt_kernel_ext
+sys.path.insert(0, os.path.dirname(__file__) + "/../build")
+from kt_kernel import kt_kernel_ext
from kt_kernel_ext.kvcache import ggml_type
import torch
from torch import inf, nn
from torch.nn import init
-from torch_attention import apply_rotary_pos_emb,DeepseekV2RMSNorm,KDeepSeekV3Cache,DeepseekV3YarnRotaryEmbedding
+from torch_attention import apply_rotary_pos_emb, DeepseekV2RMSNorm, KDeepSeekV3Cache, DeepseekV3YarnRotaryEmbedding
logger = logging.getLogger("reader")
from gguf.gguf_reader import GGUFReader
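# NOTE: the helpers below load raw tensors dumped by the kernel to ./debug_decode
# and ./debug_prefill so the decode and prefill paths can be compared offline.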
def load_fp32_tensor_raw(file_path):
# return torch.zeros(shape)
-with open(file_path, 'rb') as f:
+with open(file_path, "rb") as f:
raw_data = f.read()
tensor = torch.frombuffer(raw_data, dtype=torch.float32)
return tensor
def load_fp16_tensor(file_path, shape=None):
# return load_fp32_tensor(file_path, shape)
return load_fp32_tensor_raw(file_path)
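# NOTE: the early return above appears to short-circuit this function; the block
# below looks like unreachable leftover debug code (it also references an
# undefined weight_type).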
# return torch.zeros(shape)
-with open(file_path, 'rb') as f:
+with open(file_path, "rb") as f:
raw_data = f.read()
tensor = torch.frombuffer(raw_data, dtype=weight_type)
tensor = tensor.view(shape)  # reshape according to the given shape
return tensor
def load_fp32_tensor(file_path, shape):
# return torch.zeros(shape)
-with open(file_path, 'rb') as f:
+with open(file_path, "rb") as f:
raw_data = f.read()
tensor = torch.frombuffer(raw_data, dtype=torch.float32)
tensor = tensor.view(shape)  # reshape according to the given shape
return tensor
def test_torch():
torch.set_grad_enabled(False)
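# Each pair of loads below pulls the same intermediate tensor from the decode and
# prefill dumps (hidden states, q_lora, compressed_kv, q/k rope, attention weights,
# final output); the commented-out diff/MAE checks were used to compare them.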
-hidden_states_to_check_decode = load_fp16_tensor('./debug_decode/query_0_tp_0_input.bin')
-hidden_states_to_check_prefill = load_fp16_tensor('./debug_prefill/query_0_tp_0_input.bin')
+hidden_states_to_check_decode = load_fp16_tensor("./debug_decode/query_0_tp_0_input.bin")
+hidden_states_to_check_prefill = load_fp16_tensor("./debug_prefill/query_0_tp_0_input.bin")
# diff = torch.abs(hidden_states_to_check_prefill - hidden_states_to_check_decode).max()
# print("hidden_states diff -> ", diff)
-q_lora_to_check_decode = load_fp16_tensor('./debug_decode/query_0_tp_0_qlora.bin')
-q_lora_to_check_test_decode = load_fp16_tensor('./debug_decode/query_0_tp_0_qlora_test.bin')
-q_lora_to_check_prefill = load_fp16_tensor('./debug_prefill/query_0_tp_0_qlora.bin')
-q_lora_to_check_test_prefill = load_fp16_tensor('./debug_prefill/query_0_tp_0_qlora_test.bin')
+q_lora_to_check_decode = load_fp16_tensor("./debug_decode/query_0_tp_0_qlora.bin")
+q_lora_to_check_test_decode = load_fp16_tensor("./debug_decode/query_0_tp_0_qlora_test.bin")
+q_lora_to_check_prefill = load_fp16_tensor("./debug_prefill/query_0_tp_0_qlora.bin")
+q_lora_to_check_test_prefill = load_fp16_tensor("./debug_prefill/query_0_tp_0_qlora_test.bin")
# diff = torch.abs(q_lora_to_check_prefill - q_lora_to_check_decode).max()
# diff_test = torch.abs(q_lora_to_check_prefill - q_lora_to_check_decode).max()
# print("q_lora max diff -> ", diff)
@@ -63,8 +66,6 @@ def test_torch():
# print("q_lora mae -> ", mae)
# print("q_lora mae test -> ", mae_test)
# q_lora_norm = q_a_layernorm(q_lora)
# q_lora_norm_to_check = load_fp16_tensor('./debug/query_0_tp_0_qlora_norm.bin', q_lora_norm.shape)
# q_lora_norm_to_check_test = load_fp16_tensor('./debug/query_0_tp_0_qlora_norm_test.bin', q_lora_norm.shape)
@@ -76,7 +77,7 @@ def test_torch():
# print("q_lora_norm mae -> ", mae)
# print("q_lora_norm diff test -> ", diff_test)
# print("q_lora_norm mae test -> ", mae_test)
# q = q_b_proj(q_lora_norm)
# for v3, bsz, qlen, num_heads(128), qk_head_dim(192=128(nope)+64(rope))
# q = q.view(qlen, num_heads, nope_size+rope_size)
@@ -85,7 +86,7 @@ def test_torch():
# q_nope, q_pe = torch.split(
# q, [nope_size, rope_size], dim=-1
# )
# compressed_kv is [qlen, kv_lora_rank(512) + rope(64)]
# compressed_kv = kv_a_proj_with_mqa(batch_hidden_states)
# compressed_kv is [qlen, kv_lora_rank(512)], k_pe is [qlen, rope(64)]
@@ -94,12 +95,11 @@ def test_torch():
# )
# compressed_kv = compressed_kv.contiguous()
# compressed_kv_page_0 = compressed_kv[0:page_size, :]
-compressed_kv_to_check_decode = load_fp16_tensor('./debug_decode/query_0_tp_0_page_0_kv_lora_rank')
-compressed_kv_to_check_prefill = load_fp16_tensor('./debug_prefill/query_0_tp_0_page_0_kv_lora_rank')
+compressed_kv_to_check_decode = load_fp16_tensor("./debug_decode/query_0_tp_0_page_0_kv_lora_rank")
+compressed_kv_to_check_prefill = load_fp16_tensor("./debug_prefill/query_0_tp_0_page_0_kv_lora_rank")
# diff = torch.abs(compressed_kv_to_check_prefill - compressed_kv_to_check_decode).max()
# mae = torch.mean(torch.abs(compressed_kv_to_check_prefill - compressed_kv_to_check_decode))
# print("compressed_kv diff -> ", diff)
# print("compressed_kv mae -> ", mae)
@@ -107,20 +107,17 @@ def test_torch():
# k_pe is [qlen, 1, qk_rope_head_dim(64)]
# compressed_kv_page_0 = compressed_kv[0:page_size, :]
-compressed_kv_to_check_decode = load_fp16_tensor('./debug_decode/query_0_tp_0_page_0_kv_lora_rank_norm')
-compressed_kv_to_check_prefill = load_fp16_tensor('./debug_prefill/query_0_tp_0_page_0_kv_lora_rank_norm')
+compressed_kv_to_check_decode = load_fp16_tensor("./debug_decode/query_0_tp_0_page_0_kv_lora_rank_norm")
+compressed_kv_to_check_prefill = load_fp16_tensor("./debug_prefill/query_0_tp_0_page_0_kv_lora_rank_norm")
# diff = torch.abs(compressed_kv_page_0 - compressed_kv_to_check).max()
# mae = torch.mean(torch.abs(compressed_kv_page_0 - compressed_kv_to_check))
# print("compressed_kv diff norm -> ", diff)
# print("compressed_kv mae norm -> ", mae)
# k_pe = k_pe.view(qlen, 1, rope_size)
# compressed_kv is [qlen, 1, kv_lora_rank(512)]
# compressed_kv = compressed_kv.view(qlen, 1, kv_lora_rank)
# cos, sin = rotary_emb(q_pe, batch_position_ids)
# q_nope_check = q_nope.transpose(0, 1) # qlen is 1, no GPU overhead, same below
@@ -135,11 +132,11 @@ def test_torch():
# print("q_nope[0] mae -> ", mae)
# print("q_nope[0] diff test -> ", diff_test)
# print("q_nope[0] mae test -> ", mae_test)
# q_pe_nope = q_pe.transpose(0,1)
-q_pe_0_to_check_decode = load_fp16_tensor('./debug_decode/query_0_tp_0_q_rope')
-q_pe_0_to_check_prefill = load_fp16_tensor('./debug_prefill/query_0_tp_0_q_rope')
+q_pe_0_to_check_decode = load_fp16_tensor("./debug_decode/query_0_tp_0_q_rope")
+q_pe_0_to_check_prefill = load_fp16_tensor("./debug_prefill/query_0_tp_0_q_rope")
# q_pe_0_to_check_decode_test = load_fp16_tensor('./debug_decode/query_0_tp_0_q_rope_test')
# q_pe_0_to_check_prefill_test = load_fp16_tensor('./debug_prefill/query_0_tp_0_q_rope_test')
@@ -180,12 +177,11 @@ def test_torch():
# q_pe, k_pe = apply_rotary_pos_emb(q_pe.unsqueeze(0), k_pe.unsqueeze(0), cos, sin, unsqueeze_dim=1)
# q_pe = q_pe.squeeze(0)
# q_pe is [num_heads(128), qlen, qk_rope_head_dim(64)]
# q_pe.transpose_(0, 1)
# diff = torch.abs(q_pe - q_new).max()
# print("q_pe diff -> ", diff)
# q_pe_0_to_check = load_fp16_tensor('./debug/query_0_tp_0_q_rope', q_pe[0].shape)
# diff = torch.abs(q_pe[0] - q_pe_0_to_check).max()
# mae = torch.mean(torch.abs(q_pe[0] - q_pe_0_to_check))
@@ -240,7 +236,7 @@ def test_torch():
# if batch_compressed_kv is None or batch_k_pe is None:
# batch_compressed_kv = tmp_compressed_kv
# batch_k_pe = tmp_k_pe
# else:
# batch_compressed_kv = torch.cat((batch_compressed_kv, tmp_compressed_kv), dim=0)
# batch_k_pe = torch.cat((batch_k_pe, tmp_k_pe), dim=0)
# kv_total_len -= page_size
@@ -250,16 +246,15 @@ def test_torch():
# if batch_compressed_kv is None or batch_k_pe is None:
# batch_compressed_kv = tmp_compressed_kv
# batch_k_pe = tmp_k_pe
# else:
# batch_compressed_kv = torch.cat((batch_compressed_kv, tmp_compressed_kv), dim=0)
# batch_k_pe = torch.cat((batch_k_pe, tmp_k_pe), dim=0)
# break
# batch_compressed_kv is [kv_total_len(k_len), kv_lora_rank(512)]
# batch_k_pe is [kv_total_len(k_len), qk_rope_head_dim(64)]
-k_pe_to_check_decode = load_fp16_tensor('./debug_decode/query_0_tp_0_page_0_k_rope', (256,64))
-k_pe_to_check_prefill = load_fp16_tensor('./debug_prefill/query_0_tp_0_page_0_k_rope', (256,64))
+k_pe_to_check_decode = load_fp16_tensor("./debug_decode/query_0_tp_0_page_0_k_rope", (256, 64))
+k_pe_to_check_prefill = load_fp16_tensor("./debug_prefill/query_0_tp_0_page_0_k_rope", (256, 64))
# diff = torch.abs(k_pe_to_check_prefill - k_pe_to_check_decode).max()
# mae = torch.mean(k_pe_to_check_prefill - k_pe_to_check_decode)
# print("k_pe diff -> ", diff)
@@ -267,13 +262,13 @@ def test_torch():
# pe_weights = torch.matmul(q_pe,batch_k_pe.mT)
# kv_total_len = kv_page_nums * page_size
-pe_weights_0_decode = load_fp16_tensor('./debug_decode/query_0_tp_0_pe_attention_weights', (1024,4096))
-pe_weights_0_prefill = load_fp16_tensor('./debug_prefill/query_0_tp_0_pe_attention_weights', (1024,4096))
+pe_weights_0_decode = load_fp16_tensor("./debug_decode/query_0_tp_0_pe_attention_weights", (1024, 4096))
+pe_weights_0_prefill = load_fp16_tensor("./debug_prefill/query_0_tp_0_pe_attention_weights", (1024, 4096))
# diff = torch.abs(pe_weights[0] - pe_weights_0).max()
# print("pe_weights[0] diff -> ", diff)
# attention_weights = (pe_weights + torch.matmul(q_nope, batch_compressed_kv.mT))
# raw_weights = load_fp16_tensor('./debug/query_0_tp_0_raw_attention_weights', (1024, 4096))
# raw_weights = raw_weights[0:qlen, 0:kv_total_len]
@@ -282,25 +277,23 @@ def test_torch():
# attention_weights = attention_weights * softmax_scale
# attention_weights is [num_heads(128), qlen, k_len]
# attention_weights = attention_weights.transpose(0,1).unsqueeze(0).squeeze(-1).expand(qlen,-1,-1).transpose(0,1)
# attention_masks[i] is [qlen, k_len]
# attention_weights = (attention_weights + attention_masks)
# attention_weights shape is [num_heads(128), qlen, k_len]
# attention_weights = nn.functional.softmax(attention_weights,dim=-1,dtype=weight_type).to(q_pe.dtype)
-attention_weights_0_decode = load_fp16_tensor('./debug_decode/query_0_tp_0_attention_weights', (1024, 4096))
-attention_weights_0_prefill = load_fp16_tensor('./debug_prefill/query_0_tp_0_attention_weights', (1024, 4096))
+attention_weights_0_decode = load_fp16_tensor("./debug_decode/query_0_tp_0_attention_weights", (1024, 4096))
+attention_weights_0_prefill = load_fp16_tensor("./debug_prefill/query_0_tp_0_attention_weights", (1024, 4096))
# attention_weights_0 = attention_weights_0[0:qlen, 0:kv_total_len]
# diff = torch.abs(attention_weights[0] - attention_weights_0).max()
# print("attention_weights[0] diff -> ", diff)
# attn_output = torch.matmul(attention_weights, batch_compressed_kv) # [num_heads(128),qlen, lora_rank(512)]
# out_absorb shape is [num_heads(128), kv_lora_rank(512), v_head_dim(128)]
@@ -322,7 +315,7 @@ def test_torch():
# w_o = o_proj.weight.view([hidden_size,num_heads * nope_size])
# output = torch.matmul(attn_output,w_o.transpose(0,1))
# output = output.view(qlen, hidden_size)
# output_0_check = load_fp16_tensor('./debug/query_0_tp_0_qlen_output', (qlen, hidden_size))
# h1_o = w_o[:,:128]
# local_o_check = load_fp16_tensor('./debug/query_0_tp_0_local_w_o', (hidden_size, 128))
@@ -333,18 +326,15 @@ def test_torch():
# diff = torch.abs(h1_output - output_0_check).max()
# print("h1_output diff -> ", diff)
-output_check_decode = load_fp16_tensor('./debug_decode/output.bin')
-output_check_prefill = load_fp16_tensor('./debug_prefill/output.bin')
+output_check_decode = load_fp16_tensor("./debug_decode/output.bin")
+output_check_prefill = load_fp16_tensor("./debug_prefill/output.bin")
# diff = torch.abs(output - output_check).max()
# mae = torch.mean(torch.abs(output - output_check))
# print("output diff -> ", diff)
return None
torch.set_printoptions(sci_mode=False, precision=5)
# output_cpu = test_cpu_mla()
# output_cpu_quant = test_cpu_mla_quant()
@@ -361,7 +351,3 @@ output_torch = test_torch()
# print(f'Diff: ave:{diff.mean()}, max:{diff.max()}, min:{diff.min()}, relative_mean:{diff_relative_mean}, relative_max:{diff_relative.max()}, relative_min:{diff_relative.min()}')
# assert diff_relative_mean < 2e-1, "CPU and Torch outputs are not close enough!"