mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-05-05 07:11:39 +00:00
parent
e7d277d163
commit
d8046e1bb4
65 changed files with 12111 additions and 2502 deletions
|
|
@ -6,11 +6,6 @@ import torch
|
|||
import numpy as np
|
||||
|
||||
|
||||
# Ensure we can import the local extension
|
||||
# REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))
|
||||
# if REPO_ROOT not in sys.path:
|
||||
# sys.path.insert(0, REPO_ROOT)
|
||||
|
||||
from kt_kernel import kt_kernel_ext
|
||||
from kt_kernel_ext import CPUInfer
|
||||
|
||||
|
|
@ -54,12 +49,12 @@ def allocate_weights(expert_num, hidden_size, intermediate_size, group_size):
|
|||
)
|
||||
|
||||
|
||||
def main():
|
||||
def test_with_tp(gpu_tp_count):
|
||||
"""Test write_weight_scale_to_buffer with a specific gpu_tp_count"""
|
||||
torch.manual_seed(123)
|
||||
|
||||
expert_num = 256 # Total experts
|
||||
expert_num = 8 # Reduced for faster testing
|
||||
gpu_experts = expert_num # Number of experts on GPU
|
||||
gpu_tp_count = 2 # Number of TP parts
|
||||
|
||||
num_experts_per_tok = 8
|
||||
hidden_size = 7168
|
||||
|
|
@ -94,11 +89,7 @@ def main():
|
|||
cpuinfer.sync()
|
||||
|
||||
# TP configuration
|
||||
|
||||
# Since weights are col-major, we can directly divide the total size by tp_count
|
||||
# Each matrix is divided into gpu_tp_count parts in memory order
|
||||
|
||||
# Calculate sizes per TP part (direct division since col-major)
|
||||
# Calculate sizes per TP part (per expert)
|
||||
weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count
|
||||
scale_elems_per_expert_per_tp = per_mat_scale_elems // gpu_tp_count
|
||||
|
||||
|
|
@ -107,24 +98,19 @@ def main():
|
|||
total_scale_elems_per_tp = gpu_experts * scale_elems_per_expert_per_tp
|
||||
|
||||
# Create buffer lists for w13 (gate+up) and w2 (down)
|
||||
# These hold all experts' data for each GPU TP
|
||||
w13_weight_bufs = []
|
||||
w13_scale_bufs = []
|
||||
w2_weight_bufs = []
|
||||
w2_scale_bufs = []
|
||||
|
||||
for tp_idx in range(gpu_tp_count):
|
||||
# w13 combines gate and up, so needs 2x the size
|
||||
# w13 combines gate and up, so needs 2x the size per expert
|
||||
w13_weight_bufs.append(torch.empty(2 * total_weight_bytes_per_tp, dtype=torch.uint8))
|
||||
w13_scale_bufs.append(torch.empty(2 * total_scale_elems_per_tp, dtype=torch.bfloat16))
|
||||
w2_weight_bufs.append(torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8))
|
||||
w2_scale_bufs.append(torch.empty(total_scale_elems_per_tp, dtype=torch.bfloat16))
|
||||
|
||||
# Get data pointers for all buffers
|
||||
w13_weight_ptrs = [buf.data_ptr() for buf in w13_weight_bufs]
|
||||
w13_scale_ptrs = [buf.data_ptr() for buf in w13_scale_bufs]
|
||||
w2_weight_ptrs = [buf.data_ptr() for buf in w2_weight_bufs]
|
||||
w2_scale_ptrs = [buf.data_ptr() for buf in w2_scale_bufs]
|
||||
|
||||
print(f"Total experts: {expert_num}, GPU experts: {gpu_experts}")
|
||||
print(f"GPU TP count: {gpu_tp_count}")
|
||||
print(f"Original per matrix weight bytes: {per_mat_weight_bytes}")
|
||||
|
|
@ -133,14 +119,56 @@ def main():
|
|||
print(f"Scale elements per expert per TP: {scale_elems_per_expert_per_tp}")
|
||||
print(f"Total weight bytes per TP (w13): {2 * total_weight_bytes_per_tp}")
|
||||
print(f"Total weight bytes per TP (w2): {total_weight_bytes_per_tp}")
|
||||
print(f"Total scale elements per TP (w13): {2 * total_scale_elems_per_tp}")
|
||||
print(f"Total scale elements per TP (w2): {total_scale_elems_per_tp}")
|
||||
|
||||
for i in range(5):
|
||||
# Helper function to get pointers with expert offset
|
||||
# K2 write_weights_to_buffer writes one expert at a time, so we need to pass
|
||||
# pointers that already point to the correct location for each expert
|
||||
def get_expert_ptrs(expert_id):
|
||||
w13_weight_ptrs = []
|
||||
w13_scale_ptrs = []
|
||||
w2_weight_ptrs = []
|
||||
w2_scale_ptrs = []
|
||||
|
||||
for tp_idx in range(gpu_tp_count):
|
||||
# Calculate byte offsets for this expert
|
||||
# w13: gate_weight + up_weight interleaved by expert
|
||||
# Layout: [expert0_gate, expert0_up, expert1_gate, expert1_up, ...]
|
||||
w13_weight_expert_offset = expert_id * 2 * weight_bytes_per_expert_per_tp
|
||||
w13_scale_expert_offset = expert_id * 2 * scale_elems_per_expert_per_tp
|
||||
w2_weight_expert_offset = expert_id * weight_bytes_per_expert_per_tp
|
||||
w2_scale_expert_offset = expert_id * scale_elems_per_expert_per_tp
|
||||
|
||||
w13_weight_ptrs.append(w13_weight_bufs[tp_idx].data_ptr() + w13_weight_expert_offset)
|
||||
w13_scale_ptrs.append(w13_scale_bufs[tp_idx].data_ptr() + w13_scale_expert_offset * 2) # bf16 = 2 bytes
|
||||
w2_weight_ptrs.append(w2_weight_bufs[tp_idx].data_ptr() + w2_weight_expert_offset)
|
||||
w2_scale_ptrs.append(w2_scale_bufs[tp_idx].data_ptr() + w2_scale_expert_offset * 2) # bf16 = 2 bytes
|
||||
|
||||
return w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs
|
||||
|
||||
# Warm up
|
||||
for i in range(2):
|
||||
for expert_id in range(gpu_experts):
|
||||
w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id)
|
||||
cpuinfer.submit(
|
||||
moe.write_weight_scale_to_buffer_task(
|
||||
gpu_tp_count=gpu_tp_count,
|
||||
expert_id=expert_id,
|
||||
w13_weight_ptrs=w13_weight_ptrs,
|
||||
w13_scale_ptrs=w13_scale_ptrs,
|
||||
w2_weight_ptrs=w2_weight_ptrs,
|
||||
w2_scale_ptrs=w2_scale_ptrs,
|
||||
)
|
||||
)
|
||||
cpuinfer.sync()
|
||||
|
||||
# Timing
|
||||
begin_time = time.perf_counter_ns()
|
||||
for expert_id in range(gpu_experts):
|
||||
w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id)
|
||||
cpuinfer.submit(
|
||||
moe.write_weight_scale_to_buffer_task(
|
||||
gpu_tp_count=gpu_tp_count,
|
||||
gpu_experts_num=gpu_experts,
|
||||
expert_id=expert_id,
|
||||
w13_weight_ptrs=w13_weight_ptrs,
|
||||
w13_scale_ptrs=w13_scale_ptrs,
|
||||
w2_weight_ptrs=w2_weight_ptrs,
|
||||
|
|
@ -148,23 +176,10 @@ def main():
|
|||
)
|
||||
)
|
||||
cpuinfer.sync()
|
||||
|
||||
begin_time = time.perf_counter_ns()
|
||||
cpuinfer.submit(
|
||||
moe.write_weight_scale_to_buffer_task(
|
||||
gpu_tp_count=gpu_tp_count,
|
||||
gpu_experts_num=gpu_experts,
|
||||
w13_weight_ptrs=w13_weight_ptrs,
|
||||
w13_scale_ptrs=w13_scale_ptrs,
|
||||
w2_weight_ptrs=w2_weight_ptrs,
|
||||
w2_scale_ptrs=w2_scale_ptrs,
|
||||
)
|
||||
)
|
||||
cpuinfer.sync()
|
||||
end_time = time.perf_counter_ns()
|
||||
elapsed_ms = (end_time - begin_time) / 1000000
|
||||
total_weights = hidden_size * intermediate_size * expert_num * 3
|
||||
total_bytes = total_weights // group_size + total_weights // 2
|
||||
total_weights = hidden_size * intermediate_size * gpu_experts * 3
|
||||
total_bytes = total_weights // group_size * 2 + total_weights // 2 # scale (bf16) + weight (int4)
|
||||
print(f"write_weight_scale_to_buffer time: {elapsed_ms:.2f} ms")
|
||||
print(f"Throughput: {total_bytes / (elapsed_ms * 1e6):.2f} GB/s")
|
||||
|
||||
|
|
@ -181,9 +196,6 @@ def main():
|
|||
up_scale_experts = split_expert_tensor(up_scale, per_mat_scale_elems)
|
||||
down_scale_experts = split_expert_tensor(down_scale, per_mat_scale_elems)
|
||||
|
||||
# CPU TP count is always 2 in this test setup (one TP per NUMA node)
|
||||
cpu_tp_count = 2
|
||||
|
||||
# Verify buffers for each TP part
|
||||
for tp_idx in range(gpu_tp_count):
|
||||
expected_w13_weights = []
|
||||
|
|
@ -193,22 +205,22 @@ def main():
|
|||
|
||||
weight13_per_tp = per_mat_weight_bytes // gpu_tp_count
|
||||
scale13_per_tp = per_mat_scale_elems // gpu_tp_count
|
||||
# Process each GPU expert
|
||||
for expert_idx in range(gpu_experts):
|
||||
# For w13 (gate and up), the slicing is straightforward
|
||||
|
||||
# Process each GPU expert
|
||||
for expert_id in range(gpu_experts):
|
||||
# For w13 (gate and up), the slicing is straightforward
|
||||
start_weight = tp_idx * weight13_per_tp
|
||||
end_weight = (tp_idx + 1) * weight13_per_tp
|
||||
start_scale = tp_idx * scale13_per_tp
|
||||
end_scale = (tp_idx + 1) * scale13_per_tp
|
||||
|
||||
# Gate
|
||||
gate_weight_tp = gate_q_experts[expert_idx][start_weight:end_weight]
|
||||
gate_scale_tp = gate_scale_experts[expert_idx][start_scale:end_scale]
|
||||
gate_weight_tp = gate_q_experts[expert_id][start_weight:end_weight]
|
||||
gate_scale_tp = gate_scale_experts[expert_id][start_scale:end_scale]
|
||||
|
||||
# Up
|
||||
up_weight_tp = up_q_experts[expert_idx][start_weight:end_weight]
|
||||
up_scale_tp = up_scale_experts[expert_idx][start_scale:end_scale]
|
||||
up_weight_tp = up_q_experts[expert_id][start_weight:end_weight]
|
||||
up_scale_tp = up_scale_experts[expert_id][start_scale:end_scale]
|
||||
|
||||
# Down matrix needs special handling because it's sliced column-wise
|
||||
# We need to reconstruct it from column slices
|
||||
|
|
@ -228,16 +240,17 @@ def main():
|
|||
tp_scale_offset = col_scale_start + tp_idx * tp_slice_scale_size
|
||||
|
||||
down_weight_tp_parts.append(
|
||||
down_q_experts[expert_idx][tp_weight_offset : tp_weight_offset + tp_slice_weight_size]
|
||||
down_q_experts[expert_id][tp_weight_offset : tp_weight_offset + tp_slice_weight_size]
|
||||
)
|
||||
down_scale_tp_parts.append(
|
||||
down_scale_experts[expert_idx][tp_scale_offset : tp_scale_offset + tp_slice_scale_size]
|
||||
down_scale_experts[expert_id][tp_scale_offset : tp_scale_offset + tp_slice_scale_size]
|
||||
)
|
||||
|
||||
# Concatenate all column slices for this TP
|
||||
down_weight_tp = torch.cat(down_weight_tp_parts)
|
||||
down_scale_tp = torch.cat(down_scale_tp_parts)
|
||||
|
||||
# Append to expected lists - interleaved by expert: [gate0, up0, gate1, up1, ...]
|
||||
expected_w13_weights.append(gate_weight_tp)
|
||||
expected_w13_weights.append(up_weight_tp)
|
||||
expected_w13_scales.append(gate_scale_tp)
|
||||
|
|
@ -252,16 +265,85 @@ def main():
|
|||
expected_w2_scale = torch.cat(expected_w2_scales)
|
||||
|
||||
print(f"=== Checking TP part {tp_idx} ===")
|
||||
print(f" w13 weight shape: actual={w13_weight_bufs[tp_idx].shape}, expected={expected_w13_weight.shape}")
|
||||
print(f" w13 scale shape: actual={w13_scale_bufs[tp_idx].shape}, expected={expected_w13_scale.shape}")
|
||||
print(f" w2 weight shape: actual={w2_weight_bufs[tp_idx].shape}, expected={expected_w2_weight.shape}")
|
||||
print(f" w2 scale shape: actual={w2_scale_bufs[tp_idx].shape}, expected={expected_w2_scale.shape}")
|
||||
|
||||
# Assert all checks pass
|
||||
assert torch.equal(w13_weight_bufs[tp_idx], expected_w13_weight), f"w13 weight bytes mismatch for TP {tp_idx}"
|
||||
assert torch.allclose(w13_scale_bufs[tp_idx], expected_w13_scale), f"w13 scale values mismatch for TP {tp_idx}"
|
||||
assert torch.equal(w2_weight_bufs[tp_idx], expected_w2_weight), f"w2 weight bytes mismatch for TP {tp_idx}"
|
||||
assert torch.allclose(w2_scale_bufs[tp_idx], expected_w2_scale), f"w2 scale values mismatch for TP {tp_idx}"
|
||||
if not torch.equal(w13_weight_bufs[tp_idx], expected_w13_weight):
|
||||
diff_mask = w13_weight_bufs[tp_idx] != expected_w13_weight
|
||||
first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1
|
||||
print(f" w13 weight mismatch at index {first_diff_idx}")
|
||||
print(f" actual: {w13_weight_bufs[tp_idx][first_diff_idx:first_diff_idx+10]}")
|
||||
print(f" expected: {expected_w13_weight[first_diff_idx:first_diff_idx+10]}")
|
||||
raise AssertionError(f"w13 weight bytes mismatch for TP {tp_idx}")
|
||||
|
||||
if not torch.allclose(w13_scale_bufs[tp_idx], expected_w13_scale):
|
||||
diff = torch.abs(w13_scale_bufs[tp_idx].float() - expected_w13_scale.float())
|
||||
max_diff_idx = diff.argmax().item()
|
||||
print(f" w13 scale mismatch, max diff at index {max_diff_idx}")
|
||||
print(f" actual: {w13_scale_bufs[tp_idx][max_diff_idx]}")
|
||||
print(f" expected: {expected_w13_scale[max_diff_idx]}")
|
||||
raise AssertionError(f"w13 scale values mismatch for TP {tp_idx}")
|
||||
|
||||
if not torch.equal(w2_weight_bufs[tp_idx], expected_w2_weight):
|
||||
diff_mask = w2_weight_bufs[tp_idx] != expected_w2_weight
|
||||
first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1
|
||||
print(f" w2 weight mismatch at index {first_diff_idx}")
|
||||
print(f" actual: {w2_weight_bufs[tp_idx][first_diff_idx:first_diff_idx+10]}")
|
||||
print(f" expected: {expected_w2_weight[first_diff_idx:first_diff_idx+10]}")
|
||||
raise AssertionError(f"w2 weight bytes mismatch for TP {tp_idx}")
|
||||
|
||||
if not torch.allclose(w2_scale_bufs[tp_idx], expected_w2_scale):
|
||||
diff = torch.abs(w2_scale_bufs[tp_idx].float() - expected_w2_scale.float())
|
||||
max_diff_idx = diff.argmax().item()
|
||||
print(f" w2 scale mismatch, max diff at index {max_diff_idx}")
|
||||
print(f" actual: {w2_scale_bufs[tp_idx][max_diff_idx]}")
|
||||
print(f" expected: {expected_w2_scale[max_diff_idx]}")
|
||||
raise AssertionError(f"w2 scale values mismatch for TP {tp_idx}")
|
||||
|
||||
print(
|
||||
f"\n✓ write_weight_scale_to_buffer passed: extracted {gpu_experts} GPU experts across {gpu_tp_count} TP parts from total {expert_num} experts"
|
||||
f"\n✓ write_weight_scale_to_buffer passed: extracted {gpu_experts} GPU experts across {gpu_tp_count} TP parts"
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def main():
|
||||
"""Run tests for all gpu_tp_count values: 1, 2, 4, 8"""
|
||||
tp_values = [1, 2, 4, 8]
|
||||
all_passed = True
|
||||
results = {}
|
||||
|
||||
print("=" * 60)
|
||||
print("Testing K2 write_weight_scale_to_buffer for TP = 1, 2, 4, 8")
|
||||
print("=" * 60)
|
||||
|
||||
for tp in tp_values:
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Testing with gpu_tp_count = {tp}")
|
||||
print(f"{'='*60}")
|
||||
try:
|
||||
test_with_tp(tp)
|
||||
results[tp] = "PASSED"
|
||||
print(f"✓ TP={tp} PASSED")
|
||||
except Exception as e:
|
||||
results[tp] = f"FAILED: {e}"
|
||||
all_passed = False
|
||||
print(f"✗ TP={tp} FAILED: {e}")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("SUMMARY")
|
||||
print("=" * 60)
|
||||
for tp, result in results.items():
|
||||
status = "✓" if "PASSED" in result else "✗"
|
||||
print(f" {status} TP={tp}: {result}")
|
||||
|
||||
if all_passed:
|
||||
print("\n✓ ALL TESTS PASSED")
|
||||
else:
|
||||
print("\n✗ SOME TESTS FAILED")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue