[fix](test): fix import kt-kernel (#1728)

ErvinXie 2025-12-17 19:46:32 +08:00 committed by GitHub
parent 6fc4080a7d
commit a8667ddb58
33 changed files with 1063 additions and 1151 deletions


@@ -1,19 +1,22 @@
import logging
import os,sys
import os, sys
import time
from typing import Optional
os.environ["BLAS_NUM_THREADS"] = "1"
sys.path.insert(0, os.path.dirname(__file__) + '/../build')
import kt_kernel_ext
sys.path.insert(0, os.path.dirname(__file__) + "/../build")
from kt_kernel import kt_kernel_ext
from kt_kernel_ext.kvcache import ggml_type
import torch
from torch import inf, nn
from torch.nn import init
from torch_attention import apply_rotary_pos_emb,DeepseekV2RMSNorm,KDeepSeekV3Cache,DeepseekV3YarnRotaryEmbedding
from torch_attention import apply_rotary_pos_emb, DeepseekV2RMSNorm, KDeepSeekV3Cache, DeepseekV3YarnRotaryEmbedding
logger = logging.getLogger("reader")
from gguf.gguf_reader import GGUFReader
def read_gguf_file(gguf_file_path):
"""
Reads and prints key-value pairs and tensor information from a GGUF file in an improved format.
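The functional change in this commit is the import at the top of this hunk: the compiled extension is now reached through the installed kt_kernel package instead of being imported as a bare kt_kernel_ext module from the build directory. A minimal sketch of the corrected pattern, mirroring how this test uses the extension further down (the thread count and enum member are taken from the code below):

import os

os.environ["BLAS_NUM_THREADS"] = "1"  # set before the extension is loaded, as above

from kt_kernel import kt_kernel_ext  # new import path introduced by this commit
from kt_kernel_ext.kvcache import ggml_type  # type enum used for the weight configs

cpu_infer = kt_kernel_ext.CPUInfer(30)  # CPU thread pool, constructed the same way below
print(ggml_type.FP32)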
@@ -46,12 +49,15 @@ def read_gguf_file(gguf_file_path):
re.append(tensor)
return re
def get_torch_tensor_from_gguf(gguf_weights, name):
return torch.from_numpy(gguf_weights[name].data).contiguous()
def get_torch_tensor_and_type_from_gguf(gguf_weights, name):
return torch.from_numpy(gguf_weights[name].data).contiguous(), gguf_weights[name].tensor_type.name
def type_to_ggml_type(type):
if type == "F32":
return ggml_type.FP32
@@ -70,12 +76,12 @@ seed = 42 # you can choose any integer as the seed
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
qlen = 3212
kvlen = 0
page_table = range(20)
bsz_tensors=torch.tensor([1])
bsz_tensors = torch.tensor([1])
page_size = 256
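For orientation, a quick sanity check of the paging parameters above, using the same numbers as this test: 3212 query tokens on top of an empty KV prefix need 13 pages of 256 slots each, which fits in the 20-entry page_table.

qlen, kvlen, page_size = 3212, 0, 256
pages_needed = (qlen + kvlen + page_size - 1) // page_size  # ceil division, same formula as kv_indptr below
assert pages_needed == 13  # <= 20 pages available in page_table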
@@ -94,8 +100,7 @@ rope_theta = 10000
max_qlen = 4096
max_kvlen = 4096
max_position_embeddings = 163840
rope_scaling = {
@@ -105,11 +110,10 @@ rope_scaling = {
"mscale": 1.0,
"mscale_all_dim": 1.0,
"original_max_position_embeddings": 4096,
"type": "yarn"
"type": "yarn",
}
CPUInfer = kt_kernel_ext.CPUInfer(30)
validation_iter = 100
@@ -119,15 +123,16 @@ weight_type = torch.bfloat16
# weight_type = torch.float16
input_type = {torch.float32:torch.float32,
torch.float16:torch.float16,
torch.bfloat16:torch.float32,
}[weight_type]
input_type = {
torch.float32: torch.float32,
torch.float16: torch.float16,
torch.bfloat16: torch.float32,
}[weight_type]
q_a_proj = nn.Linear(hidden_size, q_lora_rank, bias=False, dtype=weight_type)
q_b_proj = nn.Linear(q_lora_rank, num_heads * (nope_size+rope_size) , bias=False, dtype=weight_type)
q_b_proj = nn.Linear(q_lora_rank, num_heads * (nope_size + rope_size), bias=False, dtype=weight_type)
kv_a_proj_with_mqa = nn.Linear(hidden_size, kv_lora_rank + rope_size, bias=False, dtype=weight_type)
kv_b_proj = nn.Linear( num_heads * (nope_size + nope_size),kv_lora_rank, bias=False, dtype=weight_type)
kv_b_proj = nn.Linear(num_heads * (nope_size + nope_size), kv_lora_rank, bias=False, dtype=weight_type)
o_proj = nn.Linear(num_heads * nope_size, hidden_size, bias=False, dtype=weight_type)
q_a_norm = torch.ones(hidden_size, dtype=torch.float32)
kv_a_norm = torch.ones(hidden_size, dtype=torch.float32)
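A self-contained shape check that mirrors the projection layout above. The concrete dimensions are assumptions for illustration (typical DeepSeek-V3 sizes; the script defines its own hidden_size, q_lora_rank and kv_lora_rank outside this diff); the point is only that nn.Linear stores its weight as [out_features, in_features], which the manual weight views and matmuls later in the file rely on.

import torch
from torch import nn

hidden_size, q_lora_rank, kv_lora_rank = 7168, 1536, 512  # assumed values
num_heads, nope_size, rope_size = 128, 128, 64

q_a = nn.Linear(hidden_size, q_lora_rank, bias=False)
q_b = nn.Linear(q_lora_rank, num_heads * (nope_size + rope_size), bias=False)
kv_a = nn.Linear(hidden_size, kv_lora_rank + rope_size, bias=False)

assert q_a.weight.shape == (q_lora_rank, hidden_size)
assert q_b.weight.shape == (num_heads * (nope_size + rope_size), q_lora_rank)
assert kv_a.weight.shape == (kv_lora_rank + rope_size, hidden_size)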
@@ -190,7 +195,7 @@ if use_real_weights := True:
o_proj_weight, type = get_torch_tensor_and_type_from_gguf(gguf_weights, f"blk.{layer_idx}.attn_output.weight")
o_proj.weight = nn.Parameter(o_proj_weight.view(torch.bfloat16), requires_grad=False)
else:
init.normal_(q_a_proj.weight, mean=0.0, std=0.02)
init.normal_(q_b_proj.weight, mean=0.0, std=0.02)
@@ -203,16 +208,16 @@ q_absorb = x_reshaped[:, 0]
out_absorb = x_reshaped[:, 1]
hidden_states = torch.randn((qlen, hidden_size), dtype=input_type).to('cpu').contiguous()
hidden_states = torch.randn((qlen, hidden_size), dtype=input_type).to("cpu").contiguous()
def test_cpu_mla():
os.environ["BLAS_NUM_THREADS"] = "1"
q_a_proj_weight = q_a_proj.weight.to(weight_type).to('cpu').contiguous()
q_b_proj_weight = q_b_proj.weight.to(weight_type).to('cpu').contiguous()
kv_a_proj_with_mqa_weight = kv_a_proj_with_mqa.weight.to('cpu').to(weight_type).contiguous()
kv_b_proj_weight = kv_b_proj.weight.to(weight_type).to('cpu').contiguous()
o_proj_weight = o_proj.weight.to(weight_type).to('cpu').contiguous()
q_a_proj_weight = q_a_proj.weight.to(weight_type).to("cpu").contiguous()
q_b_proj_weight = q_b_proj.weight.to(weight_type).to("cpu").contiguous()
kv_a_proj_with_mqa_weight = kv_a_proj_with_mqa.weight.to("cpu").to(weight_type).contiguous()
kv_b_proj_weight = kv_b_proj.weight.to(weight_type).to("cpu").contiguous()
o_proj_weight = o_proj.weight.to(weight_type).to("cpu").contiguous()
config = kt_kernel_ext.mla.MLAConfig(
hidden_size,
@@ -224,7 +229,7 @@ def test_cpu_mla():
)
config.max_qlen = max_qlen
config.max_kvlen = max_kvlen
config.max_position_embeddings = max_position_embeddings
config.rope_scaling_factor = rope_scaling["factor"]
config.rope_theta = rope_theta
config.rope_scaling_beta_fast = rope_scaling["beta_fast"]
@@ -245,7 +250,6 @@ def test_cpu_mla():
config.kv_a_norm_type = ggml_type.FP32
config.page_count = pages_count
if weight_type == torch.float32:
config.q_a_proj_type = ggml_type.FP32
config.q_b_proj_type = ggml_type.FP32
@@ -267,10 +271,8 @@ def test_cpu_mla():
else:
raise ValueError(f"Unsupported data type: {weight_type}")
config.pool = CPUInfer.backend_
if weight_type == torch.float32:
mla = kt_kernel_ext.mla.MLA_F32(config)
elif weight_type == torch.float16:
@@ -280,54 +282,53 @@ def test_cpu_mla():
mla = kt_kernel_ext.mla.MLA_QUAN_F32(config)
else:
raise ValueError(f"Unsupported data type: {weight_type}")
mla.load_weights()
mla.set_local_pages(pages_count)
output = torch.zeros((qlen, hidden_size), dtype=input_type).to('cpu').contiguous()
mla.forward([qlen],[page_table],[kvlen],hidden_states.data_ptr(),output.data_ptr())
print("CPU MLA Output: ",output)
output = torch.zeros((qlen, hidden_size), dtype=input_type).to("cpu").contiguous()
mla.forward([qlen], [page_table], [kvlen], hidden_states.data_ptr(), output.data_ptr())
print("CPU MLA Output: ", output)
return output
def load_fp16_tensor(file_path, shape):
# return load_fp32_tensor(file_path, shape)
return torch.zeros(shape)
with open(file_path, 'rb') as f:
with open(file_path, "rb") as f:
raw_data = f.read()
tensor = torch.frombuffer(raw_data, dtype=weight_type)
tensor = tensor.view(shape) # reshape according to your shape
return tensor
def load_fp32_tensor(file_path, shape):
return torch.zeros(shape)
with open(file_path, 'rb') as f:
with open(file_path, "rb") as f:
raw_data = f.read()
tensor = torch.frombuffer(raw_data, dtype=torch.float32)
tensor = tensor.view(shape) # reshape according to your shape
return tensor
def test_torch():
torch.set_grad_enabled(False)
softmax_scale = (nope_size + rope_size) ** -0.5
# the 1 is the number of compressed kv heads
k_caches = torch.randn(1,pages_count, page_size, 1, kv_lora_rank + rope_size).to(weight_type)
k_caches = torch.randn(1, pages_count, page_size, 1, kv_lora_rank + rope_size).to(weight_type)
kv_cache = KDeepSeekV3Cache(page_size=page_size, kv_lora_rank=kv_lora_rank, k_caches=k_caches)
q_a_layernorm = DeepseekV2RMSNorm(q_lora_rank)
q_a_layernorm.weight = nn.Parameter( q_a_norm,requires_grad=False)
q_a_layernorm.weight = nn.Parameter(q_a_norm, requires_grad=False)
x = torch.randn(q_lora_rank, dtype=weight_type)*100
x = torch.randn(q_lora_rank, dtype=weight_type) * 100
print(x)
print(q_a_layernorm(x))
kv_a_layernorm = DeepseekV2RMSNorm(kv_lora_rank)
kv_a_layernorm.weight = nn.Parameter(kv_a_norm, requires_grad=False)
# step 3: split into two tensors
# q_absorb, out_absorb = x_permuted[:, 0], x_permuted[:, 1] # 都是 (num_heads, nope_size, kv_lora_rank
# q_absorb = kv_b_proj[:, ] # torch.randn(num_heads, nope_size, kv_lora_rank, dtype=data_type)
@@ -348,65 +349,64 @@ def test_torch():
# kv_indices 是[0:bsz]page_idx=[0:bsz], page_offset=[kvlen:qlen+kvlen]
# last_page_len = [qlen+kvlen,...] layer_idx = 1
# position_ids = [kvlen:qlen+kvlen]
q_indptr = torch.tensor([0,qlen]).to(torch.int32)
q_indptr = torch.tensor([0, qlen]).to(torch.int32)
kv_indptr = torch.tensor([0,(qlen+kvlen+page_size-1)//page_size]).to(torch.int32)
kv_indptr = torch.tensor([0, (qlen + kvlen + page_size - 1) // page_size]).to(torch.int32)
kv_indices = torch.tensor(range(pages_count)).to(torch.int32)
page_idx = torch.tensor([i//page_size for i in range(kvlen,kvlen+qlen)] ).to(torch.int32)
page_offset = torch.tensor( [i%page_size for i in range(kvlen, kvlen + qlen)]).to(torch.int32)
page_idx = torch.tensor([i // page_size for i in range(kvlen, kvlen + qlen)]).to(torch.int32)
page_offset = torch.tensor([i % page_size for i in range(kvlen, kvlen + qlen)]).to(torch.int32)
last_page_len = torch.tensor([256], device=hidden_states.device)
position_ids = torch.tensor(range(kvlen, kvlen + qlen)).to(torch.int32)
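# Illustration (not part of the original test): the paged-KV bookkeeping above
# places token position p in page p // page_size at offset p % page_size, and
# kv_indptr holds the cumulative page counts per batch entry.
_positions = [0, 255, 256, 300, 3211]  # sample token positions
_pages = [p // 256 for p in _positions]  # -> [0, 0, 1, 1, 12]
_offsets = [p % 256 for p in _positions]  # -> [0, 255, 0, 44, 139]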
# build the mask row by row, shape [qlen, kvlen+qlen]
attention_masks = torch.zeros((max_qlen, max_kvlen), dtype=weight_type)
for i in range(max_qlen):
attention_masks[i, i + kvlen + 1:] = -inf
attention_masks[i, i + kvlen + 1 :] = -inf
def torch_attn(hidden_states_i: torch.Tensor,
kv_cache: KDeepSeekV3Cache,
position_ids: torch.Tensor,
page_idx: torch.Tensor,
page_offset: torch.Tensor,
attention_masks: Optional[list[torch.Tensor]] = None,
q_indptr: Optional[torch.Tensor] = None,
kv_indices: Optional[torch.Tensor] = None,
kv_indptr: Optional[torch.Tensor] = None,
bsz_tensors: Optional[torch.Tensor] = None,
last_page_len: Optional[torch.Tensor] = None,
layer_idx: Optional[int] = None,
):
def torch_attn(
hidden_states_i: torch.Tensor,
kv_cache: KDeepSeekV3Cache,
position_ids: torch.Tensor,
page_idx: torch.Tensor,
page_offset: torch.Tensor,
attention_masks: Optional[list[torch.Tensor]] = None,
q_indptr: Optional[torch.Tensor] = None,
kv_indices: Optional[torch.Tensor] = None,
kv_indptr: Optional[torch.Tensor] = None,
bsz_tensors: Optional[torch.Tensor] = None,
last_page_len: Optional[torch.Tensor] = None,
layer_idx: Optional[int] = None,
):
global out_absorb
global q_absorb
hidden_states = hidden_states_i.to(weight_type)
# range bsz_tensors
final_attention_output = torch.tensor([], device=hidden_states.device)
for i in range(bsz_tensors[0]):
batch_num_tokens_tensors = q_indptr[i+1] - q_indptr[i]
batch_num_tokens_tensors = q_indptr[i + 1] - q_indptr[i]
batch_last_page_len = last_page_len[i]
# kv_total_len is kv_len, batch_compressed_kv is compressed_kv, batch_k_pe is k_pe
batch_page_idx = page_idx[q_indptr[i]:q_indptr[i+1]]
batch_page_offset = page_offset[q_indptr[i]:q_indptr[i+1]]
batch_page_idx = page_idx[q_indptr[i] : q_indptr[i + 1]]
batch_page_offset = page_offset[q_indptr[i] : q_indptr[i + 1]]
# kv_page_nums is the number of pages for the current batch
kv_page_nums = kv_indptr[i+1] - kv_indptr[i]
kv_page_nums = kv_indptr[i + 1] - kv_indptr[i]
# kv_total_len is the total length of the kv cache for the current batch (kv_len for algorithm)
kv_total_len = kv_page_nums * page_size
if batch_last_page_len is not None:
kv_total_len = kv_total_len - (page_size - batch_last_page_len)
# print(f"kv_total_len's shape {kv_total_len.shape}")
# kv_index is the index of the kv cache pages for the current batch
kv_index = kv_indices[kv_indptr[i]:kv_indptr[i+1]]
kv_index = kv_indices[kv_indptr[i] : kv_indptr[i + 1]]
# we can index [kv_index, page_offset_indices] to get the kv cache for the current batch
# from q_indptr[i] to q_indptr[i+1] is the range of the current batch
batch_hidden_states = hidden_states[q_indptr[i]:q_indptr[i+1]]
batch_position_ids = position_ids[q_indptr[i]:q_indptr[i+1]]
batch_hidden_states = hidden_states[q_indptr[i] : q_indptr[i + 1]]
batch_position_ids = position_ids[q_indptr[i] : q_indptr[i + 1]]
qlen, _ = batch_hidden_states.size()
# print("qlen -> ", qlen)
hidden_states_to_check = load_fp16_tensor('./debug/query_0_tp_0_input.bin',batch_hidden_states.shape)
hidden_states_to_check = load_fp16_tensor("./debug/query_0_tp_0_input.bin", batch_hidden_states.shape)
diff = torch.abs(batch_hidden_states - hidden_states_to_check).max()
print("hidden_states diff -> ", diff)
@@ -422,8 +422,6 @@ def test_torch():
# print("q_lora mae -> ", mae)
# print("q_lora mae test -> ", mae_test)
q_lora_norm = q_a_layernorm(q_lora)
# q_lora_norm_to_check = load_fp16_tensor('./debug/query_0_tp_0_qlora_norm.bin', q_lora_norm.shape)
# q_lora_norm_to_check_test = load_fp16_tensor('./debug/query_0_tp_0_qlora_norm_test.bin', q_lora_norm.shape)
@@ -435,30 +433,25 @@ def test_torch():
# print("q_lora_norm mae -> ", mae)
# print("q_lora_norm diff test -> ", diff_test)
# print("q_lora_norm mae test -> ", mae_test)
q = q_b_proj(q_lora_norm)
# for v3, bsz, qlen, num_heads(128), qk_head_dim(192=128(nope)+64(rope))
q = q.view(qlen, num_heads, nope_size+rope_size)
q = q.view(qlen, num_heads, nope_size + rope_size)
# q_nope is [qlen, num_heads(128), qk_nope_head_dim(128)]
# q_pe is [qlen, num_heads(128), qk_rope_head_dim(64)]
q_nope, q_pe = torch.split(
q, [nope_size, rope_size], dim=-1
)
q_nope, q_pe = torch.split(q, [nope_size, rope_size], dim=-1)
# compressed_kv is [qlen, kv_lora_rank(512) + rope(64)]
compressed_kv = kv_a_proj_with_mqa(batch_hidden_states)
# compressed_kv is [qlen, kv_lora_rank(512)], k_pe is [qlen, rope(64)]
compressed_kv, k_pe = torch.split(
compressed_kv, [kv_lora_rank, rope_size], dim=-1
)
compressed_kv, k_pe = torch.split(compressed_kv, [kv_lora_rank, rope_size], dim=-1)
compressed_kv = compressed_kv.contiguous()
# compressed_kv_page_0 = compressed_kv[0:page_size, :]
# compressed_kv_to_check = load_fp16_tensor('./debug/query_0_tp_0_page_0_kv_lora_rank',
# compressed_kv_page_0.shape)
# diff = torch.abs(compressed_kv_page_0 - compressed_kv_to_check).max()
# mae = torch.mean(torch.abs(compressed_kv_page_0 - compressed_kv_to_check))
# print("compressed_kv diff -> ", diff)
# print("compressed_kv mae -> ", mae)
@@ -472,14 +465,11 @@ def test_torch():
# mae = torch.mean(torch.abs(compressed_kv_page_0 - compressed_kv_to_check))
# print("compressed_kv diff norm -> ", diff)
# print("compressed_kv mae norm -> ", mae)
k_pe = k_pe.view(qlen, 1, rope_size)
# compressed_kv is [qlen, 1, kv_lora_rank(512)]
compressed_kv = compressed_kv.view(qlen, 1, kv_lora_rank)
cos, sin = rotary_emb(q_pe, batch_position_ids)
# q_nope_check = q_nope.transpose(0, 1) # qlen is 1, no GPU overhead, same below
@@ -494,8 +484,8 @@ def test_torch():
# print("q_nope[0] mae -> ", mae)
# print("q_nope[0] diff test -> ", diff_test)
# print("q_nope[0] mae test -> ", mae_test)
q_pe_nope = q_pe.transpose(0,1)
q_pe_nope = q_pe.transpose(0, 1)
# q_pe_0_to_check = load_fp16_tensor('./debug/query_0_tp_0_q_rope', q_pe_nope[0].shape)
# q_pe_0_to_check = load_fp16_tensor('./debug/query_0_tp_0_q_rope_no_rope', q_pe_nope[0].shape)
# q_pe_0_to_check_test = load_fp16_tensor('./debug/query_0_tp_0_q_rope_no_rope_test', q_pe_nope[0].shape)
@@ -534,12 +524,11 @@ def test_torch():
q_pe, k_pe = apply_rotary_pos_emb(q_pe.unsqueeze(0), k_pe.unsqueeze(0), cos, sin, unsqueeze_dim=1)
q_pe = q_pe.squeeze(0)
# q_pe is [num_heads(128), qlen, qk_rope_head_dim(64)]
q_pe.transpose_(0, 1)
# diff = torch.abs(q_pe - q_new).max()
# print("q_pe diff -> ", diff)
# q_pe_0_to_check = load_fp16_tensor('./debug/query_0_tp_0_q_rope', q_pe[0].shape)
# diff = torch.abs(q_pe[0] - q_pe_0_to_check).max()
# mae = torch.mean(torch.abs(q_pe[0] - q_pe_0_to_check))
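apply_rotary_pos_emb above comes from the torch_attention helper module. For readers who want the gist, here is the standard rotate-half formulation of RoPE; it is a generic illustration and may differ in layout and in YaRN scaling from the actual DeepseekV3YarnRotaryEmbedding used by this test.

import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

dim = 8
pos = torch.arange(4, dtype=torch.float32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
freqs = torch.outer(pos, inv_freq)  # [seq, dim / 2]
emb = torch.cat((freqs, freqs), dim=-1)  # [seq, dim]
cos, sin = emb.cos(), emb.sin()

q = torch.randn(4, dim)
q_rot = q * cos + rotate_half(q) * sin  # rotated queries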
@@ -552,15 +541,22 @@ def test_torch():
# print("q_pe[0] 2 mae -> ", mae)
if kv_cache is not None:
cache_kwargs = {"sin": sin, "cos": cos, "page_idx": batch_page_idx, "page_offset": batch_page_offset} # Specific to RoPE models
compressed_kv_with_k_pe = kv_cache.update(compressed_kv.unsqueeze(0), k_pe, layer_idx, batch_page_idx, batch_page_offset, cache_kwargs)
compressed_kv = compressed_kv_with_k_pe [:, :, :, :kv_lora_rank].view(-1, page_size, kv_lora_rank)
k_pe = compressed_kv_with_k_pe [:, :, :, kv_lora_rank:].view(-1, page_size, rope_size)
cache_kwargs = {
"sin": sin,
"cos": cos,
"page_idx": batch_page_idx,
"page_offset": batch_page_offset,
} # Specific to RoPE models
compressed_kv_with_k_pe = kv_cache.update(
compressed_kv.unsqueeze(0), k_pe, layer_idx, batch_page_idx, batch_page_offset, cache_kwargs
)
compressed_kv = compressed_kv_with_k_pe[:, :, :, :kv_lora_rank].view(-1, page_size, kv_lora_rank)
k_pe = compressed_kv_with_k_pe[:, :, :, kv_lora_rank:].view(-1, page_size, rope_size)
# q_absorb is [num_heads(128), qk_nope_head_dim(128), kv_lora_rank(512)]
# out_absorb is [num_heads(128), kv_lora_rank(512), v_head_dim(128)] v_head_dim is also the nope dim
# q_absorb, out_absorb = get_absorbed()
# q_nope is [num_heads(128), qlen, qk_nope_head_dim(128)]
q_nope = q_nope.transpose(0, 1) # qlen is 1, no GPU overhead, same below
q_nope = q_nope.transpose(0, 1) # qlen is 1, no GPU overhead, same below
# q_nope_0_to_check = load_fp16_tensor('./debug/query_0_tp_0_q_nope', q_nope[0].shape)
# diff = torch.abs(q_nope[0] - q_nope_0_to_check).max()
@@ -568,7 +564,7 @@ def test_torch():
# print("q_nope[0] diff -> ", diff)
# q_nope is [num_heads(128), qlen, kv_lora_rank(512)]
q_nope = torch.matmul(q_nope, q_absorb) # batched MM
q_nope = torch.matmul(q_nope, q_absorb) # batched MM
# k_b_proj_check = load_fp16_tensor('./debug/query_0_tp_0_k_b_lora', (nope_size,kv_lora_rank))
# diff = torch.abs(q_absorb[0] - k_b_proj_check).max()
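The matmul above is the first half of the "absorbed" form of MLA: queries are mapped into the compressed kv_lora_rank space with q_absorb, attention is computed against the latent cache, and out_absorb is applied after the attention-weighted sum. A toy-sized, self-contained sketch of that data flow (dimensions are illustrative; the RoPE score term and the softmax scale are omitted for brevity):

import torch

heads, qlen, klen, nope, rank, vdim = 4, 3, 5, 8, 16, 8
q_nope = torch.randn(heads, qlen, nope)
q_absorb = torch.randn(heads, nope, rank)  # per-head absorb matrix (from kv_b_proj in this test)
latent_kv = torch.randn(klen, rank)  # compressed KV cache entries
out_absorb = torch.randn(heads, rank, vdim)  # per-head output absorb matrix

q_latent = torch.matmul(q_nope, q_absorb)  # [heads, qlen, rank]
scores = torch.matmul(q_latent, latent_kv.mT)  # [heads, qlen, klen]
probs = torch.softmax(scores, dim=-1)
ctx = torch.matmul(probs, latent_kv)  # [heads, qlen, rank]
out = torch.matmul(ctx, out_absorb)  # [heads, qlen, vdim]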
@@ -594,7 +590,7 @@ def test_torch():
if batch_compressed_kv is None or batch_k_pe is None:
batch_compressed_kv = tmp_compressed_kv
batch_k_pe = tmp_k_pe
else:
batch_compressed_kv = torch.cat((batch_compressed_kv, tmp_compressed_kv), dim=0)
batch_k_pe = torch.cat((batch_k_pe, tmp_k_pe), dim=0)
kv_total_len -= page_size
@@ -604,28 +600,27 @@ def test_torch():
if batch_compressed_kv is None or batch_k_pe is None:
batch_compressed_kv = tmp_compressed_kv
batch_k_pe = tmp_k_pe
else:
batch_compressed_kv = torch.cat((batch_compressed_kv, tmp_compressed_kv), dim=0)
batch_k_pe = torch.cat((batch_k_pe, tmp_k_pe), dim=0)
break
# batch_compressed_kv is [kv_total_len(k_len), kv_lora_rank(512)]
# batch_k_pe is [kv_total_len(k_len), qk_rope_head_dim(64)]
# k_pe_to_check = load_fp16_tensor('./debug/query_0_tp_0_page_0_k_rope', (256,64))
# diff = torch.abs(batch_k_pe[:256] - k_pe_to_check).max()
# mae = torch.mean(torch.abs(batch_k_pe[:256] - k_pe_to_check))
# print("k_pe diff -> ", diff)
# print("k_pe mae -> ", mae)
pe_weights = torch.matmul(q_pe,batch_k_pe.mT)
pe_weights = torch.matmul(q_pe, batch_k_pe.mT)
kv_total_len = kv_page_nums * page_size
# pe_weights_0 = load_fp16_tensor('./debug/query_0_tp_0_pe_attention_weights', (1024,4096))
# pe_weights_0 = pe_weights_0[0:qlen, 0:kv_total_len]
# diff = torch.abs(pe_weights[0] - pe_weights_0).max()
# print("pe_weights[0] diff -> ", diff)
attention_weights = (pe_weights + torch.matmul(q_nope, batch_compressed_kv.mT))
attention_weights = pe_weights + torch.matmul(q_nope, batch_compressed_kv.mT)
# raw_weights = load_fp16_tensor('./debug/query_0_tp_0_raw_attention_weights', (1024, 4096))
# raw_weights = raw_weights[0:qlen, 0:kv_total_len]
@@ -634,47 +629,47 @@ def test_torch():
attention_weights = attention_weights * softmax_scale
# attention_weights is [num_heads(128), qlen, k_len]
# attention_weights = attention_weights.transpose(0,1).unsqueeze(0).squeeze(-1).expand(qlen,-1,-1).transpose(0,1)
# attention_masks[i] is [qlen, k_len]
print(attention_weights.shape)
print(attention_masks.shape)
attention_weights = (attention_weights + attention_masks[ :attention_weights.shape[1],:attention_weights.shape[2]])
attention_weights = (
attention_weights + attention_masks[: attention_weights.shape[1], : attention_weights.shape[2]]
)
# attention_weights shape is [num_heads(128), qlen, k_len]
attention_weights = nn.functional.softmax(attention_weights,dim=-1,dtype=weight_type).to(q_pe.dtype)
attention_weights = nn.functional.softmax(attention_weights, dim=-1, dtype=weight_type).to(q_pe.dtype)
# attention_weights_0 = load_fp16_tensor('./debug/query_0_tp_0_attention_weights', (1024, 4096))
# attention_weights_0 = attention_weights_0[0:qlen, 0:kv_total_len]
# diff = torch.abs(attention_weights[0] - attention_weights_0).max()
# print("attention_weights[0] diff -> ", diff)
attn_output = torch.matmul(attention_weights, batch_compressed_kv) # [num_heads(128),qlen, lora_rank(512)]
attn_output = torch.matmul(attention_weights, batch_compressed_kv) # [num_heads(128),qlen, lora_rank(512)]
# out_absorb shape is [num_heads(128), kv_lora_rank(512), v_head_dim(128)]
# o_absorb_check = load_fp16_tensor('./debug/query_0_tp_0_o_absorb', (qlen,kv_lora_rank))
# diff = torch.abs(attn_output[0] - o_absorb_check).max()
# print("o absorb[0] diff -> ", diff)
out_absorb = out_absorb.transpose(1, 2) # [qlen, num_heads(128), v_head_dim(128)]
out_absorb = out_absorb.transpose(1, 2) # [qlen, num_heads(128), v_head_dim(128)]
# q for qlen, n for num_heads, h for v_head_dim, v for kv_lora_rank
attn_output = torch.matmul(attn_output, out_absorb) # [num_heads(128), qlen, v_head_dim(128)]
attn_output = torch.matmul(attn_output, out_absorb) # [num_heads(128), qlen, v_head_dim(128)]
# attn_output_check_0 = load_fp16_tensor('./debug/query_0_tp_0_attention_output', (qlen, nope_size))
# diff = torch.abs(attn_output[0] - attn_output_check_0).max()
# print("attn_output[0] diff -> ", diff)
attn_output = attn_output.transpose(0, 1) # [qlen, num_heads(128), v_head_dim(128)]
attn_output = attn_output.transpose(0, 1) # [qlen, num_heads(128), v_head_dim(128)]
attn_output = attn_output.reshape(qlen, num_heads * nope_size)
w_o = o_proj.weight.view([hidden_size,num_heads * nope_size])
output = torch.matmul(attn_output,w_o.transpose(0,1))
w_o = o_proj.weight.view([hidden_size, num_heads * nope_size])
output = torch.matmul(attn_output, w_o.transpose(0, 1))
output = output.view(qlen, hidden_size)
# output_0_check = load_fp16_tensor('./debug/query_0_tp_0_qlen_output', (qlen, hidden_size))
# h1_o = w_o[:,:128]
# local_o_check = load_fp16_tensor('./debug/query_0_tp_0_local_w_o', (hidden_size, 128))
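The manual output projection above (viewing o_proj.weight and multiplying by its transpose) is equivalent to calling o_proj directly, since nn.Linear stores its weight as [out_features, in_features] and computes x @ W.T. A tiny self-contained check of that equivalence with toy sizes:

import torch
from torch import nn

lin = nn.Linear(6, 4, bias=False)
x = torch.randn(3, 6)
manual = torch.matmul(x, lin.weight.view(4, 6).transpose(0, 1))
assert torch.allclose(manual, lin(x))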
@@ -685,35 +680,32 @@ def test_torch():
# diff = torch.abs(h1_output - output_0_check).max()
# print("h1_output diff -> ", diff)
# output_check = load_fp16_tensor('./debug/output.bin', output.shape)
# diff = torch.abs(output - output_check).max()
# mae = torch.mean(torch.abs(output - output_check))
# print("output diff -> ", diff)
final_attention_output = torch.cat((final_attention_output, output), dim=0)
return final_attention_output
torch_output = torch_attn(
hidden_states,
kv_cache,
position_ids,
page_idx,
page_offset,
attention_masks=attention_masks,
q_indptr=q_indptr,
kv_indices=kv_indices,
kv_indptr=kv_indptr,
bsz_tensors=bsz_tensors,
last_page_len=last_page_len,
layer_idx=0
)
print("Torch Output: ",torch_output)
hidden_states,
kv_cache,
position_ids,
page_idx,
page_offset,
attention_masks=attention_masks,
q_indptr=q_indptr,
kv_indices=kv_indices,
kv_indptr=kv_indptr,
bsz_tensors=bsz_tensors,
last_page_len=last_page_len,
layer_idx=0,
)
print("Torch Output: ", torch_output)
return torch_output
torch.set_printoptions(sci_mode=False, precision=5)
output_cpu = test_cpu_mla()
output_torch = test_torch()
@@ -724,11 +716,9 @@ diff = (output_cpu - output_torch).abs()
diff_relative = diff / (output_cpu.abs())
# replace NaNs in diff_relative with 0
diff_relative = torch.where(torch.isnan(diff_relative), torch.zeros_like(diff_relative), diff_relative)
diff_relative_mean = torch.mean(torch.abs(output_cpu-output_torch)) / torch.mean(torch.abs(output_torch))
diff_relative_mean = torch.mean(torch.abs(output_cpu - output_torch)) / torch.mean(torch.abs(output_torch))
print(f'Diff: ave:{diff.mean()}, max:{diff.max()}, min:{diff.min()}, relative_mean:{diff_relative_mean}, relative_max:{diff_relative.max()}, relative_min:{diff_relative.min()}')
print(
f"Diff: ave:{diff.mean()}, max:{diff.max()}, min:{diff.min()}, relative_mean:{diff_relative_mean}, relative_max:{diff_relative.max()}, relative_min:{diff_relative.min()}"
)
assert diff_relative_mean < 2e-1, "CPU and Torch outputs are not close enough!"