mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-05-05 07:11:39 +00:00
[fix](test): fix import kt-kernel (#1728)
This commit is contained in:
parent
6fc4080a7d
commit
a8667ddb58
33 changed files with 1063 additions and 1151 deletions
|
|
@ -6,7 +6,7 @@ from typing import Dict, Literal
|
|||
sys.path.insert(0, os.path.dirname(__file__) + "/../build")
|
||||
|
||||
import torch
|
||||
import kt_kernel_ext
|
||||
from kt_kernel import kt_kernel_ext
|
||||
|
||||
torch.manual_seed(42)
|
||||
|
||||
|
|
@ -132,6 +132,7 @@ def pack_to_int32(value: torch.Tensor, num_bits: int, packed_dim: Literal[0, 1]
|
|||
|
||||
return packed
|
||||
|
||||
|
||||
def pack_tensor_per_row(q: torch.Tensor, num_bits: int) -> torch.Tensor:
|
||||
e, rows, cols = q.shape
|
||||
flat = q.view(e * rows, cols)
|
||||
|
|
@ -283,9 +284,9 @@ def run_case(pattern: str) -> Dict[str, float]:
|
|||
CPUInfer.sync()
|
||||
|
||||
input_tensor_fp16 = input_tensor.to(torch.float16)
|
||||
t_output = moe_torch(
|
||||
input_tensor_fp16, expert_ids, weights, gate_fp16, up_fp16, down_fp16
|
||||
).to(torch.bfloat16)
|
||||
t_output = moe_torch(input_tensor_fp16, expert_ids, weights, gate_fp16, up_fp16, down_fp16).to(
|
||||
torch.bfloat16
|
||||
)
|
||||
|
||||
t_output = t_output.flatten()
|
||||
output = output.flatten()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue