Mirror of https://github.com/kvcache-ai/ktransformers.git
[fix](test): fix import kt-kernel (#1728)

Parent: 6fc4080a7d
Commit: a8667ddb58
33 changed files with 1063 additions and 1151 deletions
@@ -11,7 +11,7 @@ import numpy as np
 # if REPO_ROOT not in sys.path:
 #     sys.path.insert(0, REPO_ROOT)
 
-import kt_kernel_ext
+from kt_kernel import kt_kernel_ext
 from kt_kernel_ext import CPUInfer
 
 
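The fix is the one-line import change above: when kt-kernel is installed as a Python package, the compiled kt_kernel_ext extension module lives inside the kt_kernel package rather than sitting on sys.path as a top-level module. A minimal sketch of the resulting import pattern, assuming a pip-installed kt_kernel; the in-tree fallback branch is illustrative and not part of this commit:

# Resolve the compiled extension from the installed package first; fall back
# to a top-level import for an in-tree build (the fallback is an assumed
# pattern, not part of this commit).
try:
    from kt_kernel import kt_kernel_ext
except ImportError:
    import kt_kernel_ext

CPUInfer = kt_kernel_ext.CPUInfer  # the test imports this same symbol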
@@ -57,10 +57,10 @@ def allocate_weights(expert_num, hidden_size, intermediate_size, group_size):
 def main():
     torch.manual_seed(123)
 
-    expert_num = 256 # Total experts
+    expert_num = 256  # Total experts
     gpu_experts = expert_num  # Number of experts on GPU
     gpu_tp_count = 2  # Number of TP parts
 
     num_experts_per_tok = 8
     hidden_size = 7168
     intermediate_size = 2048
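The configuration (256 experts, hidden size 7168, intermediate size 2048, top-8 routing) matches a DeepSeek-V3-class MoE layer. A back-of-the-envelope sizing sketch, assuming the standard gate/up/down three-matrix expert MLP; that layout is an assumption, not spelled out in this hunk:

expert_num = 256
hidden_size = 7168
intermediate_size = 2048

# gate + up + down projections per expert (assumed standard MoE MLP layout)
weights_per_expert = 3 * hidden_size * intermediate_size
total_weights = expert_num * weights_per_expert
print(f"{total_weights / 1e9:.2f}B weights")   # ~11.27B
print(f"{total_weights / 2 / 2**30:.2f} GiB")  # int4 packs two weights per byte: ~5.25 GiB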
@@ -89,9 +89,7 @@ def main():
 
     moe = kt_kernel_ext.moe.AMXInt4_KGroup_MOE(cfg)
 
-    physical_to_logical_map = (
-        torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous()
-    )
+    physical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous()
     cpuinfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
     cpuinfer.sync()
 
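The map built here is the identity permutation: physical expert slot i holds logical expert i. Because it crosses into C++ through a raw data_ptr(), the tensor must be contiguous int64 storage and stay alive until cpuinfer.sync() returns; that lifetime reading is inferred from the pointer handoff, not stated in the commit. A small sketch:

import torch

expert_num = 256  # from the test configuration above

# Identity mapping: physical slot i -> logical expert i. A load-balancing or
# redundancy scheme would permute/repeat entries instead (an assumption about
# the map's general role, not exercised by this test).
physical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous()
assert physical_to_logical_map.is_contiguous()

# The C++ task reads through this raw pointer, so keep the tensor referenced
# until the submitted work has synced.
ptr = physical_to_logical_map.data_ptr()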
@@ -169,6 +167,7 @@ def main():
     total_bytes = total_weights // group_size + total_weights // 2
     print(f"write_weight_scale_to_buffer time: {elapsed_ms:.2f} ms")
     print(f"Throughput: {total_bytes / (elapsed_ms * 1e6):.2f} GB/s")
+
     def split_expert_tensor(tensor, chunk):
         """Split tensor by experts"""
         return [tensor[i * chunk : (i + 1) * chunk] for i in range(expert_num)]
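The total_bytes formula decomposes the quantized payload: total_weights // 2 for the int4 weights (two per byte) plus total_weights // group_size for the scales, i.e. one byte charged per quantization group. Dividing bytes by elapsed_ms * 1e6 then yields GB/s directly. A worked sketch with an assumed group size and timing, both hypothetical:

group_size = 64                 # hypothetical; the test takes this from its config
total_weights = 256 * 3 * 7168 * 2048

scale_bytes = total_weights // group_size   # one byte per group, as the formula implies
weight_bytes = total_weights // 2           # int4: two weights per byte
total_bytes = scale_bytes + weight_bytes

elapsed_ms = 50.0               # hypothetical measurement
# bytes / (ms * 1e6) == bytes / (s * 1e9) == GB/s
print(f"Throughput: {total_bytes / (elapsed_ms * 1e6):.2f} GB/s")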
@@ -229,10 +228,10 @@ def main():
             tp_scale_offset = col_scale_start + tp_idx * tp_slice_scale_size
 
             down_weight_tp_parts.append(
-                down_q_experts[expert_idx][tp_weight_offset:tp_weight_offset + tp_slice_weight_size]
+                down_q_experts[expert_idx][tp_weight_offset : tp_weight_offset + tp_slice_weight_size]
             )
             down_scale_tp_parts.append(
-                down_scale_experts[expert_idx][tp_scale_offset:tp_scale_offset + tp_slice_scale_size]
+                down_scale_experts[expert_idx][tp_scale_offset : tp_scale_offset + tp_slice_scale_size]
             )
 
         # Concatenate all column slices for this TP
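What this hunk's loop builds, in isolation: each expert's down-projection bytes (and their scales) are cut into gpu_tp_count contiguous slices at fixed offsets, one slice per tensor-parallel part. A self-contained sketch of the slicing invariant; names mirror the test, but the sizes are illustrative stand-ins:

import torch

gpu_tp_count = 2
tp_slice_weight_size = 1024  # hypothetical per-TP byte count

# Stand-in for one expert's packed down-projection bytes.
expert_bytes = (torch.arange(gpu_tp_count * tp_slice_weight_size) % 256).to(torch.uint8)

down_weight_tp_parts = []
for tp_idx in range(gpu_tp_count):
    tp_weight_offset = tp_idx * tp_slice_weight_size
    down_weight_tp_parts.append(expert_bytes[tp_weight_offset : tp_weight_offset + tp_slice_weight_size])

# The slices partition the expert's bytes: concatenating them restores the original.
assert torch.equal(torch.cat(down_weight_tp_parts), expert_bytes)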
@@ -260,7 +259,9 @@ def main():
         assert torch.equal(w2_weight_bufs[tp_idx], expected_w2_weight), f"w2 weight bytes mismatch for TP {tp_idx}"
         assert torch.allclose(w2_scale_bufs[tp_idx], expected_w2_scale), f"w2 scale values mismatch for TP {tp_idx}"
 
-    print(f"\n✓ write_weight_scale_to_buffer passed: extracted {gpu_experts} GPU experts across {gpu_tp_count} TP parts from total {expert_num} experts")
+    print(
+        f"\n✓ write_weight_scale_to_buffer passed: extracted {gpu_experts} GPU experts across {gpu_tp_count} TP parts from total {expert_num} experts"
+    )
 
 
 if __name__ == "__main__":
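A note on the two comparison modes in the assertions above: packed weight bytes are checked with torch.equal (bit-exact), while float scales are checked with torch.allclose (tolerance-based), presumably because scales can pick up rounding during conversion; that rationale is an inference, not stated in the commit. A minimal illustration:

import torch

# Quantized weight bytes must round-trip bit-for-bit.
w_bytes = torch.tensor([0x12, 0xAB, 0x7F], dtype=torch.uint8)
assert torch.equal(w_bytes, w_bytes.clone())

# Float scales only need to agree within tolerance (allclose defaults:
# rtol=1e-05, atol=1e-08), absorbing tiny conversion rounding.
scales = torch.tensor([0.125, 0.25, 0.5], dtype=torch.float32)
assert torch.allclose(scales, scales * (1 + 1e-7))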