[fix](test): fix import kt-kernel (#1728)

ErvinXie 2025-12-17 19:46:32 +08:00 committed by GitHub
parent 6fc4080a7d
commit a8667ddb58
33 changed files with 1063 additions and 1151 deletions


@@ -1,19 +1,22 @@
 import logging
-import os,sys
+import os, sys
 import time
 from typing import Optional
 os.environ["BLAS_NUM_THREADS"] = "1"
-sys.path.insert(0, os.path.dirname(__file__) + '/../build')
-import kt_kernel_ext
+sys.path.insert(0, os.path.dirname(__file__) + "/../build")
+from kt_kernel import kt_kernel_ext
 from kt_kernel_ext.kvcache import ggml_type
 import torch
 from torch import inf, nn
 from torch.nn import init
-from torch_attention import apply_rotary_pos_emb,DeepseekV2RMSNorm,KDeepSeekV3Cache,DeepseekV3YarnRotaryEmbedding
+from torch_attention import apply_rotary_pos_emb, DeepseekV2RMSNorm, KDeepSeekV3Cache, DeepseekV3YarnRotaryEmbedding
 logger = logging.getLogger("reader")
 from gguf.gguf_reader import GGUFReader
 def read_gguf_file(gguf_file_path):
     """
     Reads and prints key-value pairs and tensor information from a GGUF file in an improved format.
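The functional change in this hunk is the import: instead of pushing the build directory onto sys.path and importing the raw kt_kernel_ext extension, the test now pulls the extension from the installed kt_kernel package. A minimal sketch of the resulting import behaviour (the try/except fallback is illustrative, not part of the commit):

    try:
        # new behaviour: the extension as re-exported by the kt_kernel package
        from kt_kernel import kt_kernel_ext
    except ImportError:
        # old behaviour: load the raw extension straight out of the build tree
        import os, sys
        sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))
        import kt_kernel_ext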
@@ -46,12 +49,15 @@ def read_gguf_file(gguf_file_path):
         re.append(tensor)
     return re
 def get_torch_tensor_from_gguf(gguf_weights, name):
     return torch.from_numpy(gguf_weights[name].data).contiguous()
 def get_torch_tensor_and_type_from_gguf(gguf_weights, name):
     return torch.from_numpy(gguf_weights[name].data).contiguous(), gguf_weights[name].tensor_type.name
 def type_to_ggml_type(type):
     if type == "F32":
         return ggml_type.FP32
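The helpers above index into a gguf_weights mapping from tensor name to a gguf-py reader tensor. A sketch, assuming the standard gguf-py API, of how such a mapping is typically built (the file path is a placeholder):

    from gguf.gguf_reader import GGUFReader

    def load_gguf_weights(gguf_file_path):
        # Map tensor names to gguf-py ReaderTensor entries, each exposing
        # .data (a numpy view) and .tensor_type (a quantization enum).
        reader = GGUFReader(gguf_file_path)
        return {tensor.name: tensor for tensor in reader.tensors}

    gguf_weights = load_gguf_weights("/path/to/model.gguf")  # placeholder path
    w, t = get_torch_tensor_and_type_from_gguf(gguf_weights, "blk.0.attn_output.weight")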
@@ -75,7 +81,7 @@ kvlen = 0
 page_table = range(20)
-bsz_tensors=torch.tensor([1])
+bsz_tensors = torch.tensor([1])
 page_size = 256
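page_table = range(20) together with page_size = 256 describes a paged KV cache: twenty physical pages of 256 slots each, with an identity logical-to-physical mapping. A sketch of the usual paged-KV addressing rule (generic convention, not taken from the kt_kernel sources):

    def kv_slot(token_pos, page_table, page_size=256):
        # A token's cache entry lives in the physical page that the page
        # table maps its logical page to, at the in-page offset.
        return page_table[token_pos // page_size], token_pos % page_size

    assert kv_slot(300, list(range(20))) == (1, 44)  # logical page 1, offset 44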
@@ -94,8 +100,7 @@ rope_theta = 10000
 max_qlen = 1024
 max_kvlen = 4096
-max_position_embeddings = 163840
 max_position_embeddings = 163840
 rope_scaling = {
@@ -105,11 +110,10 @@ rope_scaling = {
     "mscale": 1.0,
     "mscale_all_dim": 1.0,
     "original_max_position_embeddings": 4096,
-    "type": "yarn"
+    "type": "yarn",
 }
 CPUInfer = kt_kernel_ext.CPUInfer(64)
 validation_iter = 100
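The rope_scaling dict mirrors the YaRN configuration of DeepSeek-style checkpoints. For reference, the attention-scale correction that the HF DeepseekV2 modeling code derives from the mscale fields is (a sketch of that known helper, not of kt_kernel internals):

    import math

    def yarn_get_mscale(scale=1.0, mscale=1.0):
        # YaRN attention-scale correction; the rotary embedding applies the
        # ratio of this value for "mscale" vs "mscale_all_dim", which
        # collapses to 1.0 when the two fields are equal, as they are here.
        if scale <= 1:
            return 1.0
        return 0.1 * mscale * math.log(scale) + 1.0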
@@ -119,15 +123,16 @@ weight_type = torch.bfloat16
 # weight_type = torch.float16
-input_type = {torch.float32:torch.float32,
-              torch.float16:torch.float16,
-              torch.bfloat16:torch.float32,
-              }[weight_type]
+input_type = {
+    torch.float32: torch.float32,
+    torch.float16: torch.float16,
+    torch.bfloat16: torch.float32,
+}[weight_type]
 q_a_proj = nn.Linear(hidden_size, q_lora_rank, bias=False, dtype=weight_type)
-q_b_proj = nn.Linear(q_lora_rank, num_heads * (nope_size+rope_size) , bias=False, dtype=weight_type)
+q_b_proj = nn.Linear(q_lora_rank, num_heads * (nope_size + rope_size), bias=False, dtype=weight_type)
 kv_a_proj_with_mqa = nn.Linear(hidden_size, kv_lora_rank + rope_size, bias=False, dtype=weight_type)
-kv_b_proj = nn.Linear( num_heads * (nope_size + nope_size),kv_lora_rank, bias=False, dtype=weight_type)
+kv_b_proj = nn.Linear(num_heads * (nope_size + nope_size), kv_lora_rank, bias=False, dtype=weight_type)
 o_proj = nn.Linear(num_heads * nope_size, hidden_size, bias=False, dtype=weight_type)
 q_a_norm = torch.ones(hidden_size, dtype=torch.float32)
 kv_a_norm = torch.ones(hidden_size, dtype=torch.float32)
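The four projections are the DeepSeek-style MLA layout: queries are compressed to q_lora_rank and re-expanded per head, while keys/values share a kv_lora_rank latent plus a rope_size rotary slice. A shape-only sketch of the query path with made-up small dimensions (the real values come from elided configuration lines above):

    import torch
    from torch import nn

    # Hypothetical small dims, for illustration only:
    hidden_size, q_lora_rank = 128, 32
    num_heads, nope_size, rope_size = 4, 8, 4

    q_a = nn.Linear(hidden_size, q_lora_rank, bias=False)
    q_b = nn.Linear(q_lora_rank, num_heads * (nope_size + rope_size), bias=False)

    x = torch.randn(1, hidden_size)
    q = q_b(q_a(x))  # low-rank query path: down-project, then up-project per head
    assert q.shape == (1, num_heads * (nope_size + rope_size))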
@@ -190,7 +195,7 @@ if use_real_weights := True:
     o_proj_weight, type = get_torch_tensor_and_type_from_gguf(gguf_weights, f"blk.{layer_idx}.attn_output.weight")
     o_proj.weight = nn.Parameter(o_proj_weight.view(torch.bfloat16), requires_grad=False)
 else:
     init.normal_(q_a_proj.weight, mean=0.0, std=0.02)
     init.normal_(q_b_proj.weight, mean=0.0, std=0.02)
@@ -203,16 +208,16 @@ q_absorb = x_reshaped[:, 0]
 out_absorb = x_reshaped[:, 1]
-hidden_states = torch.randn((qlen, hidden_size), dtype=input_type).to('cpu').contiguous()
+hidden_states = torch.randn((qlen, hidden_size), dtype=input_type).to("cpu").contiguous()
 def build_mla():
     os.environ["BLAS_NUM_THREADS"] = "1"
-    q_a_proj_weight = q_a_proj.weight.to(weight_type).to('cpu').contiguous()
-    q_b_proj_weight = q_b_proj.weight.to(weight_type).to('cpu').contiguous()
-    kv_a_proj_with_mqa_weight = kv_a_proj_with_mqa.weight.to('cpu').to(weight_type).contiguous()
-    kv_b_proj_weight = kv_b_proj.weight.to(weight_type).to('cpu').contiguous()
-    o_proj_weight = o_proj.weight.to(weight_type).to('cpu').contiguous()
+    q_a_proj_weight = q_a_proj.weight.to(weight_type).to("cpu").contiguous()
+    q_b_proj_weight = q_b_proj.weight.to(weight_type).to("cpu").contiguous()
+    kv_a_proj_with_mqa_weight = kv_a_proj_with_mqa.weight.to("cpu").to(weight_type).contiguous()
+    kv_b_proj_weight = kv_b_proj.weight.to(weight_type).to("cpu").contiguous()
+    o_proj_weight = o_proj.weight.to(weight_type).to("cpu").contiguous()
     config = kt_kernel_ext.mla.MLAConfig(
         hidden_size,
@@ -224,7 +229,7 @@ def build_mla():
     )
     config.max_qlen = max_qlen
     config.max_kvlen = max_kvlen
-    config.max_position_embeddings = max_position_embeddings
+    config.max_position_embeddings = max_position_embeddings
     config.rope_scaling_factor = rope_scaling["factor"]
     config.rope_theta = rope_theta
     config.rope_scaling_beta_fast = rope_scaling["beta_fast"]
@@ -244,7 +249,6 @@ def build_mla():
     config.kv_a_norm = kv_a_norm.data_ptr()
     config.kv_a_norm_type = ggml_type.FP32
     if weight_type == torch.float32:
         config.q_a_proj_type = ggml_type.FP32
         config.q_b_proj_type = ggml_type.FP32
@@ -266,10 +270,8 @@ def build_mla():
     else:
         raise ValueError(f"Unsupported data type: {weight_type}")
     config.pool = CPUInfer.backend_
     if weight_type == torch.float32:
         mla = kt_kernel_ext.mla.MLA_F32(config)
     elif weight_type == torch.float16:
@@ -278,25 +280,20 @@ def build_mla():
         mla = kt_kernel_ext.mla.MLA_F32(config)
     else:
         raise ValueError(f"Unsupported data type: {weight_type}")
     mla.load_weights()
     mla.set_local_pages(pages_count)
     return mla
 def load_fp32_tensor(file_path, shape):
-    with open(file_path, 'rb') as f:
+    with open(file_path, "rb") as f:
         raw_data = f.read()
     tensor = torch.frombuffer(raw_data, dtype=torch.float32)
     tensor = tensor.view(shape)  # reshape to the requested shape
     return tensor
 # page3 = load_fp32_tensor('/home/yzw/xwy/Projects/ktransformers-dev/csrc/ktransformers_ext/examples/debug1/query_0_tp_0_page_3_kv_lora_rank_norm.f32',(page_size,kv_lora_rank))
 # page3_2 = load_fp32_tensor('/home/yzw/xwy/Projects/ktransformers-dev/csrc/ktransformers_ext/examples/debug2/query_0_tp_0_page_3_kv_lora_rank_norm.f32',(page_size,kv_lora_rank))
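load_fp32_tensor round-trips the raw float32 dumps referenced in the commented debug lines. A hedged self-check with a synthetic file (the tempfile usage is illustrative, not part of the test):

    import tempfile

    ref = torch.randn(4, 8, dtype=torch.float32)
    with tempfile.NamedTemporaryFile(suffix=".f32", delete=False) as f:
        f.write(ref.numpy().tobytes())  # same raw layout the debug dumps use
    assert torch.equal(load_fp32_tensor(f.name, (4, 8)), ref)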
@@ -320,7 +317,6 @@ def load_fp32_tensor(file_path, shape):
 # print(f'PE Attention Weights Diff: ave:{diff.mean()}, max:{diff.max()}')
 # raw_attn_w_1 = load_fp32_tensor('/home/yzw/xwy/Projects/ktransformers-dev/csrc/ktransformers_ext/examples/debug1/query_0_tp_0_raw_attention_weights.f32',(1,max_kvlen))
 # raw_attn_w_2 = load_fp32_tensor('/home/yzw/xwy/Projects/ktransformers-dev/csrc/ktransformers_ext/examples/debug2/query_0_tp_0_raw_attention_weights.f32',(qlen,max_kvlen))
 # diff = torch.abs(raw_attn_w_1 - raw_attn_w_2[-1])
@@ -334,22 +330,16 @@ def load_fp32_tensor(file_path, shape):
 # print(f'Output Diff: ave:{diff.mean()}, max:{diff.max()}')
 mla = build_mla()
-output = torch.zeros((qlen, hidden_size), dtype=input_type).to('cpu').contiguous()
-mla.forward([qlen],[page_table],[kvlen],hidden_states.data_ptr(),output.data_ptr())
-print("CPU MLA Output: ",output[-1])
+output = torch.zeros((qlen, hidden_size), dtype=input_type).to("cpu").contiguous()
+mla.forward([qlen], [page_table], [kvlen], hidden_states.data_ptr(), output.data_ptr())
+print("CPU MLA Output: ", output[-1])
-output_2 = torch.zeros((1, hidden_size), dtype=input_type).to('cpu').contiguous()
-mla.forward([1],[page_table],[qlen-1],hidden_states[-1].data_ptr(),output_2.data_ptr())
-print("CPU MLA Output 2: ",output_2[-1])
+output_2 = torch.zeros((1, hidden_size), dtype=input_type).to("cpu").contiguous()
+mla.forward([1], [page_table], [qlen - 1], hidden_states[-1].data_ptr(), output_2.data_ptr())
+print("CPU MLA Output 2: ", output_2[-1])
 diff = torch.abs(output[-1] - output_2[-1])
-print(f'Diff: ave:{diff.mean()}, max:{diff.max()}')
+print(f"Diff: ave:{diff.mean()}, max:{diff.max()}")
 assert diff.max() < 1e-1, "CPU and Torch outputs are not close enough!"
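The closing check exercises the two execution paths of mla.forward as used above (per-request lists of qlen, page_table and kvlen, plus raw input/output pointers): the last token is computed once inside a qlen-token prefill and once as a single-token decode over the qlen - 1 cached tokens, and the two output rows must agree within the test's loose 1e-1 tolerance. The assert's message ("CPU and Torch outputs...") appears inherited from an earlier Torch-reference version of the test; the comparison itself is between the two CPU paths. A hedged sketch of the same pattern as a reusable helper (a hypothetical wrapper; hidden_size and input_type are the script's globals, and kvlen is 0 at prefill as set earlier):

    def prefill_decode_gap(mla, hidden_states, qlen, page_table):
        # Max per-element deviation between the full-prefill output of the
        # last token and the same token computed as a one-step decode.
        full = torch.zeros((qlen, hidden_size), dtype=input_type).contiguous()
        mla.forward([qlen], [page_table], [0], hidden_states.data_ptr(), full.data_ptr())
        step = torch.zeros((1, hidden_size), dtype=input_type).contiguous()
        mla.forward([1], [page_table], [qlen - 1], hidden_states[-1].data_ptr(), step.data_ptr())
        return torch.abs(full[-1] - step[-1]).max()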