[Feature] Add AVX-based Kimi-K2 support (#1656)

* Support Kimi-K2-Thinking original weights; fix AMX kernel bug

* Update K2 AVX kernel.

* feat: add CPUInfer write buffer task

* [feat]: add kimi k2 cpu write buffer support

- Implement write_weights_to_buffer function in k2-moe.hpp for extracting GPU expert weights
- Fix down (w2) weight column-wise slicing for different TP configurations (see the sketch after this list)
- Support three TP scenarios: cpu_tp == gpu_tp, cpu_tp > gpu_tp, cpu_tp < gpu_tp
- Add comprehensive test cases for weight extraction validation
- Ensure compatibility with Kimi model's MoE architecture
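For the matching-TP case (cpu_tp == gpu_tp), a minimal Python sketch of the column-wise w2 slicing; the helper name and the flat, column-major uint8 layout per expert are assumptions for illustration, mirroring the layout verified in examples/test_k2_write_buffer.py:

import torch

def slice_down_weight_for_tp(down_q_expert: torch.Tensor, hidden_size: int,
                             intermediate_size: int, gpu_tp_count: int,
                             tp_idx: int) -> torch.Tensor:
    # down_q_expert: one expert's packed int4 down (w2) weights as a flat uint8
    # buffer, laid out as hidden_size columns of intermediate_size // 2 bytes.
    bytes_per_col = intermediate_size // 2        # int4: two values per byte
    slice_bytes = bytes_per_col // gpu_tp_count   # contiguous slice per GPU TP rank
    parts = []
    for col in range(hidden_size):
        start = col * bytes_per_col + tp_idx * slice_bytes
        parts.append(down_q_expert[start:start + slice_bytes])
    return torch.cat(parts)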

* [fix]: correct write_weight_scale_to_buffer expert offset calculation

Fixed the bug in write_weight_scale_to_buffer_task where expert offsets in GPU buffers were incorrectly calculated. Changed from using per_expert_gpu sizes to using full gpu_tp sizes, ensuring correct memory layout for multi-expert scenarios.
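A minimal sketch of the corrected per-expert offset arithmetic (Python, illustrative helper name): each expert now advances by the full per-GPU-TP matrix size, with a factor of two for w13 because gate and up are stored back to back.

def expert_offsets(expert_id: int, hidden_size: int, intermediate_size: int,
                   gpu_tp_count: int, group_size: int):
    per_mat_elems = hidden_size * intermediate_size
    gpu_tp_weight_bytes = per_mat_elems // gpu_tp_count // 2         # packed int4
    gpu_tp_scale_elems = per_mat_elems // gpu_tp_count // group_size  # bf16 scales
    return {
        "w13_weight": expert_id * 2 * gpu_tp_weight_bytes,  # gate + up per expert
        "w13_scale": expert_id * 2 * gpu_tp_scale_elems,
        "w2_weight": expert_id * gpu_tp_weight_bytes,
        "w2_scale": expert_id * gpu_tp_scale_elems,
    }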

Also added benchmark scripts for k2 moe and write buffer operations, and cleaned up debug output in test files.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* [feat]: add write buffer wrapper

* [fix] fix comment

---------

Co-authored-by: ouqingliang <1692110604@qq.com>
Co-authored-by: Claude <noreply@anthropic.com>
Jiaqi Liao 2025-12-02 16:01:07 +08:00 committed by GitHub
parent c2b8c60c4e
commit fcf8882075
12 changed files with 2649 additions and 34 deletions


@@ -0,0 +1,363 @@
#!/usr/bin/env python
# coding=utf-8
"""
Benchmark AMX_K2_MOE_TP int4 path with packed weights and BF16 scales.
"""
import json
import math
import os
import platform
import subprocess
import sys
import time
from tqdm import tqdm
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))
import kt_kernel_ext
import torch
# Benchmark parameters (single MoE, no layer loop)
expert_num = 384
hidden_size = 7168
intermediate_size = 2048
max_len = 25600
num_experts_per_tok = 8
qlen = 1
warm_up_iter = 1000
test_iter = 5000
k_group_size = 32
physical_to_logical_map = (
torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous()
)
worker_config = kt_kernel_ext.WorkerPoolConfig()
worker_config.subpool_count = 2
worker_config.subpool_numa_map = [0, 1]
worker_config.subpool_thread_count = [40, 40]
CPUInfer = kt_kernel_ext.CPUInfer(worker_config)
def get_git_commit():
result = {}
try:
commit = (
subprocess.check_output(["git", "rev-parse", "HEAD"])
.decode("utf-8")
.strip()
)
commit_msg = (
subprocess.check_output(["git", "log", "-1", "--pretty=%B"])
.decode("utf-8")
.strip()
)
result["commit"] = commit
result["commit_message"] = commit_msg
dirty_output = (
subprocess.check_output(["git", "status", "--porcelain"])
.decode("utf-8")
.strip()
)
if dirty_output:
result["dirty"] = True
result["dirty_files"] = dirty_output.splitlines()
else:
result["dirty"] = False
except Exception as e:
result["commit"] = None
result["commit_message"] = None
result["dirty"] = None
result["error"] = str(e)
return result
def get_system_info():
info = {}
uname = platform.uname()
info["system_name"] = uname.system
info["node_name"] = uname.node
cpu_model = None
if os.path.exists("/proc/cpuinfo"):
try:
with open("/proc/cpuinfo", "r") as f:
for line in f:
if "model name" in line:
cpu_model = line.split(":", 1)[1].strip()
break
except Exception as e:
cpu_model = f"Error: {e}"
info["cpu_model"] = cpu_model
mem_total_gb = None
if os.path.exists("/proc/meminfo"):
try:
with open("/proc/meminfo", "r") as f:
for line in f:
if "MemTotal" in line:
mem_kb = float(line.split(":", 1)[1].split()[0])
mem_total_gb = round(mem_kb / (1024 * 1024), 2)
break
except Exception as e:
mem_total_gb = f"Error: {e}"
info["memory_size_GB"] = mem_total_gb
info["cpu_core_count"] = os.cpu_count()
sockets = set()
if os.path.exists("/proc/cpuinfo"):
try:
with open("/proc/cpuinfo", "r") as f:
for line in f:
if "physical id" in line:
sockets.add(line.split(":", 1)[1].strip())
except Exception:
sockets = set()
info["cpu_socket_count"] = len(sockets) if len(sockets) > 0 else 1
return info
script_path = os.path.abspath(__file__)
script_dir = os.path.dirname(script_path)
script_name = os.path.splitext(os.path.basename(script_path))[0]
json_path = os.path.join(script_dir, script_name + ".jsonl")
def record_results(result, filename=json_path):
with open(filename, "a") as f:
f.write(json.dumps(result) + "\n")
def pack_to_int32(
value: torch.Tensor, num_bits: int, packed_dim: int = 1
) -> torch.Tensor:
if value.dtype is not torch.int8:
raise ValueError("Tensor must be torch.int8 before packing")
if not (1 <= num_bits <= 8):
raise ValueError(f"num_bits must be in [1, 8], got {num_bits}")
offset = 1 << (num_bits - 1)
value = (value + offset).to(torch.uint8)
device = value.device
pack_factor = 32 // num_bits
if packed_dim == 0:
value = value.transpose(0, 1)
rows, cols = value.shape
padded_cols = math.ceil(cols / pack_factor) * pack_factor
pad_len = padded_cols - cols
if pad_len > 0:
value = torch.nn.functional.pad(value, (0, pad_len))
num_groups = padded_cols // pack_factor
reshaped = value.view(rows, num_groups, pack_factor).to(torch.int32)
bit_shifts = torch.arange(pack_factor, device=device, dtype=torch.int32) * num_bits
packed = (reshaped << bit_shifts).sum(dim=2, dtype=torch.int32)
if packed_dim == 0:
packed = packed.transpose(0, 1)
return packed
def pack_tensor_per_row(q: torch.Tensor, num_bits: int) -> torch.Tensor:
e, rows, cols = q.shape
flat = q.view(e * rows, cols)
packed = pack_to_int32(flat, num_bits)
return packed.view(e, rows, -1).contiguous()
def quantize_k2_tensor(weights: torch.Tensor, group_size: int):
"""
K2 int4 quantization producing int32-packed weights (8 int4s each) and BF16 scales.
"""
weights_f32 = weights.to(torch.float32)
e, rows, cols = weights_f32.shape
if cols % group_size != 0 or cols % 2 != 0:
raise ValueError(
f"cols ({cols}) must be divisible by group_size ({group_size}) and 2"
)
reshaped = weights_f32.view(e, rows, cols // group_size, group_size)
max_abs = reshaped.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
scales = (max_abs / 7.0).squeeze(-1)
q = torch.round(reshaped / scales.unsqueeze(-1)).clamp(-8, 7).to(torch.int8)
q = q.view(e, rows, cols)
packed = pack_tensor_per_row(q, num_bits=4).view(e, rows, cols // 8).contiguous()
scales = scales.to(torch.bfloat16).contiguous().view(
e, rows, cols // group_size
).contiguous()
return packed, scales
def build_quantized_layer_weights():
gate_proj = torch.randn(
(expert_num, intermediate_size, hidden_size),
dtype=torch.float32,
device="cpu",
).contiguous()
up_proj = torch.randn(
(expert_num, intermediate_size, hidden_size),
dtype=torch.float32,
device="cpu",
).contiguous()
down_proj = torch.randn(
(expert_num, hidden_size, intermediate_size),
dtype=torch.float32,
device="cpu",
).contiguous()
gate_q, gate_scales = quantize_k2_tensor(gate_proj, k_group_size)
up_q, up_scales = quantize_k2_tensor(up_proj, k_group_size)
down_q, down_scales = quantize_k2_tensor(down_proj, k_group_size)
return {
"gate_qweight": gate_q,
"up_qweight": up_q,
"down_qweight": down_q,
"gate_scales": gate_scales,
"up_scales": up_scales,
"down_scales": down_scales,
}
def bench_k2_moe():
with torch.inference_mode():
bytes_per_elem = 0.5 + 2.0 / k_group_size
quant_data = build_quantized_layer_weights()
config = kt_kernel_ext.moe.MOEConfig(
expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0
)
config.max_len = max_len
config.quant_config.bits = 4
config.quant_config.group_size = k_group_size
config.quant_config.zero_point = False
config.gate_proj = quant_data["gate_qweight"].data_ptr()
config.up_proj = quant_data["up_qweight"].data_ptr()
config.down_proj = quant_data["down_qweight"].data_ptr()
config.gate_scale = quant_data["gate_scales"].data_ptr()
config.up_scale = quant_data["up_scales"].data_ptr()
config.down_scale = quant_data["down_scales"].data_ptr()
config.pool = CPUInfer.backend_
moe = kt_kernel_ext.moe.AMXInt4_KGroup_MOE(config)
CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
CPUInfer.sync()
gen_iter = 3000
expert_ids = (
torch.rand(gen_iter * qlen, expert_num, device="cpu")
.argsort(dim=-1)[:, :num_experts_per_tok]
.reshape(gen_iter, qlen * num_experts_per_tok)
.contiguous()
)
weights = torch.rand(
(gen_iter, qlen, num_experts_per_tok), dtype=torch.float32, device="cpu"
).contiguous()
input_tensor = torch.randn(
(qlen, hidden_size), dtype=torch.bfloat16, device="cpu"
).contiguous()
output_tensor = torch.empty_like(input_tensor)
bsz_tensor = torch.tensor([qlen], device="cpu")
for i in tqdm(range(warm_up_iter), desc="Warm-up"):
CPUInfer.submit(
moe.forward_task(
bsz_tensor.data_ptr(),
num_experts_per_tok,
expert_ids[i % gen_iter].data_ptr(),
weights[i % gen_iter].data_ptr(),
input_tensor.data_ptr(),
output_tensor.data_ptr(),
False,
)
)
CPUInfer.sync()
start = time.perf_counter()
for i in tqdm(range(test_iter), desc="Testing"):
CPUInfer.submit(
moe.forward_task(
bsz_tensor.data_ptr(),
num_experts_per_tok,
expert_ids[i % gen_iter].data_ptr(),
weights[i % gen_iter].data_ptr(),
input_tensor.data_ptr(),
output_tensor.data_ptr(),
False,
)
)
CPUInfer.sync()
end = time.perf_counter()
total_time = end - start
time_per_iter_us = total_time / test_iter * 1e6
bandwidth = (
hidden_size
* intermediate_size
* 3
* num_experts_per_tok
* (1 / 8 * 256 * (1 - (31 / 32) ** qlen))
* bytes_per_elem
* test_iter
/ total_time
/ 1e9
)
flops = (
hidden_size
* intermediate_size
* qlen
* 3
* num_experts_per_tok
* 2
* test_iter
/ total_time
/ 1e12
)
print("Quant mode: int4_k2")
print("Time(s): ", total_time)
print("Iteration: ", test_iter)
print("Time(us) per iteration: ", time_per_iter_us)
print("Bandwidth: ", bandwidth, "GB/s")
print("Flops: ", flops, "TFLOPS")
print("")
result = {
"quant_mode": "int4_k2",
"total_time_seconds": total_time,
"iterations": test_iter,
"time_per_iteration_us": time_per_iter_us,
"bandwidth_GBs": bandwidth,
"flops_TFLOPS": flops,
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
"test_parameters": {
"expert_num": expert_num,
"hidden_size": hidden_size,
"intermediate_size": intermediate_size,
"max_len": max_len,
"num_experts_per_tok": num_experts_per_tok,
"qlen": qlen,
"warm_up_iter": warm_up_iter,
"test_iter": test_iter,
"k_group_size": k_group_size,
"bytes_per_elem": bytes_per_elem,
},
}
result.update(get_git_commit())
result.update(get_system_info())
record_results(result)
if __name__ == "__main__":
bench_k2_moe()


@@ -0,0 +1,309 @@
#!/usr/bin/env python
# coding=utf-8
"""
Benchmark write_weight_scale_to_buffer for AMX_K2_MOE_TP (int4 packed weights + bf16 scales).
"""
import json
import os
import platform
import subprocess
import sys
import time
from tqdm import tqdm
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))
import kt_kernel_ext
import torch
# Benchmark parameters (single MoE, mirror examples/test_k2_write_buffer.py)
expert_num = 384
num_experts_per_tok = expert_num
gpu_tp_count = 4
warm_up_iter = 3
test_iter = 7
gpu_experts_num = expert_num
hidden_size = 7168
intermediate_size = 2048
group_size = 32
max_len = 1
physical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous()
CPUInfer = kt_kernel_ext.CPUInfer(96)
def get_git_commit():
result = {}
try:
commit = (
subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("utf-8").strip()
)
commit_msg = (
subprocess.check_output(["git", "log", "-1", "--pretty=%B"])
.decode("utf-8")
.strip()
)
result["commit"] = commit
result["commit_message"] = commit_msg
dirty_output = (
subprocess.check_output(["git", "status", "--porcelain"]).decode("utf-8").strip()
)
if dirty_output:
result["dirty"] = True
result["dirty_files"] = dirty_output.splitlines()
else:
result["dirty"] = False
except Exception as e:
result["commit"] = None
result["commit_message"] = None
result["dirty"] = None
result["error"] = str(e)
return result
def get_system_info():
info = {}
uname = platform.uname()
info["system_name"] = uname.system
info["node_name"] = uname.node
cpu_model = None
if os.path.exists("/proc/cpuinfo"):
try:
with open("/proc/cpuinfo", "r") as f:
for line in f:
if "model name" in line:
cpu_model = line.split(":", 1)[1].strip()
break
except Exception as e:
cpu_model = f"Error: {e}"
info["cpu_model"] = cpu_model
mem_total_gb = None
if os.path.exists("/proc/meminfo"):
try:
with open("/proc/meminfo", "r") as f:
for line in f:
if "MemTotal" in line:
mem_kb = float(line.split(":", 1)[1].split()[0])
mem_total_gb = round(mem_kb / (1024 * 1024), 2)
break
except Exception as e:
mem_total_gb = f"Error: {e}"
info["memory_size_GB"] = mem_total_gb
info["cpu_core_count"] = os.cpu_count()
sockets = set()
if os.path.exists("/proc/cpuinfo"):
try:
with open("/proc/cpuinfo", "r") as f:
for line in f:
if "physical id" in line:
sockets.add(line.split(":", 1)[1].strip())
except Exception:
sockets = set()
info["cpu_socket_count"] = len(sockets) if len(sockets) > 0 else 1
return info
script_path = os.path.abspath(__file__)
script_dir = os.path.dirname(script_path)
script_name = os.path.splitext(os.path.basename(script_path))[0]
json_path = os.path.join(script_dir, script_name + ".jsonl")
def record_results(result, filename=json_path):
with open(filename, "a") as f:
f.write(json.dumps(result) + "\n")
def allocate_weights():
per_mat_weight_bytes = (hidden_size * intermediate_size) // 2
per_mat_scale_elems = (hidden_size * intermediate_size) // group_size
gate_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)
up_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)
down_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)
gate_scale = torch.randn(expert_num * per_mat_scale_elems, dtype=torch.bfloat16)
up_scale = torch.randn(expert_num * per_mat_scale_elems, dtype=torch.bfloat16)
down_scale = torch.randn(expert_num * per_mat_scale_elems, dtype=torch.bfloat16)
return (
gate_q.contiguous(),
up_q.contiguous(),
down_q.contiguous(),
gate_scale.contiguous(),
up_scale.contiguous(),
down_scale.contiguous(),
per_mat_weight_bytes,
per_mat_scale_elems,
)
def build_moe():
(
gate_q,
up_q,
down_q,
gate_scale,
up_scale,
down_scale,
per_mat_weight_bytes,
per_mat_scale_elems,
) = allocate_weights()
config = kt_kernel_ext.moe.MOEConfig(
expert_num, num_experts_per_tok, hidden_size, intermediate_size
)
config.max_len = max_len
config.quant_config.bits = 4
config.quant_config.group_size = group_size
config.quant_config.zero_point = False
config.pool = CPUInfer.backend_
config.gate_proj = gate_q.data_ptr()
config.up_proj = up_q.data_ptr()
config.down_proj = down_q.data_ptr()
config.gate_scale = gate_scale.data_ptr()
config.up_scale = up_scale.data_ptr()
config.down_scale = down_scale.data_ptr()
moe = kt_kernel_ext.moe.AMXInt4_KGroup_MOE(config)
CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
CPUInfer.sync()
# Buffer sizing per TP
weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count
scale_elems_per_expert_per_tp = per_mat_scale_elems // gpu_tp_count
total_weight_bytes_per_tp = gpu_experts_num * weight_bytes_per_expert_per_tp
total_scale_elems_per_tp = gpu_experts_num * scale_elems_per_expert_per_tp
w13_weight_bufs = [
torch.empty(2 * total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)
]
w13_scale_bufs = [
torch.empty(2 * total_scale_elems_per_tp, dtype=torch.bfloat16) for _ in range(gpu_tp_count)
]
w2_weight_bufs = [
torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)
]
w2_scale_bufs = [
torch.empty(total_scale_elems_per_tp, dtype=torch.bfloat16) for _ in range(gpu_tp_count)
]
buffer_ptrs = {
"w13_weight_ptrs": [buf.data_ptr() for buf in w13_weight_bufs],
"w13_scale_ptrs": [buf.data_ptr() for buf in w13_scale_bufs],
"w2_weight_ptrs": [buf.data_ptr() for buf in w2_weight_bufs],
"w2_scale_ptrs": [buf.data_ptr() for buf in w2_scale_bufs],
}
buffer_shapes = {
"per_mat_weight_bytes": per_mat_weight_bytes,
"per_mat_scale_elems": per_mat_scale_elems,
"weight_bytes_per_expert_per_tp": weight_bytes_per_expert_per_tp,
"scale_elems_per_expert_per_tp": scale_elems_per_expert_per_tp,
"total_weight_bytes_per_tp": total_weight_bytes_per_tp,
"total_scale_elems_per_tp": total_scale_elems_per_tp,
}
keep_tensors = {
"gate_q": gate_q,
"up_q": up_q,
"down_q": down_q,
"gate_scale": gate_scale,
"up_scale": up_scale,
"down_scale": down_scale,
"w13_weight_bufs": w13_weight_bufs,
"w13_scale_bufs": w13_scale_bufs,
"w2_weight_bufs": w2_weight_bufs,
"w2_scale_bufs": w2_scale_bufs,
}
return moe, buffer_ptrs, buffer_shapes, keep_tensors
def bench_write_buffer():
moe, buffer_ptrs, buffer_shapes, keep_tensors = build_moe()
total_weights = hidden_size * intermediate_size * expert_num * 3
# Throughput accounting consistent with examples/test_k2_write_buffer.py
bytes_per_call = total_weights // group_size + total_weights // 2
# Warm-up
for _ in tqdm(range(warm_up_iter), desc="Warm-up"):
CPUInfer.submit(
moe.write_weight_scale_to_buffer_task(
gpu_tp_count=gpu_tp_count,
gpu_experts_num=gpu_experts_num,
**buffer_ptrs,
)
)
CPUInfer.sync()
total_time = 0
for _ in tqdm(range(test_iter), desc="Testing"):
start = time.perf_counter()
CPUInfer.submit(
moe.write_weight_scale_to_buffer_task(
gpu_tp_count=gpu_tp_count,
gpu_experts_num=gpu_experts_num,
**buffer_ptrs,
)
)
CPUInfer.sync()
end = time.perf_counter()
total_time += end - start
time.sleep(0.6)
print(end - start)
time_per_iter_us = total_time / test_iter * 1e6
bandwidth_gbs = bytes_per_call * test_iter / total_time / 1e9
print("write_weight_scale_to_buffer benchmark")
print("Time(s): ", total_time)
print("Iteration: ", test_iter)
print("Time(us) per iteration: ", time_per_iter_us)
print("Bandwidth: ", bandwidth_gbs, "GB/s")
print("")
result = {
"op": "write_weight_scale_to_buffer",
"total_time_seconds": total_time,
"iterations": test_iter,
"time_per_iteration_us": time_per_iter_us,
"bandwidth_GBs": bandwidth_gbs,
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
"test_parameters": {
"expert_num": expert_num,
"hidden_size": hidden_size,
"intermediate_size": intermediate_size,
"group_size": group_size,
"max_len": max_len,
"num_experts_per_tok": num_experts_per_tok,
"gpu_tp_count": gpu_tp_count,
"gpu_experts_num": gpu_experts_num,
"warm_up_iter": warm_up_iter,
"test_iter": test_iter,
"bytes_per_call": bytes_per_call,
},
"buffer_shapes": buffer_shapes,
"keep_tensors_alive": list(keep_tensors.keys()),
}
result.update(get_git_commit())
result.update(get_system_info())
record_results(result)
if __name__ == "__main__":
bench_write_buffer()


@@ -0,0 +1,319 @@
import math
import os
import sys
from typing import Dict, Literal
sys.path.insert(0, os.path.dirname(__file__) + "/../build")
import torch
import kt_kernel_ext
torch.manual_seed(42)
hidden_size = 7168
intermediate_size = 2048
max_len = 25600
expert_num = 16
num_experts_per_tok = 8
qlen = 1
layer_num = 1
CPUInfer = kt_kernel_ext.CPUInfer(40)
validation_iter = 10
k_group_size = 32
debug_print_count = 16
physical_to_logical_map = torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous()
def _pattern_uniform(groups: int) -> torch.Tensor:
return torch.full((groups,), 0.02, dtype=torch.float32)
def _pattern_alternating(groups: int) -> torch.Tensor:
vals = torch.full((groups,), 0.015, dtype=torch.float32)
vals[1::2] = 0.03
return vals
def _pattern_ramp(groups: int) -> torch.Tensor:
return torch.linspace(0.005, 0.04, steps=groups, dtype=torch.float32)
WEIGHT_PATTERNS = {
"uniform_scale": ("All k-groups share the same abs max / scale", _pattern_uniform),
"alternating_scale": ("Alternate small / large abs max per k-group", _pattern_alternating),
"ramp_scale": ("Linearly increasing abs max per k-group", _pattern_ramp),
"random": ("Random bf16 weights (baseline)", None),
}
def act_fn(x):
return x / (1.0 + torch.exp(-x))
def mlp_torch(input, gate_proj, up_proj, down_proj):
gate_buf = torch.mm(input, gate_proj.t())
up_buf = torch.mm(input, up_proj.t())
print(f"gate_buf: {gate_buf}")
print(f"up_buf: {up_buf}")
intermediate = act_fn(gate_buf) * up_buf
ret = torch.mm(intermediate, down_proj.t())
print(f"intermediate: {intermediate}")
print(f"mlp output: {ret}")
return ret
def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):
cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num))
cnts.scatter_(1, expert_ids, 1)
tokens_per_expert = cnts.sum(dim=0)
idxs = expert_ids.view(-1).argsort()
sorted_tokens = input[idxs // expert_ids.shape[1]]
outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
end_idx = start_idx + num_tokens
if num_tokens == 0:
continue
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i])
outputs.append(expert_out)
start_idx = end_idx
outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
new_x = torch.empty_like(outs)
new_x[idxs] = outs
t_output = (
new_x.view(*expert_ids.shape, -1)
.type(weights.dtype)
.mul_(weights.unsqueeze(dim=-1))
.sum(dim=1)
.type(new_x.dtype)
)
return t_output
def pack_to_int32(value: torch.Tensor, num_bits: int, packed_dim: Literal[0, 1] = 1) -> torch.Tensor:
if value.dtype is not torch.int8:
raise ValueError("Tensor must be torch.int8 before packing")
if not (1 <= num_bits <= 8):
raise ValueError(f"num_bits must be in [1, 8], got {num_bits}")
offset = 1 << (num_bits - 1)
value = (value + offset).to(torch.uint8)
device = value.device
pack_factor = 32 // num_bits
if packed_dim == 0:
value = value.transpose(0, 1)
rows, cols = value.shape
padded_cols = math.ceil(cols / pack_factor) * pack_factor
pad_len = padded_cols - cols
if pad_len > 0:
value = torch.nn.functional.pad(value, (0, pad_len))
num_groups = padded_cols // pack_factor
# Use int32 here
reshaped = value.view(rows, num_groups, pack_factor).to(torch.int32)
bit_shifts = torch.arange(pack_factor, device=device, dtype=torch.int32) * num_bits
packed = (reshaped << bit_shifts).sum(dim=2, dtype=torch.int32)
if packed_dim == 0:
packed = packed.transpose(0, 1)
return packed
def pack_tensor_per_row(q: torch.Tensor, num_bits: int) -> torch.Tensor:
e, rows, cols = q.shape
flat = q.view(e * rows, cols)
packed = pack_to_int32(flat, num_bits)
return packed.view(e, rows, -1).contiguous()
def quantize_k2_tensor(weights: torch.Tensor, group_size: int):
"""
Symmetric max-abs/7 quantization per k-group following compressed_tensors packing.
Args:
weights: [expert_num, rows (N), cols (K)]
Returns:
packed_q: int32 tensor storing 8 int4s per element with shape [expert_num, rows * (cols // 8)]
scales: bfloat16 tensor with shape [expert_num, rows * (cols // group_size)]
"""
weights_f32 = weights.to(torch.float32)
e, rows, cols = weights_f32.shape
if cols % group_size != 0 or cols % 2 != 0:
raise ValueError(f"cols ({cols}) must be divisible by group_size ({group_size}) and 2")
reshaped = weights_f32.view(e, rows, cols // group_size, group_size)
max_abs = reshaped.abs().amax(dim=-1, keepdim=True)
max_abs = torch.clamp(max_abs, min=1e-8)
scales = (max_abs / 7.0).squeeze(-1)
q = torch.round(reshaped / scales.unsqueeze(-1)).clamp(-8, 7).to(torch.int8)
q = q.view(e, rows, cols)
packed = pack_tensor_per_row(q, num_bits=4).view(e, rows, cols // 8).contiguous()
scales = scales.to(torch.bfloat16).contiguous().view(e, rows, cols // group_size).contiguous()
print(f"Quantized weights: {packed.shape}, scales: {scales.shape}")
print(f"Quantized tensors: \n{packed},\n {scales}")
return packed, scales
def build_structured_tensor(shape: torch.Size, pattern: str) -> torch.Tensor:
if pattern == "random":
torch.manual_seed(42)
return (torch.randn(shape, dtype=torch.bfloat16, device="cpu") / 100.0).contiguous()
e, rows, cols = shape
groups = cols // k_group_size
group_builder = WEIGHT_PATTERNS[pattern][1]
group_vals = group_builder(groups).to(torch.float32)
block = group_vals.view(1, 1, groups, 1).expand(e, rows, groups, k_group_size).clone()
row_signs = torch.where(
(torch.arange(rows) % 2 == 0),
torch.ones(rows, dtype=torch.float32),
-torch.ones(rows, dtype=torch.float32),
).view(1, rows, 1, 1)
col_offsets = torch.linspace(-0.0005, 0.0005, steps=k_group_size, dtype=torch.float32).view(1, 1, 1, k_group_size)
block = block * row_signs + col_offsets
return block.reshape(shape).to(torch.bfloat16).contiguous()
def prepare_k2_quantized_weights(pattern: str) -> Dict[str, torch.Tensor]:
if pattern not in WEIGHT_PATTERNS:
raise ValueError(f"Unknown weight pattern: {pattern}")
gate_proj = build_structured_tensor((expert_num, intermediate_size, hidden_size), pattern)
up_proj = build_structured_tensor((expert_num, intermediate_size, hidden_size), pattern)
down_proj = build_structured_tensor((expert_num, hidden_size, intermediate_size), pattern)
gate_q, gate_scales = quantize_k2_tensor(gate_proj, k_group_size)
up_q, up_scales = quantize_k2_tensor(up_proj, k_group_size)
down_q, down_scales = quantize_k2_tensor(down_proj, k_group_size)
return {
"gate_qweight": gate_q.contiguous(),
"up_qweight": up_q.contiguous(),
"down_qweight": down_q.contiguous(),
"gate_scales": gate_scales.contiguous(),
"up_scales": up_scales.contiguous(),
"down_scales": down_scales.contiguous(),
"original_fp16": {
"gate_proj": gate_proj.to(torch.float16).contiguous(),
"up_proj": up_proj.to(torch.float16).contiguous(),
"down_proj": down_proj.to(torch.float16).contiguous(),
},
}
def build_moes_from_quantized_data(quant_data: Dict[str, torch.Tensor]):
moes = []
with torch.inference_mode(mode=True):
for _ in range(layer_num):
config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
config.max_len = max_len
config.quant_config.bits = 4
config.quant_config.group_size = k_group_size
config.quant_config.zero_point = False
config.gate_proj = quant_data["gate_qweight"].data_ptr()
config.up_proj = quant_data["up_qweight"].data_ptr()
config.down_proj = quant_data["down_qweight"].data_ptr()
config.gate_scale = quant_data["gate_scales"].data_ptr()
config.up_scale = quant_data["up_scales"].data_ptr()
config.down_scale = quant_data["down_scales"].data_ptr()
config.pool = CPUInfer.backend_
moe = kt_kernel_ext.moe.AMXInt4_KGroup_MOE(config)
CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
CPUInfer.sync()
# CPUInfer.submit(moe.warm_up_task())
# CPUInfer.sync()
moes.append(moe)
return moes
def run_case(pattern: str) -> Dict[str, float]:
print("\n" + "=" * 70)
desc = WEIGHT_PATTERNS[pattern][0]
print(f"Running case: {pattern} -> {desc}")
print("=" * 70)
quant_data = prepare_k2_quantized_weights(pattern)
moes = build_moes_from_quantized_data(quant_data)
original_weights = quant_data["original_fp16"]
gate_fp16 = original_weights["gate_proj"]
up_fp16 = original_weights["up_proj"]
down_fp16 = original_weights["down_proj"]
diffs = []
with torch.inference_mode(mode=True):
for i in range(validation_iter):
torch.manual_seed(100 + i)
bsz_tensor = torch.tensor([qlen], device="cpu")
expert_ids = torch.stack(
[torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)]
).contiguous()
weights = torch.randn((qlen, num_experts_per_tok), dtype=torch.float32).contiguous()
input_tensor = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous() / 100
output = torch.empty((qlen, hidden_size), dtype=torch.bfloat16).contiguous()
moe = moes[i % layer_num]
CPUInfer.submit(
moe.forward_task(
bsz_tensor.data_ptr(),
num_experts_per_tok,
expert_ids.data_ptr(),
weights.data_ptr(),
input_tensor.data_ptr(),
output.data_ptr(),
False,
)
)
CPUInfer.sync()
input_tensor_fp16 = input_tensor.to(torch.float16)
t_output = moe_torch(
input_tensor_fp16, expert_ids, weights, gate_fp16, up_fp16, down_fp16
).to(torch.bfloat16)
t_output = t_output.flatten()
output = output.flatten()
diff = torch.mean(torch.abs(output - t_output)) / (torch.mean(torch.abs(t_output)) + 1e-12)
diffs.append(diff.item())
print(f"[{pattern}] Iteration {i}: relative L1 diff = {diff:.4f}")
print(f" output {output}")
print(f" t_output {t_output}")
mean_diff = float(sum(diffs) / len(diffs))
max_diff = float(max(diffs))
min_diff = float(min(diffs))
return {"case": pattern, "description": desc, "mean": mean_diff, "max": max_diff, "min": min_diff}
def run_k2_moe_test():
summary_rows = []
for case_name in WEIGHT_PATTERNS.keys():
results = run_case(case_name)
summary_rows.append(results)
# break
print("\n=== Case vs. Relative Error Summary ===")
print(f"{'Case':<20} {'Mean':>10} {'Max':>10} {'Min':>10}")
for row in summary_rows:
print(f"{row['case']:<20} {row['mean']*100:9.2f}% {row['max']*100:9.2f}% {row['min']*100:9.2f}%")
if __name__ == "__main__":
run_k2_moe_test()


@@ -0,0 +1,267 @@
import os
import sys
import time
import torch
import numpy as np
# Ensure we can import the local extension
# REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))
# if REPO_ROOT not in sys.path:
# sys.path.insert(0, REPO_ROOT)
import kt_kernel_ext
from kt_kernel_ext import CPUInfer
def make_cpu_infer(thread_num=80):
return CPUInfer(thread_num)
def build_config(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size, group_size):
cfg = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
cfg.max_len = 1
cfg.quant_config.bits = 4
cfg.quant_config.group_size = group_size
cfg.quant_config.zero_point = False
cfg.pool = cpuinfer.backend_
return cfg
def allocate_weights(expert_num, hidden_size, intermediate_size, group_size):
# packed int4 weights: 2 values per byte
per_mat_weight_bytes = (hidden_size * intermediate_size) // 2
per_mat_scale_elems = (hidden_size * intermediate_size) // group_size
gate_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)
up_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)
down_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)
gate_scale = torch.randn(expert_num * per_mat_scale_elems, dtype=torch.bfloat16)
up_scale = torch.randn(expert_num * per_mat_scale_elems, dtype=torch.bfloat16)
down_scale = torch.randn(expert_num * per_mat_scale_elems, dtype=torch.bfloat16)
return (
gate_q,
up_q,
down_q,
gate_scale,
up_scale,
down_scale,
per_mat_weight_bytes,
per_mat_scale_elems,
)
def main():
torch.manual_seed(123)
expert_num = 256 # Total experts
gpu_experts = expert_num # Number of experts on GPU
gpu_tp_count = 2 # Number of TP parts
num_experts_per_tok = 8
hidden_size = 7168
intermediate_size = 2048
group_size = 32
cpuinfer = make_cpu_infer()
cfg = build_config(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size, group_size)
(
gate_q,
up_q,
down_q,
gate_scale,
up_scale,
down_scale,
per_mat_weight_bytes,
per_mat_scale_elems,
) = allocate_weights(expert_num, hidden_size, intermediate_size, group_size)
cfg.gate_proj = gate_q.data_ptr()
cfg.up_proj = up_q.data_ptr()
cfg.down_proj = down_q.data_ptr()
cfg.gate_scale = gate_scale.data_ptr()
cfg.up_scale = up_scale.data_ptr()
cfg.down_scale = down_scale.data_ptr()
moe = kt_kernel_ext.moe.AMXInt4_KGroup_MOE(cfg)
physical_to_logical_map = (
torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous()
)
cpuinfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
cpuinfer.sync()
# TP configuration
# Since weights are col-major, we can directly divide the total size by tp_count
# Each matrix is divided into gpu_tp_count parts in memory order
# Calculate sizes per TP part (direct division since col-major)
weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count
scale_elems_per_expert_per_tp = per_mat_scale_elems // gpu_tp_count
# Total sizes for all gpu_experts
total_weight_bytes_per_tp = gpu_experts * weight_bytes_per_expert_per_tp
total_scale_elems_per_tp = gpu_experts * scale_elems_per_expert_per_tp
# Create buffer lists for w13 (gate+up) and w2 (down)
w13_weight_bufs = []
w13_scale_bufs = []
w2_weight_bufs = []
w2_scale_bufs = []
for tp_idx in range(gpu_tp_count):
# w13 combines gate and up, so needs 2x the size
w13_weight_bufs.append(torch.empty(2 * total_weight_bytes_per_tp, dtype=torch.uint8))
w13_scale_bufs.append(torch.empty(2 * total_scale_elems_per_tp, dtype=torch.bfloat16))
w2_weight_bufs.append(torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8))
w2_scale_bufs.append(torch.empty(total_scale_elems_per_tp, dtype=torch.bfloat16))
# Get data pointers for all buffers
w13_weight_ptrs = [buf.data_ptr() for buf in w13_weight_bufs]
w13_scale_ptrs = [buf.data_ptr() for buf in w13_scale_bufs]
w2_weight_ptrs = [buf.data_ptr() for buf in w2_weight_bufs]
w2_scale_ptrs = [buf.data_ptr() for buf in w2_scale_bufs]
print(f"Total experts: {expert_num}, GPU experts: {gpu_experts}")
print(f"GPU TP count: {gpu_tp_count}")
print(f"Original per matrix weight bytes: {per_mat_weight_bytes}")
print(f"Original per matrix scale elements: {per_mat_scale_elems}")
print(f"Weight bytes per expert per TP: {weight_bytes_per_expert_per_tp}")
print(f"Scale elements per expert per TP: {scale_elems_per_expert_per_tp}")
print(f"Total weight bytes per TP (w13): {2 * total_weight_bytes_per_tp}")
print(f"Total weight bytes per TP (w2): {total_weight_bytes_per_tp}")
print(f"Total scale elements per TP (w13): {2 * total_scale_elems_per_tp}")
print(f"Total scale elements per TP (w2): {total_scale_elems_per_tp}")
for i in range(5):
cpuinfer.submit(
moe.write_weight_scale_to_buffer_task(
gpu_tp_count=gpu_tp_count,
gpu_experts_num=gpu_experts,
w13_weight_ptrs=w13_weight_ptrs,
w13_scale_ptrs=w13_scale_ptrs,
w2_weight_ptrs=w2_weight_ptrs,
w2_scale_ptrs=w2_scale_ptrs,
)
)
cpuinfer.sync()
begin_time = time.perf_counter_ns()
cpuinfer.submit(
moe.write_weight_scale_to_buffer_task(
gpu_tp_count=gpu_tp_count,
gpu_experts_num=gpu_experts,
w13_weight_ptrs=w13_weight_ptrs,
w13_scale_ptrs=w13_scale_ptrs,
w2_weight_ptrs=w2_weight_ptrs,
w2_scale_ptrs=w2_scale_ptrs,
)
)
cpuinfer.sync()
end_time = time.perf_counter_ns()
elapsed_ms = (end_time - begin_time) / 1000000
total_weights = hidden_size * intermediate_size * expert_num * 3
total_bytes = total_weights // group_size + total_weights // 2
print(f"write_weight_scale_to_buffer time: {elapsed_ms:.2f} ms")
print(f"Throughput: {total_bytes / (elapsed_ms * 1e6):.2f} GB/s")
def split_expert_tensor(tensor, chunk):
"""Split tensor by experts"""
return [tensor[i * chunk : (i + 1) * chunk] for i in range(expert_num)]
# Split by experts first
gate_q_experts = split_expert_tensor(gate_q, per_mat_weight_bytes)
up_q_experts = split_expert_tensor(up_q, per_mat_weight_bytes)
down_q_experts = split_expert_tensor(down_q, per_mat_weight_bytes)
gate_scale_experts = split_expert_tensor(gate_scale, per_mat_scale_elems)
up_scale_experts = split_expert_tensor(up_scale, per_mat_scale_elems)
down_scale_experts = split_expert_tensor(down_scale, per_mat_scale_elems)
# CPU TP count is always 2 in this test setup (one TP per NUMA node)
cpu_tp_count = 2
# Verify buffers for each TP part
for tp_idx in range(gpu_tp_count):
expected_w13_weights = []
expected_w13_scales = []
expected_w2_weights = []
expected_w2_scales = []
weight13_per_tp = per_mat_weight_bytes // gpu_tp_count
scale13_per_tp = per_mat_scale_elems // gpu_tp_count
# Process each GPU expert
for expert_idx in range(gpu_experts):
# For w13 (gate and up), the slicing is straightforward
start_weight = tp_idx * weight13_per_tp
end_weight = (tp_idx + 1) * weight13_per_tp
start_scale = tp_idx * scale13_per_tp
end_scale = (tp_idx + 1) * scale13_per_tp
# Gate
gate_weight_tp = gate_q_experts[expert_idx][start_weight:end_weight]
gate_scale_tp = gate_scale_experts[expert_idx][start_scale:end_scale]
# Up
up_weight_tp = up_q_experts[expert_idx][start_weight:end_weight]
up_scale_tp = up_scale_experts[expert_idx][start_scale:end_scale]
# Down matrix needs special handling because it's sliced column-wise
# We need to reconstruct it from column slices
down_weight_tp_parts = []
down_scale_tp_parts = []
# Iterate through each column to extract the corresponding parts
for col_idx in range(hidden_size):
col_weight_start = col_idx * (intermediate_size // 2)
col_scale_start = col_idx * (intermediate_size // group_size)
# Direct mapping: each CPU TP corresponds to a GPU TP
tp_slice_weight_size = (intermediate_size // gpu_tp_count) // 2
tp_slice_scale_size = (intermediate_size // gpu_tp_count) // group_size
tp_weight_offset = col_weight_start + tp_idx * tp_slice_weight_size
tp_scale_offset = col_scale_start + tp_idx * tp_slice_scale_size
down_weight_tp_parts.append(
down_q_experts[expert_idx][tp_weight_offset:tp_weight_offset + tp_slice_weight_size]
)
down_scale_tp_parts.append(
down_scale_experts[expert_idx][tp_scale_offset:tp_scale_offset + tp_slice_scale_size]
)
# Concatenate all column slices for this TP
down_weight_tp = torch.cat(down_weight_tp_parts)
down_scale_tp = torch.cat(down_scale_tp_parts)
expected_w13_weights.append(gate_weight_tp)
expected_w13_weights.append(up_weight_tp)
expected_w13_scales.append(gate_scale_tp)
expected_w13_scales.append(up_scale_tp)
expected_w2_weights.append(down_weight_tp)
expected_w2_scales.append(down_scale_tp)
# Concatenate all experts for this TP part
expected_w13_weight = torch.cat(expected_w13_weights)
expected_w13_scale = torch.cat(expected_w13_scales)
expected_w2_weight = torch.cat(expected_w2_weights)
expected_w2_scale = torch.cat(expected_w2_scales)
print(f"=== Checking TP part {tp_idx} ===")
# Assert all checks pass
assert torch.equal(w13_weight_bufs[tp_idx], expected_w13_weight), f"w13 weight bytes mismatch for TP {tp_idx}"
assert torch.allclose(w13_scale_bufs[tp_idx], expected_w13_scale), f"w13 scale values mismatch for TP {tp_idx}"
assert torch.equal(w2_weight_bufs[tp_idx], expected_w2_weight), f"w2 weight bytes mismatch for TP {tp_idx}"
assert torch.allclose(w2_scale_bufs[tp_idx], expected_w2_scale), f"w2 scale values mismatch for TP {tp_idx}"
print(f"\n✓ write_weight_scale_to_buffer passed: extracted {gpu_experts} GPU experts across {gpu_tp_count} TP parts from total {expert_num} experts")
if __name__ == "__main__":
main()


@@ -36,6 +36,7 @@ static const bool _is_plain_ = false;
#if defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL)
#include "operators/amx/awq-moe.hpp"
#include "operators/amx/k2-moe.hpp"
#include "operators/amx/la/amx_kernels.hpp"
#include "operators/amx/moe.hpp"
#endif
@@ -43,6 +44,7 @@ static const bool _is_plain_ = false;
#include <cstdint>
#include <memory>
#include <type_traits>
#include "operators/kvcache/kvcache.h"
#include "operators/llamafile/linear.h"
@@ -225,7 +227,9 @@ void bind_moe_module(py::module_& moe_module, const char* name) {
using MoeClass = TP_MOE<MoeTP>;
using MoeBindings = MOEBindings<MoeTP>;
auto moe_cls = py::class_<MoeClass, MoE_Interface, std::shared_ptr<MoeClass>>(moe_module, name);
moe_cls
.def(py::init<GeneralMOEConfig>())
.def("warm_up_task", &MoeBindings::WarmUpBindings::cpuinfer_interface)
.def("load_weights_task",
@@ -244,6 +248,53 @@ void bind_moe_module(py::module_& moe_module, const char* name) {
.def("warm_up", &MoeClass::warm_up)
.def("load_weights", &MoeClass::load_weights)
.def("forward", &MoeClass::forward_binding);
#if defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL)
if constexpr (std::is_same_v<MoeTP, AMX_K2_MOE_TP<amx::GemmKernel224Int4SmallKGroup>>) {
struct WriteWeightScaleToBufferBindings {
struct Args {
CPUInfer* cpuinfer;
MoeClass* moe;
int gpu_tp_count;
int gpu_experts_num;
std::vector<uintptr_t> w13_weight_ptrs;
std::vector<uintptr_t> w13_scale_ptrs;
std::vector<uintptr_t> w2_weight_ptrs;
std::vector<uintptr_t> w2_scale_ptrs;
};
static void inner(void* args) {
Args* args_ = (Args*)args;
args_->cpuinfer->enqueue(&MoeClass::write_weight_scale_to_buffer, args_->moe,
args_->gpu_tp_count, args_->gpu_experts_num,
args_->w13_weight_ptrs, args_->w13_scale_ptrs,
args_->w2_weight_ptrs, args_->w2_scale_ptrs);
}
static std::pair<intptr_t, intptr_t> cpuinfer_interface(std::shared_ptr<MoeClass> moe,
int gpu_tp_count, int gpu_experts_num,
py::list w13_weight_ptrs, py::list w13_scale_ptrs,
py::list w2_weight_ptrs, py::list w2_scale_ptrs) {
// Convert Python lists to std::vector<uintptr_t>
std::vector<uintptr_t> w13_weight_vec, w13_scale_vec, w2_weight_vec, w2_scale_vec;
for (auto item : w13_weight_ptrs) w13_weight_vec.push_back(py::cast<uintptr_t>(item));
for (auto item : w13_scale_ptrs) w13_scale_vec.push_back(py::cast<uintptr_t>(item));
for (auto item : w2_weight_ptrs) w2_weight_vec.push_back(py::cast<uintptr_t>(item));
for (auto item : w2_scale_ptrs) w2_scale_vec.push_back(py::cast<uintptr_t>(item));
Args* args = new Args{nullptr, moe.get(), gpu_tp_count, gpu_experts_num,
w13_weight_vec, w13_scale_vec, w2_weight_vec, w2_scale_vec};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
moe_cls.def("write_weight_scale_to_buffer_task", &WriteWeightScaleToBufferBindings::cpuinfer_interface,
py::arg("gpu_tp_count"), py::arg("gpu_experts_num"),
py::arg("w13_weight_ptrs"), py::arg("w13_scale_ptrs"),
py::arg("w2_weight_ptrs"), py::arg("w2_scale_ptrs"));
}
#endif
}
PYBIND11_MODULE(kt_kernel_ext, m) {
@@ -513,6 +564,7 @@ PYBIND11_MODULE(kt_kernel_ext, m) {
bind_moe_module<AMX_MOE_TP<amx::GemmKernel224Int4>>(moe_module, "AMXInt4_MOE");
bind_moe_module<AMX_MOE_TP<amx::GemmKernel224Int4_1>>(moe_module, "AMXInt4_1_MOE");
bind_moe_module<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4_1_LowKGroup>>(moe_module, "AMXInt4_1KGroup_MOE");
bind_moe_module<AMX_K2_MOE_TP<amx::GemmKernel224Int4SmallKGroup>>(moe_module, "AMXInt4_KGroup_MOE");
#endif
#if defined(USE_MOE_KERNEL)
bind_moe_module<MOE_KERNEL_TP<moe_kernel::GemmKernelInt8, _is_plain_>>(moe_module, "Int8_KERNEL_MOE");


@@ -0,0 +1,929 @@
/**
* @Description : Skeleton for K2 AMX MoE operator.
* @Author : Codex
* @Date : 2024-07-22
* @Version : 0.1.0
* @LastEditors : Codex
* @LastEditTime : 2024-07-22
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#ifndef CPUINFER_OPERATOR_AMX_K2_MOE_H
#define CPUINFER_OPERATOR_AMX_K2_MOE_H
// #define DEBUG_K2_MOE
#include <cstddef>
#include <cstdint>
#include <cstring>
// #define FORWARD_TIME_PROFILE
// #define FORWARD_TIME_REPORT
#include <immintrin.h>
#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdio>
#include <filesystem>
#include <fstream>
#include <string>
#include <vector>
#include "../../cpu_backend/shared_mem_buffer.h"
#include "../../cpu_backend/worker_pool.h"
#include "../common.hpp"
#include "../moe-tp.hpp"
#include "la/amx.hpp"
#include "llama.cpp/ggml.h"
template <class T>
class AMX_K2_MOE_TP {
private:
int tp_part_idx = 0;
void* gate_proj_ = nullptr; // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]
void* up_proj_ = nullptr; // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]
void* down_proj_ = nullptr; // [expert_num * hidden_size * intermediate_size ( /32 if quantized)]
ggml_bf16_t* m_local_input_ = nullptr; // [num_experts_per_tok * max_len * hidden_size]
ggml_bf16_t* m_local_gate_output_ = nullptr; // [num_experts_per_tok * max_len * intermediate_size]
ggml_bf16_t* m_local_up_output_ = nullptr; // [num_experts_per_tok * max_len * intermediate_size]
ggml_bf16_t* m_local_down_output_ = nullptr; // [num_experts_per_tok * max_len * hidden_size]
std::vector<std::vector<int>> m_local_pos_; // [max_len, num_experts_per_tok]
std::vector<int> m_local_num_; // [expert_num]
std::vector<int> m_expert_id_map_; // [expert_num]
std::vector<ggml_bf16_t*> m_local_input_ptr_; // [expert_num]
std::vector<ggml_bf16_t*> m_local_gate_output_ptr_; // [expert_num]
std::vector<ggml_bf16_t*> m_local_up_output_ptr_; // [expert_num]
std::vector<ggml_bf16_t*> m_local_down_output_ptr_; // [expert_num]
std::vector<std::shared_ptr<typename T::BufferA>> gate_up_ba_;
std::vector<std::shared_ptr<typename T::BufferB>> gate_bb_;
std::vector<std::shared_ptr<typename T::BufferC>> gate_bc_;
std::vector<std::shared_ptr<typename T::BufferB>> up_bb_;
std::vector<std::shared_ptr<typename T::BufferC>> up_bc_;
std::vector<std::shared_ptr<typename T::BufferA>> down_ba_;
std::vector<std::shared_ptr<typename T::BufferB>> down_bb_;
std::vector<std::shared_ptr<typename T::BufferC>> down_bc_;
#ifdef CHECK
char verify_bb[100000000];
char check_bb[100000000];
uint8_t compare_expers = 3;
#endif
#ifdef CHECK
inline void load_check() {
// TODO: implement load_check for verification.
}
void verify_load_right() {
// TODO: implement verification helpers.
}
#endif
inline void dump_buffer_b(const std::string &quantization_type, int expert_idx, const std::string &matrix_type,
typename T::BufferB *buffer) {
auto &quant_config = config_.quant_config;
int &group_size = quant_config.group_size;
printf("[DUMP_BUFFER_B] TP%d %s Expert%d %s:\n", tp_part_idx, quantization_type.c_str(), expert_idx,
matrix_type.c_str());
// Calculate dimensions based on matrix type
int rows, cols, num_groups;
size_t scale_elem_count;
if (matrix_type == "gate" || matrix_type == "up") {
rows = config_.intermediate_size;
cols = config_.hidden_size;
num_groups = cols / group_size;
scale_elem_count = num_groups * rows;
} else { // down
rows = config_.hidden_size;
cols = config_.intermediate_size;
num_groups = cols / group_size;
scale_elem_count = num_groups * rows;
}
// Dump scales (as float)
printf(" Scales[first 16]: ");
for (int i = 0; i < std::min(16, (int)scale_elem_count); i++) {
printf("%.6f ", buffer->d[i]);
}
printf("\n");
if (scale_elem_count > 16) {
printf(" Scales[last 16]: ");
int start_idx = std::max(0, (int)scale_elem_count - 16);
for (int i = start_idx; i < (int)scale_elem_count; i++) {
printf("%.6f ", buffer->d[i]);
}
printf("\n");
}
// Dump quantized weights (as hex uint8)
size_t weight_size = (rows * cols) / 2; // INT4 packed
uint8_t *weight_ptr = (uint8_t *)buffer->b;
printf(" Weights[first 32 bytes]: ");
for (int i = 0; i < std::min(32, (int)weight_size); i++) {
printf("%02x ", weight_ptr[i]);
}
printf("\n");
if (weight_size > 32) {
printf(" Weights[last 32 bytes]: ");
int start_idx = std::max(32, (int)weight_size - 32);
for (int i = start_idx; i < (int)weight_size; i++) {
printf("%02x ", weight_ptr[i]);
}
printf("\n");
}
printf(" Matrix dimensions: %dx%d, Groups: %d, Group size: %d, Scale elements: %zu\n", rows, cols, num_groups,
group_size, scale_elem_count);
printf("\n");
fflush(stdout);
}
#ifdef FORWARD_TIME_REPORT
std::chrono::time_point<std::chrono::high_resolution_clock> last_now;
#endif
public:
using input_t = ggml_bf16_t;
using output_t = float;
GeneralMOEConfig config_;
static constexpr double ELEMENT_SIZE = T::ELEMENT_SIZE;
AMX_K2_MOE_TP(GeneralMOEConfig config, int tp_part_idx_) {
auto& quant_config = config.quant_config;
int& group_size = quant_config.group_size;
if (quant_config.group_size == 0 || quant_config.zero_point) {
throw std::runtime_error("Kimi-K2 MoE only support KGroup Int4");
}
printf("Creating AMX_K2_MOE_TP %d at numa %d\n", tp_part_idx_, numa_node_of_cpu(sched_getcpu()));
auto& load = config.load;
auto& save = config.save;
if (load && config.path == "") {
load = false;
}
this->tp_part_idx = tp_part_idx_;
config_ = config;
gate_proj_ = config_.gate_proj;
up_proj_ = config_.up_proj;
down_proj_ = config_.down_proj;
MemoryRequest mem_requests;
mem_requests.append_pointer(
&m_local_input_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok * config_.max_len * config_.hidden_size);
mem_requests.append_pointer(&m_local_gate_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *
config_.max_len * config_.intermediate_size);
mem_requests.append_pointer(&m_local_up_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *
config_.max_len * config_.intermediate_size);
mem_requests.append_pointer(&m_local_down_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *
config_.max_len * config_.hidden_size);
m_local_pos_.resize(config_.max_len);
for (int i = 0; i < config_.max_len; i++) {
m_local_pos_[i].resize(config_.num_experts_per_tok);
}
m_expert_id_map_.resize(config_.expert_num);
m_local_num_.resize(config_.expert_num);
m_local_input_ptr_.resize(config_.expert_num);
m_local_gate_output_ptr_.resize(config_.expert_num);
m_local_up_output_ptr_.resize(config_.expert_num);
m_local_down_output_ptr_.resize(config_.expert_num);
for (size_t i = 0; i < config_.expert_num; i++) {
gate_up_ba_.push_back(
std::make_shared<typename T::BufferA>(config_.max_len, config_.hidden_size, group_size, nullptr));
gate_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.intermediate_size, nullptr));
up_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.intermediate_size, nullptr));
down_ba_.push_back(
std::make_shared<typename T::BufferA>(config_.max_len, config_.intermediate_size, group_size, nullptr));
down_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.hidden_size, nullptr));
void* gate_bb_ptr =
std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, group_size));
gate_bb_.push_back(std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size,
group_size, gate_bb_ptr));
void* up_bb_ptr =
std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, group_size));
up_bb_.push_back(
std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, group_size, up_bb_ptr));
void* down_bb_ptr =
std::aligned_alloc(64, T::BufferB::required_size(config_.hidden_size, config_.intermediate_size, group_size));
down_bb_.push_back(std::make_shared<typename T::BufferB>(config_.hidden_size, config_.intermediate_size,
group_size, down_bb_ptr));
}
for (int i = 0; i < config_.expert_num; i++) {
mem_requests.append_function([this, i](void* new_ptr) { gate_up_ba_[i]->set_data(new_ptr); },
T::BufferA::required_size(config_.max_len, config_.hidden_size, group_size));
mem_requests.append_function([this, i](void* new_ptr) { gate_bc_[i]->set_data(new_ptr); },
T::BufferC::required_size(config_.max_len, config_.intermediate_size));
mem_requests.append_function([this, i](void* new_ptr) { up_bc_[i]->set_data(new_ptr); },
T::BufferC::required_size(config_.max_len, config_.intermediate_size));
mem_requests.append_function([this, i](void* new_ptr) { down_ba_[i]->set_data(new_ptr); },
T::BufferA::required_size(config_.max_len, config_.intermediate_size, group_size));
mem_requests.append_function([this, i](void* new_ptr) { down_bc_[i]->set_data(new_ptr); },
T::BufferC::required_size(config_.max_len, config_.hidden_size));
}
shared_mem_buffer_numa.alloc(tp_part_idx, this, mem_requests);
}
~AMX_K2_MOE_TP() = default;
void load_weights() {
auto& quant_config = config_.quant_config;
int& group_size = quant_config.group_size;
const uint64_t* physical_to_logical_map = (const uint64_t*)config_.physical_to_logical_map;
auto pool = config_.pool->get_subpool(tp_part_idx);
if (quant_config.group_size == 0 || quant_config.zero_point) {
throw std::runtime_error("Kimi AVX MOE only support KGroup Int4.");
}
if (config_.gate_scale == nullptr) {
throw std::runtime_error("Kimi AVX MOE only support load native weight.");
}
// load weight
int nth = T::recommended_nth(config_.intermediate_size);
pool->do_work_stealing_job(
nth * config_.expert_num, nullptr,
[this, nth, physical_to_logical_map](int task_id) {
uint64_t expert_idx = task_id / nth;
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
int ith = task_id % nth;
// gate part
gate_bb_[expert_idx]->from_raw_mat(
(uint8_t*)config_.gate_proj +
((logical_expert_id * config_.intermediate_size * config_.hidden_size) >> 1),
ith, nth);
// up part
up_bb_[expert_idx]->from_raw_mat(
(uint8_t*)config_.up_proj +
((logical_expert_id * config_.intermediate_size * config_.hidden_size) >> 1),
ith, nth);
},
nullptr);
nth = T::recommended_nth(config_.hidden_size);
pool->do_work_stealing_job(
nth * config_.expert_num, nullptr,
[this, nth, physical_to_logical_map](int task_id) {
uint64_t expert_idx = task_id / nth;
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
int ith = task_id % nth;
// down part
down_bb_[expert_idx]->from_raw_mat(
(uint8_t*)config_.down_proj +
((logical_expert_id * config_.hidden_size * config_.intermediate_size) >> 1),
ith, nth);
},
nullptr);
pool->do_work_stealing_job(
config_.expert_num, nullptr,
[this, physical_to_logical_map](int task_id) {
uint64_t expert_idx = task_id;
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
size_t scale_elem_count =
(config_.hidden_size * config_.intermediate_size) / config_.quant_config.group_size;
// convert scales from BF16 to FP32
convert_or_copy(gate_bb_[expert_idx]->d,
(ggml_bf16_t*)config_.gate_scale + (logical_expert_id * scale_elem_count),
scale_elem_count);
convert_or_copy(up_bb_[expert_idx]->d,
(ggml_bf16_t*)config_.up_scale + (logical_expert_id * scale_elem_count),
scale_elem_count);
convert_or_copy(down_bb_[expert_idx]->d,
(ggml_bf16_t*)config_.down_scale + (logical_expert_id * scale_elem_count),
scale_elem_count);
},
nullptr);
// dump_buffer_b("native", 0, "down", down_bb_[0].get());
}
// Reconstruct weights for all experts to the output buffers
// This function handles the TP-specific portion of the reconstruction for all experts
void write_weights_to_buffer(int gpu_tp_count, int cpu_tp_count, int num_experts, const GeneralMOEConfig& full_config,
const std::vector<uintptr_t>& w13_weight_ptrs,
const std::vector<uintptr_t>& w13_scale_ptrs,
const std::vector<uintptr_t>& w2_weight_ptrs,
const std::vector<uintptr_t>& w2_scale_ptrs) const {
const int group_size = config_.quant_config.group_size;
auto pool = config_.pool->get_subpool(tp_part_idx);
// Calculate sizes for CPU TP part (this instance)
size_t cpu_tp_weight_elem_count = (size_t)config_.intermediate_size * config_.hidden_size;
size_t cpu_tp_weight_bytes = cpu_tp_weight_elem_count / 2; // int4 packing
size_t cpu_tp_scale_elem_count = cpu_tp_weight_elem_count / group_size;
// Calculate sizes for GPU TP part
size_t gpu_tp_weight_elem_count = (size_t)full_config.intermediate_size * full_config.hidden_size / gpu_tp_count;
size_t gpu_tp_weight_bytes = gpu_tp_weight_elem_count / 2; // int4 packing
size_t gpu_tp_scale_elem_count = gpu_tp_weight_elem_count / group_size;
// Determine mapping: which GPU TP parts should this CPU TP part write to?
// Since weights are col-major and we slice directly by memory order:
// - If cpu_tp_count >= gpu_tp_count: multiple(or one) CPU TPs write to one GPU TP
// - If cpu_tp_count < gpu_tp_count: one CPU TP writes to multiple GPU TPs
if (cpu_tp_count >= gpu_tp_count) {
// Multiple CPU TPs map to one GPU TP
int target_gpu_tp = tp_part_idx / (cpu_tp_count / gpu_tp_count);
int local_idx = tp_part_idx % (cpu_tp_count / gpu_tp_count);
// Get pointers for this GPU TP part
uint8_t* w13_weight_dst = (uint8_t*)w13_weight_ptrs[target_gpu_tp];
ggml_bf16_t* w13_scale_dst = (ggml_bf16_t*)w13_scale_ptrs[target_gpu_tp];
uint8_t* w2_weight_dst = (uint8_t*)w2_weight_ptrs[target_gpu_tp];
ggml_bf16_t* w2_scale_dst = (ggml_bf16_t*)w2_scale_ptrs[target_gpu_tp];
// Calculate offset within the GPU TP buffer
size_t offset_in_gpu_weight = local_idx * cpu_tp_weight_bytes;
size_t offset_in_gpu_scale = local_idx * cpu_tp_scale_elem_count;
// Process only the first num_experts experts (GPU experts)
int nth = T::recommended_nth(config_.intermediate_size);
nth = 1;
pool->do_work_stealing_job(
nth * num_experts, nullptr,
[&, this](int task_id) {
int expert_id = task_id / nth;
// int ith = task_id % nth;
// auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);
// Calculate base offsets for this expert in the GPU buffers
// For w13: each expert has gate+up, so the offset needs to account for 2x size
size_t w13_expert_base_weight = expert_id * 2 * gpu_tp_weight_bytes;
size_t w13_expert_base_scale = expert_id * 2 * gpu_tp_scale_elem_count;
size_t w2_expert_base_weight = expert_id * gpu_tp_weight_bytes;
size_t w2_expert_base_scale = expert_id * gpu_tp_scale_elem_count;
// Gate (first part of w13 for this expert)
uint8_t* gate_weight_src = (uint8_t*)gate_bb_[expert_id]->b;
float* gate_scale_src = gate_bb_[expert_id]->d;
std::memcpy(w13_weight_dst + w13_expert_base_weight + offset_in_gpu_weight,
gate_weight_src, cpu_tp_weight_bytes);
convert_or_copy((ggml_bf16_t*)(w13_scale_dst + w13_expert_base_scale + offset_in_gpu_scale),
gate_scale_src, cpu_tp_scale_elem_count);
// Up (second part of w13 for this expert, immediately after gate)
uint8_t* up_weight_src = (uint8_t*)up_bb_[expert_id]->b;
float* up_scale_src = up_bb_[expert_id]->d;
std::memcpy(w13_weight_dst + w13_expert_base_weight + offset_in_gpu_weight + gpu_tp_weight_bytes,
up_weight_src, cpu_tp_weight_bytes);
convert_or_copy((ggml_bf16_t*)(w13_scale_dst + w13_expert_base_scale + offset_in_gpu_scale + gpu_tp_scale_elem_count),
up_scale_src, cpu_tp_scale_elem_count);
// Down (w2) - need to handle column-wise slicing
// The down matrix is transposed compared to gate/up, so we need to extract by columns
// When multiple CPU TPs map to one GPU TP, each CPU TP has a slice of intermediate dimension
// CPU TP internal layout: each column has config_.intermediate_size elements
// GPU expects: each column has full_config.intermediate_size elements
size_t cpu_tps_per_gpu = cpu_tp_count / gpu_tp_count;
for (size_t col = 0; col < config_.hidden_size; col++) {
// GPU buffer column width is full_config.intermediate_size / gpu_tp_count
size_t gpu_col_offset = col * ((full_config.intermediate_size / gpu_tp_count) >> 1);
size_t cpu_col_offset = col * (config_.intermediate_size >> 1);
size_t gpu_col_slice_offset = local_idx * (config_.intermediate_size >> 1);
std::memcpy(w2_weight_dst + w2_expert_base_weight + gpu_col_offset + gpu_col_slice_offset,
(uint8_t*)down_bb_[expert_id]->b + cpu_col_offset,
config_.intermediate_size / 2);
// Same for scales
size_t gpu_scale_col_offset = col * ((full_config.intermediate_size / gpu_tp_count) / group_size);
size_t cpu_scale_col_offset = col * (config_.intermediate_size / group_size);
size_t gpu_scale_slice_offset = local_idx * (config_.intermediate_size / group_size);
convert_or_copy((ggml_bf16_t*)(w2_scale_dst + w2_expert_base_scale + gpu_scale_col_offset + gpu_scale_slice_offset),
down_bb_[expert_id]->d + cpu_scale_col_offset,
config_.intermediate_size / group_size);
}
},
nullptr);
} else {
// cpu_tp_count < gpu_tp_count: one CPU TP writes to multiple GPU TPs
// Each CPU TP part contains data for multiple GPU TP parts
int gpu_tps_per_cpu_tp = gpu_tp_count / cpu_tp_count;
// This CPU TP part writes to GPU TP indices: [start_gpu_tp, start_gpu_tp + gpu_tps_per_cpu_tp)
int start_gpu_tp = tp_part_idx * gpu_tps_per_cpu_tp;
// Size of data per GPU TP within this CPU TP
size_t data_per_gpu_tp_weight = cpu_tp_weight_bytes / gpu_tps_per_cpu_tp;
size_t data_per_gpu_tp_scale = cpu_tp_scale_elem_count / gpu_tps_per_cpu_tp;
// Process all experts for this GPU TP
pool->do_work_stealing_job(
gpu_tps_per_cpu_tp * num_experts, nullptr,
[&, this](int task_id) {
int expert_id = task_id % num_experts;
int local_gpu_idx = task_id / num_experts;
int gpu_tp_idx = start_gpu_tp + local_gpu_idx;
// Get pointers for this GPU TP part
uint8_t* w13_weight_dst = (uint8_t*)w13_weight_ptrs[gpu_tp_idx];
ggml_bf16_t* w13_scale_dst = (ggml_bf16_t*)w13_scale_ptrs[gpu_tp_idx];
uint8_t* w2_weight_dst = (uint8_t*)w2_weight_ptrs[gpu_tp_idx];
ggml_bf16_t* w2_scale_dst = (ggml_bf16_t*)w2_scale_ptrs[gpu_tp_idx];
// Calculate offsets within CPU TP buffers
size_t cpu_offset_weight = local_gpu_idx * data_per_gpu_tp_weight;
size_t cpu_offset_scale = local_gpu_idx * data_per_gpu_tp_scale;
// Calculate offsets for this expert in GPU buffers
// For w13: each expert has gate+up, so the offset needs to account for 2x size
size_t w13_gpu_expert_offset_weight = expert_id * 2 * gpu_tp_weight_bytes;
size_t w13_gpu_expert_offset_scale = expert_id * 2 * gpu_tp_scale_elem_count;
size_t w2_gpu_expert_offset_weight = expert_id * gpu_tp_weight_bytes;
size_t w2_gpu_expert_offset_scale = expert_id * gpu_tp_scale_elem_count;
// Gate (first part of w13 for this expert)
uint8_t* gate_weight_src = (uint8_t*)gate_bb_[expert_id]->b + cpu_offset_weight;
float* gate_scale_src = gate_bb_[expert_id]->d + cpu_offset_scale;
std::memcpy(w13_weight_dst + w13_gpu_expert_offset_weight,
gate_weight_src, data_per_gpu_tp_weight);
convert_or_copy((ggml_bf16_t*)(w13_scale_dst + w13_gpu_expert_offset_scale),
gate_scale_src, data_per_gpu_tp_scale);
// Up (second part of w13 for this expert, immediately after gate)
uint8_t* up_weight_src = (uint8_t*)up_bb_[expert_id]->b + cpu_offset_weight;
float* up_scale_src = up_bb_[expert_id]->d + cpu_offset_scale;
std::memcpy(w13_weight_dst + w13_gpu_expert_offset_weight + gpu_tp_weight_bytes,
up_weight_src, data_per_gpu_tp_weight);
convert_or_copy((ggml_bf16_t*)(w13_scale_dst + w13_gpu_expert_offset_scale + gpu_tp_scale_elem_count),
up_scale_src, data_per_gpu_tp_scale);
// Down (w2) - need to handle column-wise slicing
// The down matrix is transposed compared to gate/up, so we need to extract by columns
for (size_t col = 0; col < config_.hidden_size; col++) {
// Calculate the offset within the column for this GPU TP part
size_t col_offset_weight = (col * config_.intermediate_size / 2) + (local_gpu_idx * data_per_gpu_tp_weight / config_.hidden_size);
size_t col_offset_scale = (col * (config_.intermediate_size / group_size)) + (local_gpu_idx * data_per_gpu_tp_scale / config_.hidden_size);
// Copy weights column by column
std::memcpy(w2_weight_dst + w2_gpu_expert_offset_weight + (col * (config_.intermediate_size / gpu_tps_per_cpu_tp) / 2),
(uint8_t*)down_bb_[expert_id]->b + col_offset_weight,
(config_.intermediate_size / gpu_tps_per_cpu_tp) / 2);
// Copy scales column by column
convert_or_copy((ggml_bf16_t*)(w2_scale_dst + w2_gpu_expert_offset_scale + col * ((config_.intermediate_size / gpu_tps_per_cpu_tp) / group_size)),
down_bb_[expert_id]->d + col_offset_scale,
(config_.intermediate_size / gpu_tps_per_cpu_tp) / group_size);
}
},
nullptr);
}
}
void warm_up() {
int qlen = config_.max_len;
std::vector<uint8_t> input(sizeof(ggml_bf16_t) * qlen * config_.hidden_size);
std::vector<uint8_t> output(sizeof(ggml_bf16_t) * qlen * config_.hidden_size);
std::vector<int64_t> expert_ids(qlen * config_.num_experts_per_tok);
std::vector<float> weights(qlen * config_.num_experts_per_tok);
for (int i = 0; i < qlen * config_.num_experts_per_tok; i++) {
expert_ids[i] = i % config_.expert_num;
weights[i] = 0.01;
}
forward(qlen, config_.num_experts_per_tok, expert_ids.data(), weights.data(), input.data(), output.data());
}
void forward(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) {
if (qlen > 1) {
forward_prefill(qlen, k, expert_ids, weights, input, output);
} else {
forward_decode(k, expert_ids, weights, input, output);
}
}
#ifndef DIRECT_OR_POOL_BY_QLEN
#define DIRECT_OR_POOL_BY_QLEN(var, fn) \
do { \
if (qlen < 10) { \
for (int i = 0; i < (var); i++) { \
(fn)(i); \
} \
} else { \
pool->do_work_stealing_job((var), nullptr, (fn), nullptr); \
} \
} while (0)
#endif
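// Illustrative usage sketch for the helper above (the callable name is hypothetical):
//   auto per_expert = [&](int i) { /* per-expert work */ };
//   DIRECT_OR_POOL_BY_QLEN(activated_expert, per_expert);
// Small requests (qlen < 10) run the body inline; larger ones go through the
// work-stealing pool captured as `pool` in the calling scope.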
void forward_prefill(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input,
void* output) {
for (int i = 0; i < qlen; i ++)
forward_decode(k, expert_ids + i * k, weights + i * k, (ggml_bf16_t*)input + i * config_.hidden_size, (float*)output + i * config_.hidden_size);
}
void forward_decode(int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) {
int qlen = 1;
auto pool = config_.pool->get_subpool(tp_part_idx);
auto& quant_config = config_.quant_config;
int& group_size = quant_config.group_size;
#ifdef FORWARD_TIME_PROFILE
auto start_time = std::chrono::high_resolution_clock::now();
auto last = start_time;
// Per-stage timings (in microseconds)
long prepare_time = 0, cpy_input_time = 0, q_input_time = 0, up_gate_time = 0;
long act_time = 0, q_down_time = 0, down_time = 0, weight_time = 0;
int max_local_num = 0; // track the largest local token count
#endif
int activated_expert = 0;
for (int i = 0; i < k; i++) {
if (expert_ids[i] < config_.num_gpu_experts || expert_ids[i] >= config_.expert_num) {
continue;
}
m_expert_id_map_[activated_expert] = expert_ids[i];
activated_expert++;
}
size_t offset = 0;
for (int i = 0; i < activated_expert; i++) {
auto expert_idx = m_expert_id_map_[i];
m_local_gate_output_ptr_[expert_idx] = m_local_gate_output_ + offset * config_.intermediate_size;
m_local_up_output_ptr_[expert_idx] = m_local_up_output_ + offset * config_.intermediate_size;
m_local_down_output_ptr_[expert_idx] = m_local_down_output_ + offset * config_.hidden_size;
offset += qlen;
}
gate_up_ba_[0]->from_mat(qlen, (ggml_bf16_t*)input, 0, 1);
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
q_input_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
// calc gate & up
int nth = T::recommended_nth(config_.intermediate_size);
pool->do_work_stealing_job(
nth * activated_expert * 2, [](int _) { T::config(); },
[this, nth, qlen](int task_id2) {
int& group_size = config_.quant_config.group_size;
int task_id = task_id2 / 2;
bool do_up = task_id2 % 2;
int expert_idx = m_expert_id_map_[task_id / nth];
int ith = task_id % nth;
if (do_up) {
amx::vec_mul_kgroup(qlen, config_.intermediate_size, config_.hidden_size, group_size, gate_up_ba_[0],
up_bb_[expert_idx], up_bc_[expert_idx], ith, nth);
up_bc_[expert_idx]->to_mat(qlen, m_local_up_output_ptr_[expert_idx], ith, nth);
} else {
amx::vec_mul_kgroup(qlen, config_.intermediate_size, config_.hidden_size, group_size, gate_up_ba_[0],
gate_bb_[expert_idx], gate_bc_[expert_idx], ith, nth);
gate_bc_[expert_idx]->to_mat(qlen, m_local_gate_output_ptr_[expert_idx], ith, nth);
}
},
nullptr);
#ifdef DEBUG_K2_MOE
if (activated_expert > 0) {
int print_elems = std::min(config_.intermediate_size, 16);
for (int dbg = 0; dbg < activated_expert; ++dbg) {
int sample_expert = m_expert_id_map_[dbg];
ggml_bf16_t* gate_ptr = m_local_gate_output_ptr_[sample_expert];
if (gate_ptr == nullptr) {
continue;
}
printf("[K2][TP %d] gate_out (expert %d, first %d elems): ", tp_part_idx, sample_expert, print_elems);
for (int idx = 0; idx < print_elems; idx++) {
float val = ggml_bf16_to_fp32(gate_ptr[idx]);
printf("%.6f ", val);
}
printf("\n");
int tail_start = config_.intermediate_size > print_elems ? config_.intermediate_size - print_elems : 0;
printf("[K2][TP %d] gate_out (expert %d, last %d elems): ", tp_part_idx, sample_expert, print_elems);
for (int idx = 0; idx < print_elems; idx++) {
float val = ggml_bf16_to_fp32(gate_ptr[tail_start + idx]);
printf("%.6f ", val);
}
printf("\n");
}
}
#endif
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
up_gate_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
// act
for (int task_id = 0; task_id < nth * activated_expert; task_id++) {
int expert_idx = m_expert_id_map_[task_id / nth];
int ith = task_id % nth;
auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);
for (int i = 0; i < qlen; i++) {
ggml_bf16_t* gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
ggml_bf16_t* up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];
for (int j = n_start; j < n_end; j += 32) {
__m512 gate_val0, gate_val1, up_val0, up_val1;
avx512_32xbf16_to_32xfp32((__m512i*)(gate_output_ptr + j), &gate_val0, &gate_val1);
avx512_32xbf16_to_32xfp32((__m512i*)(up_output_ptr + j), &up_val0, &up_val1);
__m512 result0 = amx::act_fn(gate_val0, up_val0);
__m512 result1 = amx::act_fn(gate_val1, up_val1);
avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i*)(gate_output_ptr + j));
}
}
}
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
act_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
// quant, get down a
pool->do_work_stealing_job(
activated_expert, nullptr,
[this, qlen](int task_id) {
int expert_idx = m_expert_id_map_[task_id];
down_ba_[expert_idx]->from_mat(qlen, m_local_gate_output_ptr_[expert_idx], 0, 1);
},
nullptr);
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
q_down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
// * down
nth = T::recommended_nth(config_.hidden_size);
pool->do_work_stealing_job(
nth * activated_expert, [](int _) { T::config(); },
[this, nth, qlen](int task_id) {
int& group_size = config_.quant_config.group_size;
int expert_idx = m_expert_id_map_[task_id / nth];
int ith = task_id % nth;
amx::vec_mul_kgroup(qlen, config_.hidden_size, config_.intermediate_size, group_size, down_ba_[expert_idx],
down_bb_[expert_idx], down_bc_[expert_idx], ith, nth);
down_bc_[expert_idx]->to_mat(qlen, m_local_down_output_ptr_[expert_idx], ith, nth);
},
nullptr);
#ifdef DEBUG_K2_MOE
if (activated_expert > 0) {
int print_elems = std::min(config_.hidden_size, 16);
for (int dbg = 0; dbg < activated_expert; ++dbg) {
int sample_expert = m_expert_id_map_[dbg];
ggml_bf16_t* down_ptr = m_local_down_output_ptr_[sample_expert];
if (down_ptr == nullptr) {
continue;
}
printf("[K2][TP %d] down_out (expert %d, first %d elems): ", tp_part_idx, sample_expert, print_elems);
for (int idx = 0; idx < print_elems; idx++) {
float val = ggml_bf16_to_fp32(down_ptr[idx]);
printf("%.6f ", val);
}
printf("\n");
int tail_start = config_.hidden_size > print_elems ? config_.hidden_size - print_elems : 0;
printf("[K2][TP %d] down_out (expert %d, last %d elems): ", tp_part_idx, sample_expert, print_elems);
for (int idx = 0; idx < print_elems; idx++) {
float val = ggml_bf16_to_fp32(down_ptr[tail_start + idx]);
printf("%.6f ", val);
}
printf("\n");
}
}
#endif
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
// get output
for (int e = 0; e < config_.hidden_size; e += 32) {
__m512 x0 = _mm512_setzero_ps();
__m512 x1 = _mm512_setzero_ps();
for (int j = 0; j < k; j++) {
if (expert_ids[j] < config_.num_gpu_experts || expert_ids[j] >= config_.expert_num) {
continue;
}
__m512 weight = _mm512_set1_ps(weights[j]);
__m512 down_output0, down_output1;
avx512_32xbf16_to_32xfp32((__m512i*)(m_local_down_output_ptr_[expert_ids[j]] +
m_local_pos_[0][j] * config_.hidden_size + e),
&down_output0, &down_output1);
x0 = _mm512_fmadd_ps(down_output0, weight, x0);
x1 = _mm512_fmadd_ps(down_output1, weight, x1);
}
auto f32out = (__m512*)((float*)output + e);
f32out[0] = x0;
f32out[1] = x1;
}
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
weight_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
auto end_time = std::chrono::high_resolution_clock::now();
auto forward_total_time = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count();
// Print all stage timings once at the end of the function
printf(
"Profiling Results (numa[%d]): activated_expert: %d, q_input: %ld us, "
"up_gate: %ld us, act: %ld us, q_down: %ld us, down: %ld us, weight: %ld us, total: %ld us\n",
tp_part_idx, activated_expert, q_input_time, up_gate_time, act_time, q_down_time, down_time, weight_time,
forward_total_time);
#endif
}
};
template <typename K>
class TP_MOE<AMX_K2_MOE_TP<K>> : public TP_MOE_Common<AMX_K2_MOE_TP<K>> {
public:
using TP_MOE_Common<AMX_K2_MOE_TP<K>>::TP_MOE_Common;
void load_weights() {
auto& config = this->config;
auto& tps = this->tps;
auto& tp_count = this->tp_count;
auto pool = config.pool;
const uint64_t* physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map;
if (config.gate_scale == nullptr) {
throw std::runtime_error("K2 MoE only supports Packed Int4 with KGroup Scale");
}
printf("From Packed Int4 with KGroup Scale\n");
int& group_size = config.quant_config.group_size;
for (auto i = 0; i < tp_count; i++) {
auto& tpc = tps[i]->config_;
size_t weight_elem_count = tpc.intermediate_size * tpc.hidden_size;
tpc.gate_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
tpc.up_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
tpc.down_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
size_t scales_elem_count = (tpc.hidden_size / group_size) * tpc.intermediate_size;
tpc.gate_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
tpc.up_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
tpc.down_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
if (tps[i]->config_.load == false) {
pool->get_subpool(i)->do_work_stealing_job(
tpc.expert_num, nullptr,
[&](int expert_id_) { // weight and scale are all in col majored.
size_t expert_id = expert_map(physical_to_logical_map, expert_id_);
// weight and scale TP-slicing for gate and up
memcpy((uint8_t*)tpc.gate_proj + ((expert_id * weight_elem_count) >> 1),
(uint8_t*)config.gate_proj +
((expert_id * config.intermediate_size * config.hidden_size + i * weight_elem_count) >> 1),
((sizeof(uint8_t) * weight_elem_count) >> 1));
memcpy((uint8_t*)tpc.up_proj + ((expert_id * weight_elem_count) >> 1),
(uint8_t*)config.up_proj +
((expert_id * config.intermediate_size * config.hidden_size + i * weight_elem_count) >> 1),
((sizeof(uint8_t) * weight_elem_count) >> 1));
memcpy((ggml_bf16_t*)tpc.gate_scale + (expert_id * scales_elem_count),
(ggml_bf16_t*)config.gate_scale +
(expert_id * (config.hidden_size / group_size) * config.intermediate_size +
i * scales_elem_count),
sizeof(ggml_bf16_t) * scales_elem_count);
memcpy((ggml_bf16_t*)tpc.up_scale + (expert_id * scales_elem_count),
(ggml_bf16_t*)config.up_scale +
(expert_id * (config.hidden_size / group_size) * config.intermediate_size +
i * scales_elem_count),
sizeof(ggml_bf16_t) * scales_elem_count);
// memcpy((uint8_t*)tpc.down_proj + ((expert_id * weight_elem_count) >> 1),
// (uint8_t*)config.down_proj +
// ((expert_id * config.intermediate_size * config.hidden_size + i * weight_elem_count) >> 1),
// ((sizeof(uint8_t) * weight_elem_count) >> 1));
// memcpy((ggml_bf16_t*)tpc.down_scale + (expert_id * scales_elem_count),
// (ggml_bf16_t*)config.down_scale +
// (expert_id * (config.intermediate_size / group_size) * config.hidden_size +
// i * scales_elem_count),
// sizeof(ggml_bf16_t) * scales_elem_count);
// weight and scale TP-slicing for down (by column)
for (size_t col = 0; col < config.hidden_size; col++) {
memcpy((uint8_t*)tpc.down_proj + ((expert_id * weight_elem_count + col * tpc.intermediate_size) >> 1),
(uint8_t*)config.down_proj + ((expert_id * config.intermediate_size * config.hidden_size +
col * config.intermediate_size + i * tpc.intermediate_size) >>
1),
(sizeof(uint8_t) * tpc.intermediate_size) >> 1);
memcpy((ggml_bf16_t*)tpc.down_scale + (expert_id * scales_elem_count + col * (tpc.intermediate_size / group_size)),
(ggml_bf16_t*)config.down_scale + ((expert_id * (config.intermediate_size / group_size) * config.hidden_size) +
col * (config.intermediate_size / group_size) + i * (tpc.intermediate_size / group_size)),
sizeof(ggml_bf16_t) * (tpc.intermediate_size / group_size));
}
},
nullptr);
}
printf("TP %d load weight done.\n", i);
}
DO_TPS_LOAD_WEIGHTS(pool);
for (auto i = 0; i < tp_count; i++) {
auto& tpc = tps[i]->config_;
delete[] (uint8_t*)(tpc.gate_proj);
delete[] (uint8_t*)(tpc.up_proj);
delete[] (uint8_t*)(tpc.down_proj);
delete[] (ggml_bf16_t*)(tpc.gate_scale);
delete[] (ggml_bf16_t*)(tpc.up_scale);
delete[] (ggml_bf16_t*)(tpc.down_scale);
}
this->weights_loaded = true;
}
void write_weight_scale_to_buffer(int gpu_tp_count, int gpu_experts_num,
const std::vector<uintptr_t>& w13_weight_ptrs,
const std::vector<uintptr_t>& w13_scale_ptrs,
const std::vector<uintptr_t>& w2_weight_ptrs,
const std::vector<uintptr_t>& w2_scale_ptrs) {
if (this->weights_loaded == false) {
throw std::runtime_error("Not Loaded");
}
if (this->tps.empty()) {
throw std::runtime_error("No TP parts initialized");
}
// Validate input vector sizes
if (w13_weight_ptrs.size() != gpu_tp_count || w13_scale_ptrs.size() != gpu_tp_count ||
w2_weight_ptrs.size() != gpu_tp_count || w2_scale_ptrs.size() != gpu_tp_count) {
throw std::runtime_error("Pointer arrays size must match gpu_tp_count");
}
// Each TP part writes to its corresponding buffer
for (int tp_idx = 0; tp_idx < this->tp_count; tp_idx++) {
// Note: w13 combines gate and up projections; the gate/up interleaving is handled
// inside write_weights_to_buffer, so the same pointer lists are passed through.
this->tps[tp_idx]->write_weights_to_buffer(
gpu_tp_count, this->tp_count,
gpu_experts_num, this->config,
w13_weight_ptrs, w13_scale_ptrs, //gate + up use w13
w2_weight_ptrs, w2_scale_ptrs); // down uses w2
}
}
void merge_results(int qlen, void* output, bool incremental) {
auto pool = this->config.pool;
auto merge_fn = [this, output, incremental](int token_nth) {
auto& local_output_numa = this->local_output_numa;
auto& tp_configs = this->tp_configs;
auto& tp_count = this->tp_count;
auto& config = this->config;
float* merge_to = local_output_numa[0] + token_nth * tp_configs[0].hidden_size;
if (incremental) {
for (int e = 0; e < config.hidden_size; e += 32) {
__m512 x0, x1;
avx512_32xbf16_to_32xfp32((__m512i*)((ggml_bf16_t*)output + token_nth * config.hidden_size + e), &x0, &x1);
*((__m512*)(merge_to + e)) = _mm512_add_ps(*((__m512*)(merge_to + e)), x0);
*((__m512*)(merge_to + e + 16)) = _mm512_add_ps(*((__m512*)(merge_to + e + 16)), x1);
}
}
for (int i = 1; i < tp_count; i++) {
float* merge_from = local_output_numa[i] + token_nth * tp_configs[i].hidden_size;
for (int e = 0; e < tp_configs[i].hidden_size; e += 16) {
*((__m512*)(merge_to + e)) = _mm512_add_ps(*((__m512*)(merge_to + e)), *((__m512*)(merge_from + e)));
}
}
for (int e = 0; e < config.hidden_size; e += 32) {
__m512 x0 = *(__m512*)(merge_to + e);
__m512 x1 = *(__m512*)(merge_to + e + 16);
avx512_32xfp32_to_32xbf16(&x0, &x1, (__m512i*)((ggml_bf16_t*)output + token_nth * config.hidden_size + e));
}
};
for (int i = 0; i < qlen; i++) {
merge_fn(i);
}
}
void merge_results(int qlen, void* output) { merge_results(qlen, output, false); }
};
#endif // CPUINFER_OPERATOR_AMX_K2_MOE_H


@@ -4,6 +4,7 @@
 #include <cassert>
 #include <cstdint>
 #include <cstdio>
+#include <cstring>
 #include <limits>
 #include <vector>
@@ -344,9 +345,6 @@ struct BufferAKGroupImpl {
  static constexpr int K_STEP = K::K_STEP;
  static constexpr int K_BLOCK = K::K_BLOCK;
- using index_t = Packed2DLayout::index_t;
- Packed2DLayout pack;
  static size_t required_size(int max_m, int k, int k_group_size) {
    ASSERT_RELEASE(k % k_group_size == 0, "k must be multiple of k_group_size");
    return sizeof(int8_t) * max_m * k + sizeof(float) * max_m * (k / k_group_size);
@@ -355,18 +353,12 @@ struct BufferAKGroupImpl {
  BufferAKGroupImpl(int max_m, int k, int k_group_size, void* ptr)
      : max_m(max_m),
        k(k),
-       k_group_size(k_group_size),
-       pack({{static_cast<index_t>(K_STEP), 'c'},
-             {static_cast<index_t>(M_STEP), 'r'},
-             {static_cast<index_t>(k_group_size / K_STEP), 'c'},
-             {static_cast<index_t>(K_BLOCK / k_group_size), 'c'},
-             {static_cast<index_t>(max_m / M_STEP), 'r'},
-             {static_cast<index_t>(k / K_BLOCK), 'c'}}) {
+       k_group_size(k_group_size) {
    ASSERT_RELEASE(k % k_group_size == 0, "k must be multiple of k_group_size");
    ASSERT_RELEASE(max_m % M_STEP == 0, "max_m must be multiple of M_STEP");
    ASSERT_RELEASE(k % K_STEP == 0, "k must be multiple of K_STEP");
    ASSERT_RELEASE(K_BLOCK % k_group_size == 0, "K_BLOCK must be multiple of k_group_size");
-   ASSERT_RELEASE(k % K_BLOCK == 0, "k must be multiple of K_BLOCK");
+   // ASSERT_RELEASE(k % K_BLOCK == 0, "k must be multiple of K_BLOCK");
    k_group_count = k / k_group_size;
    set_data(ptr);
@@ -922,6 +914,77 @@ struct BufferBInt4WithZeroImpl {
  float* get_min(int n, int n_begin) { return mins + n_begin; }
};
// BufferB for Signed Int4 with KGroup Scale (no zero point)
// Used for K2 MoE - signed int4 range: [-8, 7]
template <typename K>
struct BufferBInt4KGroupImpl {
using dt = typename K::dt;
dt* b; // packed signed int4 weights, col majored
float* d; // scales only (no mins/zero-points), row majored
int n, k, k_group_size, k_group_count;
static constexpr int N_STEP = K::N_STEP;
static constexpr int K_STEP = K::K_STEP;
static constexpr bool SCALE = true;
// Size calculation: packed int4 weights + scales (NO mins)
static size_t required_size(int n, int k, int k_group_size) {
return sizeof(int8_t) * n * k / 2 + sizeof(float) * n * (k / k_group_size);
}
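// Example (illustrative numbers): n = 32, k = 128, k_group_size = 32 gives
// 32 * 128 / 2 = 2048 bytes of packed int4 weights plus 32 * (128 / 32) * sizeof(float)
// = 512 bytes of scales, i.e. required_size(...) = 2560 bytes.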
BufferBInt4KGroupImpl(int n, int k, int k_group_size, void* ptr) : n(n), k(k), k_group_size(k_group_size) {
assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);
assert(n % N_STEP == 0);
assert(k % K_STEP == 0);
if (n % N_STEP || k % K_STEP || k % k_group_size) {
printf("BufferBInt4KGroupImpl: n: %d, k: %d, N_STEP: %d, K_STEP: %d, k_group_size: %d\n", n, k, N_STEP,
K_STEP, k_group_size);
throw std::runtime_error("n or k is not aligned to N_STEP or K_STEP");
}
k_group_count = k / k_group_size;
b = reinterpret_cast<dt*>(ptr);
d = reinterpret_cast<float*>(offset_pointer(b, n * k / 2));
}
// Load from packed signed int4 format
// Input: proj is packed int4 weights (2 int4 values per byte)
// Each int4 value is in range [-8, 7] (signed)
void from_raw_mat(uint8_t* proj, int ith, int nth) {
auto [n_start, n_end] = K::split_range_n(n, ith, nth);
if (n_start >= n_end) {
return;
}
const size_t row_bytes = static_cast<size_t>(k) / 2;
const size_t rows = static_cast<size_t>(n_end - n_start);
uint8_t* dst_weights = reinterpret_cast<uint8_t*>(b) + n_start * row_bytes;
const uint8_t* src_weights = proj + n_start * row_bytes;
std::memcpy(dst_weights, src_weights, rows * row_bytes);
}
// Get pointer to submatrix for computation
dt* get_submat(int n, int k, int n_begin, int k_begin) {
const size_t row_bytes = static_cast<size_t>(k) / 2;
const size_t row_offset = static_cast<size_t>(n_begin) * row_bytes;
const size_t col_offset = static_cast<size_t>(k_begin) / 2;
return reinterpret_cast<dt*>(reinterpret_cast<uint8_t*>(b) + row_offset + col_offset);
}
// Get scale pointer for a specific row and k_group
float* get_scale(int n, int n_begin, int k, int k_begin) {
int k_group_idx = k_begin / k_group_size;
return d + n_begin * (k / k_group_size) + k_group_idx;
}
// Split range for parallel processing
static std::pair<int, int> split_range_n(int n, int ith, int nth) {
int n_per_thread = (n + nth - 1) / nth;
n_per_thread = (n_per_thread + N_STEP - 1) / N_STEP * N_STEP;
int n_start = std::min(ith * n_per_thread, n);
int n_end = std::min(n_start + n_per_thread, n);
return {n_start, n_end};
}
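// Example (illustrative numbers, e.g. with N_STEP = 32): n = 96, nth = 4 rounds
// n_per_thread up to 32, so threads 0..2 get [0,32), [32,64), [64,96) and thread 3
// gets the empty range [96,96).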
};
template <typename K> template <typename K>
struct BufferBInt4WithZeroKGroupImpl { struct BufferBInt4WithZeroKGroupImpl {
using dt = typename K::dt; using dt = typename K::dt;


@@ -1015,8 +1015,9 @@ struct GemmKernel224Int8 {
  static void avx_kernel(int m, int n, int k, int m_begin, int n_begin, int k_block_begin, float* c, BufferA* ba,
                         BufferB* bb) {
    __m512i* c512 = (__m512i*)c;
+   int m_block_end = std::min(m - m_begin, M_STEP);
    if (k_block_begin == 0) {
-     for (int m_i = 0; m_i < m; m_i++) {
+     for (int m_i = 0; m_i < m_block_end; m_i++) {
        c512[m_i * 2] = _mm512_setzero_si512();
        c512[m_i * 2 + 1] = _mm512_setzero_si512();
      }
@@ -1028,7 +1029,7 @@
    int32_t* a32 = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin);
    __m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin + k_begin);
-   for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
+   for (int m_i = 0; m_i < m_block_end; m_i++) {
      for (int k_i = 0; k_i < 16; k_i++) {
        __m512i ma = _mm512_set1_epi32(a32[m_i * 16 + k_i]);
        for (int n_i = 0; n_i < 2; n_i++) {
@@ -1239,8 +1240,9 @@ struct GemmKernel224Int4 {
                         BufferB* bb) {
    using K = GemmKernel224Int4;
    __m512i* c512 = (__m512i*)c;
+   int m_block_end = std::min(m - m_begin, M_STEP);
    if (k_block_begin == 0) {
-     for (int m_i = 0; m_i < m; m_i++) {
+     for (int m_i = 0; m_i < m_block_end; m_i++) {
        c512[m_i * 2] = _mm512_setzero_si512();
        c512[m_i * 2 + 1] = _mm512_setzero_si512();
      }
@@ -1250,7 +1252,7 @@
    int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin);
    int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin + K::K_STEP);
    __m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin + k_begin);
-   for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
+   for (int m_i = 0; m_i < m_block_end; m_i++) {
      for (int k_i = 0; k_i < 16; k_i++) {
        __m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]);
        __m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]);
@@ -1533,8 +1535,9 @@ struct GemmKernel224Int4_1 {
                         BufferB* bb) {
    using K = GemmKernel224Int4_1;
    __m512i* c512 = (__m512i*)c;
+   int m_block_end = std::min(m - m_begin, M_STEP);
    if (k_block_begin == 0) {
-     for (int m_i = 0; m_i < m; m_i++) {
+     for (int m_i = 0; m_i < m_block_end; m_i++) {
        c512[m_i * 2] = _mm512_setzero_si512();
        c512[m_i * 2 + 1] = _mm512_setzero_si512();
      }
@@ -1543,7 +1546,7 @@
    int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin);
    int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin + K::K_STEP);
    __m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin + k_begin);
-   for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
+   for (int m_i = 0; m_i < m_block_end; m_i++) {
      for (int k_i = 0; k_i < 16; k_i++) {
        __m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]);
        __m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]);
@@ -2193,10 +2196,11 @@ struct GemmKernel224Int4KGroup {
                         BufferB* bb, int k_group_size) {
    using K = GemmKernel224Int4KGroup;
    __m512i* c512 = (__m512i*)int_c;
+   int m_block_end = std::min(m - m_begin, M_STEP);
    // Initialize int_c to zero at the start of k_group
    if (k_block_begin % k_group_size == 0) {
-     for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
+     for (int m_i = 0; m_i < m_block_end; m_i++) {
        c512[m_i * 2] = _mm512_setzero_si512();
        c512[m_i * 2 + 1] = _mm512_setzero_si512();
      }
@@ -2205,7 +2209,7 @@
    if (k_offset == 0) {
      int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);
      __m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin);
-     for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
+     for (int m_i = 0; m_i < m_block_end; m_i++) {
        for (int k_i = 0; k_i < 16; k_i++) {
          __m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]);
          for (int n_i = 0; n_i < 2; n_i++) {
@@ -2217,7 +2221,7 @@
    } else {
      int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);
      __m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin - K::K_STEP);
-     for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
+     for (int m_i = 0; m_i < m_block_end; m_i++) {
        for (int k_i = 0; k_i < 16; k_i++) {
          __m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]);
          for (int n_i = 0; n_i < 2; n_i++) {
@@ -2471,8 +2475,9 @@ struct GemmKernel224Int4_1KGroup {
                         BufferB* bb, int k_group_size) {
    using K = GemmKernel224Int4_1KGroup;
    __m512i* c512 = (__m512i*)int_c;
+   int m_block_end = std::min(m - m_begin, M_STEP);
    if (k_block_begin % k_group_size == 0) {
-     for (int m_i = 0; m_i < m; m_i++) {
+     for (int m_i = 0; m_i < m_block_end; m_i++) {
        c512[m_i * 2] = _mm512_setzero_si512();
        c512[m_i * 2 + 1] = _mm512_setzero_si512();
      }
@@ -2481,7 +2486,7 @@
    if (k_offset == 0) {
      int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);
      __m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin);
-     for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
+     for (int m_i = 0; m_i < m_block_end; m_i++) {
        for (int k_i = 0; k_i < 16; k_i++) {
          __m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]);
          for (int n_i = 0; n_i < 2; n_i++) {
@@ -2493,7 +2498,7 @@
    } else {
      int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);
      __m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin - K::K_STEP);
-     for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
+     for (int m_i = 0; m_i < m_block_end; m_i++) {
        for (int k_i = 0; k_i < 16; k_i++) {
          __m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]);
          for (int n_i = 0; n_i < 2; n_i++) {
@@ -2746,8 +2751,9 @@ struct GemmKernel224Int4_1_LowKGroup {
                         BufferB* bb, int k_group_size) {
    using K = GemmKernel224Int4_1_LowKGroup;
    __m512i* c512 = (__m512i*)int_c;
+   int m_block_end = std::min(m - m_begin, M_STEP);
    if (k_block_begin % k_group_size == 0) {
-     for (int m_i = 0; m_i < m; m_i++) {
+     for (int m_i = 0; m_i < m_block_end; m_i++) {
        c512[m_i * 2] = _mm512_setzero_si512();
        c512[m_i * 2 + 1] = _mm512_setzero_si512();
      }
@@ -2756,7 +2762,7 @@
    if (k_offset == 0) {
      int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);
      __m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin);
-     for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
+     for (int m_i = 0; m_i < m_block_end; m_i++) {
        for (int k_i = 0; k_i < 16; k_i++) {
          __m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]);
          for (int n_i = 0; n_i < 2; n_i++) {
@@ -2768,7 +2774,7 @@
    } else {
      int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);
      __m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin - K::K_STEP);
-     for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
+     for (int m_i = 0; m_i < m_block_end; m_i++) {
        for (int k_i = 0; k_i < 16; k_i++) {
          __m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]);
          for (int n_i = 0; n_i < 2; n_i++) {
@@ -2837,6 +2843,110 @@ struct GemmKernel224Int4_1_LowKGroup {
  }
};
// K2 Signed Int4 K-group quantization kernel (AVX only, no AMX)
// For K2 MoE - signed int4 range: [-8, 7]
struct GemmKernel224Int4SmallKGroup {
using dt = uint8_t; // packed int4 type
using output_t = int32_t;
static constexpr double ELEMENT_SIZE = 0.5;
static const int VNNI_BLK = 4;
static const int M_STEP = 1;
static const int N_STEP = 32;
static const int K_STEP = 32;
static inline const int N_BLOCK = 256;
// K_BLOCK should match k_group_size for proper scaling
static inline const int K_BLOCK = 7168; // Will be overridden by k_group_size
static std::string name() { return "K2_INT4_KGROUP"; }
static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }
static std::pair<int, int> split_range_n(int n, int ith, int nth) {
int n_start = N_BLOCK * ith;
int n_end = std::min(n, N_BLOCK * (ith + 1));
return {n_start, n_end};
}
static void config() {}
alignas(64) static constexpr uint8_t hi_mask_arr[32] = {
0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0,
0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0
};
alignas(64) static constexpr uint8_t lo_mask_arr[32] = {
0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F,
0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F
};
alignas(64) static constexpr uint8_t sign_xor_arr[32] = {
0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88,
0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88
};
static __m256i hi_mask() { return *((__m256i*)(&hi_mask_arr[0])); }
static __m256i lo_mask() { return *((__m256i*)(&lo_mask_arr[0])); }
static __m256i sign_xor_mask() { return *((__m256i*)(&sign_xor_arr[0])); }
using BufferA = BufferAKGroupImpl<GemmKernel224Int4SmallKGroup>;
using BufferB = BufferBInt4KGroupImpl<GemmKernel224Int4SmallKGroup>; // Use new signed int4 buffer
using BufferC = BufferCReduceImpl<GemmKernel224Int4SmallKGroup>;
// K-group aware AVX kernel for signed int4
static inline __m512i compressed_int4_to_int8_avx512(__m256i b256) {
b256 = _mm256_xor_si256(b256, sign_xor_mask());
__m256i b_hi = _mm256_and_si256(b256, hi_mask());
__m256i b_lo = _mm256_slli_epi16(_mm256_andnot_si256(hi_mask(), b256), 4);
__m256i unpack_lo = _mm256_unpacklo_epi8(b_lo, b_hi);
__m256i unpack_hi = _mm256_unpackhi_epi8(b_lo, b_hi);
__m512i result = _mm512_inserti64x4(_mm512_castsi256_si512(unpack_lo), unpack_hi, 1);
const __m512i lane_shuffle = _mm512_set_epi64(7, 6, 3, 2, 5, 4, 1, 0);
return _mm512_permutexvar_epi64(lane_shuffle, result);
}
static inline void integer_mat_vec_kgroup(int m, int n, int k, int k_group_size, BufferA* ba, BufferB *bb, BufferC* bc, int ith, int nth) {
auto [n_start, n_end] = split_range_n(n, ith, nth);
for (int m_begin = 0; m_begin < m; m_begin ++) {
float* c = bc->get_submat(m, n, m_begin, 0);
__m512i* a512 = (__m512i*)ba->get_submat(m, k, m_begin, 0);
for (int n_block_begin = n_start; n_block_begin < n_end; n_block_begin ++) {
__m256i* b256 = (__m256i*)bb->get_submat(n, k, n_block_begin, 0);
float* as = (float*)ba->get_scale(m, m_begin, k, 0);
float* bs = (float*)bb->get_scale(n, n_block_begin, k, 0);
__m512 sum = _mm512_setzero_ps();
#define WORK_K_BLOCK(k_block) \
{ \
__m256 abscale0 = _mm256_set1_ps(as[(k_block)*2] * bs[(k_block)*2]); \
__m256 abscale1 = _mm256_set1_ps(as[(k_block)*2+1] * bs[(k_block)*2+1]); \
__m512 abscale = _mm512_insertf32x8(_mm512_castps256_ps512(abscale0), abscale1, 1); \
__m512i mul = _mm512_setzero_si512(); \
mul = _mm512_dpbssd_epi32(mul, a512[k_block], compressed_int4_to_int8_avx512(b256[k_block])); \
sum = _mm512_add_ps(sum, _mm512_mul_ps(abscale, _mm512_cvtepi32_ps(mul))); \
}
for (int k_block = 0; k_block < k / 64; k_block += 2) {
WORK_K_BLOCK(k_block);
WORK_K_BLOCK(k_block + 1);
}
c[n_block_begin] = _mm512_reduce_add_ps(sum) / 16;
}
}
}
};
inline void vec_mul_kgroup(int m, int n, int k, int k_group_size, std::shared_ptr<GemmKernel224Int4SmallKGroup::BufferA> ba,
std::shared_ptr<GemmKernel224Int4SmallKGroup::BufferB> bb,
std::shared_ptr<GemmKernel224Int4SmallKGroup::BufferC> bc, int ith, int nth) {
GemmKernel224Int4SmallKGroup::integer_mat_vec_kgroup(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith, nth);
}
inline void mat_mul_kgroup(int m, int n, int k, int k_group_size, std::shared_ptr<GemmKernel224Int4SmallKGroup::BufferA> ba,
std::shared_ptr<GemmKernel224Int4SmallKGroup::BufferB> bb,
std::shared_ptr<GemmKernel224Int4SmallKGroup::BufferC> bc, int ith, int nth) {
GemmKernel224Int4SmallKGroup::integer_mat_vec_kgroup(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith, nth);
}
// New k-group aware matrix multiplication function
template <typename K, bool amx_or_avx = true>
void integer_mat_mul_kgroup(int m, int n, int k, int k_group_size, typename K::BufferA* ba, typename K::BufferB* bb,


@@ -17,7 +17,7 @@ from typing import List, Optional
from .experts_base import BaseMoEWrapper, KExpertsCPUBuffer
# Import backend implementations
-from .utils.amx import AMXMoEWrapper
+from .utils.amx import AMXMoEWrapper, RAWAMXMoEWrapper
from .utils.llamafile import LlamafileMoEWrapper
from .utils.moe_kernel import GeneralMoEWrapper
@@ -77,7 +77,7 @@ class KTMoEWrapper:
        chunked_prefill_size: Maximum prefill chunk size
        cpu_save: Whether to save weights to CPU memory
        max_deferred_experts_per_token: Number of experts per token to defer. Defaults to 0.
-       method: Backend method ("AMXINT4", "AMXINT8", "LLAMAFILE", "MOE_INT4", "MOE_INT8")
+       method: Backend method ("AMXINT4", "AMXINT8", "RAWINT4", "LLAMAFILE", "MOE_INT4", "MOE_INT8")
    Returns:
        An instance of the appropriate backend implementation (e.g., AMXMoEWrapper)
@@ -85,6 +85,8 @@ class KTMoEWrapper:
        # Select backend based on method
        if method in ["AMXINT4", "AMXINT8"]:
            backend_cls = AMXMoEWrapper
+       elif method == "RAWINT4":
+           backend_cls = RAWAMXMoEWrapper
        elif method == "LLAMAFILE":
            backend_cls = LlamafileMoEWrapper
        elif method in ["MOE_INT4", "MOE_INT8"]:


@@ -4,13 +4,15 @@
Utilities for kt_kernel package.
"""
-from .amx import AMXMoEWrapper
+from .amx import AMXMoEWrapper, RAWAMXMoEWrapper
from .llamafile import LlamafileMoEWrapper
-from .loader import SafeTensorLoader, GGUFLoader
+from .loader import SafeTensorLoader, GGUFLoader, CompressedSafeTensorLoader
__all__ = [
    "AMXMoEWrapper",
+   "RAWAMXMoEWrapper",
    "LlamafileMoEWrapper",
    "SafeTensorLoader",
+   "CompressedSafeTensorLoader",
    "GGUFLoader",
]


@@ -4,16 +4,16 @@ import ctypes
# Use relative imports for package structure
from ..experts_base import BaseMoEWrapper
-from .loader import SafeTensorLoader
+from .loader import SafeTensorLoader, CompressedSafeTensorLoader
from kt_kernel_ext.moe import MOEConfig
try:
-    from kt_kernel_ext.moe import AMXInt4_MOE, AMXInt8_MOE
+    from kt_kernel_ext.moe import AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE
    _HAS_AMX_SUPPORT = True
except (ImportError, AttributeError):
    _HAS_AMX_SUPPORT = False
-    AMXInt4_MOE, AMXInt8_MOE = None, None
+    AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE = None, None, None
from typing import Optional
@@ -301,3 +301,152 @@ class AMXMoEWrapper(BaseMoEWrapper):
        del self.gate_scales
        del self.up_scales
        del self.down_scales
class RAWAMXMoEWrapper(BaseMoEWrapper):
"""Wrapper for RAWINT4 experts stored in compressed SafeTensor format."""
_compressed_loader_instance = None
def __init__(
self,
layer_idx: int,
num_experts: int,
num_experts_per_tok: int,
hidden_size: int,
moe_intermediate_size: int,
num_gpu_experts: int,
cpuinfer_threads: int,
threadpool_count: int,
weight_path: str,
chunked_prefill_size: int,
cpu_save: bool = False,
max_deferred_experts_per_token: Optional[int] = None,
method: str = "RAWINT4",
):
if not _HAS_AMX_SUPPORT or AMXInt4_KGroup_MOE is None:
raise RuntimeError("AMX backend with RAWINT4 support is not available.")
super().__init__(
layer_idx=layer_idx,
num_experts=num_experts,
num_experts_per_tok=num_experts_per_tok,
hidden_size=hidden_size,
moe_intermediate_size=moe_intermediate_size,
num_gpu_experts=num_gpu_experts,
cpuinfer_threads=cpuinfer_threads,
threadpool_count=threadpool_count,
weight_path=weight_path,
chunked_prefill_size=chunked_prefill_size,
cpu_save=cpu_save,
max_deferred_experts_per_token=max_deferred_experts_per_token,
method=method,
)
if RAWAMXMoEWrapper._compressed_loader_instance is None:
RAWAMXMoEWrapper._compressed_loader_instance = CompressedSafeTensorLoader(weight_path)
self.loader = RAWAMXMoEWrapper._compressed_loader_instance
self.gate_weights = None
self.up_weights = None
self.down_weights = None
self.gate_scales = None
self.up_scales = None
self.down_scales = None
def load_weights_from_tensors(
self,
gate_proj: torch.Tensor,
up_proj: torch.Tensor,
down_proj: torch.Tensor,
physical_to_logical_map_cpu: torch.Tensor,
):
raise NotImplementedError("RAWINT4 wrapper expects pre-quantized safetensor weights.")
def load_weights(self, physical_to_logical_map_cpu: torch.Tensor):
base_key = f"model.layers.{self.layer_idx}"
weights = self.loader.load_experts(base_key)
self.gate_weights = torch.stack(weights["gate"], dim=0).contiguous()
self.up_weights = torch.stack(weights["up"], dim=0).contiguous()
self.down_weights = torch.stack(weights["down"], dim=0).contiguous()
self.gate_scales = torch.stack(weights["gate_scale"], dim=0).to(torch.bfloat16).contiguous()
self.up_scales = torch.stack(weights["up_scale"], dim=0).to(torch.bfloat16).contiguous()
self.down_scales = torch.stack(weights["down_scale"], dim=0).to(torch.bfloat16).contiguous()
moe_config = MOEConfig(
self.num_experts,
self.num_experts_per_tok,
self.hidden_size,
self.moe_intermediate_size,
self.num_gpu_experts,
)
moe_config.layer_idx = self.layer_idx
moe_config.pool = self.cpu_infer.backend_
moe_config.max_len = self.chunked_prefill_size
moe_config.quant_config.bits = 4
moe_config.quant_config.group_size = 32
moe_config.quant_config.zero_point = False
moe_config.gate_proj = self.gate_weights.data_ptr()
moe_config.up_proj = self.up_weights.data_ptr()
moe_config.down_proj = self.down_weights.data_ptr()
moe_config.gate_scale = self.gate_scales.data_ptr()
moe_config.up_scale = self.up_scales.data_ptr()
moe_config.down_scale = self.down_scales.data_ptr()
self.moe = AMXInt4_KGroup_MOE(moe_config)
self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
self.cpu_infer.sync()
del self.gate_weights
del self.up_weights
del self.down_weights
del self.gate_scales
del self.up_scales
del self.down_scales
def submit_write_weight_scale_to_buffer(
self,
gpu_tp_count: int,
gpu_experts_num: int,
w13_weight_ptrs,
w13_scale_ptrs,
w2_weight_ptrs,
w2_scale_ptrs,
):
"""
Submit the write_weight_scale_to_buffer task for RAWINT4 KGroup AMX implementation.
This method submits the C++-exposed task `write_weight_scale_to_buffer_task` to the
shared CPUInfer queue. The pointer lists should be plain integer lists (e.g. from
tensor.data_ptr()).
"""
if self.moe is None:
raise RuntimeError("MoE instance not initialized; cannot submit write_weight_scale_to_buffer task.")
if not hasattr(self.moe, "write_weight_scale_to_buffer_task"):
raise NotImplementedError(
"write_weight_scale_to_buffer_task is not available for this backend implementation."
)
self.cpu_infer.submit(
self.moe.write_weight_scale_to_buffer_task(
gpu_tp_count,
gpu_experts_num,
w13_weight_ptrs,
w13_scale_ptrs,
w2_weight_ptrs,
w2_scale_ptrs,
)
)
def sync_write_weight_scale_to_buffer(self):
"""
Block until previously submitted write_weight_scale_to_buffer tasks finish.
"""
# The CPUInfer.sync() call blocks until pending tasks complete.
self.cpu_infer.sync()
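# Illustrative usage sketch (the wrapper instance and the GPU-side tensors are assumptions,
# not part of this change); the pointer lists are plain Python ints from tensor.data_ptr():
#
#   wrapper.submit_write_weight_scale_to_buffer(
#       gpu_tp_count=2,
#       gpu_experts_num=32,
#       w13_weight_ptrs=[t.data_ptr() for t in w13_weight_tensors],
#       w13_scale_ptrs=[t.data_ptr() for t in w13_scale_tensors],
#       w2_weight_ptrs=[t.data_ptr() for t in w2_weight_tensors],
#       w2_scale_ptrs=[t.data_ptr() for t in w2_scale_tensors],
#   )
#   wrapper.sync_write_weight_scale_to_buffer()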


@@ -237,6 +237,56 @@ class SafeTensorLoader:
        return name in self.tensor_file_map
class CompressedSafeTensorLoader(SafeTensorLoader):
"""Loader for compressed SafeTensor layouts (RAWINT4 weights)."""
def load_experts(self, base_key: str, device: str = "cpu"):
"""Load raw expert weights stored in compressed safetensor format."""
experts_prefix = f"{base_key}.mlp.experts"
expert_idx = 0
while self.has_tensor(f"{experts_prefix}.{expert_idx}.up_proj.weight_packed"):
expert_idx += 1
if expert_idx == 0:
raise ValueError(f"No experts found for key {experts_prefix}")
def load_projection(proj_name: str):
weight_entries = []
scale_entries = []
for exp_id in range(expert_idx):
weight_key = f"{experts_prefix}.{exp_id}.{proj_name}_proj.weight_packed"
scale_key = f"{experts_prefix}.{exp_id}.{proj_name}_proj.weight_scale"
if not self.has_tensor(weight_key):
raise KeyError(f"Missing tensor: {weight_key}")
if not self.has_tensor(scale_key):
raise KeyError(f"Missing tensor: {scale_key}")
weight_tensor = self.load_tensor(weight_key, device).contiguous()
scale_tensor = self.load_tensor(scale_key, device).contiguous()
weight_entries.append(weight_tensor)
scale_entries.append(scale_tensor)
return weight_entries, scale_entries
gate_weights, gate_scales = load_projection("gate")
up_weights, up_scales = load_projection("up")
down_weights, down_scales = load_projection("down")
return {
"gate": gate_weights,
"up": up_weights,
"down": down_weights,
"gate_scale": gate_scales,
"up_scale": up_scales,
"down_scale": down_scales,
}
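# Illustrative usage (the path and layer key below are assumptions):
#   loader = CompressedSafeTensorLoader("/path/to/weights")
#   experts = loader.load_experts("model.layers.0")
#   gate_packed_0 = experts["gate"][0]        # packed int4 weights for expert 0
#   gate_scales_0 = experts["gate_scale"][0]  # matching per-group scales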
class GGUFLoader:
    """
    GGUF format loader using the official gguf library (gguf.gguf_reader.GGUFReader)