[fix](test): fix import kt-kernel (#1728)

This commit is contained in:
ErvinXie 2025-12-17 19:46:32 +08:00 committed by GitHub
parent 6fc4080a7d
commit a8667ddb58
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
33 changed files with 1063 additions and 1151 deletions

View file

@ -14,7 +14,7 @@ from tqdm import tqdm
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))
import kt_kernel_ext
from kt_kernel import kt_kernel_ext
import torch
# Benchmark parameters (single MoE, mirror examples/test_k2_write_buffer.py)
@ -39,20 +39,12 @@ CPUInfer = kt_kernel_ext.CPUInfer(96)
def get_git_commit():
result = {}
try:
commit = (
subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("utf-8").strip()
)
commit_msg = (
subprocess.check_output(["git", "log", "-1", "--pretty=%B"])
.decode("utf-8")
.strip()
)
commit = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("utf-8").strip()
commit_msg = subprocess.check_output(["git", "log", "-1", "--pretty=%B"]).decode("utf-8").strip()
result["commit"] = commit
result["commit_message"] = commit_msg
dirty_output = (
subprocess.check_output(["git", "status", "--porcelain"]).decode("utf-8").strip()
)
dirty_output = subprocess.check_output(["git", "status", "--porcelain"]).decode("utf-8").strip()
if dirty_output:
result["dirty"] = True
result["dirty_files"] = dirty_output.splitlines()
@ -160,9 +152,7 @@ def build_moe():
per_mat_scale_elems,
) = allocate_weights()
config = kt_kernel_ext.moe.MOEConfig(
expert_num, num_experts_per_tok, hidden_size, intermediate_size
)
config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
config.max_len = max_len
config.quant_config.bits = 4
config.quant_config.group_size = group_size
@ -186,18 +176,10 @@ def build_moe():
total_weight_bytes_per_tp = gpu_experts_num * weight_bytes_per_expert_per_tp
total_scale_elems_per_tp = gpu_experts_num * scale_elems_per_expert_per_tp
w13_weight_bufs = [
torch.empty(2 * total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)
]
w13_scale_bufs = [
torch.empty(2 * total_scale_elems_per_tp, dtype=torch.bfloat16) for _ in range(gpu_tp_count)
]
w2_weight_bufs = [
torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)
]
w2_scale_bufs = [
torch.empty(total_scale_elems_per_tp, dtype=torch.bfloat16) for _ in range(gpu_tp_count)
]
w13_weight_bufs = [torch.empty(2 * total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
w13_scale_bufs = [torch.empty(2 * total_scale_elems_per_tp, dtype=torch.bfloat16) for _ in range(gpu_tp_count)]
w2_weight_bufs = [torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
w2_scale_bufs = [torch.empty(total_scale_elems_per_tp, dtype=torch.bfloat16) for _ in range(gpu_tp_count)]
buffer_ptrs = {
"w13_weight_ptrs": [buf.data_ptr() for buf in w13_weight_bufs],
@ -248,7 +230,7 @@ def bench_write_buffer():
)
)
CPUInfer.sync()
total_time = 0
for _ in tqdm(range(test_iter), desc="Testing"):
start = time.perf_counter()
@ -265,8 +247,6 @@ def bench_write_buffer():
time.sleep(0.6)
print(end - start)
time_per_iter_us = total_time / test_iter * 1e6
bandwidth_gbs = bytes_per_call * test_iter / total_time / 1e9