mirror of https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-30 21:00:07 +00:00

fix(test): fix import kt-kernel (#1728)

This commit is contained in:
parent 6fc4080a7d
commit a8667ddb58

33 changed files with 1063 additions and 1151 deletions
@@ -14,7 +14,7 @@ from tqdm import tqdm
 sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))
 
-import kt_kernel_ext
+from kt_kernel import kt_kernel_ext
 import torch
 
 # Benchmark parameters (single MoE, mirror examples/test_k2_write_buffer.py)
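Note: the fix imports the binding through the installed kt_kernel package instead of relying on the build-tree path hack alone. A minimal sketch of a tolerant import, assuming one still wants a local-build fallback (the fallback branch is an illustration, not part of this commit):

    import os
    import sys

    try:
        # Preferred: the binding shipped inside the kt_kernel package (this commit).
        from kt_kernel import kt_kernel_ext
    except ImportError:
        # Assumed fallback for source checkouts: look in the local build tree.
        sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))
        import kt_kernel_ext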
@@ -39,20 +39,12 @@ CPUInfer = kt_kernel_ext.CPUInfer(96)
 def get_git_commit():
     result = {}
     try:
-        commit = (
-            subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("utf-8").strip()
-        )
-        commit_msg = (
-            subprocess.check_output(["git", "log", "-1", "--pretty=%B"])
-            .decode("utf-8")
-            .strip()
-        )
+        commit = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("utf-8").strip()
+        commit_msg = subprocess.check_output(["git", "log", "-1", "--pretty=%B"]).decode("utf-8").strip()
         result["commit"] = commit
         result["commit_message"] = commit_msg
 
-        dirty_output = (
-            subprocess.check_output(["git", "status", "--porcelain"]).decode("utf-8").strip()
-        )
+        dirty_output = subprocess.check_output(["git", "status", "--porcelain"]).decode("utf-8").strip()
         if dirty_output:
             result["dirty"] = True
             result["dirty_files"] = dirty_output.splitlines()
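Note: a hypothetical use of get_git_commit(), e.g. stamping a benchmark record with repo provenance (the file name and record fields below are made up for illustration):

    import json

    record = {"bench": "moe_write_buffer", **get_git_commit()}
    with open("bench_result.json", "w") as f:
        json.dump(record, f, indent=2)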
@@ -160,9 +152,7 @@ def build_moe():
         per_mat_scale_elems,
     ) = allocate_weights()
 
-    config = kt_kernel_ext.moe.MOEConfig(
-        expert_num, num_experts_per_tok, hidden_size, intermediate_size
-    )
+    config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
     config.max_len = max_len
     config.quant_config.bits = 4
     config.quant_config.group_size = group_size
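Note: for readers unfamiliar with the API, this is how the reformatted MOEConfig call could look with concrete numbers. The values below are assumptions for illustration (K2-style shapes), not taken from this diff; only the call pattern mirrors the code above:

    expert_num = 384            # assumed: routed experts in the layer
    num_experts_per_tok = 8     # assumed: top-k experts activated per token
    hidden_size = 7168          # assumed model hidden size
    intermediate_size = 2048    # assumed per-expert FFN width

    config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
    config.max_len = 1024                 # assumed cap on tokens per call
    config.quant_config.bits = 4          # 4-bit weights, as in the benchmark
    config.quant_config.group_size = 128  # assumed quantization group size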
@@ -186,18 +176,10 @@ def build_moe():
     total_weight_bytes_per_tp = gpu_experts_num * weight_bytes_per_expert_per_tp
     total_scale_elems_per_tp = gpu_experts_num * scale_elems_per_expert_per_tp
 
-    w13_weight_bufs = [
-        torch.empty(2 * total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)
-    ]
-    w13_scale_bufs = [
-        torch.empty(2 * total_scale_elems_per_tp, dtype=torch.bfloat16) for _ in range(gpu_tp_count)
-    ]
-    w2_weight_bufs = [
-        torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)
-    ]
-    w2_scale_bufs = [
-        torch.empty(total_scale_elems_per_tp, dtype=torch.bfloat16) for _ in range(gpu_tp_count)
-    ]
+    w13_weight_bufs = [torch.empty(2 * total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
+    w13_scale_bufs = [torch.empty(2 * total_scale_elems_per_tp, dtype=torch.bfloat16) for _ in range(gpu_tp_count)]
+    w2_weight_bufs = [torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
+    w2_scale_bufs = [torch.empty(total_scale_elems_per_tp, dtype=torch.bfloat16) for _ in range(gpu_tp_count)]
 
     buffer_ptrs = {
         "w13_weight_ptrs": [buf.data_ptr() for buf in w13_weight_bufs],
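Note: the per-expert sizes consumed here are computed earlier in the script; a hedged sketch of how they could be derived for 4-bit weights (these formulas and values are assumptions for illustration, not shown in this diff):

    # Assumed packing: 4-bit weights store two values per byte,
    # with one bf16 scale per quantization group.
    hidden_size, intermediate_size = 7168, 2048  # assumed K2-style shapes
    group_size = 128                             # assumed quant group size
    gpu_tp_count = 2                             # assumed tensor-parallel degree

    elems_per_mat_per_tp = hidden_size * intermediate_size // gpu_tp_count
    weight_bytes_per_expert_per_tp = elems_per_mat_per_tp // 2        # 4-bit packing
    scale_elems_per_expert_per_tp = elems_per_mat_per_tp // group_size

The 2 * factor on the w13 buffers presumably covers the fused gate and up projections, which is why only the w2 buffers use the plain per-expert size.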
@@ -248,7 +230,7 @@ def bench_write_buffer():
         )
     )
     CPUInfer.sync()
-
+
     total_time = 0
     for _ in tqdm(range(test_iter), desc="Testing"):
         start = time.perf_counter()
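Note: the diff only shows fragments of the timed loop; a sketch of the measurement pattern it uses, with the submit step left abstract (an assumption, since that code is not visible here):

    import time

    total_time = 0
    for _ in range(test_iter):
        start = time.perf_counter()
        # ... submit the write-buffer task to CPUInfer here ...
        CPUInfer.sync()  # block until the queued CPU work has completed
        end = time.perf_counter()
        total_time += end - start
        time.sleep(0.6)  # cool-down between iterations, as in the script
        print(end - start)

Calling CPUInfer.sync() above the loop, before the clock starts, keeps warm-up work out of the measurement.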
@@ -265,8 +247,6 @@ def bench_write_buffer():
         time.sleep(0.6)
         print(end - start)
-
-
 
     time_per_iter_us = total_time / test_iter * 1e6
     bandwidth_gbs = bytes_per_call * test_iter / total_time / 1e9
 
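Note: a worked example of the two reported metrics (the numbers below are made up for illustration):

    test_iter = 10
    total_time = 0.25               # seconds across all timed iterations
    bytes_per_call = 1_000_000_000  # bytes moved per write-buffer call (assumed)

    time_per_iter_us = total_time / test_iter * 1e6                # 25000.0 us
    bandwidth_gbs = bytes_per_call * test_iter / total_time / 1e9  # 40.0 GB/s
    print(f"{time_per_iter_us:.1f} us/iter, {bandwidth_gbs:.1f} GB/s")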