[fix](test): fix import kt-kernel (#1728)

This commit is contained in:
ErvinXie 2025-12-17 19:46:32 +08:00 committed by GitHub
parent 6fc4080a7d
commit a8667ddb58
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
33 changed files with 1063 additions and 1151 deletions

View file

@@ -11,7 +11,7 @@ import numpy as np
# if REPO_ROOT not in sys.path:
# sys.path.insert(0, REPO_ROOT)
-import kt_kernel_ext
+from kt_kernel import kt_kernel_ext
from kt_kernel_ext import CPUInfer
@@ -57,10 +57,10 @@ def allocate_weights(expert_num, hidden_size, intermediate_size, group_size):
def main():
torch.manual_seed(123)
expert_num = 256 # Total experts
expert_num = 256 # Total experts
gpu_experts = expert_num # Number of experts on GPU
gpu_tp_count = 2 # Number of TP parts
num_experts_per_tok = 8
hidden_size = 7168
intermediate_size = 2048
@@ -89,9 +89,7 @@ def main():
moe = kt_kernel_ext.moe.AMXInt4_KGroup_MOE(cfg)
-physical_to_logical_map = (
-torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous()
-)
+physical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous()
cpuinfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
cpuinfer.sync()
@@ -169,6 +167,7 @@ def main():
total_bytes = total_weights // group_size + total_weights // 2
print(f"write_weight_scale_to_buffer time: {elapsed_ms:.2f} ms")
print(f"Throughput: {total_bytes / (elapsed_ms * 1e6):.2f} GB/s")
def split_expert_tensor(tensor, chunk):
"""Split tensor by experts"""
return [tensor[i * chunk : (i + 1) * chunk] for i in range(expert_num)]
@@ -229,10 +228,10 @@ def main():
tp_scale_offset = col_scale_start + tp_idx * tp_slice_scale_size
down_weight_tp_parts.append(
-down_q_experts[expert_idx][tp_weight_offset:tp_weight_offset + tp_slice_weight_size]
+down_q_experts[expert_idx][tp_weight_offset : tp_weight_offset + tp_slice_weight_size]
)
down_scale_tp_parts.append(
-down_scale_experts[expert_idx][tp_scale_offset:tp_scale_offset + tp_slice_scale_size]
+down_scale_experts[expert_idx][tp_scale_offset : tp_scale_offset + tp_slice_scale_size]
)
# Concatenate all column slices for this TP
@@ -260,7 +259,9 @@ def main():
assert torch.equal(w2_weight_bufs[tp_idx], expected_w2_weight), f"w2 weight bytes mismatch for TP {tp_idx}"
assert torch.allclose(w2_scale_bufs[tp_idx], expected_w2_scale), f"w2 scale values mismatch for TP {tp_idx}"
-print(f"\n✓ write_weight_scale_to_buffer passed: extracted {gpu_experts} GPU experts across {gpu_tp_count} TP parts from total {expert_num} experts")
+print(
+f"\n✓ write_weight_scale_to_buffer passed: extracted {gpu_experts} GPU experts across {gpu_tp_count} TP parts from total {expert_num} experts"
+)
if __name__ == "__main__":