mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-05-04 22:51:51 +00:00
[fix](test): fix import kt-kernel (#1728)
This commit is contained in:
parent
6fc4080a7d
commit
a8667ddb58
33 changed files with 1063 additions and 1151 deletions
|
|
@ -1,13 +1,14 @@
|
|||
import os,sys
|
||||
import os, sys
|
||||
import time
|
||||
from typing import Optional
|
||||
sys.path.insert(0, os.path.dirname(__file__) + '/../build')
|
||||
import kt_kernel_ext
|
||||
|
||||
sys.path.insert(0, os.path.dirname(__file__) + "/../build")
|
||||
from kt_kernel import kt_kernel_ext
|
||||
from kt_kernel_ext.kvcache import ggml_type
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import init
|
||||
from torch_attention import apply_rotary_pos_emb,DeepseekV2RMSNorm,KDeepSeekV3Cache,DeepseekV3YarnRotaryEmbedding
|
||||
from torch_attention import apply_rotary_pos_emb, DeepseekV2RMSNorm, KDeepSeekV3Cache, DeepseekV3YarnRotaryEmbedding
|
||||
|
||||
|
||||
seed = 42 # 你可以选择任何整数作为种子
|
||||
|
|
@ -19,7 +20,7 @@ kvlen = 0
|
|||
|
||||
|
||||
page_table = range(20)
|
||||
bsz_tensors=torch.tensor([1])
|
||||
bsz_tensors = torch.tensor([1])
|
||||
|
||||
|
||||
page_size = 256
|
||||
|
|
@ -38,8 +39,7 @@ rope_theta = 10000
|
|||
max_qlen = 1024
|
||||
max_kvlen = 4096
|
||||
|
||||
max_position_embeddings = 163840
|
||||
|
||||
max_position_embeddings = 163840
|
||||
|
||||
|
||||
rope_scaling = {
|
||||
|
|
@ -49,17 +49,16 @@ rope_scaling = {
|
|||
"mscale": 1.0,
|
||||
"mscale_all_dim": 1.0,
|
||||
"original_max_position_embeddings": 4096,
|
||||
"type": "yarn"
|
||||
"type": "yarn",
|
||||
}
|
||||
|
||||
|
||||
|
||||
CPUInfer = kt_kernel_ext.CPUInfer(64)
|
||||
validation_iter = 100
|
||||
|
||||
|
||||
q_a_proj = nn.Linear(hidden_size, q_lora_rank, bias=False, dtype=torch.float16)
|
||||
q_b_proj = nn.Linear(q_lora_rank, num_heads * (nope_size+rope_size) , bias=False, dtype=torch.float16)
|
||||
q_b_proj = nn.Linear(q_lora_rank, num_heads * (nope_size + rope_size), bias=False, dtype=torch.float16)
|
||||
kv_a_proj_with_mqa = nn.Linear(hidden_size, kv_lora_rank + rope_size, bias=False, dtype=torch.float16)
|
||||
kv_b_proj = nn.Linear(kv_lora_rank, num_heads * (nope_size + nope_size), bias=False, dtype=torch.float16)
|
||||
o_proj = nn.Linear(num_heads * nope_size, hidden_size, bias=False, dtype=torch.float16)
|
||||
|
|
@ -70,13 +69,11 @@ init.normal_(kv_a_proj_with_mqa.weight, mean=0.0, std=0.02)
|
|||
init.normal_(kv_b_proj.weight, mean=0.0, std=0.02)
|
||||
init.normal_(o_proj.weight, mean=0.0, std=0.02)
|
||||
|
||||
q_a_proj_weight = q_a_proj.weight.to(torch.float16).to('cpu').contiguous()
|
||||
q_b_proj_weight = q_b_proj.weight.to(torch.float16).to('cpu').contiguous()
|
||||
kv_a_proj_with_mqa_weight = kv_a_proj_with_mqa.weight.to('cpu').to(torch.float16).contiguous()
|
||||
kv_b_proj_weight = kv_b_proj.weight.to(torch.float16).to('cpu').contiguous()
|
||||
o_proj_weight = o_proj.weight.to(torch.float16).to('cpu').contiguous()
|
||||
|
||||
|
||||
q_a_proj_weight = q_a_proj.weight.to(torch.float16).to("cpu").contiguous()
|
||||
q_b_proj_weight = q_b_proj.weight.to(torch.float16).to("cpu").contiguous()
|
||||
kv_a_proj_with_mqa_weight = kv_a_proj_with_mqa.weight.to("cpu").to(torch.float16).contiguous()
|
||||
kv_b_proj_weight = kv_b_proj.weight.to(torch.float16).to("cpu").contiguous()
|
||||
o_proj_weight = o_proj.weight.to(torch.float16).to("cpu").contiguous()
|
||||
|
||||
|
||||
config = kt_kernel_ext.mla.MLAConfig(
|
||||
|
|
@ -89,7 +86,7 @@ config = kt_kernel_ext.mla.MLAConfig(
|
|||
)
|
||||
config.max_qlen = max_qlen
|
||||
config.max_kvlen = max_kvlen
|
||||
config.max_position_embeddings = max_position_embeddings
|
||||
config.max_position_embeddings = max_position_embeddings
|
||||
config.rope_scaling_factor = rope_scaling["factor"]
|
||||
config.rope_theta = rope_theta
|
||||
config.rope_scaling_beta_fast = rope_scaling["beta_fast"]
|
||||
|
|
@ -114,30 +111,27 @@ config.w_o_type = ggml_type.FP16
|
|||
config.pool = CPUInfer.backend_
|
||||
|
||||
|
||||
|
||||
mla = kt_kernel_ext.mla.MLA(config)
|
||||
mla.load_weights()
|
||||
mla.set_local_pages(pages_count)
|
||||
|
||||
|
||||
|
||||
input = torch.randn((qlen, hidden_size), dtype=torch.float16).to('cpu').contiguous()
|
||||
input = torch.randn((qlen, hidden_size), dtype=torch.float16).to("cpu").contiguous()
|
||||
|
||||
|
||||
output = torch.zeros((qlen, hidden_size), dtype=torch.float16).to('cpu').contiguous()
|
||||
mla.forward([qlen],[page_table],[kvlen],input.data_ptr(),output.data_ptr())
|
||||
print("CPU MLA Output: ",output)
|
||||
|
||||
output = torch.zeros((qlen, hidden_size), dtype=torch.float16).to("cpu").contiguous()
|
||||
mla.forward([qlen], [page_table], [kvlen], input.data_ptr(), output.data_ptr())
|
||||
print("CPU MLA Output: ", output)
|
||||
|
||||
|
||||
softmax_scale = (nope_size + rope_size) ** -0.5
|
||||
# 1代表的是压缩的kv的头数
|
||||
k_caches = torch.randn(1,pages_count, page_size, 1, kv_lora_rank + rope_size).to(torch.float16)
|
||||
k_caches = torch.randn(1, pages_count, page_size, 1, kv_lora_rank + rope_size).to(torch.float16)
|
||||
kv_cache = KDeepSeekV3Cache(page_size=page_size, kv_lora_rank=kv_lora_rank, k_caches=k_caches)
|
||||
|
||||
q_a_layernorm = DeepseekV2RMSNorm(q_lora_rank)
|
||||
|
||||
x = torch.randn(q_lora_rank, dtype=torch.float16)*100
|
||||
x = torch.randn(q_lora_rank, dtype=torch.float16) * 100
|
||||
print(x)
|
||||
print(q_a_layernorm(x))
|
||||
|
||||
|
|
@ -163,110 +157,114 @@ rotary_emb = DeepseekV3YarnRotaryEmbedding(
|
|||
# last_page_len = [qlen+kvlen,...] layer_idx = 1
|
||||
# position_ids = [kvlen:qlen+kvlen]
|
||||
hidden_states = torch.randn(qlen, hidden_size, dtype=torch.float16)
|
||||
q_indptr = torch.tensor([0,qlen]).to(torch.int32)
|
||||
q_indptr = torch.tensor([0, qlen]).to(torch.int32)
|
||||
|
||||
kv_indptr = torch.tensor([0,(qlen+kvlen+page_size-1)//page_size]).to(torch.int32)
|
||||
kv_indptr = torch.tensor([0, (qlen + kvlen + page_size - 1) // page_size]).to(torch.int32)
|
||||
kv_indices = torch.tensor(range(pages_count)).to(torch.int32)
|
||||
|
||||
page_idx = torch.tensor([i//page_size for i in range(kvlen,kvlen+qlen)] ).to(torch.int32)
|
||||
page_offset = torch.tensor( [i%page_size for i in range(kvlen, kvlen + qlen)]).to(torch.int32)
|
||||
page_idx = torch.tensor([i // page_size for i in range(kvlen, kvlen + qlen)]).to(torch.int32)
|
||||
page_offset = torch.tensor([i % page_size for i in range(kvlen, kvlen + qlen)]).to(torch.int32)
|
||||
|
||||
last_page_len = torch.tensor([(qlen+kvlen)%page_size], device=hidden_states.device)
|
||||
last_page_len = torch.tensor([(qlen + kvlen) % page_size], device=hidden_states.device)
|
||||
position_ids = torch.tensor(range(kvlen, kvlen + qlen)).to(torch.int32)
|
||||
|
||||
|
||||
# 按照行创建 mask [qlen,kvlen+qlen]
|
||||
attention_masks = torch.zeros((qlen, kvlen + qlen), dtype=torch.float16)
|
||||
for i in range(qlen):
|
||||
attention_masks[i, i + kvlen + 1: i + kvlen + qlen] = -65504.0
|
||||
attention_masks[i, i + kvlen + 1 : i + kvlen + qlen] = -65504.0
|
||||
|
||||
|
||||
def torch_attn(hidden_states: torch.Tensor,
|
||||
kv_cache: KDeepSeekV3Cache,
|
||||
position_ids: torch.Tensor,
|
||||
page_idx: torch.Tensor,
|
||||
page_offset: torch.Tensor,
|
||||
attention_masks: Optional[list[torch.Tensor]] = None,
|
||||
q_indptr: Optional[torch.Tensor] = None,
|
||||
kv_indices: Optional[torch.Tensor] = None,
|
||||
kv_indptr: Optional[torch.Tensor] = None,
|
||||
bsz_tensors: Optional[torch.Tensor] = None,
|
||||
last_page_len: Optional[torch.Tensor] = None,
|
||||
layer_idx: Optional[int] = None,
|
||||
):
|
||||
def torch_attn(
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KDeepSeekV3Cache,
|
||||
position_ids: torch.Tensor,
|
||||
page_idx: torch.Tensor,
|
||||
page_offset: torch.Tensor,
|
||||
attention_masks: Optional[list[torch.Tensor]] = None,
|
||||
q_indptr: Optional[torch.Tensor] = None,
|
||||
kv_indices: Optional[torch.Tensor] = None,
|
||||
kv_indptr: Optional[torch.Tensor] = None,
|
||||
bsz_tensors: Optional[torch.Tensor] = None,
|
||||
last_page_len: Optional[torch.Tensor] = None,
|
||||
layer_idx: Optional[int] = None,
|
||||
):
|
||||
global out_absorb
|
||||
global q_absorb
|
||||
# range bsz_tensors
|
||||
final_attention_output = torch.tensor([], device=hidden_states.device)
|
||||
for i in range(bsz_tensors[0]):
|
||||
batch_num_tokens_tensors = q_indptr[i+1] - q_indptr[i]
|
||||
batch_num_tokens_tensors = q_indptr[i + 1] - q_indptr[i]
|
||||
batch_last_page_len = last_page_len[i]
|
||||
# kv_total_len is kv_len, batch_compressed_kv is compressed_kv, batch_k_pe is k_pe
|
||||
batch_page_idx = page_idx[q_indptr[i]:q_indptr[i+1]]
|
||||
batch_page_offset = page_offset[q_indptr[i]:q_indptr[i+1]]
|
||||
batch_page_idx = page_idx[q_indptr[i] : q_indptr[i + 1]]
|
||||
batch_page_offset = page_offset[q_indptr[i] : q_indptr[i + 1]]
|
||||
# kv_page_nums is the number of pages for the current batch
|
||||
kv_page_nums = kv_indptr[i+1] - kv_indptr[i]
|
||||
kv_page_nums = kv_indptr[i + 1] - kv_indptr[i]
|
||||
# kv_total_len is the total length of the kv cache for the current batch (kv_len for algorithm)
|
||||
kv_total_len = kv_page_nums * page_size
|
||||
if batch_last_page_len is not None:
|
||||
kv_total_len = kv_total_len - (page_size - batch_last_page_len)
|
||||
# print(f"kv_total_len's shape {kv_total_len.shape}")
|
||||
# kv_index is the index of the kv cache pages for the current batch
|
||||
kv_index = kv_indices[kv_indptr[i]:kv_indptr[i+1]]
|
||||
kv_index = kv_indices[kv_indptr[i] : kv_indptr[i + 1]]
|
||||
# we can index [kv_index, page_offset_indices] to get the kv cache for the current batch
|
||||
# from q_indptr[i] to q_indptr[i+1] is the range of the current batch
|
||||
batch_hidden_states = hidden_states[q_indptr[i]:q_indptr[i+1]]
|
||||
batch_position_ids = position_ids[q_indptr[i]:q_indptr[i+1]]
|
||||
batch_hidden_states = hidden_states[q_indptr[i] : q_indptr[i + 1]]
|
||||
batch_position_ids = position_ids[q_indptr[i] : q_indptr[i + 1]]
|
||||
qlen, _ = batch_hidden_states.size()
|
||||
# print("qlen -> ", qlen)
|
||||
q_lora = q_a_proj(batch_hidden_states)
|
||||
print('q_a_proj',q_a_proj.weight)
|
||||
print('q_lora',q_lora)
|
||||
|
||||
print("q_a_proj", q_a_proj.weight)
|
||||
print("q_lora", q_lora)
|
||||
|
||||
q = q_b_proj(q_a_layernorm(q_lora))
|
||||
print('q_b_proj',q_b_proj.weight)
|
||||
print("q_b_proj", q_b_proj.weight)
|
||||
# for v3, bsz, qlen, num_heads(128), qk_head_dim(192=128(nope)+64(rope))
|
||||
q = q.view(qlen, num_heads, nope_size+rope_size)
|
||||
q = q.view(qlen, num_heads, nope_size + rope_size)
|
||||
# q_nope is [qlen, num_heads(128), qk_nope_head_dim(128)]
|
||||
# q_pe is [qlen, num_heads(128), qk_rope_head_dim(64)]
|
||||
q_nope, q_pe = torch.split(
|
||||
q, [nope_size, rope_size], dim=-1
|
||||
)
|
||||
print('q_nope',q_nope)
|
||||
print('q_pe',q_pe)
|
||||
q_nope, q_pe = torch.split(q, [nope_size, rope_size], dim=-1)
|
||||
print("q_nope", q_nope)
|
||||
print("q_pe", q_pe)
|
||||
# compressed_kv is [qlen, kv_lora_rank(512) + rope(64)]
|
||||
compressed_kv = kv_a_proj_with_mqa(batch_hidden_states)
|
||||
# compressed_kv is [qlen, kv_lora_rank(512)], k_pe is [qlen, rope(64)]
|
||||
compressed_kv, k_pe = torch.split(
|
||||
compressed_kv, [kv_lora_rank, rope_size], dim=-1
|
||||
)
|
||||
compressed_kv, k_pe = torch.split(compressed_kv, [kv_lora_rank, rope_size], dim=-1)
|
||||
compressed_kv = compressed_kv.contiguous()
|
||||
compressed_kv = kv_a_layernorm(compressed_kv)
|
||||
# k_pe is [qlen, 1, qk_rope_head_dim(64)]
|
||||
print('compressed_kv ',compressed_kv)
|
||||
print('k_pe ',k_pe)
|
||||
print("compressed_kv ", compressed_kv)
|
||||
print("k_pe ", k_pe)
|
||||
k_pe = k_pe.view(qlen, 1, rope_size)
|
||||
# compressed_kv is [qlen, 1, kv_lora_rank(512)]
|
||||
compressed_kv = compressed_kv.view(qlen, 1, kv_lora_rank)
|
||||
|
||||
|
||||
cos, sin = rotary_emb(q_pe, batch_position_ids)
|
||||
# print(f"q_pe shape{q_pe.shape}, k_pe shape {k_pe.shape}")
|
||||
q_pe, k_pe = apply_rotary_pos_emb(q_pe.unsqueeze(0), k_pe.unsqueeze(0), cos, sin, unsqueeze_dim=1)
|
||||
q_pe = q_pe.squeeze(0)
|
||||
# q_pe is [num_heads(128), qlen, qk_rope_head_dim(64)]
|
||||
q_pe.transpose_(0, 1)
|
||||
q_pe.transpose_(0, 1)
|
||||
if kv_cache is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "page_idx": batch_page_idx, "page_offset": batch_page_offset} # Specific to RoPE models
|
||||
compressed_kv_with_k_pe = kv_cache.update(compressed_kv.unsqueeze(0), k_pe, layer_idx, batch_page_idx, batch_page_offset, cache_kwargs)
|
||||
compressed_kv = compressed_kv_with_k_pe [:, :, :, :kv_lora_rank].view(-1, page_size, kv_lora_rank)
|
||||
k_pe = compressed_kv_with_k_pe [:, :, :, kv_lora_rank:].view(-1, page_size, rope_size)
|
||||
cache_kwargs = {
|
||||
"sin": sin,
|
||||
"cos": cos,
|
||||
"page_idx": batch_page_idx,
|
||||
"page_offset": batch_page_offset,
|
||||
} # Specific to RoPE models
|
||||
compressed_kv_with_k_pe = kv_cache.update(
|
||||
compressed_kv.unsqueeze(0), k_pe, layer_idx, batch_page_idx, batch_page_offset, cache_kwargs
|
||||
)
|
||||
compressed_kv = compressed_kv_with_k_pe[:, :, :, :kv_lora_rank].view(-1, page_size, kv_lora_rank)
|
||||
k_pe = compressed_kv_with_k_pe[:, :, :, kv_lora_rank:].view(-1, page_size, rope_size)
|
||||
# q_absorb is [num_heads(128), qk_nope_head_dim(128), kv_lora_rank(512)]
|
||||
# out_absorb is [num_heads(128), kv_lora_rank(512), v_head_dim(128)] v_head_dim is also the nope dim
|
||||
# q_absorb, out_absorb = get_absorbed()
|
||||
# q_nope is [num_heads(128), qlen, qk_nope_head_dim(128)]
|
||||
q_nope = q_nope.transpose(0, 1) # qlen is 1, no GPU overhead, same below
|
||||
q_nope = q_nope.transpose(0, 1) # qlen is 1, no GPU overhead, same below
|
||||
# q_nope is [num_heads(128), qlen, kv_lora_rank(512)]
|
||||
q_nope = torch.matmul(q_nope, q_absorb) # batched MM
|
||||
q_nope = torch.matmul(q_nope, q_absorb) # batched MM
|
||||
|
||||
# # q_nope is [qlen, num_heads(128), kv_lora_rank(512)]
|
||||
# q_nope = q_nope.transpose(0, 1)
|
||||
|
|
@ -281,7 +279,7 @@ def torch_attn(hidden_states: torch.Tensor,
|
|||
if batch_compressed_kv is None or batch_k_pe is None:
|
||||
batch_compressed_kv = tmp_compressed_kv
|
||||
batch_k_pe = tmp_k_pe
|
||||
else:
|
||||
else:
|
||||
batch_compressed_kv = torch.cat((batch_compressed_kv, tmp_compressed_kv), dim=0)
|
||||
batch_k_pe = torch.cat((batch_k_pe, tmp_k_pe), dim=0)
|
||||
kv_total_len -= page_size
|
||||
|
|
@ -291,57 +289,48 @@ def torch_attn(hidden_states: torch.Tensor,
|
|||
if batch_compressed_kv is None or batch_k_pe is None:
|
||||
batch_compressed_kv = tmp_compressed_kv
|
||||
batch_k_pe = tmp_k_pe
|
||||
else:
|
||||
else:
|
||||
batch_compressed_kv = torch.cat((batch_compressed_kv, tmp_compressed_kv), dim=0)
|
||||
batch_k_pe = torch.cat((batch_k_pe, tmp_k_pe), dim=0)
|
||||
break
|
||||
# batch_compressed_kv is [kv_total_len(k_len), kv_lora_rank(512)]
|
||||
# batch_k_pe is [kv_total_len(k_len), qk_rope_head_dim(64)]
|
||||
pe_weights = torch.matmul(q_pe,batch_k_pe.mT)
|
||||
print('pe_weights',pe_weights)
|
||||
pe_weights = torch.matmul(q_pe, batch_k_pe.mT)
|
||||
print("pe_weights", pe_weights)
|
||||
attention_weights = (pe_weights + torch.matmul(q_nope, batch_compressed_kv.mT)) * softmax_scale
|
||||
# attention_weights is [num_heads(128), qlen, k_len]
|
||||
|
||||
|
||||
# attention_weights = attention_weights.transpose(0,1).unsqueeze(0).squeeze(-1).expand(qlen,-1,-1).transpose(0,1)
|
||||
|
||||
|
||||
# attention_masks[i] is [qlen, k_len]
|
||||
|
||||
attention_weights = (attention_weights + attention_masks[i])
|
||||
|
||||
attention_weights = attention_weights + attention_masks[i]
|
||||
# attention_weights shape is [num_heads(128), qlen, k_len]
|
||||
attention_weights = nn.functional.softmax(attention_weights,dim=-1,dtype=torch.float16).to(q_pe.dtype)
|
||||
attn_output = torch.matmul(attention_weights, batch_compressed_kv) # [num_heads(128),qlen, lora_rank(512)]
|
||||
attention_weights = nn.functional.softmax(attention_weights, dim=-1, dtype=torch.float16).to(q_pe.dtype)
|
||||
attn_output = torch.matmul(attention_weights, batch_compressed_kv) # [num_heads(128),qlen, lora_rank(512)]
|
||||
# out_absorb shape is [num_heads(128), kv_lora_rank(512), v_head_dim(128)]
|
||||
out_absorb = out_absorb.transpose(1,2)
|
||||
out_absorb = out_absorb.transpose(1, 2)
|
||||
# q for qlen, n for num_heads, h for v_head_dim, v for kv_lora_rank
|
||||
attn_output = torch.matmul(attn_output, out_absorb) # [num_heads(128), qlen, v_head_dim(128)]
|
||||
attn_output = attn_output.transpose(0, 1) # [qlen, num_heads(128), v_head_dim(128)]
|
||||
attn_output = torch.matmul(attn_output, out_absorb) # [num_heads(128), qlen, v_head_dim(128)]
|
||||
attn_output = attn_output.transpose(0, 1) # [qlen, num_heads(128), v_head_dim(128)]
|
||||
attn_output = attn_output.reshape(qlen, num_heads * nope_size)
|
||||
attn_output = o_proj(attn_output)
|
||||
final_attention_output = torch.cat((final_attention_output, attn_output), dim=0)
|
||||
return final_attention_output
|
||||
|
||||
|
||||
|
||||
torch_output = torch_attn(
|
||||
input,
|
||||
kv_cache,
|
||||
position_ids,
|
||||
page_idx,
|
||||
page_offset,
|
||||
attention_masks=attention_masks,
|
||||
q_indptr=q_indptr,
|
||||
kv_indices=kv_indices,
|
||||
kv_indptr=kv_indptr,
|
||||
bsz_tensors=bsz_tensors,
|
||||
last_page_len=last_page_len,
|
||||
layer_idx=0
|
||||
)
|
||||
print("Torch Output: ",torch_output)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
input,
|
||||
kv_cache,
|
||||
position_ids,
|
||||
page_idx,
|
||||
page_offset,
|
||||
attention_masks=attention_masks,
|
||||
q_indptr=q_indptr,
|
||||
kv_indices=kv_indices,
|
||||
kv_indptr=kv_indptr,
|
||||
bsz_tensors=bsz_tensors,
|
||||
last_page_len=last_page_len,
|
||||
layer_idx=0,
|
||||
)
|
||||
print("Torch Output: ", torch_output)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue