[fix](test): fix import kt-kernel (#1728)

This commit is contained in:
ErvinXie 2025-12-17 19:46:32 +08:00 committed by GitHub
parent 6fc4080a7d
commit a8667ddb58
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
33 changed files with 1063 additions and 1151 deletions

View file

@@ -6,7 +6,7 @@ from typing import Dict, Literal
sys.path.insert(0, os.path.dirname(__file__) + "/../build")
import torch
-import kt_kernel_ext
+from kt_kernel import kt_kernel_ext
torch.manual_seed(42)
@@ -132,6 +132,7 @@ def pack_to_int32(value: torch.Tensor, num_bits: int, packed_dim: Literal[0, 1]
return packed
def pack_tensor_per_row(q: torch.Tensor, num_bits: int) -> torch.Tensor:
e, rows, cols = q.shape
flat = q.view(e * rows, cols)
@@ -283,9 +284,9 @@ def run_case(pattern: str) -> Dict[str, float]:
CPUInfer.sync()
input_tensor_fp16 = input_tensor.to(torch.float16)
-t_output = moe_torch(
-input_tensor_fp16, expert_ids, weights, gate_fp16, up_fp16, down_fp16
-).to(torch.bfloat16)
+t_output = moe_torch(input_tensor_fp16, expert_ids, weights, gate_fp16, up_fp16, down_fp16).to(
+torch.bfloat16
+)
t_output = t_output.flatten()
output = output.flatten()