[fix](test): fix import kt-kernel (#1728)

This commit is contained in:
ErvinXie 2025-12-17 19:46:32 +08:00 committed by GitHub
parent 6fc4080a7d
commit a8667ddb58
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
33 changed files with 1063 additions and 1151 deletions

View file

@ -1,8 +1,9 @@
import os, sys
import time
os.environ["BLAS_NUM_THREADS"] = "1"
sys.path.insert(0, os.path.dirname(__file__) + "/../build")
import kt_kernel_ext
from kt_kernel import kt_kernel_ext
from kt_kernel_ext.kvcache import ggml_type
import torch
import logging
@ -188,7 +189,6 @@ def build_mla(layer_idx, json_config, gguf_weights):
config.layer_idx = layer_idx
config.pool = CPUInfer.backend_
config.page_count = pages_count
if q_a_type == "F32":
mla = kt_kernel_ext.mla.MLA_F32(config)
@ -284,22 +284,21 @@ def build_moegate(layer_idx, json_config, gguf_weights):
json_config["topk_group"],
)
config.routed_scaling_factor = json_config['routed_scaling_factor']
config.routed_scaling_factor = json_config["routed_scaling_factor"]
config.pool = CPUInfer.backend_
weight,weight_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f"blk.{layer_idx}.ffn_gate_inp.weight")
weight, weight_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f"blk.{layer_idx}.ffn_gate_inp.weight")
config.weight = weight.data_ptr()
config.weight_type = type_to_ggml_type(weight_type)
bias,bias_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f"blk.{layer_idx}.exp_probs_b.bias")
bias, bias_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f"blk.{layer_idx}.exp_probs_b.bias")
config.e_score_correction_bias = bias.data_ptr()
config.e_score_correction_bias_type = type_to_ggml_type(bias_type)
gate = kt_kernel_ext.gate.MoEGate(config)
return gate
def build_llm(json_config, gguf_weights):
@ -312,15 +311,15 @@ def build_llm(json_config, gguf_weights):
general_config.n_shared_experts = json_config["n_shared_experts"]
general_config.max_qlen = max_qlen
lm_heads,lm_heads_type = get_torch_tensor_and_type_from_gguf(gguf_weights, "output.weight")
lm_heads, lm_heads_type = get_torch_tensor_and_type_from_gguf(gguf_weights, "output.weight")
general_config.lm_heads_ptr = lm_heads.data_ptr()
general_config.lm_heads_type = type_to_ggml_type(lm_heads_type)
output_norm, output_norm_type = get_torch_tensor_and_type_from_gguf(gguf_weights, "output_norm.weight")
general_config.norm_weights_ptr = output_norm.data_ptr()
general_config.norm_weights_type = type_to_ggml_type(output_norm_type)
general_config.norm_weights_type = type_to_ggml_type(output_norm_type)
token_embd,token_embd_type = get_torch_tensor_and_type_from_gguf(weights, "token_embd.weight")
token_embd, token_embd_type = get_torch_tensor_and_type_from_gguf(weights, "token_embd.weight")
general_config.token_embd_ptr = token_embd.data_ptr()
general_config.token_embd_type = type_to_ggml_type(token_embd_type)
@ -330,12 +329,11 @@ def build_llm(json_config, gguf_weights):
model = kt_kernel_ext.DeepseekV3Model(general_config)
llm.model = model
decoder_layers = []
for i in range(json_config["num_hidden_layers"]):
# for i in range(6):
# for i in [0,1,2,3,4,5,6,7,8,9,10]:
layer = kt_kernel_ext.DeepseekV3DecoderLayer(general_config,i)
# for i in range(6):
# for i in [0,1,2,3,4,5,6,7,8,9,10]:
layer = kt_kernel_ext.DeepseekV3DecoderLayer(general_config, i)
attn_norm, attn_norm_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f"blk.{i}.attn_norm.weight")
ffn_norm, ffn_norm_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f"blk.{i}.ffn_norm.weight")
@ -351,11 +349,11 @@ def build_llm(json_config, gguf_weights):
layer.ffn = build_ffn(i, json_config, gguf_weights)
decoder_layers.append(layer)
model.layers = decoder_layers
model.layers = decoder_layers
return llm
safetensor_path = '/home/bd/models/DeepSeek-R1'
safetensor_path = "/home/bd/models/DeepSeek-R1"
json_path = os.path.join(safetensor_path, "config.json")
json_config = json.load(open(json_path, "r"))
print(json_config)
@ -368,8 +366,8 @@ weights = dict(sorted(weights.items()))
for name, t in weights.items():
# if not name.startswith("blk"):
# if name.startswith("blk.10."):
# if "ffn_gate." in name:
# print(f"Found weight: {t.name}, Shape: {t.shape}, Type: {t.tensor_type.name}, Size: {t.n_elements}")
# if "ffn_gate." in name:
# print(f"Found weight: {t.name}, Shape: {t.shape}, Type: {t.tensor_type.name}, Size: {t.n_elements}")
print(f"Found weight: {t.name}, Shape: {t.shape}, Type: {t.tensor_type.name}, Size: {t.n_elements}")
print("Building LLM ...")
llm = build_llm(json_config, weights)
@ -384,7 +382,7 @@ prompt_file = None
force_think = False
output_logits = torch.zeros((max_qlen, json_config['vocab_size']), dtype=torch.float32)
output_logits = torch.zeros((max_qlen, json_config["vocab_size"]), dtype=torch.float32)
def start_chat():
@ -411,16 +409,14 @@ def start_chat():
content = "Please write a piece of quicksort code in C++."
elif os.path.isfile(content):
content = open(content, "r").read()
messages = [{"role": "user", "content": content}]
input_tensor = tokenizer.apply_chat_template(
messages, add_generation_prompt=True, return_tensors="pt"
)
input_tensor = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
if force_think:
token_thinks = torch.tensor([tokenizer.encode("<think>\\n",add_special_tokens=False)],device=input_tensor.device)
input_tensor = torch.cat(
[input_tensor, token_thinks], dim=1
token_thinks = torch.tensor(
[tokenizer.encode("<think>\\n", add_special_tokens=False)], device=input_tensor.device
)
input_tensor = torch.cat([input_tensor, token_thinks], dim=1)
input_tensor = input_tensor.squeeze(0) # Remove the leading batch dimension
print(f"Input tensor: {input_tensor}, type {input_tensor.dtype}, shape {input_tensor.shape}")
@ -431,28 +427,27 @@ def start_chat():
qlens = [qlen]
kvlens = [0]
page_tables = [list(range(pages_count))]
llm.forward(qlens,page_tables, kvlens, input_tensor.data_ptr(), output_logits.data_ptr())
llm.forward(qlens, page_tables, kvlens, input_tensor.data_ptr(), output_logits.data_ptr())
logits = output_logits[0]
# print(logits)
# sample
# sample
next_token = torch.argmax(logits).item()
# print(f"Next token: {next_token}, {tokenizer.decode(next_token)}")
input_tensor = torch.cat((input_tensor, torch.tensor([next_token])), dim=-1)
if next_token == tokenizer.eos_token_id or tokenizer.decode(next_token) == '<|im_end|>':
if next_token == tokenizer.eos_token_id or tokenizer.decode(next_token) == "<|im_end|>":
print(stream.end(), end="", flush=True)
break
else:
print(stream.put(torch.tensor([next_token])), end="", flush=True)
job_id = 0
while True:
try:
# ---------- Let the user decide whether to continue ----------
choice = input(
"\n【回车】开始对话 | 输入 q/quit/exit 退出程序: "
).strip().lower()
choice = input("\n【回车】开始对话 | 输入 q/quit/exit 退出程序: ").strip().lower()
if choice in {"q", "quit", "exit"}:
print("收到退出指令,程序结束。")
break
@ -464,15 +459,4 @@ while True:
# Ctrl-C at any time abandons the current task and restarts
print(f"\n检测到 Ctrl-C已终止对话 #{job_id},马上重启…")
finally:
job_id += 1 # Whether interrupted or not, give the next task a new number
job_id += 1 # Whether interrupted or not, give the next task a new number