Mirror of https://github.com/kvcache-ai/ktransformers.git
[fix](test): fix import kt-kernel (#1728)
Parent: 6fc4080a7d
Commit: a8667ddb58
33 changed files with 1063 additions and 1151 deletions
@@ -1,8 +1,9 @@
 import os, sys
 import time
 
 os.environ["BLAS_NUM_THREADS"] = "1"
-sys.path.insert(0, os.path.dirname(__file__) + "/../build")
-import kt_kernel_ext
+from kt_kernel import kt_kernel_ext
+from kt_kernel_ext.kvcache import ggml_type
 import torch
 import logging
+
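Note: the hunk above is the point of this commit: the test previously reached into the local build tree for the compiled extension; it now imports the packaged binding. A minimal sketch of the resulting import pattern, assuming kt-kernel is pip-installed in the environment and exposes the extension both as kt_kernel.kt_kernel_ext and as a top-level kt_kernel_ext module (exactly what the two new import lines imply):

import os

os.environ["BLAS_NUM_THREADS"] = "1"  # keep BLAS single-threaded, as the test does

# packaged imports replacing the sys.path hack
from kt_kernel import kt_kernel_ext
from kt_kernel_ext.kvcache import ggml_type  # quantization-type enum, name per this diff

cpu_infer = kt_kernel_ext.CPUInfer(304)  # worker-thread count taken from the test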
@@ -20,6 +21,7 @@ from transformers import (
 logger = logging.getLogger("reader")
+
 from gguf.gguf_reader import GGUFReader
 
 # load_layers = 6
 load_layers = None
 CPUInfer = kt_kernel_ext.CPUInfer(304)
@@ -284,22 +286,21 @@ def build_moegate(layer_idx, json_config, gguf_weights):
         json_config["topk_group"],
     )
 
-    config.routed_scaling_factor = json_config['routed_scaling_factor']
+    config.routed_scaling_factor = json_config["routed_scaling_factor"]
 
     config.pool = CPUInfer.backend_
 
-    weight,weight_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f"blk.{layer_idx}.ffn_gate_inp.weight")
+    weight, weight_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f"blk.{layer_idx}.ffn_gate_inp.weight")
     config.weight = weight.data_ptr()
     config.weight_type = type_to_ggml_type(weight_type)
 
-    bias,bias_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f"blk.{layer_idx}.exp_probs_b.bias")
+    bias, bias_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f"blk.{layer_idx}.exp_probs_b.bias")
     config.e_score_correction_bias = bias.data_ptr()
     config.e_score_correction_bias_type = type_to_ggml_type(bias_type)
 
     gate = kt_kernel_ext.gate.MoEGate(config)
-
 
     return gate
 
 
 def build_llm(json_config, gguf_weights):
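Note: build_moegate hands the C++ config raw data_ptr() addresses plus a ggml type id, so the kernel reads the PyTorch tensor storage in place, with no copy. One consequence worth spelling out: data_ptr() does not extend the tensor's lifetime. A hedged sketch of that zero-copy binding discipline, using a hypothetical bind_weight helper (not a kt_kernel API):

import torch

def bind_weight(config, tensor: torch.Tensor, ggml_type_id: int) -> None:
    # data_ptr() hands out a raw address without keeping the tensor alive,
    # so pin a Python reference for as long as the kernel may dereference it
    config._keepalive = getattr(config, "_keepalive", []) + [tensor]
    config.weight = tensor.data_ptr()   # raw address of the (possibly quantized) buffer
    config.weight_type = ggml_type_id   # tells the C++ side how to decode those bytes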
@@ -312,15 +313,15 @@ def build_llm(json_config, gguf_weights):
     general_config.n_shared_experts = json_config["n_shared_experts"]
     general_config.max_qlen = max_qlen
 
-    lm_heads,lm_heads_type = get_torch_tensor_and_type_from_gguf(gguf_weights, "output.weight")
+    lm_heads, lm_heads_type = get_torch_tensor_and_type_from_gguf(gguf_weights, "output.weight")
     general_config.lm_heads_ptr = lm_heads.data_ptr()
     general_config.lm_heads_type = type_to_ggml_type(lm_heads_type)
 
     output_norm, output_norm_type = get_torch_tensor_and_type_from_gguf(gguf_weights, "output_norm.weight")
     general_config.norm_weights_ptr = output_norm.data_ptr()
     general_config.norm_weights_type = type_to_ggml_type(output_norm_type)
 
-    token_embd,token_embd_type = get_torch_tensor_and_type_from_gguf(weights, "token_embd.weight")
+    token_embd, token_embd_type = get_torch_tensor_and_type_from_gguf(weights, "token_embd.weight")
     general_config.token_embd_ptr = token_embd.data_ptr()
     general_config.token_embd_type = type_to_ggml_type(token_embd_type)
 
@@ -330,12 +331,11 @@ def build_llm(json_config, gguf_weights):
     model = kt_kernel_ext.DeepseekV3Model(general_config)
     llm.model = model
 
-
     decoder_layers = []
     real_load_layers = json_config["num_hidden_layers"] if load_layers is None else load_layers
 
     for i in range(real_load_layers):
-        layer = kt_kernel_ext.DeepseekV3DecoderLayer(general_config,i)
+        layer = kt_kernel_ext.DeepseekV3DecoderLayer(general_config, i)
         attn_norm, attn_norm_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f"blk.{i}.attn_norm.weight")
         ffn_norm, ffn_norm_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f"blk.{i}.ffn_norm.weight")
@@ -351,11 +351,11 @@ def build_llm(json_config, gguf_weights):
         layer.ffn = build_ffn(i, json_config, gguf_weights)
         decoder_layers.append(layer)
 
     model.layers = decoder_layers
     return llm
 
 
-safetensor_path = '/home/bd/models/DeepSeek-R1'
+safetensor_path = "/home/bd/models/DeepSeek-R1"
 json_path = os.path.join(safetensor_path, "config.json")
 json_config = json.load(open(json_path, "r"))
 print(json_config)
@@ -368,11 +368,11 @@ weights = dict(sorted(weights.items()))
 for name, t in weights.items():
     # if not name.startswith("blk"):
     # if name.startswith("blk.10."):
     # if "ffn_gate." in name:
     #     print(f"Found weight: {t.name}, Shape: {t.shape}, Type: {t.tensor_type.name}, Size: {t.n_elements}")
     print(f"Found weight: {t.name}, Shape: {t.shape}, Type: {t.tensor_type.name}, Size: {t.n_elements}")
 
 print("Building LLM ...")
 load_start_time = time.perf_counter()
 llm = build_llm(json_config, weights)
 load_end_time = time.perf_counter()
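Note: the weight-listing loop relies on gguf-py's reader; t.name, t.shape, t.tensor_type, and t.n_elements are its standard tensor fields. A self-contained sketch of the same enumeration (the file path is illustrative):

from gguf.gguf_reader import GGUFReader

reader = GGUFReader("/path/to/DeepSeek-R1.gguf")  # hypothetical GGUF path
weights = {t.name: t for t in reader.tensors}
for name in sorted(weights):
    t = weights[name]
    print(f"Found weight: {t.name}, Shape: {t.shape}, Type: {t.tensor_type.name}, Size: {t.n_elements}")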
@@ -389,22 +389,20 @@ config = AutoConfig.from_pretrained(safetensor_path, trust_remote_code=True)
 force_think = False
 
 
-output_logits = torch.zeros((max_qlen, json_config['vocab_size']), dtype=torch.float32)
+output_logits = torch.zeros((max_qlen, json_config["vocab_size"]), dtype=torch.float32)
 
 
 def start_chat(content=None):
     if content is None:
         content = input("Chat: ")
 
     messages = [{"role": "user", "content": content}]
-    input_tensor = tokenizer.apply_chat_template(
-        messages, add_generation_prompt=True, return_tensors="pt"
-    )
+    input_tensor = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
     if force_think:
-        token_thinks = torch.tensor([tokenizer.encode("<think>\\n",add_special_tokens=False)],device=input_tensor.device)
-        input_tensor = torch.cat(
-            [input_tensor, token_thinks], dim=1
-        )
+        token_thinks = torch.tensor(
+            [tokenizer.encode("<think>\\n", add_special_tokens=False)], device=input_tensor.device
+        )
+        input_tensor = torch.cat([input_tensor, token_thinks], dim=1)
     input_tensor = input_tensor.squeeze(0)  # remove the batch dimension
 
     print(f"Input tensor: {input_tensor}, type {input_tensor.dtype}, shape {input_tensor.shape}")
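Note: the force_think branch reflowed above implements a known prompt trick for DeepSeek-R1-style checkpoints: splice the encoded think tag onto the prompt so generation starts inside the reasoning block. A standalone sketch of just that step (the helper name is mine; the string literal matches the test; token ids vary by model):

import torch

def force_think_prompt(tokenizer, input_ids: torch.Tensor) -> torch.Tensor:
    # encode the think tag without BOS/EOS and append it to the (1, seq_len) prompt
    think_ids = torch.tensor(
        [tokenizer.encode("<think>\\n", add_special_tokens=False)], device=input_ids.device
    )
    return torch.cat([input_ids, think_ids], dim=1)  # shape (1, seq_len + k)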
@@ -415,34 +413,36 @@ def start_chat(content=None):
         stream = TextStreamer(tokenizer)
 
         qlen = input_tensor.shape[0]
-        qlens = [qlen-kvlen]
+        qlens = [qlen - kvlen]
         kvlens = [kvlen]
         page_tables = [list(range(pages_count))]
         start_time = time.perf_counter()
-        llm.forward(qlens,page_tables, kvlens, input_tensor[kvlen:].data_ptr(), output_logits.data_ptr())
+        llm.forward(qlens, page_tables, kvlens, input_tensor[kvlen:].data_ptr(), output_logits.data_ptr())
         end_time = time.perf_counter()
-        print(f"Forward time: {end_time - start_time:.4f} seconds, tps: {qlens[0] / (end_time - start_time)} tokens/sec")
+        print(
+            f"Forward time: {end_time - start_time:.4f} seconds, tps: {qlens[0] / (end_time - start_time)} tokens/sec"
+        )
 
         logits = output_logits[0]
         # print(logits)
         # sample
         next_token = torch.argmax(logits).item()
         # print(f"Next token: {next_token}, {tokenizer.decode(next_token)}")
         kvlen = input_tensor.shape[0]
         input_tensor = torch.cat((input_tensor, torch.tensor([next_token])), dim=-1)
 
-        if next_token == tokenizer.eos_token_id or tokenizer.decode(next_token) == '<|im_end|>':
+        if next_token == tokenizer.eos_token_id or tokenizer.decode(next_token) == "<|im_end|>":
            stream.end()
            break
         else:
            stream.put(torch.tensor([next_token]))
 
 
 job_id = 0
 while True:
     try:
         # ---------- let the user decide whether to continue ----------
-        choice = input(
-            "\n[Enter] start a chat | enter 1 to read from a file | enter q/quit/exit to quit: "
-        ).strip().lower()
+        choice = input("\n[Enter] start a chat | enter 1 to read from a file | enter q/quit/exit to quit: ").strip().lower()
         if choice in {"q", "quit", "exit"}:
             print("Received quit command, exiting.")
             break
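Note: the hunk above shows the whole llm.forward calling convention as this test exercises it: per-request lists of new-token counts (qlens), cached-token counts (kvlens), and KV-cache page ids, plus raw input/output pointers, followed by greedy argmax sampling. A condensed sketch of one decode step under those inferred semantics (greedy_step is a hypothetical wrapper, not a kt_kernel API):

import torch

def greedy_step(llm, tokens: torch.Tensor, kvlen: int, pages: list, logits: torch.Tensor) -> int:
    qlens = [tokens.shape[0] - kvlen]   # tokens not yet in the KV cache
    kvlens = [kvlen]                    # tokens already cached
    llm.forward(qlens, [pages], kvlens, tokens[kvlen:].data_ptr(), logits.data_ptr())
    return torch.argmax(logits[0]).item()  # greedy sampling, as the test does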
@@ -466,15 +466,4 @@
         print(f"\nAn error occurred: {e}\nChat #{job_id} was terminated, restarting shortly...")
         logger.error(f"Error in job {job_id}: {e}", exc_info=True)
     finally:
         job_id += 1  # move on to the next job id whether or not this one was interrupted
-
-
-
-
-
-
-
-
-
-
-