[fix](test): fix import kt-kernel (#1728)

This commit is contained in:
ErvinXie 2025-12-17 19:46:32 +08:00 committed by GitHub
parent 6fc4080a7d
commit a8667ddb58
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
33 changed files with 1063 additions and 1151 deletions

View file

@ -1,8 +1,9 @@
import os, sys
import time
os.environ["BLAS_NUM_THREADS"] = "1"
sys.path.insert(0, os.path.dirname(__file__) + "/../build")
import kt_kernel_ext
from kt_kernel import kt_kernel_ext
from kt_kernel_ext.kvcache import ggml_type
import torch
import logging
@ -20,6 +21,7 @@ from transformers import (
logger = logging.getLogger("reader")
from gguf.gguf_reader import GGUFReader
# load_layers = 6
load_layers = None
CPUInfer = kt_kernel_ext.CPUInfer(304)
@ -284,22 +286,21 @@ def build_moegate(layer_idx, json_config, gguf_weights):
json_config["topk_group"],
)
config.routed_scaling_factor = json_config['routed_scaling_factor']
config.routed_scaling_factor = json_config["routed_scaling_factor"]
config.pool = CPUInfer.backend_
weight,weight_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f"blk.{layer_idx}.ffn_gate_inp.weight")
weight, weight_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f"blk.{layer_idx}.ffn_gate_inp.weight")
config.weight = weight.data_ptr()
config.weight_type = type_to_ggml_type(weight_type)
bias,bias_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f"blk.{layer_idx}.exp_probs_b.bias")
bias, bias_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f"blk.{layer_idx}.exp_probs_b.bias")
config.e_score_correction_bias = bias.data_ptr()
config.e_score_correction_bias_type = type_to_ggml_type(bias_type)
gate = kt_kernel_ext.gate.MoEGate(config)
return gate
def build_llm(json_config, gguf_weights):
@ -312,15 +313,15 @@ def build_llm(json_config, gguf_weights):
general_config.n_shared_experts = json_config["n_shared_experts"]
general_config.max_qlen = max_qlen
lm_heads,lm_heads_type = get_torch_tensor_and_type_from_gguf(gguf_weights, "output.weight")
lm_heads, lm_heads_type = get_torch_tensor_and_type_from_gguf(gguf_weights, "output.weight")
general_config.lm_heads_ptr = lm_heads.data_ptr()
general_config.lm_heads_type = type_to_ggml_type(lm_heads_type)
output_norm, output_norm_type = get_torch_tensor_and_type_from_gguf(gguf_weights, "output_norm.weight")
general_config.norm_weights_ptr = output_norm.data_ptr()
general_config.norm_weights_type = type_to_ggml_type(output_norm_type)
general_config.norm_weights_type = type_to_ggml_type(output_norm_type)
token_embd,token_embd_type = get_torch_tensor_and_type_from_gguf(weights, "token_embd.weight")
token_embd, token_embd_type = get_torch_tensor_and_type_from_gguf(weights, "token_embd.weight")
general_config.token_embd_ptr = token_embd.data_ptr()
general_config.token_embd_type = type_to_ggml_type(token_embd_type)
@ -330,12 +331,11 @@ def build_llm(json_config, gguf_weights):
model = kt_kernel_ext.DeepseekV3Model(general_config)
llm.model = model
decoder_layers = []
real_load_layers = json_config["num_hidden_layers"] if load_layers is None else load_layers
for i in range(real_load_layers):
layer = kt_kernel_ext.DeepseekV3DecoderLayer(general_config,i)
layer = kt_kernel_ext.DeepseekV3DecoderLayer(general_config, i)
attn_norm, attn_norm_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f"blk.{i}.attn_norm.weight")
ffn_norm, ffn_norm_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f"blk.{i}.ffn_norm.weight")
@ -351,11 +351,11 @@ def build_llm(json_config, gguf_weights):
layer.ffn = build_ffn(i, json_config, gguf_weights)
decoder_layers.append(layer)
model.layers = decoder_layers
model.layers = decoder_layers
return llm
safetensor_path = '/home/bd/models/DeepSeek-R1'
safetensor_path = "/home/bd/models/DeepSeek-R1"
json_path = os.path.join(safetensor_path, "config.json")
json_config = json.load(open(json_path, "r"))
print(json_config)
@ -368,11 +368,11 @@ weights = dict(sorted(weights.items()))
for name, t in weights.items():
# if not name.startswith("blk"):
# if name.startswith("blk.10."):
# if "ffn_gate." in name:
# print(f"Found weight: {t.name}, Shape: {t.shape}, Type: {t.tensor_type.name}, Size: {t.n_elements}")
# if "ffn_gate." in name:
# print(f"Found weight: {t.name}, Shape: {t.shape}, Type: {t.tensor_type.name}, Size: {t.n_elements}")
print(f"Found weight: {t.name}, Shape: {t.shape}, Type: {t.tensor_type.name}, Size: {t.n_elements}")
print("Building LLM ...")
print("Building LLM ...")
load_start_time = time.perf_counter()
llm = build_llm(json_config, weights)
load_end_time = time.perf_counter()
@ -389,22 +389,20 @@ config = AutoConfig.from_pretrained(safetensor_path, trust_remote_code=True)
force_think = False
output_logits = torch.zeros((max_qlen, json_config['vocab_size']), dtype=torch.float32)
output_logits = torch.zeros((max_qlen, json_config["vocab_size"]), dtype=torch.float32)
def start_chat(content=None):
if content is None:
content = input("Chat: ")
messages = [{"role": "user", "content": content}]
input_tensor = tokenizer.apply_chat_template(
messages, add_generation_prompt=True, return_tensors="pt"
)
input_tensor = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
if force_think:
token_thinks = torch.tensor([tokenizer.encode("<think>\\n",add_special_tokens=False)],device=input_tensor.device)
input_tensor = torch.cat(
[input_tensor, token_thinks], dim=1
token_thinks = torch.tensor(
[tokenizer.encode("<think>\\n", add_special_tokens=False)], device=input_tensor.device
)
input_tensor = torch.cat([input_tensor, token_thinks], dim=1)
input_tensor = input_tensor.squeeze(0) # Add batch dimension
print(f"Input tensor: {input_tensor}, type {input_tensor.dtype}, shape {input_tensor.shape}")
@ -415,34 +413,36 @@ def start_chat(content=None):
stream = TextStreamer(tokenizer)
qlen = input_tensor.shape[0]
qlens = [qlen-kvlen]
qlens = [qlen - kvlen]
kvlens = [kvlen]
page_tables = [list(range(pages_count))]
start_time = time.perf_counter()
llm.forward(qlens,page_tables, kvlens, input_tensor[kvlen:].data_ptr(), output_logits.data_ptr())
llm.forward(qlens, page_tables, kvlens, input_tensor[kvlen:].data_ptr(), output_logits.data_ptr())
end_time = time.perf_counter()
print(f"Forward time: {end_time - start_time:.4f} seconds, tps: {qlens[0] / (end_time - start_time)} tokens/sec")
print(
f"Forward time: {end_time - start_time:.4f} seconds, tps: {qlens[0] / (end_time - start_time)} tokens/sec"
)
logits = output_logits[0]
# print(logits)
# sample
# sample
next_token = torch.argmax(logits).item()
# print(f"Next token: {next_token}, {tokenizer.decode(next_token)}")
kvlen = input_tensor.shape[0]
input_tensor = torch.cat((input_tensor, torch.tensor([next_token])), dim=-1)
if next_token == tokenizer.eos_token_id or tokenizer.decode(next_token) == '<|im_end|>':
if next_token == tokenizer.eos_token_id or tokenizer.decode(next_token) == "<|im_end|>":
stream.end()
break
else:
stream.put(torch.tensor([next_token]))
job_id = 0
while True:
try:
# ---------- 让用户决定是否继续 ----------
choice = input(
"\n【回车】开始对话 | 输入 1 读取文件 | 输入 q/quit/exit 退出程序: "
).strip().lower()
choice = input("\n【回车】开始对话 | 输入 1 读取文件 | 输入 q/quit/exit 退出程序: ").strip().lower()
if choice in {"q", "quit", "exit"}:
print("收到退出指令,程序结束。")
break
@ -466,15 +466,4 @@ while True:
print(f"\n发生错误:{e}\n已终止对话 #{job_id},马上重启…")
logger.error(f"Error in job {job_id}: {e}", exc_info=True)
finally:
job_id += 1 # 不管中断与否,都给下一任务换编号
job_id += 1 # 不管中断与否,都给下一任务换编号