[Feature] Add avx-based kimi-k2 support (#1656)

* Support Kimi-K2-Thinking original weights; fix AMX kernel bug

* Update K2 AVX kernel.

* feat: add CPUInfer write buffer task

* [feat]: add Kimi K2 CPU write buffer support

- Implement write_weights_to_buffer in k2-moe.hpp for exporting expert weights into GPU-side buffers
- Fix down (w2) weight column-wise slicing for different TP configurations
- Support three TP scenarios: cpu_tp == gpu_tp, cpu_tp > gpu_tp, cpu_tp < gpu_tp (see the sketch after this list)
- Add comprehensive test cases for weight extraction validation
- Ensure compatibility with Kimi model's MoE architecture
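
A minimal sketch of the down (w2) column-wise slicing described above (sizes hypothetical, mirroring the test logic in examples/test_k2_write_buffer.py):

    # Each w2 column stores `intermediate_size` int4 values packed 2 per byte;
    # GPU TP part `tp_idx` owns one contiguous slice of every column.
    intermediate_size, gpu_tp_count = 2048, 4

    def w2_column_slice(col_idx: int, tp_idx: int):
        col_start = col_idx * (intermediate_size // 2)          # bytes per column
        slice_bytes = (intermediate_size // gpu_tp_count) // 2  # bytes per TP slice
        start = col_start + tp_idx * slice_bytes
        return start, start + slice_bytes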

* [fix]: correct write_weight_scale_to_buffer expert offset calculation

Fixed a bug in write_weight_scale_to_buffer_task where expert offsets in the GPU buffers were calculated incorrectly: offsets now stride by the full gpu_tp sizes instead of the per_expert_gpu slice sizes, ensuring a correct memory layout in multi-expert scenarios.
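
A sketch of the corrected stride (sizes hypothetical; int4 packs 2 values per byte):

    hidden_size, intermediate_size, gpu_tp_count = 7168, 2048, 4
    # full per-expert size within one GPU TP buffer
    gpu_tp_weight_bytes = hidden_size * intermediate_size // gpu_tp_count // 2
    # wrong: expert_id * per_expert_gpu_bytes (the smaller slice one CPU TP contributes)
    # right: experts stride by the full gpu_tp size
    def w2_expert_offset(expert_id: int) -> int:
        return expert_id * gpu_tp_weight_bytes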

Also added benchmark scripts for the K2 MoE and write-buffer operations, and cleaned up debug output in the test files.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* [feat]: add write buffer wrapper

* [fix] fix comment

---------

Co-authored-by: ouqingliang <1692110604@qq.com>
Co-authored-by: Claude <noreply@anthropic.com>
Jiaqi Liao 2025-12-02 16:01:07 +08:00 committed by GitHub
parent c2b8c60c4e
commit fcf8882075
12 changed files with 2649 additions and 34 deletions


@@ -0,0 +1,363 @@
#!/usr/bin/env python
# coding=utf-8
"""
Benchmark AMX_K2_MOE_TP int4 path with packed weights and BF16 scales.
"""
import json
import math
import os
import platform
import subprocess
import sys
import time
from tqdm import tqdm
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))
import kt_kernel_ext
import torch
# Benchmark parameters (single MoE, no layer loop)
expert_num = 384
hidden_size = 7168
intermediate_size = 2048
max_len = 25600
num_experts_per_tok = 8
qlen = 1
warm_up_iter = 1000
test_iter = 5000
k_group_size = 32
physical_to_logical_map = (
torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous()
)
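# Two worker subpools, pinned one per NUMA node with 40 threads each (assumes a dual-socket host)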
worker_config = kt_kernel_ext.WorkerPoolConfig()
worker_config.subpool_count = 2
worker_config.subpool_numa_map = [0, 1]
worker_config.subpool_thread_count = [40, 40]
CPUInfer = kt_kernel_ext.CPUInfer(worker_config)
def get_git_commit():
result = {}
try:
commit = (
subprocess.check_output(["git", "rev-parse", "HEAD"])
.decode("utf-8")
.strip()
)
commit_msg = (
subprocess.check_output(["git", "log", "-1", "--pretty=%B"])
.decode("utf-8")
.strip()
)
result["commit"] = commit
result["commit_message"] = commit_msg
dirty_output = (
subprocess.check_output(["git", "status", "--porcelain"])
.decode("utf-8")
.strip()
)
if dirty_output:
result["dirty"] = True
result["dirty_files"] = dirty_output.splitlines()
else:
result["dirty"] = False
except Exception as e:
result["commit"] = None
result["commit_message"] = None
result["dirty"] = None
result["error"] = str(e)
return result
def get_system_info():
info = {}
uname = platform.uname()
info["system_name"] = uname.system
info["node_name"] = uname.node
cpu_model = None
if os.path.exists("/proc/cpuinfo"):
try:
with open("/proc/cpuinfo", "r") as f:
for line in f:
if "model name" in line:
cpu_model = line.split(":", 1)[1].strip()
break
except Exception as e:
cpu_model = f"Error: {e}"
info["cpu_model"] = cpu_model
mem_total_gb = None
if os.path.exists("/proc/meminfo"):
try:
with open("/proc/meminfo", "r") as f:
for line in f:
if "MemTotal" in line:
mem_kb = float(line.split(":", 1)[1].split()[0])
mem_total_gb = round(mem_kb / (1024 * 1024), 2)
break
except Exception as e:
mem_total_gb = f"Error: {e}"
info["memory_size_GB"] = mem_total_gb
info["cpu_core_count"] = os.cpu_count()
sockets = set()
if os.path.exists("/proc/cpuinfo"):
try:
with open("/proc/cpuinfo", "r") as f:
for line in f:
if "physical id" in line:
sockets.add(line.split(":", 1)[1].strip())
except Exception:
sockets = set()
info["cpu_socket_count"] = len(sockets) if len(sockets) > 0 else 1
return info
script_path = os.path.abspath(__file__)
script_dir = os.path.dirname(script_path)
script_name = os.path.splitext(os.path.basename(script_path))[0]
json_path = os.path.join(script_dir, script_name + ".jsonl")
def record_results(result, filename=json_path):
with open(filename, "a") as f:
f.write(json.dumps(result) + "\n")
def pack_to_int32(
value: torch.Tensor, num_bits: int, packed_dim: int = 1
) -> torch.Tensor:
if value.dtype is not torch.int8:
raise ValueError("Tensor must be torch.int8 before packing")
if not (1 <= num_bits <= 8):
raise ValueError(f"num_bits must be in [1, 8], got {num_bits}")
offset = 1 << (num_bits - 1)
value = (value + offset).to(torch.uint8)
device = value.device
pack_factor = 32 // num_bits
if packed_dim == 0:
value = value.transpose(0, 1)
rows, cols = value.shape
padded_cols = math.ceil(cols / pack_factor) * pack_factor
pad_len = padded_cols - cols
if pad_len > 0:
value = torch.nn.functional.pad(value, (0, pad_len))
num_groups = padded_cols // pack_factor
reshaped = value.view(rows, num_groups, pack_factor).to(torch.int32)
bit_shifts = torch.arange(pack_factor, device=device, dtype=torch.int32) * num_bits
packed = (reshaped << bit_shifts).sum(dim=2, dtype=torch.int32)
if packed_dim == 0:
packed = packed.transpose(0, 1)
return packed
def pack_tensor_per_row(q: torch.Tensor, num_bits: int) -> torch.Tensor:
e, rows, cols = q.shape
flat = q.view(e * rows, cols)
packed = pack_to_int32(flat, num_bits)
return packed.view(e, rows, -1).contiguous()
def quantize_k2_tensor(weights: torch.Tensor, group_size: int):
"""
K2 int4 quantization producing int32-packed weights (8 int4s each) and BF16 scales.
"""
weights_f32 = weights.to(torch.float32)
e, rows, cols = weights_f32.shape
if cols % group_size != 0 or cols % 2 != 0:
raise ValueError(
f"cols ({cols}) must be divisible by group_size ({group_size}) and 2"
)
reshaped = weights_f32.view(e, rows, cols // group_size, group_size)
max_abs = reshaped.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
scales = (max_abs / 7.0).squeeze(-1)
q = torch.round(reshaped / scales.unsqueeze(-1)).clamp(-8, 7).to(torch.int8)
q = q.view(e, rows, cols)
packed = pack_tensor_per_row(q, num_bits=4).view(e, rows, cols // 8).contiguous()
scales = scales.to(torch.bfloat16).contiguous().view(
e, rows, cols // group_size
).contiguous()
return packed, scales
def build_quantized_layer_weights():
gate_proj = torch.randn(
(expert_num, intermediate_size, hidden_size),
dtype=torch.float32,
device="cpu",
).contiguous()
up_proj = torch.randn(
(expert_num, intermediate_size, hidden_size),
dtype=torch.float32,
device="cpu",
).contiguous()
down_proj = torch.randn(
(expert_num, hidden_size, intermediate_size),
dtype=torch.float32,
device="cpu",
).contiguous()
gate_q, gate_scales = quantize_k2_tensor(gate_proj, k_group_size)
up_q, up_scales = quantize_k2_tensor(up_proj, k_group_size)
down_q, down_scales = quantize_k2_tensor(down_proj, k_group_size)
return {
"gate_qweight": gate_q,
"up_qweight": up_q,
"down_qweight": down_q,
"gate_scales": gate_scales,
"up_scales": up_scales,
"down_scales": down_scales,
}
def bench_k2_moe():
with torch.inference_mode():
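# int4 weights cost 0.5 byte per element, plus one bf16 scale (2 bytes) per k_group_size elements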
bytes_per_elem = 0.5 + 2.0 / k_group_size
quant_data = build_quantized_layer_weights()
config = kt_kernel_ext.moe.MOEConfig(
expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0
)
config.max_len = max_len
config.quant_config.bits = 4
config.quant_config.group_size = k_group_size
config.quant_config.zero_point = False
config.gate_proj = quant_data["gate_qweight"].data_ptr()
config.up_proj = quant_data["up_qweight"].data_ptr()
config.down_proj = quant_data["down_qweight"].data_ptr()
config.gate_scale = quant_data["gate_scales"].data_ptr()
config.up_scale = quant_data["up_scales"].data_ptr()
config.down_scale = quant_data["down_scales"].data_ptr()
config.pool = CPUInfer.backend_
moe = kt_kernel_ext.moe.AMXInt4_KGroup_MOE(config)
CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
CPUInfer.sync()
gen_iter = 3000
expert_ids = (
torch.rand(gen_iter * qlen, expert_num, device="cpu")
.argsort(dim=-1)[:, :num_experts_per_tok]
.reshape(gen_iter, qlen * num_experts_per_tok)
.contiguous()
)
weights = torch.rand(
(gen_iter, qlen, num_experts_per_tok), dtype=torch.float32, device="cpu"
).contiguous()
input_tensor = torch.randn(
(qlen, hidden_size), dtype=torch.bfloat16, device="cpu"
).contiguous()
output_tensor = torch.empty_like(input_tensor)
bsz_tensor = torch.tensor([qlen], device="cpu")
for i in tqdm(range(warm_up_iter), desc="Warm-up"):
CPUInfer.submit(
moe.forward_task(
bsz_tensor.data_ptr(),
num_experts_per_tok,
expert_ids[i % gen_iter].data_ptr(),
weights[i % gen_iter].data_ptr(),
input_tensor.data_ptr(),
output_tensor.data_ptr(),
False,
)
)
CPUInfer.sync()
start = time.perf_counter()
for i in tqdm(range(test_iter), desc="Testing"):
CPUInfer.submit(
moe.forward_task(
bsz_tensor.data_ptr(),
num_experts_per_tok,
expert_ids[i % gen_iter].data_ptr(),
weights[i % gen_iter].data_ptr(),
input_tensor.data_ptr(),
output_tensor.data_ptr(),
False,
)
)
CPUInfer.sync()
end = time.perf_counter()
total_time = end - start
time_per_iter_us = total_time / test_iter * 1e6
bandwidth = (
hidden_size
* intermediate_size
* 3
* num_experts_per_tok
* (1 / 8 * 256 * (1 - (31 / 32) ** qlen))
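# the factor above appears to estimate distinct experts touched per iteration; the constants (256, 31/32) assume a 256-expert pool with 8 routed per token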
* bytes_per_elem
* test_iter
/ total_time
/ 1e9
)
flops = (
hidden_size
* intermediate_size
* qlen
* 3
* num_experts_per_tok
* 2
* test_iter
/ total_time
/ 1e12
)
print("Quant mode: int4_k2")
print("Time(s): ", total_time)
print("Iteration: ", test_iter)
print("Time(us) per iteration: ", time_per_iter_us)
print("Bandwidth: ", bandwidth, "GB/s")
print("Flops: ", flops, "TFLOPS")
print("")
result = {
"quant_mode": "int4_k2",
"total_time_seconds": total_time,
"iterations": test_iter,
"time_per_iteration_us": time_per_iter_us,
"bandwidth_GBs": bandwidth,
"flops_TFLOPS": flops,
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
"test_parameters": {
"expert_num": expert_num,
"hidden_size": hidden_size,
"intermediate_size": intermediate_size,
"max_len": max_len,
"num_experts_per_tok": num_experts_per_tok,
"qlen": qlen,
"warm_up_iter": warm_up_iter,
"test_iter": test_iter,
"k_group_size": k_group_size,
"bytes_per_elem": bytes_per_elem,
},
}
result.update(get_git_commit())
result.update(get_system_info())
record_results(result)
if __name__ == "__main__":
bench_k2_moe()


@@ -0,0 +1,309 @@
#!/usr/bin/env python
# coding=utf-8
"""
Benchmark write_weight_scale_to_buffer for AMX_K2_MOE_TP (int4 packed weights + bf16 scales).
"""
import json
import os
import platform
import subprocess
import sys
import time
from tqdm import tqdm
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))
import kt_kernel_ext
import torch
# Benchmark parameters (single MoE, mirror examples/test_k2_write_buffer.py)
expert_num = 384
num_experts_per_tok = expert_num
gpu_tp_count = 4
warm_up_iter = 3
test_iter = 7
gpu_experts_num = expert_num
hidden_size = 7168
intermediate_size = 2048
group_size = 32
max_len = 1
physical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous()
CPUInfer = kt_kernel_ext.CPUInfer(96)
def get_git_commit():
result = {}
try:
commit = (
subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("utf-8").strip()
)
commit_msg = (
subprocess.check_output(["git", "log", "-1", "--pretty=%B"])
.decode("utf-8")
.strip()
)
result["commit"] = commit
result["commit_message"] = commit_msg
dirty_output = (
subprocess.check_output(["git", "status", "--porcelain"]).decode("utf-8").strip()
)
if dirty_output:
result["dirty"] = True
result["dirty_files"] = dirty_output.splitlines()
else:
result["dirty"] = False
except Exception as e:
result["commit"] = None
result["commit_message"] = None
result["dirty"] = None
result["error"] = str(e)
return result
def get_system_info():
info = {}
uname = platform.uname()
info["system_name"] = uname.system
info["node_name"] = uname.node
cpu_model = None
if os.path.exists("/proc/cpuinfo"):
try:
with open("/proc/cpuinfo", "r") as f:
for line in f:
if "model name" in line:
cpu_model = line.split(":", 1)[1].strip()
break
except Exception as e:
cpu_model = f"Error: {e}"
info["cpu_model"] = cpu_model
mem_total_gb = None
if os.path.exists("/proc/meminfo"):
try:
with open("/proc/meminfo", "r") as f:
for line in f:
if "MemTotal" in line:
mem_kb = float(line.split(":", 1)[1].split()[0])
mem_total_gb = round(mem_kb / (1024 * 1024), 2)
break
except Exception as e:
mem_total_gb = f"Error: {e}"
info["memory_size_GB"] = mem_total_gb
info["cpu_core_count"] = os.cpu_count()
sockets = set()
if os.path.exists("/proc/cpuinfo"):
try:
with open("/proc/cpuinfo", "r") as f:
for line in f:
if "physical id" in line:
sockets.add(line.split(":", 1)[1].strip())
except Exception:
sockets = set()
info["cpu_socket_count"] = len(sockets) if len(sockets) > 0 else 1
return info
script_path = os.path.abspath(__file__)
script_dir = os.path.dirname(script_path)
script_name = os.path.splitext(os.path.basename(script_path))[0]
json_path = os.path.join(script_dir, script_name + ".jsonl")
def record_results(result, filename=json_path):
with open(filename, "a") as f:
f.write(json.dumps(result) + "\n")
def allocate_weights():
per_mat_weight_bytes = (hidden_size * intermediate_size) // 2
per_mat_scale_elems = (hidden_size * intermediate_size) // group_size
gate_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)
up_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)
down_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)
gate_scale = torch.randn(expert_num * per_mat_scale_elems, dtype=torch.bfloat16)
up_scale = torch.randn(expert_num * per_mat_scale_elems, dtype=torch.bfloat16)
down_scale = torch.randn(expert_num * per_mat_scale_elems, dtype=torch.bfloat16)
return (
gate_q.contiguous(),
up_q.contiguous(),
down_q.contiguous(),
gate_scale.contiguous(),
up_scale.contiguous(),
down_scale.contiguous(),
per_mat_weight_bytes,
per_mat_scale_elems,
)
def build_moe():
(
gate_q,
up_q,
down_q,
gate_scale,
up_scale,
down_scale,
per_mat_weight_bytes,
per_mat_scale_elems,
) = allocate_weights()
config = kt_kernel_ext.moe.MOEConfig(
expert_num, num_experts_per_tok, hidden_size, intermediate_size
)
config.max_len = max_len
config.quant_config.bits = 4
config.quant_config.group_size = group_size
config.quant_config.zero_point = False
config.pool = CPUInfer.backend_
config.gate_proj = gate_q.data_ptr()
config.up_proj = up_q.data_ptr()
config.down_proj = down_q.data_ptr()
config.gate_scale = gate_scale.data_ptr()
config.up_scale = up_scale.data_ptr()
config.down_scale = down_scale.data_ptr()
moe = kt_kernel_ext.moe.AMXInt4_KGroup_MOE(config)
CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
CPUInfer.sync()
# Buffer sizing per TP
weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count
scale_elems_per_expert_per_tp = per_mat_scale_elems // gpu_tp_count
total_weight_bytes_per_tp = gpu_experts_num * weight_bytes_per_expert_per_tp
total_scale_elems_per_tp = gpu_experts_num * scale_elems_per_expert_per_tp
w13_weight_bufs = [
torch.empty(2 * total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)
]
w13_scale_bufs = [
torch.empty(2 * total_scale_elems_per_tp, dtype=torch.bfloat16) for _ in range(gpu_tp_count)
]
w2_weight_bufs = [
torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)
]
w2_scale_bufs = [
torch.empty(total_scale_elems_per_tp, dtype=torch.bfloat16) for _ in range(gpu_tp_count)
]
buffer_ptrs = {
"w13_weight_ptrs": [buf.data_ptr() for buf in w13_weight_bufs],
"w13_scale_ptrs": [buf.data_ptr() for buf in w13_scale_bufs],
"w2_weight_ptrs": [buf.data_ptr() for buf in w2_weight_bufs],
"w2_scale_ptrs": [buf.data_ptr() for buf in w2_scale_bufs],
}
buffer_shapes = {
"per_mat_weight_bytes": per_mat_weight_bytes,
"per_mat_scale_elems": per_mat_scale_elems,
"weight_bytes_per_expert_per_tp": weight_bytes_per_expert_per_tp,
"scale_elems_per_expert_per_tp": scale_elems_per_expert_per_tp,
"total_weight_bytes_per_tp": total_weight_bytes_per_tp,
"total_scale_elems_per_tp": total_scale_elems_per_tp,
}
keep_tensors = {
"gate_q": gate_q,
"up_q": up_q,
"down_q": down_q,
"gate_scale": gate_scale,
"up_scale": up_scale,
"down_scale": down_scale,
"w13_weight_bufs": w13_weight_bufs,
"w13_scale_bufs": w13_scale_bufs,
"w2_weight_bufs": w2_weight_bufs,
"w2_scale_bufs": w2_scale_bufs,
}
return moe, buffer_ptrs, buffer_shapes, keep_tensors
def bench_write_buffer():
moe, buffer_ptrs, buffer_shapes, keep_tensors = build_moe()
total_weights = hidden_size * intermediate_size * expert_num * 3
# Throughput accounting consistent with examples/test_k2_write_buffer.py:
# bf16 scales cost 2 bytes each, int4 weights pack 2 values per byte
bytes_per_call = total_weights // group_size * 2 + total_weights // 2
# Warm-up
for _ in tqdm(range(warm_up_iter), desc="Warm-up"):
CPUInfer.submit(
moe.write_weight_scale_to_buffer_task(
gpu_tp_count=gpu_tp_count,
gpu_experts_num=gpu_experts_num,
**buffer_ptrs,
)
)
CPUInfer.sync()
total_time = 0
for _ in tqdm(range(test_iter), desc="Testing"):
start = time.perf_counter()
CPUInfer.submit(
moe.write_weight_scale_to_buffer_task(
gpu_tp_count=gpu_tp_count,
gpu_experts_num=gpu_experts_num,
**buffer_ptrs,
)
)
CPUInfer.sync()
end = time.perf_counter()
total_time += end - start
time.sleep(0.6)
print(end - start)
time_per_iter_us = total_time / test_iter * 1e6
bandwidth_gbs = bytes_per_call * test_iter / total_time / 1e9
print("write_weight_scale_to_buffer benchmark")
print("Time(s): ", total_time)
print("Iteration: ", test_iter)
print("Time(us) per iteration: ", time_per_iter_us)
print("Bandwidth: ", bandwidth_gbs, "GB/s")
print("")
result = {
"op": "write_weight_scale_to_buffer",
"total_time_seconds": total_time,
"iterations": test_iter,
"time_per_iteration_us": time_per_iter_us,
"bandwidth_GBs": bandwidth_gbs,
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
"test_parameters": {
"expert_num": expert_num,
"hidden_size": hidden_size,
"intermediate_size": intermediate_size,
"group_size": group_size,
"max_len": max_len,
"num_experts_per_tok": num_experts_per_tok,
"gpu_tp_count": gpu_tp_count,
"gpu_experts_num": gpu_experts_num,
"warm_up_iter": warm_up_iter,
"test_iter": test_iter,
"bytes_per_call": bytes_per_call,
},
"buffer_shapes": buffer_shapes,
"keep_tensors_alive": list(keep_tensors.keys()),
}
result.update(get_git_commit())
result.update(get_system_info())
record_results(result)
if __name__ == "__main__":
bench_write_buffer()


@@ -0,0 +1,319 @@
import math
import os
import sys
from typing import Dict, Literal
sys.path.insert(0, os.path.dirname(__file__) + "/../build")
import torch
import kt_kernel_ext
torch.manual_seed(42)
hidden_size = 7168
intermediate_size = 2048
max_len = 25600
expert_num = 16
num_experts_per_tok = 8
qlen = 1
layer_num = 1
CPUInfer = kt_kernel_ext.CPUInfer(40)
validation_iter = 10
k_group_size = 32
debug_print_count = 16
physical_to_logical_map = torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous()
def _pattern_uniform(groups: int) -> torch.Tensor:
return torch.full((groups,), 0.02, dtype=torch.float32)
def _pattern_alternating(groups: int) -> torch.Tensor:
vals = torch.full((groups,), 0.015, dtype=torch.float32)
vals[1::2] = 0.03
return vals
def _pattern_ramp(groups: int) -> torch.Tensor:
return torch.linspace(0.005, 0.04, steps=groups, dtype=torch.float32)
WEIGHT_PATTERNS = {
"uniform_scale": ("All k-groups share the same abs max / scale", _pattern_uniform),
"alternating_scale": ("Alternate small / large abs max per k-group", _pattern_alternating),
"ramp_scale": ("Linearly increasing abs max per k-group", _pattern_ramp),
"random": ("Random bf16 weights (baseline)", None),
}
def act_fn(x):
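# SiLU: x * sigmoid(x)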
return x / (1.0 + torch.exp(-x))
def mlp_torch(input, gate_proj, up_proj, down_proj):
gate_buf = torch.mm(input, gate_proj.t())
up_buf = torch.mm(input, up_proj.t())
print(f"gate_buf: {gate_buf}")
print(f"up_buf: {up_buf}")
intermediate = act_fn(gate_buf) * up_buf
ret = torch.mm(intermediate, down_proj.t())
print(f"intermediate: {intermediate}")
print(f"mlp output: {ret}")
return ret
def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):
cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num))
cnts.scatter_(1, expert_ids, 1)
tokens_per_expert = cnts.sum(dim=0)
idxs = expert_ids.view(-1).argsort()
sorted_tokens = input[idxs // expert_ids.shape[1]]
outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
end_idx = start_idx + num_tokens
if num_tokens == 0:
continue
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i])
outputs.append(expert_out)
start_idx = end_idx
outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
new_x = torch.empty_like(outs)
new_x[idxs] = outs
t_output = (
new_x.view(*expert_ids.shape, -1)
.type(weights.dtype)
.mul_(weights.unsqueeze(dim=-1))
.sum(dim=1)
.type(new_x.dtype)
)
return t_output
def pack_to_int32(value: torch.Tensor, num_bits: int, packed_dim: Literal[0, 1] = 1) -> torch.Tensor:
if value.dtype is not torch.int8:
raise ValueError("Tensor must be torch.int8 before packing")
if not (1 <= num_bits <= 8):
raise ValueError(f"num_bits must be in [1, 8], got {num_bits}")
offset = 1 << (num_bits - 1)
value = (value + offset).to(torch.uint8)
device = value.device
pack_factor = 32 // num_bits
if packed_dim == 0:
value = value.transpose(0, 1)
rows, cols = value.shape
padded_cols = math.ceil(cols / pack_factor) * pack_factor
pad_len = padded_cols - cols
if pad_len > 0:
value = torch.nn.functional.pad(value, (0, pad_len))
num_groups = padded_cols // pack_factor
# Use int32 here
reshaped = value.view(rows, num_groups, pack_factor).to(torch.int32)
bit_shifts = torch.arange(pack_factor, device=device, dtype=torch.int32) * num_bits
packed = (reshaped << bit_shifts).sum(dim=2, dtype=torch.int32)
if packed_dim == 0:
packed = packed.transpose(0, 1)
return packed
def pack_tensor_per_row(q: torch.Tensor, num_bits: int) -> torch.Tensor:
e, rows, cols = q.shape
flat = q.view(e * rows, cols)
packed = pack_to_int32(flat, num_bits)
return packed.view(e, rows, -1).contiguous()
def quantize_k2_tensor(weights: torch.Tensor, group_size: int):
"""
Symmetric max-abs/7 quantization per k-group following compressed_tensors packing.
Args:
weights: [expert_num, rows (N), cols (K)]
Returns:
packed_q: int32 tensor storing 8 int4s per element with shape [expert_num, rows * (cols // 8)]
scales: bfloat16 tensor with shape [expert_num, rows * (cols // group_size)]
"""
weights_f32 = weights.to(torch.float32)
e, rows, cols = weights_f32.shape
if cols % group_size != 0 or cols % 2 != 0:
raise ValueError(f"cols ({cols}) must be divisible by group_size ({group_size}) and 2")
reshaped = weights_f32.view(e, rows, cols // group_size, group_size)
max_abs = reshaped.abs().amax(dim=-1, keepdim=True)
max_abs = torch.clamp(max_abs, min=1e-8)
scales = (max_abs / 7.0).squeeze(-1)
q = torch.round(reshaped / scales.unsqueeze(-1)).clamp(-8, 7).to(torch.int8)
q = q.view(e, rows, cols)
packed = pack_tensor_per_row(q, num_bits=4).view(e, rows, cols // 8).contiguous()
scales = scales.to(torch.bfloat16).contiguous().view(e, rows, cols // group_size).contiguous()
print(f"Quantized weights: {packed.shape}, scales: {scales.shape}")
print(f"Quantized tensors: \n{packed},\n {scales}")
return packed, scales
def build_structured_tensor(shape: torch.Size, pattern: str) -> torch.Tensor:
if pattern == "random":
torch.manual_seed(42)
return (torch.randn(shape, dtype=torch.bfloat16, device="cpu") / 100.0).contiguous()
e, rows, cols = shape
groups = cols // k_group_size
group_builder = WEIGHT_PATTERNS[pattern][1]
group_vals = group_builder(groups).to(torch.float32)
block = group_vals.view(1, 1, groups, 1).expand(e, rows, groups, k_group_size).clone()
row_signs = torch.where(
(torch.arange(rows) % 2 == 0),
torch.ones(rows, dtype=torch.float32),
-torch.ones(rows, dtype=torch.float32),
).view(1, rows, 1, 1)
col_offsets = torch.linspace(-0.0005, 0.0005, steps=k_group_size, dtype=torch.float32).view(1, 1, 1, k_group_size)
block = block * row_signs + col_offsets
return block.reshape(shape).to(torch.bfloat16).contiguous()
def prepare_k2_quantized_weights(pattern: str) -> Dict[str, torch.Tensor]:
if pattern not in WEIGHT_PATTERNS:
raise ValueError(f"Unknown weight pattern: {pattern}")
gate_proj = build_structured_tensor((expert_num, intermediate_size, hidden_size), pattern)
up_proj = build_structured_tensor((expert_num, intermediate_size, hidden_size), pattern)
down_proj = build_structured_tensor((expert_num, hidden_size, intermediate_size), pattern)
gate_q, gate_scales = quantize_k2_tensor(gate_proj, k_group_size)
up_q, up_scales = quantize_k2_tensor(up_proj, k_group_size)
down_q, down_scales = quantize_k2_tensor(down_proj, k_group_size)
return {
"gate_qweight": gate_q.contiguous(),
"up_qweight": up_q.contiguous(),
"down_qweight": down_q.contiguous(),
"gate_scales": gate_scales.contiguous(),
"up_scales": up_scales.contiguous(),
"down_scales": down_scales.contiguous(),
"original_fp16": {
"gate_proj": gate_proj.to(torch.float16).contiguous(),
"up_proj": up_proj.to(torch.float16).contiguous(),
"down_proj": down_proj.to(torch.float16).contiguous(),
},
}
def build_moes_from_quantized_data(quant_data: Dict[str, torch.Tensor]):
moes = []
with torch.inference_mode(mode=True):
for _ in range(layer_num):
config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
config.max_len = max_len
config.quant_config.bits = 4
config.quant_config.group_size = k_group_size
config.quant_config.zero_point = False
config.gate_proj = quant_data["gate_qweight"].data_ptr()
config.up_proj = quant_data["up_qweight"].data_ptr()
config.down_proj = quant_data["down_qweight"].data_ptr()
config.gate_scale = quant_data["gate_scales"].data_ptr()
config.up_scale = quant_data["up_scales"].data_ptr()
config.down_scale = quant_data["down_scales"].data_ptr()
config.pool = CPUInfer.backend_
moe = kt_kernel_ext.moe.AMXInt4_KGroup_MOE(config)
CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
CPUInfer.sync()
# CPUInfer.submit(moe.warm_up_task())
# CPUInfer.sync()
moes.append(moe)
return moes
def run_case(pattern: str) -> Dict[str, float]:
print("\n" + "=" * 70)
desc = WEIGHT_PATTERNS[pattern][0]
print(f"Running case: {pattern} -> {desc}")
print("=" * 70)
quant_data = prepare_k2_quantized_weights(pattern)
moes = build_moes_from_quantized_data(quant_data)
original_weights = quant_data["original_fp16"]
gate_fp16 = original_weights["gate_proj"]
up_fp16 = original_weights["up_proj"]
down_fp16 = original_weights["down_proj"]
diffs = []
with torch.inference_mode(mode=True):
for i in range(validation_iter):
torch.manual_seed(100 + i)
bsz_tensor = torch.tensor([qlen], device="cpu")
expert_ids = torch.stack(
[torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)]
).contiguous()
weights = torch.randn((qlen, num_experts_per_tok), dtype=torch.float32).contiguous()
input_tensor = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous() / 100
output = torch.empty((qlen, hidden_size), dtype=torch.bfloat16).contiguous()
moe = moes[i % layer_num]
CPUInfer.submit(
moe.forward_task(
bsz_tensor.data_ptr(),
num_experts_per_tok,
expert_ids.data_ptr(),
weights.data_ptr(),
input_tensor.data_ptr(),
output.data_ptr(),
False,
)
)
CPUInfer.sync()
input_tensor_fp16 = input_tensor.to(torch.float16)
t_output = moe_torch(
input_tensor_fp16, expert_ids, weights, gate_fp16, up_fp16, down_fp16
).to(torch.bfloat16)
t_output = t_output.flatten()
output = output.flatten()
diff = torch.mean(torch.abs(output - t_output)) / (torch.mean(torch.abs(t_output)) + 1e-12)
diffs.append(diff.item())
print(f"[{pattern}] Iteration {i}: relative L1 diff = {diff:.4f}")
print(f" output {output}")
print(f" t_output {t_output}")
mean_diff = float(sum(diffs) / len(diffs))
max_diff = float(max(diffs))
min_diff = float(min(diffs))
return {"case": pattern, "description": desc, "mean": mean_diff, "max": max_diff, "min": min_diff}
def run_k2_moe_test():
summary_rows = []
for case_name in WEIGHT_PATTERNS.keys():
results = run_case(case_name)
summary_rows.append(results)
# break
print("\n=== Case vs. Relative Error Summary ===")
print(f"{'Case':<20} {'Mean':>10} {'Max':>10} {'Min':>10}")
for row in summary_rows:
print(f"{row['case']:<20} {row['mean']*100:9.2f}% {row['max']*100:9.2f}% {row['min']*100:9.2f}%")
if __name__ == "__main__":
run_k2_moe_test()


@@ -0,0 +1,267 @@
import os
import sys
import time
import torch
import numpy as np
# Ensure we can import the local extension
# REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))
# if REPO_ROOT not in sys.path:
# sys.path.insert(0, REPO_ROOT)
import kt_kernel_ext
from kt_kernel_ext import CPUInfer
def make_cpu_infer(thread_num=80):
return CPUInfer(thread_num)
def build_config(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size, group_size):
cfg = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
cfg.max_len = 1
cfg.quant_config.bits = 4
cfg.quant_config.group_size = group_size
cfg.quant_config.zero_point = False
cfg.pool = cpuinfer.backend_
return cfg
def allocate_weights(expert_num, hidden_size, intermediate_size, group_size):
# packed int4 weights: 2 values per byte
per_mat_weight_bytes = (hidden_size * intermediate_size) // 2
per_mat_scale_elems = (hidden_size * intermediate_size) // group_size
gate_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)
up_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)
down_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)
gate_scale = torch.randn(expert_num * per_mat_scale_elems, dtype=torch.bfloat16)
up_scale = torch.randn(expert_num * per_mat_scale_elems, dtype=torch.bfloat16)
down_scale = torch.randn(expert_num * per_mat_scale_elems, dtype=torch.bfloat16)
return (
gate_q,
up_q,
down_q,
gate_scale,
up_scale,
down_scale,
per_mat_weight_bytes,
per_mat_scale_elems,
)
def main():
torch.manual_seed(123)
expert_num = 256 # Total experts
gpu_experts = expert_num # Number of experts on GPU
gpu_tp_count = 2 # Number of TP parts
num_experts_per_tok = 8
hidden_size = 7168
intermediate_size = 2048
group_size = 32
cpuinfer = make_cpu_infer()
cfg = build_config(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size, group_size)
(
gate_q,
up_q,
down_q,
gate_scale,
up_scale,
down_scale,
per_mat_weight_bytes,
per_mat_scale_elems,
) = allocate_weights(expert_num, hidden_size, intermediate_size, group_size)
cfg.gate_proj = gate_q.data_ptr()
cfg.up_proj = up_q.data_ptr()
cfg.down_proj = down_q.data_ptr()
cfg.gate_scale = gate_scale.data_ptr()
cfg.up_scale = up_scale.data_ptr()
cfg.down_scale = down_scale.data_ptr()
moe = kt_kernel_ext.moe.AMXInt4_KGroup_MOE(cfg)
physical_to_logical_map = (
torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous()
)
cpuinfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
cpuinfer.sync()
# TP configuration
# Since weights are col-major, we can directly divide the total size by tp_count
# Each matrix is divided into gpu_tp_count parts in memory order
# Calculate sizes per TP part (direct division since col-major)
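# e.g. gpu_tp_count=2: TP 0 gets the first half of each expert's packed bytes, TP 1 the second half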
weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count
scale_elems_per_expert_per_tp = per_mat_scale_elems // gpu_tp_count
# Total sizes for all gpu_experts
total_weight_bytes_per_tp = gpu_experts * weight_bytes_per_expert_per_tp
total_scale_elems_per_tp = gpu_experts * scale_elems_per_expert_per_tp
# Create buffer lists for w13 (gate+up) and w2 (down)
w13_weight_bufs = []
w13_scale_bufs = []
w2_weight_bufs = []
w2_scale_bufs = []
for tp_idx in range(gpu_tp_count):
# w13 combines gate and up, so needs 2x the size
w13_weight_bufs.append(torch.empty(2 * total_weight_bytes_per_tp, dtype=torch.uint8))
w13_scale_bufs.append(torch.empty(2 * total_scale_elems_per_tp, dtype=torch.bfloat16))
w2_weight_bufs.append(torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8))
w2_scale_bufs.append(torch.empty(total_scale_elems_per_tp, dtype=torch.bfloat16))
# Get data pointers for all buffers
w13_weight_ptrs = [buf.data_ptr() for buf in w13_weight_bufs]
w13_scale_ptrs = [buf.data_ptr() for buf in w13_scale_bufs]
w2_weight_ptrs = [buf.data_ptr() for buf in w2_weight_bufs]
w2_scale_ptrs = [buf.data_ptr() for buf in w2_scale_bufs]
print(f"Total experts: {expert_num}, GPU experts: {gpu_experts}")
print(f"GPU TP count: {gpu_tp_count}")
print(f"Original per matrix weight bytes: {per_mat_weight_bytes}")
print(f"Original per matrix scale elements: {per_mat_scale_elems}")
print(f"Weight bytes per expert per TP: {weight_bytes_per_expert_per_tp}")
print(f"Scale elements per expert per TP: {scale_elems_per_expert_per_tp}")
print(f"Total weight bytes per TP (w13): {2 * total_weight_bytes_per_tp}")
print(f"Total weight bytes per TP (w2): {total_weight_bytes_per_tp}")
print(f"Total scale elements per TP (w13): {2 * total_scale_elems_per_tp}")
print(f"Total scale elements per TP (w2): {total_scale_elems_per_tp}")
for i in range(5):
cpuinfer.submit(
moe.write_weight_scale_to_buffer_task(
gpu_tp_count=gpu_tp_count,
gpu_experts_num=gpu_experts,
w13_weight_ptrs=w13_weight_ptrs,
w13_scale_ptrs=w13_scale_ptrs,
w2_weight_ptrs=w2_weight_ptrs,
w2_scale_ptrs=w2_scale_ptrs,
)
)
cpuinfer.sync()
begin_time = time.perf_counter_ns()
cpuinfer.submit(
moe.write_weight_scale_to_buffer_task(
gpu_tp_count=gpu_tp_count,
gpu_experts_num=gpu_experts,
w13_weight_ptrs=w13_weight_ptrs,
w13_scale_ptrs=w13_scale_ptrs,
w2_weight_ptrs=w2_weight_ptrs,
w2_scale_ptrs=w2_scale_ptrs,
)
)
cpuinfer.sync()
end_time = time.perf_counter_ns()
elapsed_ms = (end_time - begin_time) / 1000000
total_weights = hidden_size * intermediate_size * expert_num * 3
total_bytes = total_weights // group_size * 2 + total_weights // 2  # bf16 scales (2 B each) + packed int4 weights
print(f"write_weight_scale_to_buffer time: {elapsed_ms:.2f} ms")
print(f"Throughput: {total_bytes / (elapsed_ms * 1e6):.2f} GB/s")
def split_expert_tensor(tensor, chunk):
"""Split tensor by experts"""
return [tensor[i * chunk : (i + 1) * chunk] for i in range(expert_num)]
# Split by experts first
gate_q_experts = split_expert_tensor(gate_q, per_mat_weight_bytes)
up_q_experts = split_expert_tensor(up_q, per_mat_weight_bytes)
down_q_experts = split_expert_tensor(down_q, per_mat_weight_bytes)
gate_scale_experts = split_expert_tensor(gate_scale, per_mat_scale_elems)
up_scale_experts = split_expert_tensor(up_scale, per_mat_scale_elems)
down_scale_experts = split_expert_tensor(down_scale, per_mat_scale_elems)
# CPU TP count is always 2 in this test setup (one TP per NUMA node)
cpu_tp_count = 2
# Verify buffers for each TP part
for tp_idx in range(gpu_tp_count):
expected_w13_weights = []
expected_w13_scales = []
expected_w2_weights = []
expected_w2_scales = []
weight13_per_tp = per_mat_weight_bytes // gpu_tp_count
scale13_per_tp = per_mat_scale_elems // gpu_tp_count
# Process each GPU expert
for expert_idx in range(gpu_experts):
# For w13 (gate and up), the slicing is straightforward
start_weight = tp_idx * weight13_per_tp
end_weight = (tp_idx + 1) * weight13_per_tp
start_scale = tp_idx * scale13_per_tp
end_scale = (tp_idx + 1) * scale13_per_tp
# Gate
gate_weight_tp = gate_q_experts[expert_idx][start_weight:end_weight]
gate_scale_tp = gate_scale_experts[expert_idx][start_scale:end_scale]
# Up
up_weight_tp = up_q_experts[expert_idx][start_weight:end_weight]
up_scale_tp = up_scale_experts[expert_idx][start_scale:end_scale]
# Down matrix needs special handling because it's sliced column-wise
# We need to reconstruct it from column slices
down_weight_tp_parts = []
down_scale_tp_parts = []
# Iterate through each column to extract the corresponding parts
for col_idx in range(hidden_size):
col_weight_start = col_idx * (intermediate_size // 2)
col_scale_start = col_idx * (intermediate_size // group_size)
# Direct mapping: each CPU TP corresponds to a GPU TP
tp_slice_weight_size = (intermediate_size // gpu_tp_count) // 2
tp_slice_scale_size = (intermediate_size // gpu_tp_count) // group_size
tp_weight_offset = col_weight_start + tp_idx * tp_slice_weight_size
tp_scale_offset = col_scale_start + tp_idx * tp_slice_scale_size
down_weight_tp_parts.append(
down_q_experts[expert_idx][tp_weight_offset:tp_weight_offset + tp_slice_weight_size]
)
down_scale_tp_parts.append(
down_scale_experts[expert_idx][tp_scale_offset:tp_scale_offset + tp_slice_scale_size]
)
# Concatenate all column slices for this TP
down_weight_tp = torch.cat(down_weight_tp_parts)
down_scale_tp = torch.cat(down_scale_tp_parts)
expected_w13_weights.append(gate_weight_tp)
expected_w13_weights.append(up_weight_tp)
expected_w13_scales.append(gate_scale_tp)
expected_w13_scales.append(up_scale_tp)
expected_w2_weights.append(down_weight_tp)
expected_w2_scales.append(down_scale_tp)
# Concatenate all experts for this TP part
expected_w13_weight = torch.cat(expected_w13_weights)
expected_w13_scale = torch.cat(expected_w13_scales)
expected_w2_weight = torch.cat(expected_w2_weights)
expected_w2_scale = torch.cat(expected_w2_scales)
print(f"=== Checking TP part {tp_idx} ===")
# Assert all checks pass
assert torch.equal(w13_weight_bufs[tp_idx], expected_w13_weight), f"w13 weight bytes mismatch for TP {tp_idx}"
assert torch.allclose(w13_scale_bufs[tp_idx], expected_w13_scale), f"w13 scale values mismatch for TP {tp_idx}"
assert torch.equal(w2_weight_bufs[tp_idx], expected_w2_weight), f"w2 weight bytes mismatch for TP {tp_idx}"
assert torch.allclose(w2_scale_bufs[tp_idx], expected_w2_scale), f"w2 scale values mismatch for TP {tp_idx}"
print(f"\n✓ write_weight_scale_to_buffer passed: extracted {gpu_experts} GPU experts across {gpu_tp_count} TP parts from total {expert_num} experts")
if __name__ == "__main__":
main()


@@ -36,6 +36,7 @@ static const bool _is_plain_ = false;
#if defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL)
#include "operators/amx/awq-moe.hpp"
#include "operators/amx/k2-moe.hpp"
#include "operators/amx/la/amx_kernels.hpp"
#include "operators/amx/moe.hpp"
#endif
@@ -43,6 +44,7 @@ static const bool _is_plain_ = false;
#include <cstdint>
#include <memory>
#include <type_traits>
#include "operators/kvcache/kvcache.h"
#include "operators/llamafile/linear.h"
@@ -225,7 +227,9 @@ void bind_moe_module(py::module_& moe_module, const char* name) {
using MoeClass = TP_MOE<MoeTP>;
using MoeBindings = MOEBindings<MoeTP>;
py::class_<MoeClass, MoE_Interface, std::shared_ptr<MoeClass>>(moe_module, name)
auto moe_cls = py::class_<MoeClass, MoE_Interface, std::shared_ptr<MoeClass>>(moe_module, name);
moe_cls
.def(py::init<GeneralMOEConfig>())
.def("warm_up_task", &MoeBindings::WarmUpBindings::cpuinfer_interface)
.def("load_weights_task",
@@ -244,6 +248,53 @@ void bind_moe_module(py::module_& moe_module, const char* name) {
.def("warm_up", &MoeClass::warm_up)
.def("load_weights", &MoeClass::load_weights)
.def("forward", &MoeClass::forward_binding);
#if defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL)
if constexpr (std::is_same_v<MoeTP, AMX_K2_MOE_TP<amx::GemmKernel224Int4SmallKGroup>>) {
struct WriteWeightScaleToBufferBindings {
struct Args {
CPUInfer* cpuinfer;
MoeClass* moe;
int gpu_tp_count;
int gpu_experts_num;
std::vector<uintptr_t> w13_weight_ptrs;
std::vector<uintptr_t> w13_scale_ptrs;
std::vector<uintptr_t> w2_weight_ptrs;
std::vector<uintptr_t> w2_scale_ptrs;
};
static void inner(void* args) {
Args* args_ = (Args*)args;
args_->cpuinfer->enqueue(&MoeClass::write_weight_scale_to_buffer, args_->moe,
args_->gpu_tp_count, args_->gpu_experts_num,
args_->w13_weight_ptrs, args_->w13_scale_ptrs,
args_->w2_weight_ptrs, args_->w2_scale_ptrs);
}
static std::pair<intptr_t, intptr_t> cpuinfer_interface(std::shared_ptr<MoeClass> moe,
int gpu_tp_count, int gpu_experts_num,
py::list w13_weight_ptrs, py::list w13_scale_ptrs,
py::list w2_weight_ptrs, py::list w2_scale_ptrs) {
// Convert Python lists to std::vector<uintptr_t>
std::vector<uintptr_t> w13_weight_vec, w13_scale_vec, w2_weight_vec, w2_scale_vec;
for (auto item : w13_weight_ptrs) w13_weight_vec.push_back(py::cast<uintptr_t>(item));
for (auto item : w13_scale_ptrs) w13_scale_vec.push_back(py::cast<uintptr_t>(item));
for (auto item : w2_weight_ptrs) w2_weight_vec.push_back(py::cast<uintptr_t>(item));
for (auto item : w2_scale_ptrs) w2_scale_vec.push_back(py::cast<uintptr_t>(item));
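// cpuinfer is left null here on purpose: the submitting CPUInfer is expected to write itself into this leading field of Args before inner runs.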
Args* args = new Args{nullptr, moe.get(), gpu_tp_count, gpu_experts_num,
w13_weight_vec, w13_scale_vec, w2_weight_vec, w2_scale_vec};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
moe_cls.def("write_weight_scale_to_buffer_task", &WriteWeightScaleToBufferBindings::cpuinfer_interface,
py::arg("gpu_tp_count"), py::arg("gpu_experts_num"),
py::arg("w13_weight_ptrs"), py::arg("w13_scale_ptrs"),
py::arg("w2_weight_ptrs"), py::arg("w2_scale_ptrs"));
}
#endif
}
PYBIND11_MODULE(kt_kernel_ext, m) {
@@ -513,6 +564,7 @@ PYBIND11_MODULE(kt_kernel_ext, m) {
bind_moe_module<AMX_MOE_TP<amx::GemmKernel224Int4>>(moe_module, "AMXInt4_MOE");
bind_moe_module<AMX_MOE_TP<amx::GemmKernel224Int4_1>>(moe_module, "AMXInt4_1_MOE");
bind_moe_module<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4_1_LowKGroup>>(moe_module, "AMXInt4_1KGroup_MOE");
bind_moe_module<AMX_K2_MOE_TP<amx::GemmKernel224Int4SmallKGroup>>(moe_module, "AMXInt4_KGroup_MOE");
#endif
#if defined(USE_MOE_KERNEL)
bind_moe_module<MOE_KERNEL_TP<moe_kernel::GemmKernelInt8, _is_plain_>>(moe_module, "Int8_KERNEL_MOE");


@@ -0,0 +1,929 @@
/**
* @Description : Skeleton for K2 AMX MoE operator.
* @Author : Codex
* @Date : 2024-07-22
* @Version : 0.1.0
* @LastEditors : Codex
* @LastEditTime : 2024-07-22
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#ifndef CPUINFER_OPERATOR_AMX_K2_MOE_H
#define CPUINFER_OPERATOR_AMX_K2_MOE_H
// #define DEBUG_K2_MOE
#include <cstddef>
#include <cstdint>
#include <cstring>
// #define FORWARD_TIME_PROFILE
// #define FORWARD_TIME_REPORT
#include <immintrin.h>
#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdio>
#include <filesystem>
#include <fstream>
#include <string>
#include <vector>
#include "../../cpu_backend/shared_mem_buffer.h"
#include "../../cpu_backend/worker_pool.h"
#include "../common.hpp"
#include "../moe-tp.hpp"
#include "la/amx.hpp"
#include "llama.cpp/ggml.h"
template <class T>
class AMX_K2_MOE_TP {
private:
int tp_part_idx = 0;
void* gate_proj_ = nullptr; // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]
void* up_proj_ = nullptr; // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]
void* down_proj_ = nullptr; // [expert_num * hidden_size * intermediate_size ( /32 if quantized)]
ggml_bf16_t* m_local_input_ = nullptr; // [num_experts_per_tok * max_len * hidden_size]
ggml_bf16_t* m_local_gate_output_ = nullptr; // [num_experts_per_tok * max_len * intermediate_size]
ggml_bf16_t* m_local_up_output_ = nullptr; // [num_experts_per_tok * max_len * intermediate_size]
ggml_bf16_t* m_local_down_output_ = nullptr; // [num_experts_per_tok * max_len * hidden_size]
std::vector<std::vector<int>> m_local_pos_; // [max_len, num_experts_per_tok]
std::vector<int> m_local_num_; // [expert_num]
std::vector<int> m_expert_id_map_; // [expert_num]
std::vector<ggml_bf16_t*> m_local_input_ptr_; // [expert_num]
std::vector<ggml_bf16_t*> m_local_gate_output_ptr_; // [expert_num]
std::vector<ggml_bf16_t*> m_local_up_output_ptr_; // [expert_num]
std::vector<ggml_bf16_t*> m_local_down_output_ptr_; // [expert_num]
std::vector<std::shared_ptr<typename T::BufferA>> gate_up_ba_;
std::vector<std::shared_ptr<typename T::BufferB>> gate_bb_;
std::vector<std::shared_ptr<typename T::BufferC>> gate_bc_;
std::vector<std::shared_ptr<typename T::BufferB>> up_bb_;
std::vector<std::shared_ptr<typename T::BufferC>> up_bc_;
std::vector<std::shared_ptr<typename T::BufferA>> down_ba_;
std::vector<std::shared_ptr<typename T::BufferB>> down_bb_;
std::vector<std::shared_ptr<typename T::BufferC>> down_bc_;
#ifdef CHECK
char verify_bb[100000000];
char check_bb[100000000];
uint8_t compare_expers = 3;
#endif
#ifdef CHECK
inline void load_check() {
// TODO: implement load_check for verification.
}
void verify_load_right() {
// TODO: implement verification helpers.
}
#endif
inline void dump_buffer_b(const std::string &quantization_type, int expert_idx, const std::string &matrix_type,
typename T::BufferB *buffer) {
auto &quant_config = config_.quant_config;
int &group_size = quant_config.group_size;
printf("[DUMP_BUFFER_B] TP%d %s Expert%d %s:\n", tp_part_idx, quantization_type.c_str(), expert_idx,
matrix_type.c_str());
// Calculate dimensions based on matrix type
int rows, cols, num_groups;
size_t scale_elem_count;
if (matrix_type == "gate" || matrix_type == "up") {
rows = config_.intermediate_size;
cols = config_.hidden_size;
num_groups = cols / group_size;
scale_elem_count = num_groups * rows;
} else { // down
rows = config_.hidden_size;
cols = config_.intermediate_size;
num_groups = cols / group_size;
scale_elem_count = num_groups * rows;
}
// Dump scales (as float)
printf(" Scales[first 16]: ");
for (int i = 0; i < std::min(16, (int)scale_elem_count); i++) {
printf("%.6f ", buffer->d[i]);
}
printf("\n");
if (scale_elem_count > 16) {
printf(" Scales[last 16]: ");
int start_idx = std::max(0, (int)scale_elem_count - 16);
for (int i = start_idx; i < (int)scale_elem_count; i++) {
printf("%.6f ", buffer->d[i]);
}
printf("\n");
}
// Dump quantized weights (as hex uint8)
size_t weight_size = (rows * cols) / 2; // INT4 packed
uint8_t *weight_ptr = (uint8_t *)buffer->b;
printf(" Weights[first 32 bytes]: ");
for (int i = 0; i < std::min(32, (int)weight_size); i++) {
printf("%02x ", weight_ptr[i]);
}
printf("\n");
if (weight_size > 32) {
printf(" Weights[last 32 bytes]: ");
int start_idx = std::max(32, (int)weight_size - 32);
for (int i = start_idx; i < (int)weight_size; i++) {
printf("%02x ", weight_ptr[i]);
}
printf("\n");
}
printf(" Matrix dimensions: %dx%d, Groups: %d, Group size: %d, Scale elements: %zu\n", rows, cols, num_groups,
group_size, scale_elem_count);
printf("\n");
fflush(stdout);
}
#ifdef FORWARD_TIME_REPORT
std::chrono::time_point<std::chrono::high_resolution_clock> last_now;
#endif
public:
using input_t = ggml_bf16_t;
using output_t = float;
GeneralMOEConfig config_;
static constexpr double ELEMENT_SIZE = T::ELEMENT_SIZE;
AMX_K2_MOE_TP(GeneralMOEConfig config, int tp_part_idx_) {
auto& quant_config = config.quant_config;
int& group_size = quant_config.group_size;
if (quant_config.group_size == 0 || quant_config.zero_point) {
throw std::runtime_error("Kimi-K2 MoE only support KGroup Int4");
}
printf("Creating AMX_K2_MOE_TP %d at numa %d\n", tp_part_idx_, numa_node_of_cpu(sched_getcpu()));
auto& load = config.load;
auto& save = config.save;
if (load && config.path == "") {
load = false;
}
this->tp_part_idx = tp_part_idx_;
config_ = config;
gate_proj_ = config_.gate_proj;
up_proj_ = config_.up_proj;
down_proj_ = config_.down_proj;
MemoryRequest mem_requests;
mem_requests.append_pointer(
&m_local_input_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok * config_.max_len * config_.hidden_size);
mem_requests.append_pointer(&m_local_gate_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *
config_.max_len * config_.intermediate_size);
mem_requests.append_pointer(&m_local_up_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *
config_.max_len * config_.intermediate_size);
mem_requests.append_pointer(&m_local_down_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *
config_.max_len * config_.hidden_size);
m_local_pos_.resize(config_.max_len);
for (int i = 0; i < config_.max_len; i++) {
m_local_pos_[i].resize(config_.num_experts_per_tok);
}
m_expert_id_map_.resize(config_.expert_num);
m_local_num_.resize(config_.expert_num);
m_local_input_ptr_.resize(config_.expert_num);
m_local_gate_output_ptr_.resize(config_.expert_num);
m_local_up_output_ptr_.resize(config_.expert_num);
m_local_down_output_ptr_.resize(config_.expert_num);
for (size_t i = 0; i < config_.expert_num; i++) {
gate_up_ba_.push_back(
std::make_shared<typename T::BufferA>(config_.max_len, config_.hidden_size, group_size, nullptr));
gate_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.intermediate_size, nullptr));
up_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.intermediate_size, nullptr));
down_ba_.push_back(
std::make_shared<typename T::BufferA>(config_.max_len, config_.intermediate_size, group_size, nullptr));
down_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.hidden_size, nullptr));
void* gate_bb_ptr =
std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, group_size));
gate_bb_.push_back(std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size,
group_size, gate_bb_ptr));
void* up_bb_ptr =
std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, group_size));
up_bb_.push_back(
std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, group_size, up_bb_ptr));
void* down_bb_ptr =
std::aligned_alloc(64, T::BufferB::required_size(config_.hidden_size, config_.intermediate_size, group_size));
down_bb_.push_back(std::make_shared<typename T::BufferB>(config_.hidden_size, config_.intermediate_size,
group_size, down_bb_ptr));
}
for (int i = 0; i < config_.expert_num; i++) {
mem_requests.append_function([this, i](void* new_ptr) { gate_up_ba_[i]->set_data(new_ptr); },
T::BufferA::required_size(config_.max_len, config_.hidden_size, group_size));
mem_requests.append_function([this, i](void* new_ptr) { gate_bc_[i]->set_data(new_ptr); },
T::BufferC::required_size(config_.max_len, config_.intermediate_size));
mem_requests.append_function([this, i](void* new_ptr) { up_bc_[i]->set_data(new_ptr); },
T::BufferC::required_size(config_.max_len, config_.intermediate_size));
mem_requests.append_function([this, i](void* new_ptr) { down_ba_[i]->set_data(new_ptr); },
T::BufferA::required_size(config_.max_len, config_.intermediate_size, group_size));
mem_requests.append_function([this, i](void* new_ptr) { down_bc_[i]->set_data(new_ptr); },
T::BufferC::required_size(config_.max_len, config_.hidden_size));
}
shared_mem_buffer_numa.alloc(tp_part_idx, this, mem_requests);
}
~AMX_K2_MOE_TP() = default;
void load_weights() {
auto& quant_config = config_.quant_config;
int& group_size = quant_config.group_size;
const uint64_t* physical_to_logical_map = (const uint64_t*)config_.physical_to_logical_map;
auto pool = config_.pool->get_subpool(tp_part_idx);
if (quant_config.group_size == 0 || quant_config.zero_point) {
throw std::runtime_error("Kimi AVX MOE only support KGroup Int4.");
}
if (config_.gate_scale == nullptr) {
throw std::runtime_error("Kimi AVX MOE only support load native weight.");
}
// load weight
int nth = T::recommended_nth(config_.intermediate_size);
pool->do_work_stealing_job(
nth * config_.expert_num, nullptr,
[this, nth, physical_to_logical_map](int task_id) {
uint64_t expert_idx = task_id / nth;
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
int ith = task_id % nth;
// gate part
gate_bb_[expert_idx]->from_raw_mat(
(uint8_t*)config_.gate_proj +
((logical_expert_id * config_.intermediate_size * config_.hidden_size) >> 1),
ith, nth);
// up part
up_bb_[expert_idx]->from_raw_mat(
(uint8_t*)config_.up_proj +
((logical_expert_id * config_.intermediate_size * config_.hidden_size) >> 1),
ith, nth);
},
nullptr);
nth = T::recommended_nth(config_.hidden_size);
pool->do_work_stealing_job(
nth * config_.expert_num, nullptr,
[this, nth, physical_to_logical_map](int task_id) {
uint64_t expert_idx = task_id / nth;
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
int ith = task_id % nth;
// down part
down_bb_[expert_idx]->from_raw_mat(
(uint8_t*)config_.down_proj +
((logical_expert_id * config_.hidden_size * config_.intermediate_size) >> 1),
ith, nth);
},
nullptr);
pool->do_work_stealing_job(
config_.expert_num, nullptr,
[this, physical_to_logical_map](int task_id) {
uint64_t expert_idx = task_id;
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
size_t scale_elem_count =
(config_.hidden_size * config_.intermediate_size) / config_.quant_config.group_size;
// convert scales from BF16 to FP32
convert_or_copy(gate_bb_[expert_idx]->d,
(ggml_bf16_t*)config_.gate_scale + (logical_expert_id * scale_elem_count),
scale_elem_count);
convert_or_copy(up_bb_[expert_idx]->d,
(ggml_bf16_t*)config_.up_scale + (logical_expert_id * scale_elem_count),
scale_elem_count);
convert_or_copy(down_bb_[expert_idx]->d,
(ggml_bf16_t*)config_.down_scale + (logical_expert_id * scale_elem_count),
scale_elem_count);
},
nullptr);
// dump_buffer_b("native", 0, "down", down_bb_[0].get());
}
// Reconstruct weights for all experts to the output buffers
// This function handles the TP-specific portion of the reconstruction for all experts
void write_weights_to_buffer(int gpu_tp_count, int cpu_tp_count, int num_experts, const GeneralMOEConfig& full_config,
const std::vector<uintptr_t>& w13_weight_ptrs,
const std::vector<uintptr_t>& w13_scale_ptrs,
const std::vector<uintptr_t>& w2_weight_ptrs,
const std::vector<uintptr_t>& w2_scale_ptrs) const {
const int group_size = config_.quant_config.group_size;
auto pool = config_.pool->get_subpool(tp_part_idx);
// Calculate sizes for CPU TP part (this instance)
size_t cpu_tp_weight_elem_count = (size_t)config_.intermediate_size * config_.hidden_size;
size_t cpu_tp_weight_bytes = cpu_tp_weight_elem_count / 2; // int4 packing
size_t cpu_tp_scale_elem_count = cpu_tp_weight_elem_count / group_size;
// Calculate sizes for GPU TP part
size_t gpu_tp_weight_elem_count = (size_t)full_config.intermediate_size * full_config.hidden_size / gpu_tp_count;
size_t gpu_tp_weight_bytes = gpu_tp_weight_elem_count / 2; // int4 packing
size_t gpu_tp_scale_elem_count = gpu_tp_weight_elem_count / group_size;
// Determine mapping: which GPU TP parts should this CPU TP part write to?
// Since weights are col-major and we slice directly by memory order:
// - If cpu_tp_count >= gpu_tp_count: multiple (or one) CPU TPs write to one GPU TP
// - If cpu_tp_count < gpu_tp_count: one CPU TP writes to multiple GPU TPs
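// Example: cpu_tp=4, gpu_tp=2 -> CPU TPs {0,1} fill GPU TP 0 and {2,3} fill GPU TP 1;
// cpu_tp=2, gpu_tp=4 -> CPU TP 0 fills GPU TPs {0,1} and CPU TP 1 fills GPU TPs {2,3}.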
if (cpu_tp_count >= gpu_tp_count) {
// Multiple CPU TPs map to one GPU TP
int target_gpu_tp = tp_part_idx / (cpu_tp_count / gpu_tp_count);
int local_idx = tp_part_idx % (cpu_tp_count / gpu_tp_count);
// Get pointers for this GPU TP part
uint8_t* w13_weight_dst = (uint8_t*)w13_weight_ptrs[target_gpu_tp];
ggml_bf16_t* w13_scale_dst = (ggml_bf16_t*)w13_scale_ptrs[target_gpu_tp];
uint8_t* w2_weight_dst = (uint8_t*)w2_weight_ptrs[target_gpu_tp];
ggml_bf16_t* w2_scale_dst = (ggml_bf16_t*)w2_scale_ptrs[target_gpu_tp];
// Calculate offset within the GPU TP buffer
size_t offset_in_gpu_weight = local_idx * cpu_tp_weight_bytes;
size_t offset_in_gpu_scale = local_idx * cpu_tp_scale_elem_count;
// Process only the first num_experts experts (GPU experts)
int nth = T::recommended_nth(config_.intermediate_size);
nth = 1;
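// nth forced to 1: the per-expert ith/nth column split below is currently disabled (see the commented-out lines).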
pool->do_work_stealing_job(
nth * num_experts, nullptr,
[&, this](int task_id) {
int expert_id = task_id / nth;
// int ith = task_id % nth;
// auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);
// Calculate base offsets for this expert in the GPU buffers
// For w13: each expert has gate+up, so the offset needs to account for 2x size
size_t w13_expert_base_weight = expert_id * 2 * gpu_tp_weight_bytes;
size_t w13_expert_base_scale = expert_id * 2 * gpu_tp_scale_elem_count;
size_t w2_expert_base_weight = expert_id * gpu_tp_weight_bytes;
size_t w2_expert_base_scale = expert_id * gpu_tp_scale_elem_count;
// Gate (first part of w13 for this expert)
uint8_t* gate_weight_src = (uint8_t*)gate_bb_[expert_id]->b;
float* gate_scale_src = gate_bb_[expert_id]->d;
std::memcpy(w13_weight_dst + w13_expert_base_weight + offset_in_gpu_weight,
gate_weight_src, cpu_tp_weight_bytes);
convert_or_copy((ggml_bf16_t*)(w13_scale_dst + w13_expert_base_scale + offset_in_gpu_scale),
gate_scale_src, cpu_tp_scale_elem_count);
// Up (second part of w13 for this expert, immediately after gate)
uint8_t* up_weight_src = (uint8_t*)up_bb_[expert_id]->b;
float* up_scale_src = up_bb_[expert_id]->d;
std::memcpy(w13_weight_dst + w13_expert_base_weight + offset_in_gpu_weight + gpu_tp_weight_bytes,
up_weight_src, cpu_tp_weight_bytes);
convert_or_copy((ggml_bf16_t*)(w13_scale_dst + w13_expert_base_scale + offset_in_gpu_scale + gpu_tp_scale_elem_count),
up_scale_src, cpu_tp_scale_elem_count);
// Down (w2) - need to handle column-wise slicing
// The down matrix is transposed compared to gate/up, so we need to extract by columns
// When multiple CPU TPs map to one GPU TP, each CPU TP has a slice of intermediate dimension
// CPU TP internal layout: each column has config_.intermediate_size elements
// GPU expects: each column has full_config.intermediate_size elements
for (size_t col = 0; col < config_.hidden_size; col++) {
// GPU buffer column width is full_config.intermediate_size / gpu_tp_count
size_t gpu_col_offset = col * ((full_config.intermediate_size / gpu_tp_count) >> 1);
size_t cpu_col_offset = col * (config_.intermediate_size >> 1);
size_t gpu_col_slice_offset = local_idx * (config_.intermediate_size >> 1);
std::memcpy(w2_weight_dst + w2_expert_base_weight + gpu_col_offset + gpu_col_slice_offset,
(uint8_t*)down_bb_[expert_id]->b + cpu_col_offset,
config_.intermediate_size / 2);
// Same for scales
size_t gpu_scale_col_offset = col * ((full_config.intermediate_size / gpu_tp_count) / group_size);
size_t cpu_scale_col_offset = col * (config_.intermediate_size / group_size);
size_t gpu_scale_slice_offset = local_idx * (config_.intermediate_size / group_size);
convert_or_copy((ggml_bf16_t*)(w2_scale_dst + w2_expert_base_scale + gpu_scale_col_offset + gpu_scale_slice_offset),
down_bb_[expert_id]->d + cpu_scale_col_offset,
config_.intermediate_size / group_size);
}
},
nullptr);
} else {
// cpu_tp_count < gpu_tp_count: one CPU TP writes to multiple GPU TPs
// Each CPU TP part contains data for multiple GPU TP parts
int gpu_tps_per_cpu_tp = gpu_tp_count / cpu_tp_count;
// This CPU TP part writes to GPU TP indices: [start_gpu_tp, start_gpu_tp + gpu_tps_per_cpu_tp)
int start_gpu_tp = tp_part_idx * gpu_tps_per_cpu_tp;
// Size of data per GPU TP within this CPU TP
size_t data_per_gpu_tp_weight = cpu_tp_weight_bytes / gpu_tps_per_cpu_tp;
size_t data_per_gpu_tp_scale = cpu_tp_scale_elem_count / gpu_tps_per_cpu_tp;
// Process all experts for each target GPU TP
pool->do_work_stealing_job(
gpu_tps_per_cpu_tp * num_experts, nullptr,
[&, this](int task_id) {
int expert_id = task_id % num_experts;
int local_gpu_idx = task_id / num_experts;
int gpu_tp_idx = start_gpu_tp + local_gpu_idx;
// Get pointers for this GPU TP part
uint8_t* w13_weight_dst = (uint8_t*)w13_weight_ptrs[gpu_tp_idx];
ggml_bf16_t* w13_scale_dst = (ggml_bf16_t*)w13_scale_ptrs[gpu_tp_idx];
uint8_t* w2_weight_dst = (uint8_t*)w2_weight_ptrs[gpu_tp_idx];
ggml_bf16_t* w2_scale_dst = (ggml_bf16_t*)w2_scale_ptrs[gpu_tp_idx];
// Calculate offsets within CPU TP buffers
size_t cpu_offset_weight = local_gpu_idx * data_per_gpu_tp_weight;
size_t cpu_offset_scale = local_gpu_idx * data_per_gpu_tp_scale;
// Calculate offsets for this expert in GPU buffers
// For w13: each expert has gate+up, so the offset needs to account for 2x size
size_t w13_gpu_expert_offset_weight = expert_id * 2 * gpu_tp_weight_bytes;
size_t w13_gpu_expert_offset_scale = expert_id * 2 * gpu_tp_scale_elem_count;
size_t w2_gpu_expert_offset_weight = expert_id * gpu_tp_weight_bytes;
size_t w2_gpu_expert_offset_scale = expert_id * gpu_tp_scale_elem_count;
// Gate (first part of w13 for this expert)
uint8_t* gate_weight_src = (uint8_t*)gate_bb_[expert_id]->b + cpu_offset_weight;
float* gate_scale_src = gate_bb_[expert_id]->d + cpu_offset_scale;
std::memcpy(w13_weight_dst + w13_gpu_expert_offset_weight,
gate_weight_src, data_per_gpu_tp_weight);
convert_or_copy((ggml_bf16_t*)(w13_scale_dst + w13_gpu_expert_offset_scale),
gate_scale_src, data_per_gpu_tp_scale);
// Up (second part of w13 for this expert, immediately after gate)
uint8_t* up_weight_src = (uint8_t*)up_bb_[expert_id]->b + cpu_offset_weight;
float* up_scale_src = up_bb_[expert_id]->d + cpu_offset_scale;
std::memcpy(w13_weight_dst + w13_gpu_expert_offset_weight + gpu_tp_weight_bytes,
up_weight_src, data_per_gpu_tp_weight);
convert_or_copy((ggml_bf16_t*)(w13_scale_dst + w13_gpu_expert_offset_scale + gpu_tp_scale_elem_count),
up_scale_src, data_per_gpu_tp_scale);
// Down (w2) - need to handle column-wise slicing
// The down matrix is transposed compared to gate/up, so we need to extract by columns
for (size_t col = 0; col < config_.hidden_size; col++) {
// Calculate the offset within the column for this GPU TP part
size_t col_offset_weight = (col * config_.intermediate_size / 2) + (local_gpu_idx * data_per_gpu_tp_weight / config_.hidden_size);
size_t col_offset_scale = (col * (config_.intermediate_size / group_size)) + (local_gpu_idx * data_per_gpu_tp_scale / config_.hidden_size);
// Copy weights column by column
std::memcpy(w2_weight_dst + w2_gpu_expert_offset_weight + (col * (config_.intermediate_size / gpu_tps_per_cpu_tp) / 2),
(uint8_t*)down_bb_[expert_id]->b + col_offset_weight,
(config_.intermediate_size / gpu_tps_per_cpu_tp) / 2);
// Copy scales column by column
convert_or_copy((ggml_bf16_t*)(w2_scale_dst + w2_gpu_expert_offset_scale + col * ((config_.intermediate_size / gpu_tps_per_cpu_tp) / group_size)),
down_bb_[expert_id]->d + col_offset_scale,
(config_.intermediate_size / gpu_tps_per_cpu_tp) / group_size);
}
},
nullptr);
}
}
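The index arithmetic in write_weights_to_buffer is the part most likely to regress, so here is a small, self-contained Python sketch that mirrors the two mapping regimes (an editorial illustration, not part of the codebase; the name map_cpu_part and the returned tuples are hypothetical):

def map_cpu_part(tp_part_idx, cpu_tp, gpu_tp, inter_full, hidden, group_size=32):
    inter_cpu = inter_full // cpu_tp                 # rows held by this CPU part
    cpu_w_bytes = inter_cpu * hidden // 2            # int4: two values per byte
    cpu_s_elems = inter_cpu * hidden // group_size
    if cpu_tp >= gpu_tp:
        # one or more CPU parts fill a single GPU part back to back;
        # the offsets are destinations inside the GPU buffer
        ratio = cpu_tp // gpu_tp
        target, local = tp_part_idx // ratio, tp_part_idx % ratio
        return [(target, local * cpu_w_bytes, local * cpu_s_elems)]
    # one CPU part feeds several consecutive GPU parts;
    # the offsets are sources inside this CPU part's buffer
    ratio = gpu_tp // cpu_tp
    w_per_gpu, s_per_gpu = cpu_w_bytes // ratio, cpu_s_elems // ratio
    return [(tp_part_idx * ratio + g, g * w_per_gpu, g * s_per_gpu)
            for g in range(ratio)]

# 2 CPU parts, 4 GPU parts: CPU part 1 serves GPU parts 2 and 3
print(map_cpu_part(1, 2, 4, inter_full=2048, hidden=7168))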
void warm_up() {
int qlen = config_.max_len;
std::vector<uint8_t> input(sizeof(ggml_bf16_t) * qlen * config_.hidden_size);
std::vector<uint8_t> output(sizeof(float) * qlen * config_.hidden_size);  // forward_decode writes fp32 output
std::vector<int64_t> expert_ids(qlen * config_.num_experts_per_tok);
std::vector<float> weights(qlen * config_.num_experts_per_tok);
for (int i = 0; i < qlen * config_.num_experts_per_tok; i++) {
expert_ids[i] = i % config_.expert_num;
weights[i] = 0.01f;
}
forward(qlen, config_.num_experts_per_tok, expert_ids.data(), weights.data(), input.data(), output.data());
}
void forward(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) {
if (qlen > 1) {
forward_prefill(qlen, k, expert_ids, weights, input, output);
} else {
forward_decode(k, expert_ids, weights, input, output);
}
}
#ifndef DIRECT_OR_POOL_BY_QLEN
// Run fn inline for short sequences; dispatch to the thread pool otherwise.
#define DIRECT_OR_POOL_BY_QLEN(var, fn) \
do { \
if (qlen < 10) { \
for (int i = 0; i < (var); i++) { \
(fn)(i); \
} \
} else { \
pool->do_work_stealing_job((var), nullptr, (fn), nullptr); \
} \
} while (0)
#endif
void forward_prefill(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input,
void* output) {
for (int i = 0; i < qlen; i++)
forward_decode(k, expert_ids + i * k, weights + i * k, (ggml_bf16_t*)input + i * config_.hidden_size, (float*)output + i * config_.hidden_size);
}
void forward_decode(int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) {
int qlen = 1;
auto pool = config_.pool->get_subpool(tp_part_idx);
auto& quant_config = config_.quant_config;
int& group_size = quant_config.group_size;
#ifdef FORWARD_TIME_PROFILE
auto start_time = std::chrono::high_resolution_clock::now();
auto last = start_time;
// Per-stage timings (in microseconds)
long prepare_time = 0, cpy_input_time = 0, q_input_time = 0, up_gate_time = 0;
long act_time = 0, q_down_time = 0, down_time = 0, weight_time = 0;
int max_local_num = 0;  // tracks the largest per-expert local token count
#endif
int activated_expert = 0;
for (int i = 0; i < k; i++) {
// Skip experts resident on the GPU (id < num_gpu_experts) and padded/invalid ids.
if (expert_ids[i] < config_.num_gpu_experts || expert_ids[i] >= config_.expert_num) {
continue;
}
m_expert_id_map_[activated_expert] = expert_ids[i];
activated_expert++;
}
size_t offset = 0;
for (int i = 0; i < activated_expert; i++) {
auto expert_idx = m_expert_id_map_[i];
m_local_gate_output_ptr_[expert_idx] = m_local_gate_output_ + offset * config_.intermediate_size;
m_local_up_output_ptr_[expert_idx] = m_local_up_output_ + offset * config_.intermediate_size;
m_local_down_output_ptr_[expert_idx] = m_local_down_output_ + offset * config_.hidden_size;
offset += qlen;
}
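As a reference for the routing filter and compaction above, a minimal Python sketch (names are hypothetical):

# CPU handles only ids in [num_gpu_experts, expert_num); active experts
# get contiguous scratch-buffer slots in activation order.
def compact_experts(expert_ids, num_gpu_experts, expert_num):
    id_map, slot = [], {}
    for e in expert_ids:
        if e < num_gpu_experts or e >= expert_num:
            continue                      # GPU-resident or padded id
        slot[e] = len(id_map)             # scratch slot for this expert
        id_map.append(e)
    return id_map, slot

print(compact_experts([0, 5, 130, 383, 384], num_gpu_experts=2, expert_num=384))
# -> ([5, 130, 383], {5: 0, 130: 1, 383: 2})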
gate_up_ba_[0]->from_mat(qlen, (ggml_bf16_t*)input, 0, 1);
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
q_input_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
// calc gate & up
int nth = T::recommended_nth(config_.intermediate_size);
pool->do_work_stealing_job(
nth * activated_expert * 2, [](int _) { T::config(); },
[this, nth, qlen](int task_id2) {
int& group_size = config_.quant_config.group_size;
int task_id = task_id2 / 2;
bool do_up = task_id2 % 2;
int expert_idx = m_expert_id_map_[task_id / nth];
int ith = task_id % nth;
if (do_up) {
amx::vec_mul_kgroup(qlen, config_.intermediate_size, config_.hidden_size, group_size, gate_up_ba_[0],
up_bb_[expert_idx], up_bc_[expert_idx], ith, nth);
up_bc_[expert_idx]->to_mat(qlen, m_local_up_output_ptr_[expert_idx], ith, nth);
} else {
amx::vec_mul_kgroup(qlen, config_.intermediate_size, config_.hidden_size, group_size, gate_up_ba_[0],
gate_bb_[expert_idx], gate_bc_[expert_idx], ith, nth);
gate_bc_[expert_idx]->to_mat(qlen, m_local_gate_output_ptr_[expert_idx], ith, nth);
}
},
nullptr);
#ifdef DEBUG_K2_MOE
if (activated_expert > 0) {
int print_elems = std::min(config_.intermediate_size, 16);
for (int dbg = 0; dbg < activated_expert; ++dbg) {
int sample_expert = m_expert_id_map_[dbg];
ggml_bf16_t* gate_ptr = m_local_gate_output_ptr_[sample_expert];
if (gate_ptr == nullptr) {
continue;
}
printf("[K2][TP %d] gate_out (expert %d, first %d elems): ", tp_part_idx, sample_expert, print_elems);
for (int idx = 0; idx < print_elems; idx++) {
float val = ggml_bf16_to_fp32(gate_ptr[idx]);
printf("%.6f ", val);
}
printf("\n");
int tail_start = config_.intermediate_size > print_elems ? config_.intermediate_size - print_elems : 0;
printf("[K2][TP %d] gate_out (expert %d, last %d elems): ", tp_part_idx, sample_expert, print_elems);
for (int idx = 0; idx < print_elems; idx++) {
float val = ggml_bf16_to_fp32(gate_ptr[tail_start + idx]);
printf("%.6f ", val);
}
printf("\n");
}
}
#endif
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
up_gate_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
// activation: fuse gate and up outputs in place on the calling thread
for (int task_id = 0; task_id < nth * activated_expert; task_id++) {
int expert_idx = m_expert_id_map_[task_id / nth];
int ith = task_id % nth;
auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);
for (int i = 0; i < qlen; i++) {
ggml_bf16_t* gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
ggml_bf16_t* up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];
for (int j = n_start; j < n_end; j += 32) {
__m512 gate_val0, gate_val1, up_val0, up_val1;
avx512_32xbf16_to_32xfp32((__m512i*)(gate_output_ptr + j), &gate_val0, &gate_val1);
avx512_32xbf16_to_32xfp32((__m512i*)(up_output_ptr + j), &up_val0, &up_val1);
__m512 result0 = amx::act_fn(gate_val0, up_val0);
__m512 result1 = amx::act_fn(gate_val1, up_val1);
avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i*)(gate_output_ptr + j));
}
}
}
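A scalar reference for the fused activation loop, assuming amx::act_fn is SiLU (the usual choice for SwiGLU-style MoE FFNs; its definition is not shown in this excerpt):

import math

# out[j] = silu(gate[j]) * up[j], written back over the gate buffer
def act_reference(gate, up):
    return [g / (1.0 + math.exp(-g)) * u for g, u in zip(gate, up)]

print(act_reference([0.5, -1.0], [2.0, 3.0]))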
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
act_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
// quantize activations into the down-projection input buffer (BufferA)
pool->do_work_stealing_job(
activated_expert, nullptr,
[this, qlen](int task_id) {
int expert_idx = m_expert_id_map_[task_id];
down_ba_[expert_idx]->from_mat(qlen, m_local_gate_output_ptr_[expert_idx], 0, 1);
},
nullptr);
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
q_down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
// down projection
nth = T::recommended_nth(config_.hidden_size);
pool->do_work_stealing_job(
nth * activated_expert, [](int _) { T::config(); },
[this, nth, qlen](int task_id) {
int& group_size = config_.quant_config.group_size;
int expert_idx = m_expert_id_map_[task_id / nth];
int ith = task_id % nth;
amx::vec_mul_kgroup(qlen, config_.hidden_size, config_.intermediate_size, group_size, down_ba_[expert_idx],
down_bb_[expert_idx], down_bc_[expert_idx], ith, nth);
down_bc_[expert_idx]->to_mat(qlen, m_local_down_output_ptr_[expert_idx], ith, nth);
},
nullptr);
#ifdef DEBUG_K2_MOE
if (activated_expert > 0) {
int print_elems = std::min(config_.hidden_size, 16);
for (int dbg = 0; dbg < activated_expert; ++dbg) {
int sample_expert = m_expert_id_map_[dbg];
ggml_bf16_t* down_ptr = m_local_down_output_ptr_[sample_expert];
if (down_ptr == nullptr) {
continue;
}
printf("[K2][TP %d] down_out (expert %d, first %d elems): ", tp_part_idx, sample_expert, print_elems);
for (int idx = 0; idx < print_elems; idx++) {
float val = ggml_bf16_to_fp32(down_ptr[idx]);
printf("%.6f ", val);
}
printf("\n");
int tail_start = config_.hidden_size > print_elems ? config_.hidden_size - print_elems : 0;
printf("[K2][TP %d] down_out (expert %d, last %d elems): ", tp_part_idx, sample_expert, print_elems);
for (int idx = 0; idx < print_elems; idx++) {
float val = ggml_bf16_to_fp32(down_ptr[tail_start + idx]);
printf("%.6f ", val);
}
printf("\n");
}
}
#endif
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
// weighted combination of the activated experts' down outputs
for (int e = 0; e < config_.hidden_size; e += 32) {
__m512 x0 = _mm512_setzero_ps();
__m512 x1 = _mm512_setzero_ps();
for (int j = 0; j < k; j++) {
if (expert_ids[j] < config_.num_gpu_experts || expert_ids[j] >= config_.expert_num) {
continue;
}
__m512 weight = _mm512_set1_ps(weights[j]);
__m512 down_output0, down_output1;
avx512_32xbf16_to_32xfp32((__m512i*)(m_local_down_output_ptr_[expert_ids[j]] +
m_local_pos_[0][j] * config_.hidden_size + e),
&down_output0, &down_output1);
x0 = _mm512_fmadd_ps(down_output0, weight, x0);
x1 = _mm512_fmadd_ps(down_output1, weight, x1);
}
auto f32out = (__m512*)((float*)output + e);
f32out[0] = x0;
f32out[1] = x1;
}
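The accumulation loop above, in scalar form (an illustrative sketch; names are hypothetical):

# Each CPU-assigned expert's down output is scaled by its routing weight
# and summed into the token's output vector.
def combine(down_out, expert_ids, weights, num_gpu_experts, expert_num, hidden):
    out = [0.0] * hidden
    for e, w in zip(expert_ids, weights):
        if e < num_gpu_experts or e >= expert_num:
            continue                     # routed to the GPU (or padding)
        for j in range(hidden):
            out[j] += w * down_out[e][j]
    return out

print(combine({2: [1.0, 2.0]}, [0, 2], [0.5, 0.25], 1, 4, hidden=2))
# -> [0.25, 0.5]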
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
weight_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
auto end_time = std::chrono::high_resolution_clock::now();
auto forward_total_time = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count();
// Print all stage timings once at the end of the function
printf(
"Profiling Results (numa[%d]): activated_expert: %d, q_input: %ld us, "
"up_gate: %ld us, act: %ld us, q_down: %ld us, down: %ld us, weight: %ld us, total: %ld us\n",
tp_part_idx, activated_expert, q_input_time, up_gate_time, act_time, q_down_time, down_time, weight_time,
forward_total_time);
#endif
}
};
template <typename K>
class TP_MOE<AMX_K2_MOE_TP<K>> : public TP_MOE_Common<AMX_K2_MOE_TP<K>> {
public:
using TP_MOE_Common<AMX_K2_MOE_TP<K>>::TP_MOE_Common;
void load_weights() {
auto& config = this->config;
auto& tps = this->tps;
auto& tp_count = this->tp_count;
auto pool = config.pool;
const uint64_t* physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map;
if (config.gate_scale == nullptr) {
throw std::runtime_error("K2 MoE only supports Packed Int4 with KGroup Scale");
}
printf("From Packed Int4 with KGroup Scale\n");
int& group_size = config.quant_config.group_size;
for (auto i = 0; i < tp_count; i++) {
auto& tpc = tps[i]->config_;
size_t weight_elem_count = (size_t)tpc.intermediate_size * tpc.hidden_size;
tpc.gate_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
tpc.up_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
tpc.down_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
size_t scales_elem_count = (tpc.hidden_size / group_size) * tpc.intermediate_size;
tpc.gate_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
tpc.up_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
tpc.down_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
if (tps[i]->config_.load == false) {
pool->get_subpool(i)->do_work_stealing_job(
tpc.expert_num, nullptr,
[&](int expert_id_) { // weights and scales are all column-major
size_t expert_id = expert_map(physical_to_logical_map, expert_id_);
// weight and scale TP-slicing for gate and up
memcpy((uint8_t*)tpc.gate_proj + ((expert_id * weight_elem_count) >> 1),
(uint8_t*)config.gate_proj +
((expert_id * config.intermediate_size * config.hidden_size + i * weight_elem_count) >> 1),
((sizeof(uint8_t) * weight_elem_count) >> 1));
memcpy((uint8_t*)tpc.up_proj + ((expert_id * weight_elem_count) >> 1),
(uint8_t*)config.up_proj +
((expert_id * config.intermediate_size * config.hidden_size + i * weight_elem_count) >> 1),
((sizeof(uint8_t) * weight_elem_count) >> 1));
memcpy((ggml_bf16_t*)tpc.gate_scale + (expert_id * scales_elem_count),
(ggml_bf16_t*)config.gate_scale +
(expert_id * (config.hidden_size / group_size) * config.intermediate_size +
i * scales_elem_count),
sizeof(ggml_bf16_t) * scales_elem_count);
memcpy((ggml_bf16_t*)tpc.up_scale + (expert_id * scales_elem_count),
(ggml_bf16_t*)config.up_scale +
(expert_id * (config.hidden_size / group_size) * config.intermediate_size +
i * scales_elem_count),
sizeof(ggml_bf16_t) * scales_elem_count);
// weight and scale TP-slicing for down (by column)
for (size_t col = 0; col < config.hidden_size; col++) {
memcpy((uint8_t*)tpc.down_proj + ((expert_id * weight_elem_count + col * tpc.intermediate_size) >> 1),
(uint8_t*)config.down_proj + ((expert_id * config.intermediate_size * config.hidden_size +
col * config.intermediate_size + i * tpc.intermediate_size) >>
1),
(sizeof(uint8_t) * tpc.intermediate_size) >> 1);
memcpy((ggml_bf16_t*)tpc.down_scale + (expert_id * scales_elem_count + col * (tpc.intermediate_size / group_size)),
(ggml_bf16_t*)config.down_scale + ((expert_id * (config.intermediate_size / group_size) * config.hidden_size) +
col * (config.intermediate_size / group_size) + i * (tpc.intermediate_size / group_size)),
sizeof(ggml_bf16_t) * (tpc.intermediate_size / group_size));
}
},
nullptr);
}
printf("TP %d load weight done.\n", i);
}
DO_TPS_LOAD_WEIGHTS(pool);
for (auto i = 0; i < tp_count; i++) {
auto& tpc = tps[i]->config_;
delete[] (uint8_t*)(tpc.gate_proj);
delete[] (uint8_t*)(tpc.up_proj);
delete[] (uint8_t*)(tpc.down_proj);
delete[] (ggml_bf16_t*)(tpc.gate_scale);
delete[] (ggml_bf16_t*)(tpc.up_scale);
delete[] (ggml_bf16_t*)(tpc.down_scale);
}
this->weights_loaded = true;
}
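For reference, the column-wise down (w2) slicing performed in the load loop above can be written in a few lines of Python (a sketch under the column-major, int4-packed layout described in the comments; names are hypothetical):

# TP part i takes rows [i*inter_tp, (i+1)*inter_tp) out of every column of
# the full [hidden x inter_full] packed matrix.
def slice_down_int4(full, hidden, inter_full, tp, i):
    inter_tp = inter_full // tp
    col_bytes_full, col_bytes_tp = inter_full // 2, inter_tp // 2
    out = bytearray(hidden * col_bytes_tp)
    for col in range(hidden):
        src = col * col_bytes_full + i * col_bytes_tp
        out[col * col_bytes_tp:(col + 1) * col_bytes_tp] = \
            full[src:src + col_bytes_tp]
    return bytes(out)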
void write_weight_scale_to_buffer(int gpu_tp_count, int gpu_experts_num,
const std::vector<uintptr_t>& w13_weight_ptrs,
const std::vector<uintptr_t>& w13_scale_ptrs,
const std::vector<uintptr_t>& w2_weight_ptrs,
const std::vector<uintptr_t>& w2_scale_ptrs) {
if (this->weights_loaded == false) {
throw std::runtime_error("Not Loaded");
}
if (this->tps.empty()) {
throw std::runtime_error("No TP parts initialized");
}
// Validate input vector sizes
if (w13_weight_ptrs.size() != (size_t)gpu_tp_count || w13_scale_ptrs.size() != (size_t)gpu_tp_count ||
w2_weight_ptrs.size() != (size_t)gpu_tp_count || w2_scale_ptrs.size() != (size_t)gpu_tp_count) {
throw std::runtime_error("Pointer array sizes must match gpu_tp_count");
}
// Each TP part writes to its corresponding buffer
for (int tp_idx = 0; tp_idx < this->tp_count; tp_idx++) {
// w13 stores gate and up back to back per expert; down uses w2.
this->tps[tp_idx]->write_weights_to_buffer(
gpu_tp_count, this->tp_count,
gpu_experts_num, this->config,
w13_weight_ptrs, w13_scale_ptrs, //gate + up use w13
w2_weight_ptrs, w2_scale_ptrs); // down uses w2
}
}
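Callers must size the destination buffers consistently with the offsets used in write_weights_to_buffer. A small Python sketch of the implied per-rank sizes (hypothetical helper, assuming group_size 32 and bf16 scales):

# int4 weights: two values per byte; one scale per group_size values.
def gpu_buffer_sizes(num_experts, hidden, inter_full, gpu_tp, group_size=32):
    inter_gpu = inter_full // gpu_tp
    w_bytes = inter_gpu * hidden // 2           # per expert, per matrix
    s_elems = inter_gpu * hidden // group_size
    return {
        "w13_weight_bytes": num_experts * 2 * w_bytes,   # gate + up
        "w13_scale_elems":  num_experts * 2 * s_elems,
        "w2_weight_bytes":  num_experts * w_bytes,
        "w2_scale_elems":   num_experts * s_elems,
    }

print(gpu_buffer_sizes(num_experts=8, hidden=7168, inter_full=2048, gpu_tp=1))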
void merge_results(int qlen, void* output, bool incremental) {
auto pool = this->config.pool;
auto merge_fn = [this, output, incremental](int token_nth) {
auto& local_output_numa = this->local_output_numa;
auto& tp_configs = this->tp_configs;
auto& tp_count = this->tp_count;
auto& config = this->config;
float* merge_to = local_output_numa[0] + token_nth * tp_configs[0].hidden_size;
if (incremental) {
for (int e = 0; e < config.hidden_size; e += 32) {
__m512 x0, x1;
avx512_32xbf16_to_32xfp32((__m512i*)((ggml_bf16_t*)output + token_nth * config.hidden_size + e), &x0, &x1);
*((__m512*)(merge_to + e)) = _mm512_add_ps(*((__m512*)(merge_to + e)), x0);
*((__m512*)(merge_to + e + 16)) = _mm512_add_ps(*((__m512*)(merge_to + e + 16)), x1);
}
}
for (int i = 1; i < tp_count; i++) {
float* merge_from = local_output_numa[i] + token_nth * tp_configs[i].hidden_size;
for (int e = 0; e < tp_configs[i].hidden_size; e += 16) {
*((__m512*)(merge_to + e)) = _mm512_add_ps(*((__m512*)(merge_to + e)), *((__m512*)(merge_from + e)));
}
}
for (int e = 0; e < config.hidden_size; e += 32) {
__m512 x0 = *(__m512*)(merge_to + e);
__m512 x1 = *(__m512*)(merge_to + e + 16);
avx512_32xfp32_to_32xbf16(&x0, &x1, (__m512i*)((ggml_bf16_t*)output + token_nth * config.hidden_size + e));
}
};
for (int i = 0; i < qlen; i++) {
merge_fn(i);
}
}
void merge_results(int qlen, void* output) { merge_results(qlen, output, false); }
};
#endif // CPUINFER_OPERATOR_AMX_K2_MOE_H
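To make the merge semantics of merge_results concrete, a scalar Python model (illustrative only; the real code accumulates in fp32 registers and rounds back to bf16 when writing the output):

# NUMA part 0 is the accumulator; the other parts are added in.
# In incremental mode the existing output seeds the sum first.
def merge_reference(parts, prev_output=None):
    acc = list(parts[0])
    if prev_output is not None:           # incremental mode
        acc = [a + p for a, p in zip(acc, prev_output)]
    for part in parts[1:]:
        acc = [a + p for a, p in zip(acc, part)]
    return acc                            # stored as bf16 in the real code

print(merge_reference([[1.0, 2.0], [3.0, 4.0]], prev_output=[10.0, 10.0]))
# -> [14.0, 16.0]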

View file

@@ -4,6 +4,7 @@
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <limits>
#include <vector>
@@ -344,9 +345,6 @@ struct BufferAKGroupImpl {
static constexpr int K_STEP = K::K_STEP;
static constexpr int K_BLOCK = K::K_BLOCK;
using index_t = Packed2DLayout::index_t;
Packed2DLayout pack;
static size_t required_size(int max_m, int k, int k_group_size) {
ASSERT_RELEASE(k % k_group_size == 0, "k must be multiple of k_group_size");
return sizeof(int8_t) * max_m * k + sizeof(float) * max_m * (k / k_group_size);
@@ -355,18 +353,12 @@
BufferAKGroupImpl(int max_m, int k, int k_group_size, void* ptr)
: max_m(max_m),
k(k),
k_group_size(k_group_size),
pack({{static_cast<index_t>(K_STEP), 'c'},
{static_cast<index_t>(M_STEP), 'r'},
{static_cast<index_t>(k_group_size / K_STEP), 'c'},
{static_cast<index_t>(K_BLOCK / k_group_size), 'c'},
{static_cast<index_t>(max_m / M_STEP), 'r'},
{static_cast<index_t>(k / K_BLOCK), 'c'}}) {
k_group_size(k_group_size) {
ASSERT_RELEASE(k % k_group_size == 0, "k must be multiple of k_group_size");
ASSERT_RELEASE(max_m % M_STEP == 0, "max_m must be multiple of M_STEP");
ASSERT_RELEASE(k % K_STEP == 0, "k must be multiple of K_STEP");
ASSERT_RELEASE(K_BLOCK % k_group_size == 0, "K_BLOCK must be multiple of k_group_size");
ASSERT_RELEASE(k % K_BLOCK == 0, "k must be multiple of K_BLOCK");
// ASSERT_RELEASE(k % K_BLOCK == 0, "k must be multiple of K_BLOCK");
k_group_count = k / k_group_size;
set_data(ptr);
@@ -922,6 +914,77 @@ struct BufferBInt4WithZeroImpl {
float* get_min(int n, int n_begin) { return mins + n_begin; }
};
// BufferB for Signed Int4 with KGroup Scale (no zero point)
// Used for K2 MoE - signed int4 range: [-8, 7]
template <typename K>
struct BufferBInt4KGroupImpl {
using dt = typename K::dt;
dt* b; // packed signed int4 weights, col majored
float* d; // scales only (no mins/zero-points), row majored
int n, k, k_group_size, k_group_count;
static constexpr int N_STEP = K::N_STEP;
static constexpr int K_STEP = K::K_STEP;
static constexpr bool SCALE = true;
// Size calculation: packed int4 weights + scales (NO mins)
static size_t required_size(int n, int k, int k_group_size) {
return sizeof(int8_t) * n * k / 2 + sizeof(float) * n * (k / k_group_size);
}
BufferBInt4KGroupImpl(int n, int k, int k_group_size, void* ptr) : n(n), k(k), k_group_size(k_group_size) {
assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);
assert(n % N_STEP == 0);
assert(k % K_STEP == 0);
if (n % N_STEP || k % K_STEP || k % k_group_size) {
printf("BufferBInt4KGroupImpl: n: %d, k: %d, N_STEP: %d, K_STEP: %d, k_group_size: %d\n", n, k, N_STEP,
K_STEP, k_group_size);
throw std::runtime_error("n or k is not aligned to N_STEP or K_STEP");
}
k_group_count = k / k_group_size;
b = reinterpret_cast<dt*>(ptr);
d = reinterpret_cast<float*>(offset_pointer(b, n * k / 2));
}
// Load from packed signed int4 format
// Input: proj is packed int4 weights (2 int4 values per byte)
// Each int4 value is in range [-8, 7] (signed)
void from_raw_mat(uint8_t* proj, int ith, int nth) {
auto [n_start, n_end] = K::split_range_n(n, ith, nth);
if (n_start >= n_end) {
return;
}
const size_t row_bytes = static_cast<size_t>(k) / 2;
const size_t rows = static_cast<size_t>(n_end - n_start);
uint8_t* dst_weights = reinterpret_cast<uint8_t*>(b) + n_start * row_bytes;
const uint8_t* src_weights = proj + n_start * row_bytes;
std::memcpy(dst_weights, src_weights, rows * row_bytes);
}
// Get pointer to submatrix for computation
dt* get_submat(int n, int k, int n_begin, int k_begin) {
const size_t row_bytes = static_cast<size_t>(k) / 2;
const size_t row_offset = static_cast<size_t>(n_begin) * row_bytes;
const size_t col_offset = static_cast<size_t>(k_begin) / 2;
return reinterpret_cast<dt*>(reinterpret_cast<uint8_t*>(b) + row_offset + col_offset);
}
// Get scale pointer for a specific row and k_group
float* get_scale(int n, int n_begin, int k, int k_begin) {
int k_group_idx = k_begin / k_group_size;
return d + n_begin * (k / k_group_size) + k_group_idx;
}
// Split range for parallel processing
static std::pair<int, int> split_range_n(int n, int ith, int nth) {
int n_per_thread = (n + nth - 1) / nth;
n_per_thread = (n_per_thread + N_STEP - 1) / N_STEP * N_STEP;
int n_start = std::min(ith * n_per_thread, n);
int n_end = std::min(n_start + n_per_thread, n);
return {n_start, n_end};
}
};
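The layout of BufferBInt4KGroupImpl is simply the packed weights followed by per-(row, k-group) fp32 scales. A Python sketch of required_size and the scale indexing (hypothetical helper):

# get_scale(n_begin, k_begin) resolves to d[n_begin * (k // gs) + k_begin // gs]
def bufferb_layout(n, k, gs):
    weight_bytes = n * k // 2               # two int4 values per byte
    scale_floats = n * (k // gs)
    required = weight_bytes + 4 * scale_floats
    scale_index = lambda row, k_begin: row * (k // gs) + k_begin // gs
    return required, scale_index

req, idx = bufferb_layout(n=2048, k=7168, gs=32)
print(req, idx(3, 64))   # scale of row 3, k-group starting at column 64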
template <typename K>
struct BufferBInt4WithZeroKGroupImpl {
using dt = typename K::dt;

View file

@@ -1015,8 +1015,9 @@ struct GemmKernel224Int8 {
static void avx_kernel(int m, int n, int k, int m_begin, int n_begin, int k_block_begin, float* c, BufferA* ba,
BufferB* bb) {
__m512i* c512 = (__m512i*)c;
int m_block_end = std::min(m - m_begin, M_STEP);
if (k_block_begin == 0) {
for (int m_i = 0; m_i < m; m_i++) {
for (int m_i = 0; m_i < m_block_end; m_i++) {
c512[m_i * 2] = _mm512_setzero_si512();
c512[m_i * 2 + 1] = _mm512_setzero_si512();
}
@@ -1028,7 +1029,7 @@
int32_t* a32 = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin);
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin + k_begin);
for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
for (int m_i = 0; m_i < m_block_end; m_i++) {
for (int k_i = 0; k_i < 16; k_i++) {
__m512i ma = _mm512_set1_epi32(a32[m_i * 16 + k_i]);
for (int n_i = 0; n_i < 2; n_i++) {
@@ -1239,8 +1240,9 @@ struct GemmKernel224Int4 {
BufferB* bb) {
using K = GemmKernel224Int4;
__m512i* c512 = (__m512i*)c;
int m_block_end = std::min(m - m_begin, M_STEP);
if (k_block_begin == 0) {
for (int m_i = 0; m_i < m; m_i++) {
for (int m_i = 0; m_i < m_block_end; m_i++) {
c512[m_i * 2] = _mm512_setzero_si512();
c512[m_i * 2 + 1] = _mm512_setzero_si512();
}
@@ -1250,7 +1252,7 @@
int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin);
int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin + K::K_STEP);
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin + k_begin);
for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
for (int m_i = 0; m_i < m_block_end; m_i++) {
for (int k_i = 0; k_i < 16; k_i++) {
__m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]);
__m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]);
@@ -1533,8 +1535,9 @@ struct GemmKernel224Int4_1 {
BufferB* bb) {
using K = GemmKernel224Int4_1;
__m512i* c512 = (__m512i*)c;
int m_block_end = std::min(m - m_begin, M_STEP);
if (k_block_begin == 0) {
for (int m_i = 0; m_i < m; m_i++) {
for (int m_i = 0; m_i < m_block_end; m_i++) {
c512[m_i * 2] = _mm512_setzero_si512();
c512[m_i * 2 + 1] = _mm512_setzero_si512();
}
@@ -1543,7 +1546,7 @@
int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin);
int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin + K::K_STEP);
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin + k_begin);
for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
for (int m_i = 0; m_i < m_block_end; m_i++) {
for (int k_i = 0; k_i < 16; k_i++) {
__m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]);
__m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]);
@@ -2193,10 +2196,11 @@ struct GemmKernel224Int4KGroup {
BufferB* bb, int k_group_size) {
using K = GemmKernel224Int4KGroup;
__m512i* c512 = (__m512i*)int_c;
int m_block_end = std::min(m - m_begin, M_STEP);
// Initialize int_c to zero at the start of k_group
if (k_block_begin % k_group_size == 0) {
for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
for (int m_i = 0; m_i < m_block_end; m_i++) {
c512[m_i * 2] = _mm512_setzero_si512();
c512[m_i * 2 + 1] = _mm512_setzero_si512();
}
@@ -2205,7 +2209,7 @@
if (k_offset == 0) {
int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin);
for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
for (int m_i = 0; m_i < m_block_end; m_i++) {
for (int k_i = 0; k_i < 16; k_i++) {
__m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]);
for (int n_i = 0; n_i < 2; n_i++) {
@@ -2217,7 +2221,7 @@
} else {
int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin - K::K_STEP);
for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
for (int m_i = 0; m_i < m_block_end; m_i++) {
for (int k_i = 0; k_i < 16; k_i++) {
__m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]);
for (int n_i = 0; n_i < 2; n_i++) {
@@ -2471,8 +2475,9 @@ struct GemmKernel224Int4_1KGroup {
BufferB* bb, int k_group_size) {
using K = GemmKernel224Int4_1KGroup;
__m512i* c512 = (__m512i*)int_c;
int m_block_end = std::min(m - m_begin, M_STEP);
if (k_block_begin % k_group_size == 0) {
for (int m_i = 0; m_i < m; m_i++) {
for (int m_i = 0; m_i < m_block_end; m_i++) {
c512[m_i * 2] = _mm512_setzero_si512();
c512[m_i * 2 + 1] = _mm512_setzero_si512();
}
@@ -2481,7 +2486,7 @@
if (k_offset == 0) {
int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin);
for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
for (int m_i = 0; m_i < m_block_end; m_i++) {
for (int k_i = 0; k_i < 16; k_i++) {
__m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]);
for (int n_i = 0; n_i < 2; n_i++) {
@@ -2493,7 +2498,7 @@
} else {
int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin - K::K_STEP);
for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
for (int m_i = 0; m_i < m_block_end; m_i++) {
for (int k_i = 0; k_i < 16; k_i++) {
__m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]);
for (int n_i = 0; n_i < 2; n_i++) {
@@ -2746,8 +2751,9 @@ struct GemmKernel224Int4_1_LowKGroup {
BufferB* bb, int k_group_size) {
using K = GemmKernel224Int4_1_LowKGroup;
__m512i* c512 = (__m512i*)int_c;
int m_block_end = std::min(m - m_begin, M_STEP);
if (k_block_begin % k_group_size == 0) {
for (int m_i = 0; m_i < m; m_i++) {
for (int m_i = 0; m_i < m_block_end; m_i++) {
c512[m_i * 2] = _mm512_setzero_si512();
c512[m_i * 2 + 1] = _mm512_setzero_si512();
}
@@ -2756,7 +2762,7 @@
if (k_offset == 0) {
int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin);
for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
for (int m_i = 0; m_i < m_block_end; m_i++) {
for (int k_i = 0; k_i < 16; k_i++) {
__m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]);
for (int n_i = 0; n_i < 2; n_i++) {
@@ -2768,7 +2774,7 @@
} else {
int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin - K::K_STEP);
for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
for (int m_i = 0; m_i < m_block_end; m_i++) {
for (int k_i = 0; k_i < 16; k_i++) {
__m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]);
for (int n_i = 0; n_i < 2; n_i++) {
@@ -2837,6 +2843,110 @@ struct GemmKernel224Int4_1_LowKGroup {
}
};
// K2 Signed Int4 K-group quantization kernel (AVX only, no AMX)
// For K2 MoE - signed int4 range: [-8, 7]
struct GemmKernel224Int4SmallKGroup {
using dt = uint8_t; // packed int4 type
using output_t = int32_t;
static constexpr double ELEMENT_SIZE = 0.5;
static const int VNNI_BLK = 4;
static const int M_STEP = 1;
static const int N_STEP = 32;
static const int K_STEP = 32;
static inline const int N_BLOCK = 256;
// K_BLOCK here is only a placeholder; the effective grouping comes from the runtime k_group_size
static inline const int K_BLOCK = 7168;
static std::string name() { return "K2_INT4_KGROUP"; }
static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }
static std::pair<int, int> split_range_n(int n, int ith, int nth) {
int n_start = N_BLOCK * ith;
int n_end = std::min(n, N_BLOCK * (ith + 1));
return {n_start, n_end};
}
static void config() {}
alignas(64) static constexpr uint8_t hi_mask_arr[32] = {
0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0,
0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0
};
alignas(64) static constexpr uint8_t lo_mask_arr[32] = {
0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F,
0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F
};
alignas(64) static constexpr uint8_t sign_xor_arr[32] = {
0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88,
0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88
};
static __m256i hi_mask() { return *((__m256i*)(&hi_mask_arr[0])); }
static __m256i lo_mask() { return *((__m256i*)(&lo_mask_arr[0])); }
static __m256i sign_xor_mask() { return *((__m256i*)(&sign_xor_arr[0])); }
using BufferA = BufferAKGroupImpl<GemmKernel224Int4SmallKGroup>;
using BufferB = BufferBInt4KGroupImpl<GemmKernel224Int4SmallKGroup>; // Use new signed int4 buffer
using BufferC = BufferCReduceImpl<GemmKernel224Int4SmallKGroup>;
// K-group aware AVX kernel for signed int4
static inline __m512i compressed_int4_to_int8_avx512(__m256i b256) {
b256 = _mm256_xor_si256(b256, sign_xor_mask());
__m256i b_hi = _mm256_and_si256(b256, hi_mask());
__m256i b_lo = _mm256_slli_epi16(_mm256_andnot_si256(hi_mask(), b256), 4);
__m256i unpack_lo = _mm256_unpacklo_epi8(b_lo, b_hi);
__m256i unpack_hi = _mm256_unpackhi_epi8(b_lo, b_hi);
__m512i result = _mm512_inserti64x4(_mm512_castsi256_si512(unpack_lo), unpack_hi, 1);
const __m512i lane_shuffle = _mm512_set_epi64(7, 6, 3, 2, 5, 4, 1, 0);
return _mm512_permutexvar_epi64(lane_shuffle, result);
}
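A scalar model of the nibble expansion above. It assumes the packed nibbles are stored offset-binary (value + 8), which is the reading consistent with the XOR against 0x88 and with the final division by 16 in the kernel below; treat this as an editorial sketch rather than a format specification:

# XOR with 8 turns an offset-binary nibble into a two's-complement nibble;
# placing it in the high nibble of an int8 lane multiplies it by 16, which
# the kernel divides back out after the reduction.
def decode_byte(b):
    out = []
    for nib in (b & 0x0F, b >> 4):        # low nibble first, then high
        t = (nib ^ 0x8) << 4              # two's-complement nibble, x16
        out.append(t - 256 if t >= 128 else t)
    return out                            # each entry equals 16 * int4_value

print(decode_byte(0x08))   # low nibble 8 -> 0 (value 0); high nibble 0 -> -128 (value -8)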
static inline void integer_mat_vec_kgroup(int m, int n, int k, int k_group_size, BufferA* ba, BufferB *bb, BufferC* bc, int ith, int nth) {
auto [n_start, n_end] = split_range_n(n, ith, nth);
for (int m_begin = 0; m_begin < m; m_begin++) {
float* c = bc->get_submat(m, n, m_begin, 0);
__m512i* a512 = (__m512i*)ba->get_submat(m, k, m_begin, 0);
for (int n_block_begin = n_start; n_block_begin < n_end; n_block_begin++) {
__m256i* b256 = (__m256i*)bb->get_submat(n, k, n_block_begin, 0);
float* as = (float*)ba->get_scale(m, m_begin, k, 0);
float* bs = (float*)bb->get_scale(n, n_block_begin, k, 0);
__m512 sum = _mm512_setzero_ps();
// Each 64-byte A block spans two 32-element k-groups, hence the paired
// scales below; this indexing assumes k_group_size == 32.
#define WORK_K_BLOCK(k_block) \
{ \
__m256 abscale0 = _mm256_set1_ps(as[(k_block)*2] * bs[(k_block)*2]); \
__m256 abscale1 = _mm256_set1_ps(as[(k_block)*2+1] * bs[(k_block)*2+1]); \
__m512 abscale = _mm512_insertf32x8(_mm512_castps256_ps512(abscale0), abscale1, 1); \
__m512i mul = _mm512_setzero_si512(); \
mul = _mm512_dpbssd_epi32(mul, a512[k_block], compressed_int4_to_int8_avx512(b256[k_block])); \
sum = _mm512_add_ps(sum, _mm512_mul_ps(abscale, _mm512_cvtepi32_ps(mul))); \
}
for (int k_block = 0; k_block < k / 64; k_block += 2) {
WORK_K_BLOCK(k_block);
WORK_K_BLOCK(k_block + 1);
}
c[n_block_begin] = _mm512_reduce_add_ps(sum) / 16;  // undo the <<4 (x16) from int4 -> int8 lane expansion
}
}
}
};
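And a pure-Python reference for the whole k-group GEMV, operating on already-decoded int4 values so no final /16 is needed (a sketch assuming k_group_size = 32, matching the scale indexing above):

# y = sum over k-groups of a_scale[g] * b_scale[g] * (int32 dot of the group)
def gemv_kgroup(a_int8, a_scale, b_int4, b_scale, k, gs=32):
    acc = 0.0
    for g in range(k // gs):
        dot = sum(a_int8[i] * b_int4[i] for i in range(g * gs, (g + 1) * gs))
        acc += a_scale[g] * b_scale[g] * dot
    return acc

print(gemv_kgroup([1] * 64, [0.5, 0.25], [2] * 64, [1.0, 1.0], k=64))  # -> 48.0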
inline void vec_mul_kgroup(int m, int n, int k, int k_group_size, std::shared_ptr<GemmKernel224Int4SmallKGroup::BufferA> ba,
std::shared_ptr<GemmKernel224Int4SmallKGroup::BufferB> bb,
std::shared_ptr<GemmKernel224Int4SmallKGroup::BufferC> bc, int ith, int nth) {
GemmKernel224Int4SmallKGroup::integer_mat_vec_kgroup(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith, nth);
}
// The matrix path currently reuses the row-wise vector kernel.
inline void mat_mul_kgroup(int m, int n, int k, int k_group_size, std::shared_ptr<GemmKernel224Int4SmallKGroup::BufferA> ba,
std::shared_ptr<GemmKernel224Int4SmallKGroup::BufferB> bb,
std::shared_ptr<GemmKernel224Int4SmallKGroup::BufferC> bc, int ith, int nth) {
GemmKernel224Int4SmallKGroup::integer_mat_vec_kgroup(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith, nth);
}
// New k-group aware matrix multiplication function
template <typename K, bool amx_or_avx = true>
void integer_mat_mul_kgroup(int m, int n, int k, int k_group_size, typename K::BufferA* ba, typename K::BufferB* bb,

View file

@@ -17,7 +17,7 @@ from typing import List, Optional
from .experts_base import BaseMoEWrapper, KExpertsCPUBuffer
# Import backend implementations
from .utils.amx import AMXMoEWrapper
from .utils.amx import AMXMoEWrapper, RAWAMXMoEWrapper
from .utils.llamafile import LlamafileMoEWrapper
from .utils.moe_kernel import GeneralMoEWrapper
@@ -77,7 +77,7 @@ class KTMoEWrapper:
chunked_prefill_size: Maximum prefill chunk size
cpu_save: Whether to save weights to CPU memory
max_deferred_experts_per_token: Number of experts per token to defer. Defaults to 0.
method: Backend method ("AMXINT4", "AMXINT8", "LLAMAFILE", "MOE_INT4", "MOE_INT8")
method: Backend method ("AMXINT4", "AMXINT8", "RAWINT4", "LLAMAFILE", "MOE_INT4", "MOE_INT8")
Returns:
An instance of the appropriate backend implementation (e.g., AMXMoEWrapper)
@@ -85,6 +85,8 @@ class KTMoEWrapper:
# Select backend based on method
if method in ["AMXINT4", "AMXINT8"]:
backend_cls = AMXMoEWrapper
elif method == "RAWINT4":
backend_cls = RAWAMXMoEWrapper
elif method == "LLAMAFILE":
backend_cls = LlamafileMoEWrapper
elif method in ["MOE_INT4", "MOE_INT8"]:

View file

@@ -4,13 +4,15 @@
Utilities for kt_kernel package.
"""
from .amx import AMXMoEWrapper
from .amx import AMXMoEWrapper, RAWAMXMoEWrapper
from .llamafile import LlamafileMoEWrapper
from .loader import SafeTensorLoader, GGUFLoader
from .loader import SafeTensorLoader, GGUFLoader, CompressedSafeTensorLoader
__all__ = [
"AMXMoEWrapper",
"RAWAMXMoEWrapper",
"LlamafileMoEWrapper",
"SafeTensorLoader",
"CompressedSafeTensorLoader",
"GGUFLoader",
]

View file

@ -4,16 +4,16 @@ import ctypes
# Use relative imports for package structure
from ..experts_base import BaseMoEWrapper
from .loader import SafeTensorLoader
from .loader import SafeTensorLoader, CompressedSafeTensorLoader
from kt_kernel_ext.moe import MOEConfig
try:
from kt_kernel_ext.moe import AMXInt4_MOE, AMXInt8_MOE
from kt_kernel_ext.moe import AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE
_HAS_AMX_SUPPORT = True
except (ImportError, AttributeError):
_HAS_AMX_SUPPORT = False
AMXInt4_MOE, AMXInt8_MOE = None, None
AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE = None, None, None
from typing import Optional
@@ -301,3 +301,152 @@ class AMXMoEWrapper(BaseMoEWrapper):
del self.gate_scales
del self.up_scales
del self.down_scales
class RAWAMXMoEWrapper(BaseMoEWrapper):
"""Wrapper for RAWINT4 experts stored in compressed SafeTensor format."""
_compressed_loader_instance = None
def __init__(
self,
layer_idx: int,
num_experts: int,
num_experts_per_tok: int,
hidden_size: int,
moe_intermediate_size: int,
num_gpu_experts: int,
cpuinfer_threads: int,
threadpool_count: int,
weight_path: str,
chunked_prefill_size: int,
cpu_save: bool = False,
max_deferred_experts_per_token: Optional[int] = None,
method: str = "RAWINT4",
):
if not _HAS_AMX_SUPPORT or AMXInt4_KGroup_MOE is None:
raise RuntimeError("AMX backend with RAWINT4 support is not available.")
super().__init__(
layer_idx=layer_idx,
num_experts=num_experts,
num_experts_per_tok=num_experts_per_tok,
hidden_size=hidden_size,
moe_intermediate_size=moe_intermediate_size,
num_gpu_experts=num_gpu_experts,
cpuinfer_threads=cpuinfer_threads,
threadpool_count=threadpool_count,
weight_path=weight_path,
chunked_prefill_size=chunked_prefill_size,
cpu_save=cpu_save,
max_deferred_experts_per_token=max_deferred_experts_per_token,
method=method,
)
if RAWAMXMoEWrapper._compressed_loader_instance is None:
RAWAMXMoEWrapper._compressed_loader_instance = CompressedSafeTensorLoader(weight_path)
self.loader = RAWAMXMoEWrapper._compressed_loader_instance
self.gate_weights = None
self.up_weights = None
self.down_weights = None
self.gate_scales = None
self.up_scales = None
self.down_scales = None
def load_weights_from_tensors(
self,
gate_proj: torch.Tensor,
up_proj: torch.Tensor,
down_proj: torch.Tensor,
physical_to_logical_map_cpu: torch.Tensor,
):
raise NotImplementedError("RAWINT4 wrapper expects pre-quantized safetensor weights.")
def load_weights(self, physical_to_logical_map_cpu: torch.Tensor):
base_key = f"model.layers.{self.layer_idx}"
weights = self.loader.load_experts(base_key)
self.gate_weights = torch.stack(weights["gate"], dim=0).contiguous()
self.up_weights = torch.stack(weights["up"], dim=0).contiguous()
self.down_weights = torch.stack(weights["down"], dim=0).contiguous()
self.gate_scales = torch.stack(weights["gate_scale"], dim=0).to(torch.bfloat16).contiguous()
self.up_scales = torch.stack(weights["up_scale"], dim=0).to(torch.bfloat16).contiguous()
self.down_scales = torch.stack(weights["down_scale"], dim=0).to(torch.bfloat16).contiguous()
moe_config = MOEConfig(
self.num_experts,
self.num_experts_per_tok,
self.hidden_size,
self.moe_intermediate_size,
self.num_gpu_experts,
)
moe_config.layer_idx = self.layer_idx
moe_config.pool = self.cpu_infer.backend_
moe_config.max_len = self.chunked_prefill_size
moe_config.quant_config.bits = 4
moe_config.quant_config.group_size = 32
moe_config.quant_config.zero_point = False
moe_config.gate_proj = self.gate_weights.data_ptr()
moe_config.up_proj = self.up_weights.data_ptr()
moe_config.down_proj = self.down_weights.data_ptr()
moe_config.gate_scale = self.gate_scales.data_ptr()
moe_config.up_scale = self.up_scales.data_ptr()
moe_config.down_scale = self.down_scales.data_ptr()
self.moe = AMXInt4_KGroup_MOE(moe_config)
self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
self.cpu_infer.sync()
del self.gate_weights
del self.up_weights
del self.down_weights
del self.gate_scales
del self.up_scales
del self.down_scales
def submit_write_weight_scale_to_buffer(
self,
gpu_tp_count: int,
gpu_experts_num: int,
w13_weight_ptrs,
w13_scale_ptrs,
w2_weight_ptrs,
w2_scale_ptrs,
):
"""
Submit the write_weight_scale_to_buffer task for RAWINT4 KGroup AMX implementation.
This method submits the C++-exposed task `write_weight_scale_to_buffer_task` to the
shared CPUInfer queue. The pointer lists should be plain integer lists (e.g. from
tensor.data_ptr()).
"""
if self.moe is None:
raise RuntimeError("MoE instance not initialized; cannot submit write_weight_scale_to_buffer task.")
if not hasattr(self.moe, "write_weight_scale_to_buffer_task"):
raise NotImplementedError(
"write_weight_scale_to_buffer_task is not available for this backend implementation."
)
self.cpu_infer.submit(
self.moe.write_weight_scale_to_buffer_task(
gpu_tp_count,
gpu_experts_num,
w13_weight_ptrs,
w13_scale_ptrs,
w2_weight_ptrs,
w2_scale_ptrs,
)
)
def sync_write_weight_scale_to_buffer(self):
"""
Block until previously submitted write_weight_scale_to_buffer tasks finish.
"""
# The CPUInfer.sync() call blocks until pending tasks complete.
self.cpu_infer.sync()
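A hypothetical end-to-end use of this wrapper (paths and sizes are placeholders; w13_w, w13_s, w2_w, w2_s, and physical_to_logical_map_cpu are assumed preallocated torch tensors and are not from the source):

wrapper = RAWAMXMoEWrapper(
    layer_idx=0,
    num_experts=384,
    num_experts_per_tok=8,
    hidden_size=7168,
    moe_intermediate_size=2048,
    num_gpu_experts=0,
    cpuinfer_threads=80,
    threadpool_count=2,
    weight_path="/path/to/compressed/safetensors",
    chunked_prefill_size=25600,
)
wrapper.load_weights(physical_to_logical_map_cpu)
wrapper.submit_write_weight_scale_to_buffer(
    gpu_tp_count=1, gpu_experts_num=0,
    w13_weight_ptrs=[w13_w.data_ptr()], w13_scale_ptrs=[w13_s.data_ptr()],
    w2_weight_ptrs=[w2_w.data_ptr()], w2_scale_ptrs=[w2_s.data_ptr()],
)
wrapper.sync_write_weight_scale_to_buffer()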

View file

@@ -237,6 +237,56 @@ class SafeTensorLoader:
return name in self.tensor_file_map
class CompressedSafeTensorLoader(SafeTensorLoader):
"""Loader for compressed SafeTensor layouts (RAWINT4 weights)."""
def load_experts(self, base_key: str, device: str = "cpu"):
"""Load raw expert weights stored in compressed safetensor format."""
experts_prefix = f"{base_key}.mlp.experts"
num_experts = 0
while self.has_tensor(f"{experts_prefix}.{num_experts}.up_proj.weight_packed"):
num_experts += 1
if num_experts == 0:
raise ValueError(f"No experts found for key {experts_prefix}")
def load_projection(proj_name: str):
weight_entries = []
scale_entries = []
for exp_id in range(num_experts):
weight_key = f"{experts_prefix}.{exp_id}.{proj_name}_proj.weight_packed"
scale_key = f"{experts_prefix}.{exp_id}.{proj_name}_proj.weight_scale"
if not self.has_tensor(weight_key):
raise KeyError(f"Missing tensor: {weight_key}")
if not self.has_tensor(scale_key):
raise KeyError(f"Missing tensor: {scale_key}")
weight_tensor = self.load_tensor(weight_key, device).contiguous()
scale_tensor = self.load_tensor(scale_key, device).contiguous()
weight_entries.append(weight_tensor)
scale_entries.append(scale_tensor)
return weight_entries, scale_entries
gate_weights, gate_scales = load_projection("gate")
up_weights, up_scales = load_projection("up")
down_weights, down_scales = load_projection("down")
return {
"gate": gate_weights,
"up": up_weights,
"down": down_weights,
"gate_scale": gate_scales,
"up_scale": up_scales,
"down_scale": down_scales,
}
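A minimal usage sketch for the loader (the path and layer index are placeholders; the returned entries are torch tensors, as the RAWINT4 wrapper's torch.stack calls assume):

loader = CompressedSafeTensorLoader("/path/to/compressed/safetensors")
experts = loader.load_experts("model.layers.0")
print(len(experts["gate"]), experts["gate"][0].shape, experts["gate_scale"][0].dtype)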
class GGUFLoader:
"""
GGUF format loader using the official gguf library (gguf.gguf_reader.GGUFReader)