mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-30 21:00:07 +00:00
[fix](test): fix import kt-kernel (#1728)
This commit is contained in:
parent
6fc4080a7d
commit
a8667ddb58
33 changed files with 1063 additions and 1151 deletions
|
|
@ -1,19 +1,22 @@
|
|||
import logging
|
||||
import os,sys
|
||||
import os, sys
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
os.environ["BLAS_NUM_THREADS"] = "1"
|
||||
sys.path.insert(0, os.path.dirname(__file__) + '/../build')
|
||||
import kt_kernel_ext
|
||||
sys.path.insert(0, os.path.dirname(__file__) + "/../build")
|
||||
from kt_kernel import kt_kernel_ext
|
||||
from kt_kernel_ext.kvcache import ggml_type
|
||||
import torch
|
||||
from torch import inf, nn
|
||||
from torch.nn import init
|
||||
from torch_attention import apply_rotary_pos_emb,DeepseekV2RMSNorm,KDeepSeekV3Cache,DeepseekV3YarnRotaryEmbedding
|
||||
from torch_attention import apply_rotary_pos_emb, DeepseekV2RMSNorm, KDeepSeekV3Cache, DeepseekV3YarnRotaryEmbedding
|
||||
|
||||
logger = logging.getLogger("reader")
|
||||
|
||||
from gguf.gguf_reader import GGUFReader
|
||||
|
||||
|
||||
def read_gguf_file(gguf_file_path):
|
||||
"""
|
||||
Reads and prints key-value pairs and tensor information from a GGUF file in an improved format.
|
||||
|
|
@ -46,12 +49,15 @@ def read_gguf_file(gguf_file_path):
|
|||
re.append(tensor)
|
||||
return re
|
||||
|
||||
|
||||
def get_torch_tensor_from_gguf(gguf_weights, name):
    """Return the tensor stored under ``name`` as a contiguous torch tensor.

    Args:
        gguf_weights: Mapping indexed by tensor name; each entry must expose
            a ``.data`` NumPy array.
        name: Key of the tensor to fetch.

    Returns:
        A ``torch.Tensor`` built with ``torch.from_numpy``. It shares memory
        with the source array unless ``.contiguous()`` has to make a copy.
    """
    return torch.from_numpy(gguf_weights[name].data).contiguous()
|
||||
|
||||
|
||||
def get_torch_tensor_and_type_from_gguf(gguf_weights, name):
    """Return the tensor stored under ``name`` together with its type name.

    Args:
        gguf_weights: Mapping indexed by tensor name; each entry must expose
            a ``.data`` NumPy array and a ``.tensor_type`` with a ``.name``.
        name: Key of the tensor to fetch.

    Returns:
        A ``(tensor, type_name)`` tuple: the contiguous ``torch.Tensor`` built
        from ``.data`` and the value of ``.tensor_type.name``.
    """
    entry = gguf_weights[name]
    return torch.from_numpy(entry.data).contiguous(), entry.tensor_type.name
|
||||
|
||||
|
||||
def type_to_ggml_type(type):
|
||||
if type == "F32":
|
||||
return ggml_type.FP32
|
||||
|
|
@ -75,7 +81,7 @@ kvlen = 0
|
|||
|
||||
|
||||
page_table = range(20)
|
||||
bsz_tensors=torch.tensor([1])
|
||||
bsz_tensors = torch.tensor([1])
|
||||
|
||||
|
||||
page_size = 256
|
||||
|
|
@ -94,8 +100,7 @@ rope_theta = 10000
|
|||
max_qlen = 1024
|
||||
max_kvlen = 4096
|
||||
|
||||
max_position_embeddings = 163840
|
||||
|
||||
max_position_embeddings = 163840
|
||||
|
||||
|
||||
rope_scaling = {
|
||||
|
|
@ -105,11 +110,10 @@ rope_scaling = {
|
|||
"mscale": 1.0,
|
||||
"mscale_all_dim": 1.0,
|
||||
"original_max_position_embeddings": 4096,
|
||||
"type": "yarn"
|
||||
"type": "yarn",
|
||||
}
|
||||
|
||||
|
||||
|
||||
CPUInfer = kt_kernel_ext.CPUInfer(64)
|
||||
validation_iter = 100
|
||||
|
||||
|
|
@ -119,15 +123,16 @@ weight_type = torch.bfloat16
|
|||
# weight_type = torch.float16
|
||||
|
||||
|
||||
input_type = {torch.float32:torch.float32,
|
||||
torch.float16:torch.float16,
|
||||
torch.bfloat16:torch.float32,
|
||||
}[weight_type]
|
||||
input_type = {
|
||||
torch.float32: torch.float32,
|
||||
torch.float16: torch.float16,
|
||||
torch.bfloat16: torch.float32,
|
||||
}[weight_type]
|
||||
|
||||
q_a_proj = nn.Linear(hidden_size, q_lora_rank, bias=False, dtype=weight_type)
|
||||
q_b_proj = nn.Linear(q_lora_rank, num_heads * (nope_size+rope_size) , bias=False, dtype=weight_type)
|
||||
q_b_proj = nn.Linear(q_lora_rank, num_heads * (nope_size + rope_size), bias=False, dtype=weight_type)
|
||||
kv_a_proj_with_mqa = nn.Linear(hidden_size, kv_lora_rank + rope_size, bias=False, dtype=weight_type)
|
||||
kv_b_proj = nn.Linear( num_heads * (nope_size + nope_size),kv_lora_rank, bias=False, dtype=weight_type)
|
||||
kv_b_proj = nn.Linear(num_heads * (nope_size + nope_size), kv_lora_rank, bias=False, dtype=weight_type)
|
||||
o_proj = nn.Linear(num_heads * nope_size, hidden_size, bias=False, dtype=weight_type)
|
||||
q_a_norm = torch.ones(hidden_size, dtype=torch.float32)
|
||||
kv_a_norm = torch.ones(hidden_size, dtype=torch.float32)
|
||||
|
|
@ -190,7 +195,7 @@ if use_real_weights := True:
|
|||
|
||||
o_proj_weight, type = get_torch_tensor_and_type_from_gguf(gguf_weights, f"blk.{layer_idx}.attn_output.weight")
|
||||
o_proj.weight = nn.Parameter(o_proj_weight.view(torch.bfloat16), requires_grad=False)
|
||||
|
||||
|
||||
else:
|
||||
init.normal_(q_a_proj.weight, mean=0.0, std=0.02)
|
||||
init.normal_(q_b_proj.weight, mean=0.0, std=0.02)
|
||||
|
|
@ -203,16 +208,16 @@ q_absorb = x_reshaped[:, 0]
|
|||
out_absorb = x_reshaped[:, 1]
|
||||
|
||||
|
||||
hidden_states = torch.randn((qlen, hidden_size), dtype=input_type).to('cpu').contiguous()
|
||||
hidden_states = torch.randn((qlen, hidden_size), dtype=input_type).to("cpu").contiguous()
|
||||
|
||||
|
||||
def build_mla():
|
||||
os.environ["BLAS_NUM_THREADS"] = "1"
|
||||
q_a_proj_weight = q_a_proj.weight.to(weight_type).to('cpu').contiguous()
|
||||
q_b_proj_weight = q_b_proj.weight.to(weight_type).to('cpu').contiguous()
|
||||
kv_a_proj_with_mqa_weight = kv_a_proj_with_mqa.weight.to('cpu').to(weight_type).contiguous()
|
||||
kv_b_proj_weight = kv_b_proj.weight.to(weight_type).to('cpu').contiguous()
|
||||
o_proj_weight = o_proj.weight.to(weight_type).to('cpu').contiguous()
|
||||
q_a_proj_weight = q_a_proj.weight.to(weight_type).to("cpu").contiguous()
|
||||
q_b_proj_weight = q_b_proj.weight.to(weight_type).to("cpu").contiguous()
|
||||
kv_a_proj_with_mqa_weight = kv_a_proj_with_mqa.weight.to("cpu").to(weight_type).contiguous()
|
||||
kv_b_proj_weight = kv_b_proj.weight.to(weight_type).to("cpu").contiguous()
|
||||
o_proj_weight = o_proj.weight.to(weight_type).to("cpu").contiguous()
|
||||
|
||||
config = kt_kernel_ext.mla.MLAConfig(
|
||||
hidden_size,
|
||||
|
|
@ -224,7 +229,7 @@ def build_mla():
|
|||
)
|
||||
config.max_qlen = max_qlen
|
||||
config.max_kvlen = max_kvlen
|
||||
config.max_position_embeddings = max_position_embeddings
|
||||
config.max_position_embeddings = max_position_embeddings
|
||||
config.rope_scaling_factor = rope_scaling["factor"]
|
||||
config.rope_theta = rope_theta
|
||||
config.rope_scaling_beta_fast = rope_scaling["beta_fast"]
|
||||
|
|
@ -244,7 +249,6 @@ def build_mla():
|
|||
config.kv_a_norm = kv_a_norm.data_ptr()
|
||||
config.kv_a_norm_type = ggml_type.FP32
|
||||
|
||||
|
||||
if weight_type == torch.float32:
|
||||
config.q_a_proj_type = ggml_type.FP32
|
||||
config.q_b_proj_type = ggml_type.FP32
|
||||
|
|
@ -266,10 +270,8 @@ def build_mla():
|
|||
else:
|
||||
raise ValueError(f"Unsupported data type: {weight_type}")
|
||||
|
||||
|
||||
config.pool = CPUInfer.backend_
|
||||
|
||||
|
||||
if weight_type == torch.float32:
|
||||
mla = kt_kernel_ext.mla.MLA_F32(config)
|
||||
elif weight_type == torch.float16:
|
||||
|
|
@ -278,25 +280,20 @@ def build_mla():
|
|||
mla = kt_kernel_ext.mla.MLA_F32(config)
|
||||
else:
|
||||
raise ValueError(f"Unsupported data type: {weight_type}")
|
||||
|
||||
|
||||
mla.load_weights()
|
||||
mla.set_local_pages(pages_count)
|
||||
return mla
|
||||
|
||||
|
||||
|
||||
|
||||
def load_fp32_tensor(file_path, shape):
    """Load a raw float32 dump from ``file_path`` as a tensor of ``shape``.

    The file is read in full and interpreted as native-endian float32 values
    with no header.

    Args:
        file_path: Path of the binary file to read.
        shape: Target shape passed to ``Tensor.view``; its element count must
            match the number of float32 values in the file.

    Returns:
        A ``torch.float32`` tensor viewed as ``shape``.
    """
    with open(file_path, "rb") as f:
        # torch.frombuffer shares memory with its buffer. An immutable
        # ``bytes`` object triggers PyTorch's non-writable-buffer warning and
        # makes in-place writes undefined, so copy into a mutable bytearray.
        raw_data = bytearray(f.read())
    tensor = torch.frombuffer(raw_data, dtype=torch.float32)
    # Reshape the flat buffer to the caller-supplied shape.
    return tensor.view(shape)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# page3 = load_fp32_tensor('/home/yzw/xwy/Projects/ktransformers-dev/csrc/ktransformers_ext/examples/debug1/query_0_tp_0_page_3_kv_lora_rank_norm.f32',(page_size,kv_lora_rank))
|
||||
# page3_2 = load_fp32_tensor('/home/yzw/xwy/Projects/ktransformers-dev/csrc/ktransformers_ext/examples/debug2/query_0_tp_0_page_3_kv_lora_rank_norm.f32',(page_size,kv_lora_rank))
|
||||
|
||||
|
|
@ -320,7 +317,6 @@ def load_fp32_tensor(file_path, shape):
|
|||
# print(f'PE Attention Weights Diff: ave:{diff.mean()}, max:{diff.max()}')
|
||||
|
||||
|
||||
|
||||
# raw_attn_w_1 = load_fp32_tensor('/home/yzw/xwy/Projects/ktransformers-dev/csrc/ktransformers_ext/examples/debug1/query_0_tp_0_raw_attention_weights.f32',(1,max_kvlen))
|
||||
# raw_attn_w_2 = load_fp32_tensor('/home/yzw/xwy/Projects/ktransformers-dev/csrc/ktransformers_ext/examples/debug2/query_0_tp_0_raw_attention_weights.f32',(qlen,max_kvlen))
|
||||
# diff = torch.abs(raw_attn_w_1 - raw_attn_w_2[-1])
|
||||
|
|
@ -334,22 +330,16 @@ def load_fp32_tensor(file_path, shape):
|
|||
# print(f'Output Diff: ave:{diff.mean()}, max:{diff.max()}')
|
||||
|
||||
|
||||
|
||||
|
||||
mla = build_mla()
|
||||
output = torch.zeros((qlen, hidden_size), dtype=input_type).to('cpu').contiguous()
|
||||
mla.forward([qlen],[page_table],[kvlen],hidden_states.data_ptr(),output.data_ptr())
|
||||
print("CPU MLA Output: ",output[-1])
|
||||
output = torch.zeros((qlen, hidden_size), dtype=input_type).to("cpu").contiguous()
|
||||
mla.forward([qlen], [page_table], [kvlen], hidden_states.data_ptr(), output.data_ptr())
|
||||
print("CPU MLA Output: ", output[-1])
|
||||
|
||||
|
||||
output_2 = torch.zeros((1, hidden_size), dtype=input_type).to('cpu').contiguous()
|
||||
mla.forward([1],[page_table],[qlen-1],hidden_states[-1].data_ptr(),output_2.data_ptr())
|
||||
print("CPU MLA Output 2: ",output_2[-1])
|
||||
output_2 = torch.zeros((1, hidden_size), dtype=input_type).to("cpu").contiguous()
|
||||
mla.forward([1], [page_table], [qlen - 1], hidden_states[-1].data_ptr(), output_2.data_ptr())
|
||||
print("CPU MLA Output 2: ", output_2[-1])
|
||||
|
||||
diff = torch.abs(output[-1] - output_2[-1])
|
||||
print(f'Diff: ave:{diff.mean()}, max:{diff.max()}')
|
||||
print(f"Diff: ave:{diff.mean()}, max:{diff.max()}")
|
||||
assert diff.max() < 1e-1, "CPU and Torch outputs are not close enough!"
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue