mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-28 11:49:51 +00:00
[Feature] Add avx-based kimi-k2 support (#1656)
Some checks are pending
Book-CI / test-2 (push) Waiting to run
Book-CI / test (push) Waiting to run
Book-CI / test-1 (push) Waiting to run
Deploy / deploy (macos-latest) (push) Waiting to run
Deploy / deploy (ubuntu-latest) (push) Waiting to run
Deploy / deploy (windows-latest) (push) Waiting to run
Some checks are pending
Book-CI / test-2 (push) Waiting to run
Book-CI / test (push) Waiting to run
Book-CI / test-1 (push) Waiting to run
Deploy / deploy (macos-latest) (push) Waiting to run
Deploy / deploy (ubuntu-latest) (push) Waiting to run
Deploy / deploy (windows-latest) (push) Waiting to run
* support Kimi-K2-Thinking original weight fix amx kernel bug * update k2 avx kernel. * feat: add CPUInfer write buffer task * [feat]: add kimi k2 cpu write buffer support - Implement write_weights_to_buffer function in k2-moe.hpp for extracting GPU expert weights - Fix down (w2) weight column-wise slicing for different TP configurations - Support three TP scenarios: cpu_tp == gpu_tp, cpu_tp > gpu_tp, cpu_tp < gpu_tp - Add comprehensive test cases for weight extraction validation - Ensure compatibility with Kimi model's MoE architecture * [fix]: correct write_weight_scale_to_buffer expert offset calculation Fixed the bug in write_weight_scale_to_buffer_task where expert offsets in GPU buffers were incorrectly calculated. Changed from using per_expert_gpu sizes to using full gpu_tp sizes, ensuring correct memory layout for multi-expert scenarios. Also added benchmark scripts for k2 moe and write buffer operations, and cleaned up debug output in test files. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * [feat]: add write buffer wrapper * [fix] fix comment --------- Co-authored-by: ouqingliang <1692110604@qq.com> Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
parent
c2b8c60c4e
commit
fcf8882075
12 changed files with 2649 additions and 34 deletions
363
kt-kernel/bench/bench_k2_moe_amx.py
Normal file
363
kt-kernel/bench/bench_k2_moe_amx.py
Normal file
|
|
@ -0,0 +1,363 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
Benchmark AMX_K2_MOE_TP int4 path with packed weights and BF16 scales.
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))
|
||||||
|
|
||||||
|
import kt_kernel_ext
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# Benchmark parameters (single MoE, no layer loop)
|
||||||
|
expert_num = 384
|
||||||
|
hidden_size = 7168
|
||||||
|
intermediate_size = 2048
|
||||||
|
max_len = 25600
|
||||||
|
num_experts_per_tok = 8
|
||||||
|
qlen = 1
|
||||||
|
warm_up_iter = 1000
|
||||||
|
test_iter = 5000
|
||||||
|
k_group_size = 32
|
||||||
|
|
||||||
|
physical_to_logical_map = (
|
||||||
|
torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous()
|
||||||
|
)
|
||||||
|
|
||||||
|
worker_config = kt_kernel_ext.WorkerPoolConfig()
|
||||||
|
worker_config.subpool_count = 2
|
||||||
|
worker_config.subpool_numa_map = [0, 1]
|
||||||
|
worker_config.subpool_thread_count = [40, 40]
|
||||||
|
CPUInfer = kt_kernel_ext.CPUInfer(worker_config)
|
||||||
|
|
||||||
|
|
||||||
|
def get_git_commit():
|
||||||
|
result = {}
|
||||||
|
try:
|
||||||
|
commit = (
|
||||||
|
subprocess.check_output(["git", "rev-parse", "HEAD"])
|
||||||
|
.decode("utf-8")
|
||||||
|
.strip()
|
||||||
|
)
|
||||||
|
commit_msg = (
|
||||||
|
subprocess.check_output(["git", "log", "-1", "--pretty=%B"])
|
||||||
|
.decode("utf-8")
|
||||||
|
.strip()
|
||||||
|
)
|
||||||
|
result["commit"] = commit
|
||||||
|
result["commit_message"] = commit_msg
|
||||||
|
|
||||||
|
dirty_output = (
|
||||||
|
subprocess.check_output(["git", "status", "--porcelain"])
|
||||||
|
.decode("utf-8")
|
||||||
|
.strip()
|
||||||
|
)
|
||||||
|
if dirty_output:
|
||||||
|
result["dirty"] = True
|
||||||
|
result["dirty_files"] = dirty_output.splitlines()
|
||||||
|
else:
|
||||||
|
result["dirty"] = False
|
||||||
|
except Exception as e:
|
||||||
|
result["commit"] = None
|
||||||
|
result["commit_message"] = None
|
||||||
|
result["dirty"] = None
|
||||||
|
result["error"] = str(e)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def get_system_info():
|
||||||
|
info = {}
|
||||||
|
uname = platform.uname()
|
||||||
|
info["system_name"] = uname.system
|
||||||
|
info["node_name"] = uname.node
|
||||||
|
|
||||||
|
cpu_model = None
|
||||||
|
if os.path.exists("/proc/cpuinfo"):
|
||||||
|
try:
|
||||||
|
with open("/proc/cpuinfo", "r") as f:
|
||||||
|
for line in f:
|
||||||
|
if "model name" in line:
|
||||||
|
cpu_model = line.split(":", 1)[1].strip()
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
cpu_model = f"Error: {e}"
|
||||||
|
info["cpu_model"] = cpu_model
|
||||||
|
|
||||||
|
mem_total_gb = None
|
||||||
|
if os.path.exists("/proc/meminfo"):
|
||||||
|
try:
|
||||||
|
with open("/proc/meminfo", "r") as f:
|
||||||
|
for line in f:
|
||||||
|
if "MemTotal" in line:
|
||||||
|
mem_kb = float(line.split(":", 1)[1].split()[0])
|
||||||
|
mem_total_gb = round(mem_kb / (1024 * 1024), 2)
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
mem_total_gb = f"Error: {e}"
|
||||||
|
info["memory_size_GB"] = mem_total_gb
|
||||||
|
|
||||||
|
info["cpu_core_count"] = os.cpu_count()
|
||||||
|
|
||||||
|
sockets = set()
|
||||||
|
if os.path.exists("/proc/cpuinfo"):
|
||||||
|
try:
|
||||||
|
with open("/proc/cpuinfo", "r") as f:
|
||||||
|
for line in f:
|
||||||
|
if "physical id" in line:
|
||||||
|
sockets.add(line.split(":", 1)[1].strip())
|
||||||
|
except Exception:
|
||||||
|
sockets = set()
|
||||||
|
info["cpu_socket_count"] = len(sockets) if len(sockets) > 0 else 1
|
||||||
|
|
||||||
|
return info
|
||||||
|
|
||||||
|
|
||||||
|
script_path = os.path.abspath(__file__)
|
||||||
|
script_dir = os.path.dirname(script_path)
|
||||||
|
script_name = os.path.splitext(os.path.basename(script_path))[0]
|
||||||
|
json_path = os.path.join(script_dir, script_name + ".jsonl")
|
||||||
|
|
||||||
|
|
||||||
|
def record_results(result, filename=json_path):
    """Append *result* as a single JSON line to *filename* (JSONL log)."""
    line = json.dumps(result)
    with open(filename, "a") as f:
        f.write(line + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
def pack_to_int32(
    value: torch.Tensor, num_bits: int, packed_dim: int = 1
) -> torch.Tensor:
    """Pack int8 values (each holding a `num_bits` quantity) into int32 words.

    Each entry is shifted into its unsigned representation and 32 // num_bits
    of them are packed little-endian into one int32 along the chosen
    dimension.  The packed dimension is zero-padded up to a multiple of the
    pack factor.  Raises ValueError on a non-int8 input or num_bits outside
    [1, 8].
    """
    if value.dtype is not torch.int8:
        raise ValueError("Tensor must be torch.int8 before packing")
    if not (1 <= num_bits <= 8):
        raise ValueError(f"num_bits must be in [1, 8], got {num_bits}")

    # Map signed range onto unsigned, e.g. [-8, 7] -> [0, 15] for 4 bits.
    work = (value + (1 << (num_bits - 1))).to(torch.uint8)
    per_word = 32 // num_bits

    if packed_dim == 0:
        work = work.transpose(0, 1)

    rows, cols = work.shape
    remainder = cols % per_word
    if remainder:
        # Pad with zeros so the column count is a whole number of int32 words.
        work = torch.nn.functional.pad(work, (0, per_word - remainder))
        cols += per_word - remainder

    grouped = work.view(rows, cols // per_word, per_word).to(torch.int32)
    shifts = torch.arange(per_word, device=work.device, dtype=torch.int32) * num_bits
    packed = (grouped << shifts).sum(dim=2, dtype=torch.int32)

    if packed_dim == 0:
        packed = packed.transpose(0, 1)

    return packed
|
||||||
|
|
||||||
|
|
||||||
|
def pack_tensor_per_row(q: torch.Tensor, num_bits: int) -> torch.Tensor:
    """Row-wise int32 packing for a [experts, rows, cols] int8 tensor."""
    experts, rows, _ = q.shape
    packed = pack_to_int32(q.view(experts * rows, -1), num_bits)
    return packed.view(experts, rows, -1).contiguous()
|
||||||
|
|
||||||
|
|
||||||
|
def quantize_k2_tensor(weights: torch.Tensor, group_size: int):
    """
    K2 int4 quantization producing int32-packed weights (8 int4s each) and BF16 scales.
    """
    w = weights.to(torch.float32)
    e, rows, cols = w.shape
    if cols % group_size != 0 or cols % 2 != 0:
        raise ValueError(
            f"cols ({cols}) must be divisible by group_size ({group_size}) and 2"
        )

    # Symmetric per-k-group scale: abs-max mapped onto the int4 range [-8, 7].
    grouped = w.view(e, rows, cols // group_size, group_size)
    group_max = grouped.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
    scales = (group_max / 7.0).squeeze(-1)

    q = torch.round(grouped / scales.unsqueeze(-1)).clamp(-8, 7).to(torch.int8)
    q = q.view(e, rows, cols)

    packed = pack_tensor_per_row(q, num_bits=4).view(e, rows, cols // 8).contiguous()
    bf16_scales = (
        scales.to(torch.bfloat16).contiguous().view(e, rows, cols // group_size).contiguous()
    )
    return packed, bf16_scales
|
||||||
|
|
||||||
|
|
||||||
|
def build_quantized_layer_weights():
    """Create random gate/up/down expert weight stacks and quantize each to
    int4-packed weights + BF16 per-group scales."""

    def _random_proj(out_dim, in_dim):
        # One float32 weight matrix per expert, CPU resident.
        return torch.randn(
            (expert_num, out_dim, in_dim), dtype=torch.float32, device="cpu"
        ).contiguous()

    # Generation order matters for RNG reproducibility: gate, up, down.
    gate_proj = _random_proj(intermediate_size, hidden_size)
    up_proj = _random_proj(intermediate_size, hidden_size)
    down_proj = _random_proj(hidden_size, intermediate_size)

    gate_q, gate_s = quantize_k2_tensor(gate_proj, k_group_size)
    up_q, up_s = quantize_k2_tensor(up_proj, k_group_size)
    down_q, down_s = quantize_k2_tensor(down_proj, k_group_size)

    return {
        "gate_qweight": gate_q,
        "up_qweight": up_q,
        "down_qweight": down_q,
        "gate_scales": gate_s,
        "up_scales": up_s,
        "down_scales": down_s,
    }
|
||||||
|
|
||||||
|
|
||||||
|
def bench_k2_moe():
    """Benchmark the AMX int4 k-group MoE forward path and log results as JSONL."""
    with torch.inference_mode():
        # Effective bytes fetched per logical weight element: 4-bit weight
        # plus the amortized bf16 scale (2 bytes per k_group_size elements).
        bytes_per_elem = 0.5 + 2.0 / k_group_size

        quant_data = build_quantized_layer_weights()
        config = kt_kernel_ext.moe.MOEConfig(
            expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0
        )
        config.max_len = max_len
        config.quant_config.bits = 4
        config.quant_config.group_size = k_group_size
        config.quant_config.zero_point = False

        # Hand raw data pointers to the C++ side; the Python tensors in
        # quant_data must stay alive for as long as the MoE object is used.
        config.gate_proj = quant_data["gate_qweight"].data_ptr()
        config.up_proj = quant_data["up_qweight"].data_ptr()
        config.down_proj = quant_data["down_qweight"].data_ptr()

        config.gate_scale = quant_data["gate_scales"].data_ptr()
        config.up_scale = quant_data["up_scales"].data_ptr()
        config.down_scale = quant_data["down_scales"].data_ptr()
        config.pool = CPUInfer.backend_

        moe = kt_kernel_ext.moe.AMXInt4_KGroup_MOE(config)
        CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
        CPUInfer.sync()

        # Pre-generate routing (expert ids + weights) for gen_iter iterations
        # so the timed loop does no Python-side tensor work.
        gen_iter = 3000
        expert_ids = (
            torch.rand(gen_iter * qlen, expert_num, device="cpu")
            .argsort(dim=-1)[:, :num_experts_per_tok]
            .reshape(gen_iter, qlen * num_experts_per_tok)
            .contiguous()
        )
        weights = torch.rand(
            (gen_iter, qlen, num_experts_per_tok), dtype=torch.float32, device="cpu"
        ).contiguous()
        input_tensor = torch.randn(
            (qlen, hidden_size), dtype=torch.bfloat16, device="cpu"
        ).contiguous()
        output_tensor = torch.empty_like(input_tensor)
        bsz_tensor = torch.tensor([qlen], device="cpu")

        # Warm-up: same submit/sync pattern as the timed loop, untimed.
        for i in tqdm(range(warm_up_iter), desc="Warm-up"):
            CPUInfer.submit(
                moe.forward_task(
                    bsz_tensor.data_ptr(),
                    num_experts_per_tok,
                    expert_ids[i % gen_iter].data_ptr(),
                    weights[i % gen_iter].data_ptr(),
                    input_tensor.data_ptr(),
                    output_tensor.data_ptr(),
                    False,
                )
            )
            CPUInfer.sync()

        start = time.perf_counter()
        for i in tqdm(range(test_iter), desc="Testing"):
            CPUInfer.submit(
                moe.forward_task(
                    bsz_tensor.data_ptr(),
                    num_experts_per_tok,
                    expert_ids[i % gen_iter].data_ptr(),
                    weights[i % gen_iter].data_ptr(),
                    input_tensor.data_ptr(),
                    output_tensor.data_ptr(),
                    False,
                )
            )
            CPUInfer.sync()
        end = time.perf_counter()
        total_time = end - start

        time_per_iter_us = total_time / test_iter * 1e6
        # NOTE(review): the (1 / 8 * 256 * (1 - (31 / 32) ** qlen)) factor
        # presumably models the expected number of weight tiles actually
        # touched per token at this qlen — confirm against the kernel's
        # tiling before relying on the bandwidth figure.
        bandwidth = (
            hidden_size
            * intermediate_size
            * 3
            * num_experts_per_tok
            * (1 / 8 * 256 * (1 - (31 / 32) ** qlen))
            * bytes_per_elem
            * test_iter
            / total_time
            / 1e9
        )
        # 3 matrices per expert, 2 FLOPs per MAC.
        flops = (
            hidden_size
            * intermediate_size
            * qlen
            * 3
            * num_experts_per_tok
            * 2
            * test_iter
            / total_time
            / 1e12
        )

        print("Quant mode: int4_k2")
        print("Time(s): ", total_time)
        print("Iteration: ", test_iter)
        print("Time(us) per iteration: ", time_per_iter_us)
        print("Bandwidth: ", bandwidth, "GB/s")
        print("Flops: ", flops, "TFLOPS")
        print("")

        # Persist one JSONL record with metrics, parameters, git and host info.
        result = {
            "quant_mode": "int4_k2",
            "total_time_seconds": total_time,
            "iterations": test_iter,
            "time_per_iteration_us": time_per_iter_us,
            "bandwidth_GBs": bandwidth,
            "flops_TFLOPS": flops,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
            "test_parameters": {
                "expert_num": expert_num,
                "hidden_size": hidden_size,
                "intermediate_size": intermediate_size,
                "max_len": max_len,
                "num_experts_per_tok": num_experts_per_tok,
                "qlen": qlen,
                "warm_up_iter": warm_up_iter,
                "test_iter": test_iter,
                "k_group_size": k_group_size,
                "bytes_per_elem": bytes_per_elem,
            },
        }
        result.update(get_git_commit())
        result.update(get_system_info())
        record_results(result)


if __name__ == "__main__":
    bench_k2_moe()
|
||||||
309
kt-kernel/bench/bench_k2_write_buffer.py
Normal file
309
kt-kernel/bench/bench_k2_write_buffer.py
Normal file
|
|
@ -0,0 +1,309 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
Benchmark write_weight_scale_to_buffer for AMX_K2_MOE_TP (int4 packed weights + bf16 scales).
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))
|
||||||
|
|
||||||
|
import kt_kernel_ext
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# Benchmark parameters (single MoE, mirror examples/test_k2_write_buffer.py)
|
||||||
|
expert_num = 384
|
||||||
|
num_experts_per_tok = expert_num
|
||||||
|
gpu_tp_count = 4
|
||||||
|
|
||||||
|
warm_up_iter = 3
|
||||||
|
test_iter = 7
|
||||||
|
|
||||||
|
gpu_experts_num = expert_num
|
||||||
|
|
||||||
|
hidden_size = 7168
|
||||||
|
intermediate_size = 2048
|
||||||
|
group_size = 32
|
||||||
|
max_len = 1
|
||||||
|
|
||||||
|
physical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous()
|
||||||
|
CPUInfer = kt_kernel_ext.CPUInfer(96)
|
||||||
|
|
||||||
|
|
||||||
|
def get_git_commit():
|
||||||
|
result = {}
|
||||||
|
try:
|
||||||
|
commit = (
|
||||||
|
subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("utf-8").strip()
|
||||||
|
)
|
||||||
|
commit_msg = (
|
||||||
|
subprocess.check_output(["git", "log", "-1", "--pretty=%B"])
|
||||||
|
.decode("utf-8")
|
||||||
|
.strip()
|
||||||
|
)
|
||||||
|
result["commit"] = commit
|
||||||
|
result["commit_message"] = commit_msg
|
||||||
|
|
||||||
|
dirty_output = (
|
||||||
|
subprocess.check_output(["git", "status", "--porcelain"]).decode("utf-8").strip()
|
||||||
|
)
|
||||||
|
if dirty_output:
|
||||||
|
result["dirty"] = True
|
||||||
|
result["dirty_files"] = dirty_output.splitlines()
|
||||||
|
else:
|
||||||
|
result["dirty"] = False
|
||||||
|
except Exception as e:
|
||||||
|
result["commit"] = None
|
||||||
|
result["commit_message"] = None
|
||||||
|
result["dirty"] = None
|
||||||
|
result["error"] = str(e)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def get_system_info():
|
||||||
|
info = {}
|
||||||
|
uname = platform.uname()
|
||||||
|
info["system_name"] = uname.system
|
||||||
|
info["node_name"] = uname.node
|
||||||
|
|
||||||
|
cpu_model = None
|
||||||
|
if os.path.exists("/proc/cpuinfo"):
|
||||||
|
try:
|
||||||
|
with open("/proc/cpuinfo", "r") as f:
|
||||||
|
for line in f:
|
||||||
|
if "model name" in line:
|
||||||
|
cpu_model = line.split(":", 1)[1].strip()
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
cpu_model = f"Error: {e}"
|
||||||
|
info["cpu_model"] = cpu_model
|
||||||
|
|
||||||
|
mem_total_gb = None
|
||||||
|
if os.path.exists("/proc/meminfo"):
|
||||||
|
try:
|
||||||
|
with open("/proc/meminfo", "r") as f:
|
||||||
|
for line in f:
|
||||||
|
if "MemTotal" in line:
|
||||||
|
mem_kb = float(line.split(":", 1)[1].split()[0])
|
||||||
|
mem_total_gb = round(mem_kb / (1024 * 1024), 2)
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
mem_total_gb = f"Error: {e}"
|
||||||
|
info["memory_size_GB"] = mem_total_gb
|
||||||
|
|
||||||
|
info["cpu_core_count"] = os.cpu_count()
|
||||||
|
|
||||||
|
sockets = set()
|
||||||
|
if os.path.exists("/proc/cpuinfo"):
|
||||||
|
try:
|
||||||
|
with open("/proc/cpuinfo", "r") as f:
|
||||||
|
for line in f:
|
||||||
|
if "physical id" in line:
|
||||||
|
sockets.add(line.split(":", 1)[1].strip())
|
||||||
|
except Exception:
|
||||||
|
sockets = set()
|
||||||
|
info["cpu_socket_count"] = len(sockets) if len(sockets) > 0 else 1
|
||||||
|
|
||||||
|
return info
|
||||||
|
|
||||||
|
|
||||||
|
script_path = os.path.abspath(__file__)
|
||||||
|
script_dir = os.path.dirname(script_path)
|
||||||
|
script_name = os.path.splitext(os.path.basename(script_path))[0]
|
||||||
|
json_path = os.path.join(script_dir, script_name + ".jsonl")
|
||||||
|
|
||||||
|
|
||||||
|
def record_results(result, filename=json_path):
    """Append *result* as a single JSON line to *filename* (JSONL log)."""
    line = json.dumps(result)
    with open(filename, "a") as f:
        f.write(line + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
def allocate_weights():
    """Allocate random int4-packed weight bytes and bf16 scales for all experts.

    Returns the six flat tensors (gate/up/down weights and scales) followed by
    the per-matrix weight-byte and scale-element counts used for buffer sizing.
    """
    # int4 packing: two weights per byte; one bf16 scale per group_size weights.
    per_mat_weight_bytes = (hidden_size * intermediate_size) // 2
    per_mat_scale_elems = (hidden_size * intermediate_size) // group_size

    def _weight_bytes():
        return torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)

    def _scale_elems():
        return torch.randn(expert_num * per_mat_scale_elems, dtype=torch.bfloat16)

    # Allocation order (weights first, then scales) preserves RNG reproducibility.
    gate_q = _weight_bytes()
    up_q = _weight_bytes()
    down_q = _weight_bytes()

    gate_scale = _scale_elems()
    up_scale = _scale_elems()
    down_scale = _scale_elems()

    return (
        gate_q.contiguous(),
        up_q.contiguous(),
        down_q.contiguous(),
        gate_scale.contiguous(),
        up_scale.contiguous(),
        down_scale.contiguous(),
        per_mat_weight_bytes,
        per_mat_scale_elems,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def build_moe():
    """Construct the AMX MoE object plus per-TP destination buffers.

    Returns (moe, buffer_ptrs, buffer_shapes, keep_tensors); keep_tensors
    must stay referenced by the caller because the C++ side only holds raw
    pointers into those tensors.
    """
    (
        gate_q,
        up_q,
        down_q,
        gate_scale,
        up_scale,
        down_scale,
        per_mat_weight_bytes,
        per_mat_scale_elems,
    ) = allocate_weights()

    # NOTE(review): the sibling bench (bench_k2_moe_amx.py) passes a fifth
    # positional argument (0) to MOEConfig — confirm this 4-arg form targets
    # the same constructor overload.
    config = kt_kernel_ext.moe.MOEConfig(
        expert_num, num_experts_per_tok, hidden_size, intermediate_size
    )
    config.max_len = max_len
    config.quant_config.bits = 4
    config.quant_config.group_size = group_size
    config.quant_config.zero_point = False
    config.pool = CPUInfer.backend_

    # Raw data pointers handed to C++; backing tensors are kept alive below.
    config.gate_proj = gate_q.data_ptr()
    config.up_proj = up_q.data_ptr()
    config.down_proj = down_q.data_ptr()
    config.gate_scale = gate_scale.data_ptr()
    config.up_scale = up_scale.data_ptr()
    config.down_scale = down_scale.data_ptr()

    moe = kt_kernel_ext.moe.AMXInt4_KGroup_MOE(config)
    CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
    CPUInfer.sync()

    # Buffer sizing per TP
    weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count
    scale_elems_per_expert_per_tp = per_mat_scale_elems // gpu_tp_count
    total_weight_bytes_per_tp = gpu_experts_num * weight_bytes_per_expert_per_tp
    total_scale_elems_per_tp = gpu_experts_num * scale_elems_per_expert_per_tp

    # w13 holds gate+up fused (factor 2); w2 holds down only.
    w13_weight_bufs = [
        torch.empty(2 * total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)
    ]
    w13_scale_bufs = [
        torch.empty(2 * total_scale_elems_per_tp, dtype=torch.bfloat16) for _ in range(gpu_tp_count)
    ]
    w2_weight_bufs = [
        torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)
    ]
    w2_scale_bufs = [
        torch.empty(total_scale_elems_per_tp, dtype=torch.bfloat16) for _ in range(gpu_tp_count)
    ]

    # Keyword names match write_weight_scale_to_buffer_task's parameters.
    buffer_ptrs = {
        "w13_weight_ptrs": [buf.data_ptr() for buf in w13_weight_bufs],
        "w13_scale_ptrs": [buf.data_ptr() for buf in w13_scale_bufs],
        "w2_weight_ptrs": [buf.data_ptr() for buf in w2_weight_bufs],
        "w2_scale_ptrs": [buf.data_ptr() for buf in w2_scale_bufs],
    }

    buffer_shapes = {
        "per_mat_weight_bytes": per_mat_weight_bytes,
        "per_mat_scale_elems": per_mat_scale_elems,
        "weight_bytes_per_expert_per_tp": weight_bytes_per_expert_per_tp,
        "scale_elems_per_expert_per_tp": scale_elems_per_expert_per_tp,
        "total_weight_bytes_per_tp": total_weight_bytes_per_tp,
        "total_scale_elems_per_tp": total_scale_elems_per_tp,
    }

    # Everything the C++ side points into — caller must keep this dict alive.
    keep_tensors = {
        "gate_q": gate_q,
        "up_q": up_q,
        "down_q": down_q,
        "gate_scale": gate_scale,
        "up_scale": up_scale,
        "down_scale": down_scale,
        "w13_weight_bufs": w13_weight_bufs,
        "w13_scale_bufs": w13_scale_bufs,
        "w2_weight_bufs": w2_weight_bufs,
        "w2_scale_bufs": w2_scale_bufs,
    }

    return moe, buffer_ptrs, buffer_shapes, keep_tensors
|
||||||
|
|
||||||
|
|
||||||
|
def bench_write_buffer():
    """Benchmark write_weight_scale_to_buffer and log one JSONL record."""
    moe, buffer_ptrs, buffer_shapes, keep_tensors = build_moe()

    total_weights = hidden_size * intermediate_size * expert_num * 3
    # Throughput accounting consistent with examples/test_k2_write_buffer.py:
    # bf16 scales (2 bytes per group_size weights -> // group_size gives scale
    # count, counted as bytes here) plus int4 weights (2 per byte).
    bytes_per_call = total_weights // group_size + total_weights // 2

    # Warm-up
    for _ in tqdm(range(warm_up_iter), desc="Warm-up"):
        CPUInfer.submit(
            moe.write_weight_scale_to_buffer_task(
                gpu_tp_count=gpu_tp_count,
                gpu_experts_num=gpu_experts_num,
                **buffer_ptrs,
            )
        )
        CPUInfer.sync()

    total_time = 0
    for _ in tqdm(range(test_iter), desc="Testing"):
        start = time.perf_counter()
        CPUInfer.submit(
            moe.write_weight_scale_to_buffer_task(
                gpu_tp_count=gpu_tp_count,
                gpu_experts_num=gpu_experts_num,
                **buffer_ptrs,
            )
        )
        CPUInfer.sync()
        end = time.perf_counter()
        total_time += end - start
        # Idle gap between iterations (thermal/turbo settling); the sleep is
        # outside the timed window so it does not affect the metrics.
        time.sleep(0.6)
        print(end - start)

    time_per_iter_us = total_time / test_iter * 1e6
    bandwidth_gbs = bytes_per_call * test_iter / total_time / 1e9

    print("write_weight_scale_to_buffer benchmark")
    print("Time(s): ", total_time)
    print("Iteration: ", test_iter)
    print("Time(us) per iteration: ", time_per_iter_us)
    print("Bandwidth: ", bandwidth_gbs, "GB/s")
    print("")

    # Persist metrics, parameters, git and host info as one JSONL record.
    result = {
        "op": "write_weight_scale_to_buffer",
        "total_time_seconds": total_time,
        "iterations": test_iter,
        "time_per_iteration_us": time_per_iter_us,
        "bandwidth_GBs": bandwidth_gbs,
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
        "test_parameters": {
            "expert_num": expert_num,
            "hidden_size": hidden_size,
            "intermediate_size": intermediate_size,
            "group_size": group_size,
            "max_len": max_len,
            "num_experts_per_tok": num_experts_per_tok,
            "gpu_tp_count": gpu_tp_count,
            "gpu_experts_num": gpu_experts_num,
            "warm_up_iter": warm_up_iter,
            "test_iter": test_iter,
            "bytes_per_call": bytes_per_call,
        },
        "buffer_shapes": buffer_shapes,
        "keep_tensors_alive": list(keep_tensors.keys()),
    }
    result.update(get_git_commit())
    result.update(get_system_info())
    record_results(result)


if __name__ == "__main__":
    bench_write_buffer()
|
||||||
319
kt-kernel/examples/test_k2_moe_amx.py
Normal file
319
kt-kernel/examples/test_k2_moe_amx.py
Normal file
|
|
@ -0,0 +1,319 @@
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from typing import Dict, Literal
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.dirname(__file__) + "/../build")
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import kt_kernel_ext
|
||||||
|
|
||||||
|
torch.manual_seed(42)
|
||||||
|
|
||||||
|
hidden_size = 7168
|
||||||
|
intermediate_size = 2048
|
||||||
|
max_len = 25600
|
||||||
|
|
||||||
|
expert_num = 16
|
||||||
|
num_experts_per_tok = 8
|
||||||
|
|
||||||
|
qlen = 1
|
||||||
|
layer_num = 1
|
||||||
|
CPUInfer = kt_kernel_ext.CPUInfer(40)
|
||||||
|
validation_iter = 10
|
||||||
|
k_group_size = 32
|
||||||
|
debug_print_count = 16
|
||||||
|
|
||||||
|
physical_to_logical_map = torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous()
|
||||||
|
|
||||||
|
|
||||||
|
def _pattern_uniform(groups: int) -> torch.Tensor:
|
||||||
|
return torch.full((groups,), 0.02, dtype=torch.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def _pattern_alternating(groups: int) -> torch.Tensor:
|
||||||
|
vals = torch.full((groups,), 0.015, dtype=torch.float32)
|
||||||
|
vals[1::2] = 0.03
|
||||||
|
return vals
|
||||||
|
|
||||||
|
|
||||||
|
def _pattern_ramp(groups: int) -> torch.Tensor:
|
||||||
|
return torch.linspace(0.005, 0.04, steps=groups, dtype=torch.float32)
|
||||||
|
|
||||||
|
|
||||||
|
# Named weight-generation patterns used to validate quantization: maps a
# pattern key to (human-readable description, per-k-group abs-max generator).
# A None generator means plain random bf16 weights (the baseline case).
WEIGHT_PATTERNS = {
    "uniform_scale": ("All k-groups share the same abs max / scale", _pattern_uniform),
    "alternating_scale": ("Alternate small / large abs max per k-group", _pattern_alternating),
    "ramp_scale": ("Linearly increasing abs max per k-group", _pattern_ramp),
    "random": ("Random bf16 weights (baseline)", None),
}
|
||||||
|
|
||||||
|
|
||||||
|
def act_fn(x):
    """SiLU activation, spelled out explicitly as x / (1 + exp(-x))."""
    denom = 1.0 + torch.exp(-x)
    return x / denom
|
||||||
|
|
||||||
|
|
||||||
|
def mlp_torch(input, gate_proj, up_proj, down_proj):
    """Eager-torch reference expert MLP: down( silu(x @ gate^T) * (x @ up^T) ).

    Debug-prints every intermediate so kernel output can be diffed step by
    step against this reference.
    """
    h_gate = torch.mm(input, gate_proj.t())
    h_up = torch.mm(input, up_proj.t())
    print(f"gate_buf: {h_gate}")
    print(f"up_buf: {h_up}")
    # act_fn inlined: SiLU written as x / (1 + exp(-x)).
    fused = (h_gate / (1.0 + torch.exp(-h_gate))) * h_up
    out = torch.mm(fused, down_proj.t())
    print(f"intermediate: {fused}")
    print(f"mlp output: {out}")
    return out
|
||||||
|
|
||||||
|
|
||||||
|
def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):
    """Reference MoE forward in eager torch.

    Groups tokens by their routed expert, runs ``mlp_torch`` per expert, then
    scatters the expert outputs back and combines them with the routing
    weights.  Shapes as used by this file: input [tokens, hidden];
    expert_ids / weights [tokens, num_experts_per_tok].
    """
    # One-hot count of how many (token, slot) pairs route to each expert.
    cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num))
    cnts.scatter_(1, expert_ids, 1)
    tokens_per_expert = cnts.sum(dim=0)
    # Sort all routing slots by expert id so each expert's tokens are contiguous.
    idxs = expert_ids.view(-1).argsort()
    # idxs // k maps a flat slot index back to its source token row.
    sorted_tokens = input[idxs // expert_ids.shape[1]]

    outputs = []
    start_idx = 0

    for i, num_tokens in enumerate(tokens_per_expert):
        end_idx = start_idx + num_tokens
        if num_tokens == 0:
            # Skipping the start_idx update is safe: end_idx == start_idx here.
            continue
        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
        expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i])
        outputs.append(expert_out)
        start_idx = end_idx

    outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)

    # Undo the sort: place each expert output back at its original slot.
    new_x = torch.empty_like(outs)
    new_x[idxs] = outs
    # Weighted sum over the k expert slots of each token, computed in the
    # routing weights' dtype and cast back.
    t_output = (
        new_x.view(*expert_ids.shape, -1)
        .type(weights.dtype)
        .mul_(weights.unsqueeze(dim=-1))
        .sum(dim=1)
        .type(new_x.dtype)
    )
    return t_output
|
||||||
|
|
||||||
|
|
||||||
|
def pack_to_int32(value: torch.Tensor, num_bits: int, packed_dim: Literal[0, 1] = 1) -> torch.Tensor:
    """Pack int8 values (each holding a `num_bits` quantity) into int32 words.

    Each entry is shifted into its unsigned representation and 32 // num_bits
    of them are packed little-endian into one int32 along `packed_dim`.  The
    packed dimension is zero-padded up to a multiple of the pack factor.
    Raises ValueError on a non-int8 input or num_bits outside [1, 8].
    """
    if value.dtype is not torch.int8:
        raise ValueError("Tensor must be torch.int8 before packing")
    if not (1 <= num_bits <= 8):
        raise ValueError(f"num_bits must be in [1, 8], got {num_bits}")

    # Map signed range onto unsigned, e.g. [-8, 7] -> [0, 15] for 4 bits.
    work = (value + (1 << (num_bits - 1))).to(torch.uint8)
    per_word = 32 // num_bits

    if packed_dim == 0:
        work = work.transpose(0, 1)

    rows, cols = work.shape
    remainder = cols % per_word
    if remainder:
        # Pad with zeros so the column count is a whole number of int32 words.
        work = torch.nn.functional.pad(work, (0, per_word - remainder))
        cols += per_word - remainder

    # Widen to int32 before shifting so high nibbles are not truncated.
    grouped = work.view(rows, cols // per_word, per_word).to(torch.int32)
    shifts = torch.arange(per_word, device=work.device, dtype=torch.int32) * num_bits
    packed = (grouped << shifts).sum(dim=2, dtype=torch.int32)

    if packed_dim == 0:
        packed = packed.transpose(0, 1)

    return packed
|
||||||
|
|
||||||
|
def pack_tensor_per_row(q: torch.Tensor, num_bits: int) -> torch.Tensor:
    """Pack each row of a [experts, rows, cols] int8 tensor with pack_to_int32."""
    experts, n_rows, n_cols = q.shape
    # Collapse the expert axis so pack_to_int32 sees a plain 2-D matrix.
    packed_2d = pack_to_int32(q.view(experts * n_rows, n_cols), num_bits)
    return packed_2d.view(experts, n_rows, -1).contiguous()
|
||||||
|
|
||||||
|
|
||||||
|
def quantize_k2_tensor(weights: torch.Tensor, group_size: int):
    """
    Symmetric max-abs/7 quantization per k-group following compressed_tensors packing.

    Args:
        weights: [expert_num, rows (N), cols (K)]
        group_size: number of contiguous K elements sharing one bf16 scale

    Returns:
        packed_q: int32 tensor storing 8 int4s per element with shape [expert_num, rows, cols // 8]
        scales: bfloat16 tensor with shape [expert_num, rows, cols // group_size]

    Raises:
        ValueError: if cols is not divisible by both group_size and 2.
    """
    weights_f32 = weights.to(torch.float32)
    e, rows, cols = weights_f32.shape
    if cols % group_size != 0 or cols % 2 != 0:
        raise ValueError(f"cols ({cols}) must be divisible by group_size ({group_size}) and 2")

    # One scale per k-group; max-abs/7 keeps the quantized range symmetric in
    # [-7, 7] (the clamp below still allows -8 after rounding).
    reshaped = weights_f32.view(e, rows, cols // group_size, group_size)
    max_abs = reshaped.abs().amax(dim=-1, keepdim=True)
    max_abs = torch.clamp(max_abs, min=1e-8)  # avoid divide-by-zero for all-zero groups
    scales = (max_abs / 7.0).squeeze(-1)
    q = torch.round(reshaped / scales.unsqueeze(-1)).clamp(-8, 7).to(torch.int8)
    q = q.view(e, rows, cols)
    packed = pack_tensor_per_row(q, num_bits=4).view(e, rows, cols // 8).contiguous()
    scales = scales.to(torch.bfloat16).contiguous().view(e, rows, cols // group_size).contiguous()

    # Debug prints removed: dumping full packed tensors for every expert was
    # extremely noisy and expensive for real model sizes.
    return packed, scales
|
||||||
|
|
||||||
|
|
||||||
|
def build_structured_tensor(shape: torch.Size, pattern: str) -> torch.Tensor:
    """Create a bf16 test-weight tensor of `shape` following a named pattern.

    "random" reseeds the RNG so every call reproduces identical weights; the
    other patterns tile per-group base values with alternating row signs and a
    small per-column ramp.
    """
    if pattern == "random":
        # Fixed seed: every call yields the same pseudo-random weights.
        torch.manual_seed(42)
        return (torch.randn(shape, dtype=torch.bfloat16, device="cpu") / 100.0).contiguous()

    num_experts, num_rows, num_cols = shape
    num_groups = num_cols // k_group_size

    # Per-group base values come from the pattern table, broadcast over experts/rows.
    base = WEIGHT_PATTERNS[pattern][1](num_groups).to(torch.float32)
    tile = base.view(1, 1, num_groups, 1).expand(num_experts, num_rows, num_groups, k_group_size).clone()

    # Alternate the sign per row and add a small linear ramp along each group.
    signs = torch.ones(num_rows, dtype=torch.float32)
    signs[1::2] = -1.0
    ramp = torch.linspace(-0.0005, 0.0005, steps=k_group_size, dtype=torch.float32)

    tile = tile * signs.view(1, num_rows, 1, 1) + ramp.view(1, 1, 1, k_group_size)
    return tile.reshape(shape).to(torch.bfloat16).contiguous()
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_k2_quantized_weights(pattern: str) -> Dict[str, torch.Tensor]:
    """Build gate/up/down test weights for `pattern` and quantize them to int4 k-groups.

    Returns packed weights, scales, and the fp16 originals used as the
    accuracy reference.
    """
    if pattern not in WEIGHT_PATTERNS:
        raise ValueError(f"Unknown weight pattern: {pattern}")

    # gate/up are [E, I, H]; down is the transposed shape [E, H, I].
    shapes = {
        "gate_proj": (expert_num, intermediate_size, hidden_size),
        "up_proj": (expert_num, intermediate_size, hidden_size),
        "down_proj": (expert_num, hidden_size, intermediate_size),
    }
    dense = {name: build_structured_tensor(shape, pattern) for name, shape in shapes.items()}
    quantized = {name: quantize_k2_tensor(t, k_group_size) for name, t in dense.items()}

    return {
        "gate_qweight": quantized["gate_proj"][0].contiguous(),
        "up_qweight": quantized["up_proj"][0].contiguous(),
        "down_qweight": quantized["down_proj"][0].contiguous(),
        "gate_scales": quantized["gate_proj"][1].contiguous(),
        "up_scales": quantized["up_proj"][1].contiguous(),
        "down_scales": quantized["down_proj"][1].contiguous(),
        "original_fp16": {name: t.to(torch.float16).contiguous() for name, t in dense.items()},
    }
|
||||||
|
|
||||||
|
|
||||||
|
def build_moes_from_quantized_data(quant_data: Dict[str, torch.Tensor]):
    """Instantiate one AMXInt4_KGroup_MOE per layer, all sharing the same quantized weights."""
    layers = []
    with torch.inference_mode(mode=True):
        for _ in range(layer_num):
            cfg = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
            cfg.max_len = max_len

            # Kimi-K2 uses symmetric (no zero-point) int4 with per-k-group scales.
            cfg.quant_config.bits = 4
            cfg.quant_config.group_size = k_group_size
            cfg.quant_config.zero_point = False

            # Hand raw data pointers to the C++ side; the tensors in
            # `quant_data` must stay alive while the MoE objects are in use.
            cfg.gate_proj = quant_data["gate_qweight"].data_ptr()
            cfg.up_proj = quant_data["up_qweight"].data_ptr()
            cfg.down_proj = quant_data["down_qweight"].data_ptr()
            cfg.gate_scale = quant_data["gate_scales"].data_ptr()
            cfg.up_scale = quant_data["up_scales"].data_ptr()
            cfg.down_scale = quant_data["down_scales"].data_ptr()
            cfg.pool = CPUInfer.backend_

            layer_moe = kt_kernel_ext.moe.AMXInt4_KGroup_MOE(cfg)
            CPUInfer.submit(layer_moe.load_weights_task(physical_to_logical_map.data_ptr()))
            CPUInfer.sync()
            layers.append(layer_moe)
    return layers
|
||||||
|
|
||||||
|
|
||||||
|
def run_case(pattern: str) -> Dict[str, float]:
    """Run `validation_iter` forward passes for one weight pattern and compare
    the AMX kernel output against the fp16 torch reference.

    Returns:
        dict with keys case/description/mean/max/min (relative L1 errors).
    """
    print("\n" + "=" * 70)
    desc = WEIGHT_PATTERNS[pattern][0]
    print(f"Running case: {pattern} -> {desc}")
    print("=" * 70)

    quant_data = prepare_k2_quantized_weights(pattern)
    moes = build_moes_from_quantized_data(quant_data)

    original_weights = quant_data["original_fp16"]
    gate_fp16 = original_weights["gate_proj"]
    up_fp16 = original_weights["up_proj"]
    down_fp16 = original_weights["down_proj"]

    diffs = []
    with torch.inference_mode(mode=True):
        for i in range(validation_iter):
            # Per-iteration seed: different data each iteration, reproducible across runs.
            torch.manual_seed(100 + i)
            bsz_tensor = torch.tensor([qlen], device="cpu")
            expert_ids = torch.stack(
                [torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)]
            ).contiguous()
            weights = torch.randn((qlen, num_experts_per_tok), dtype=torch.float32).contiguous()
            input_tensor = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous() / 100
            output = torch.empty((qlen, hidden_size), dtype=torch.bfloat16).contiguous()

            # Round-robin over layers so every MoE instance gets exercised.
            moe = moes[i % layer_num]
            CPUInfer.submit(
                moe.forward_task(
                    bsz_tensor.data_ptr(),
                    num_experts_per_tok,
                    expert_ids.data_ptr(),
                    weights.data_ptr(),
                    input_tensor.data_ptr(),
                    output.data_ptr(),
                    False,
                )
            )
            CPUInfer.sync()

            input_tensor_fp16 = input_tensor.to(torch.float16)
            t_output = moe_torch(
                input_tensor_fp16, expert_ids, weights, gate_fp16, up_fp16, down_fp16
            ).to(torch.bfloat16)

            t_output = t_output.flatten()
            output = output.flatten()

            # Relative L1 error; epsilon guards against an all-zero reference.
            diff = torch.mean(torch.abs(output - t_output)) / (torch.mean(torch.abs(t_output)) + 1e-12)
            diffs.append(diff.item())
            print(f"[{pattern}] Iteration {i}: relative L1 diff = {diff:.4f}")
            # Debug prints that dumped the full output tensors every iteration
            # were removed: they flooded stdout without aiding validation.

    mean_diff = float(sum(diffs) / len(diffs))
    max_diff = float(max(diffs))
    min_diff = float(min(diffs))
    return {"case": pattern, "description": desc, "mean": mean_diff, "max": max_diff, "min": min_diff}
|
||||||
|
|
||||||
|
|
||||||
|
def run_k2_moe_test():
    """Run every registered weight pattern and print an error summary table."""
    summary_rows = [run_case(case_name) for case_name in WEIGHT_PATTERNS]

    print("\n=== Case vs. Relative Error Summary ===")
    print(f"{'Case':<20} {'Mean':>10} {'Max':>10} {'Min':>10}")
    for row in summary_rows:
        print(f"{row['case']:<20} {row['mean']*100:9.2f}% {row['max']*100:9.2f}% {row['min']*100:9.2f}%")
|
||||||
|
|
||||||
|
|
||||||
|
# Entry point: run the full K2 MoE validation suite when executed as a script.
if __name__ == "__main__":
    run_k2_moe_test()
|
||||||
267
kt-kernel/examples/test_k2_write_buffer.py
Normal file
267
kt-kernel/examples/test_k2_write_buffer.py
Normal file
|
|
@ -0,0 +1,267 @@
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
# Ensure we can import the local extension
|
||||||
|
# REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))
|
||||||
|
# if REPO_ROOT not in sys.path:
|
||||||
|
# sys.path.insert(0, REPO_ROOT)
|
||||||
|
|
||||||
|
import kt_kernel_ext
|
||||||
|
from kt_kernel_ext import CPUInfer
|
||||||
|
|
||||||
|
|
||||||
|
def make_cpu_infer(thread_num=80):
    """Create a CPUInfer worker pool with `thread_num` threads (default 80)."""
    return CPUInfer(thread_num)
|
||||||
|
|
||||||
|
|
||||||
|
def build_config(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size, group_size):
    """Build a MOEConfig for symmetric int4 k-group quantization bound to `cpuinfer`'s pool."""
    # NOTE(review): other call sites construct MOEConfig with a trailing 5th
    # positional argument (0); confirm which overload the binding expects here.
    cfg = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
    cfg.max_len = 1  # single-token test, no batched prefill
    # Symmetric int4 with per-group scales (no zero point).
    cfg.quant_config.bits = 4
    cfg.quant_config.group_size = group_size
    cfg.quant_config.zero_point = False
    cfg.pool = cpuinfer.backend_
    return cfg
|
||||||
|
|
||||||
|
|
||||||
|
def allocate_weights(expert_num, hidden_size, intermediate_size, group_size):
    """Allocate random int4-packed weight bytes and bf16 scales for gate/up/down.

    Returns the six tensors plus the per-matrix weight byte count and the
    per-matrix scale element count (both per expert).
    """
    # Two int4 values fit in one byte, so each matrix needs (H * I) / 2 bytes.
    per_mat_weight_bytes = (hidden_size * intermediate_size) // 2
    # One bf16 scale per `group_size` weights.
    per_mat_scale_elems = (hidden_size * intermediate_size) // group_size

    def rand_weight_bytes():
        return torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)

    def rand_scales():
        return torch.randn(expert_num * per_mat_scale_elems, dtype=torch.bfloat16)

    # Keep the RNG call order (weights first, then scales) stable.
    gate_q = rand_weight_bytes()
    up_q = rand_weight_bytes()
    down_q = rand_weight_bytes()
    gate_scale = rand_scales()
    up_scale = rand_scales()
    down_scale = rand_scales()

    return (
        gate_q,
        up_q,
        down_q,
        gate_scale,
        up_scale,
        down_scale,
        per_mat_weight_bytes,
        per_mat_scale_elems,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """End-to-end check of write_weight_scale_to_buffer.

    Loads random int4 weights into the MoE operator, asks the C++ side to copy
    them back out into per-TP buffers, then verifies each buffer byte-for-byte
    against a pure-Python slicing of the original tensors.
    """
    torch.manual_seed(123)

    expert_num = 256  # Total experts
    gpu_experts = expert_num  # Number of experts on GPU
    gpu_tp_count = 2  # Number of TP parts

    num_experts_per_tok = 8
    hidden_size = 7168
    intermediate_size = 2048
    group_size = 32

    cpuinfer = make_cpu_infer()
    cfg = build_config(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size, group_size)

    (
        gate_q,
        up_q,
        down_q,
        gate_scale,
        up_scale,
        down_scale,
        per_mat_weight_bytes,
        per_mat_scale_elems,
    ) = allocate_weights(expert_num, hidden_size, intermediate_size, group_size)

    # Hand raw pointers to the C++ config; the tensors must outlive `moe`.
    cfg.gate_proj = gate_q.data_ptr()
    cfg.up_proj = up_q.data_ptr()
    cfg.down_proj = down_q.data_ptr()
    cfg.gate_scale = gate_scale.data_ptr()
    cfg.up_scale = up_scale.data_ptr()
    cfg.down_scale = down_scale.data_ptr()

    moe = kt_kernel_ext.moe.AMXInt4_KGroup_MOE(cfg)

    # Identity mapping: physical expert i holds logical expert i.
    physical_to_logical_map = (
        torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous()
    )
    cpuinfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
    cpuinfer.sync()

    # TP configuration
    # Since weights are col-major, we can directly divide the total size by tp_count
    # Each matrix is divided into gpu_tp_count parts in memory order

    # Calculate sizes per TP part (direct division since col-major)
    weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count
    scale_elems_per_expert_per_tp = per_mat_scale_elems // gpu_tp_count

    # Total sizes for all gpu_experts
    total_weight_bytes_per_tp = gpu_experts * weight_bytes_per_expert_per_tp
    total_scale_elems_per_tp = gpu_experts * scale_elems_per_expert_per_tp

    # Create buffer lists for w13 (gate+up) and w2 (down)
    w13_weight_bufs = []
    w13_scale_bufs = []
    w2_weight_bufs = []
    w2_scale_bufs = []

    for tp_idx in range(gpu_tp_count):
        # w13 combines gate and up, so needs 2x the size
        w13_weight_bufs.append(torch.empty(2 * total_weight_bytes_per_tp, dtype=torch.uint8))
        w13_scale_bufs.append(torch.empty(2 * total_scale_elems_per_tp, dtype=torch.bfloat16))
        w2_weight_bufs.append(torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8))
        w2_scale_bufs.append(torch.empty(total_scale_elems_per_tp, dtype=torch.bfloat16))

    # Get data pointers for all buffers
    w13_weight_ptrs = [buf.data_ptr() for buf in w13_weight_bufs]
    w13_scale_ptrs = [buf.data_ptr() for buf in w13_scale_bufs]
    w2_weight_ptrs = [buf.data_ptr() for buf in w2_weight_bufs]
    w2_scale_ptrs = [buf.data_ptr() for buf in w2_scale_bufs]

    print(f"Total experts: {expert_num}, GPU experts: {gpu_experts}")
    print(f"GPU TP count: {gpu_tp_count}")
    print(f"Original per matrix weight bytes: {per_mat_weight_bytes}")
    print(f"Original per matrix scale elements: {per_mat_scale_elems}")
    print(f"Weight bytes per expert per TP: {weight_bytes_per_expert_per_tp}")
    print(f"Scale elements per expert per TP: {scale_elems_per_expert_per_tp}")
    print(f"Total weight bytes per TP (w13): {2 * total_weight_bytes_per_tp}")
    print(f"Total weight bytes per TP (w2): {total_weight_bytes_per_tp}")
    print(f"Total scale elements per TP (w13): {2 * total_scale_elems_per_tp}")
    print(f"Total scale elements per TP (w2): {total_scale_elems_per_tp}")

    # Warm up so the timed run below measures steady-state copy bandwidth.
    for _ in range(5):
        cpuinfer.submit(
            moe.write_weight_scale_to_buffer_task(
                gpu_tp_count=gpu_tp_count,
                gpu_experts_num=gpu_experts,
                w13_weight_ptrs=w13_weight_ptrs,
                w13_scale_ptrs=w13_scale_ptrs,
                w2_weight_ptrs=w2_weight_ptrs,
                w2_scale_ptrs=w2_scale_ptrs,
            )
        )
        cpuinfer.sync()

    begin_time = time.perf_counter_ns()
    cpuinfer.submit(
        moe.write_weight_scale_to_buffer_task(
            gpu_tp_count=gpu_tp_count,
            gpu_experts_num=gpu_experts,
            w13_weight_ptrs=w13_weight_ptrs,
            w13_scale_ptrs=w13_scale_ptrs,
            w2_weight_ptrs=w2_weight_ptrs,
            w2_scale_ptrs=w2_scale_ptrs,
        )
    )
    cpuinfer.sync()
    end_time = time.perf_counter_ns()
    elapsed_ms = (end_time - begin_time) / 1000000
    total_weights = hidden_size * intermediate_size * expert_num * 3
    # int4 weights are half a byte each; bf16 scales are 2 bytes each.
    # (Bug fix: the scale term previously counted elements, not bytes.)
    total_bytes = (total_weights // group_size) * 2 + total_weights // 2
    print(f"write_weight_scale_to_buffer time: {elapsed_ms:.2f} ms")
    print(f"Throughput: {total_bytes / (elapsed_ms * 1e6):.2f} GB/s")

    def split_expert_tensor(tensor, chunk):
        """Split tensor by experts"""
        return [tensor[i * chunk : (i + 1) * chunk] for i in range(expert_num)]

    # Split by experts first
    gate_q_experts = split_expert_tensor(gate_q, per_mat_weight_bytes)
    up_q_experts = split_expert_tensor(up_q, per_mat_weight_bytes)
    down_q_experts = split_expert_tensor(down_q, per_mat_weight_bytes)

    gate_scale_experts = split_expert_tensor(gate_scale, per_mat_scale_elems)
    up_scale_experts = split_expert_tensor(up_scale, per_mat_scale_elems)
    down_scale_experts = split_expert_tensor(down_scale, per_mat_scale_elems)

    # Verify buffers for each TP part.  (An unused `cpu_tp_count` local was
    # removed; nothing below reads it.)
    for tp_idx in range(gpu_tp_count):
        expected_w13_weights = []
        expected_w13_scales = []
        expected_w2_weights = []
        expected_w2_scales = []

        weight13_per_tp = per_mat_weight_bytes // gpu_tp_count
        scale13_per_tp = per_mat_scale_elems // gpu_tp_count
        # Process each GPU expert
        for expert_idx in range(gpu_experts):
            # For w13 (gate and up), the slicing is straightforward:
            # each TP owns one contiguous block of the expert's rows.
            start_weight = tp_idx * weight13_per_tp
            end_weight = (tp_idx + 1) * weight13_per_tp
            start_scale = tp_idx * scale13_per_tp
            end_scale = (tp_idx + 1) * scale13_per_tp

            # Gate
            gate_weight_tp = gate_q_experts[expert_idx][start_weight:end_weight]
            gate_scale_tp = gate_scale_experts[expert_idx][start_scale:end_scale]

            # Up
            up_weight_tp = up_q_experts[expert_idx][start_weight:end_weight]
            up_scale_tp = up_scale_experts[expert_idx][start_scale:end_scale]

            # Down matrix needs special handling because it's sliced column-wise:
            # reconstruct each TP's share one hidden-size column at a time.
            down_weight_tp_parts = []
            down_scale_tp_parts = []

            for col_idx in range(hidden_size):
                col_weight_start = col_idx * (intermediate_size // 2)
                col_scale_start = col_idx * (intermediate_size // group_size)

                # Direct mapping: each CPU TP corresponds to a GPU TP
                tp_slice_weight_size = (intermediate_size // gpu_tp_count) // 2
                tp_slice_scale_size = (intermediate_size // gpu_tp_count) // group_size

                tp_weight_offset = col_weight_start + tp_idx * tp_slice_weight_size
                tp_scale_offset = col_scale_start + tp_idx * tp_slice_scale_size

                down_weight_tp_parts.append(
                    down_q_experts[expert_idx][tp_weight_offset:tp_weight_offset + tp_slice_weight_size]
                )
                down_scale_tp_parts.append(
                    down_scale_experts[expert_idx][tp_scale_offset:tp_scale_offset + tp_slice_scale_size]
                )

            # Concatenate all column slices for this TP
            down_weight_tp = torch.cat(down_weight_tp_parts)
            down_scale_tp = torch.cat(down_scale_tp_parts)

            # Buffer layout interleaves gate then up per expert.
            expected_w13_weights.append(gate_weight_tp)
            expected_w13_weights.append(up_weight_tp)
            expected_w13_scales.append(gate_scale_tp)
            expected_w13_scales.append(up_scale_tp)
            expected_w2_weights.append(down_weight_tp)
            expected_w2_scales.append(down_scale_tp)

        # Concatenate all experts for this TP part
        expected_w13_weight = torch.cat(expected_w13_weights)
        expected_w13_scale = torch.cat(expected_w13_scales)
        expected_w2_weight = torch.cat(expected_w2_weights)
        expected_w2_scale = torch.cat(expected_w2_scales)

        print(f"=== Checking TP part {tp_idx} ===")

        # Assert all checks pass
        assert torch.equal(w13_weight_bufs[tp_idx], expected_w13_weight), f"w13 weight bytes mismatch for TP {tp_idx}"
        assert torch.allclose(w13_scale_bufs[tp_idx], expected_w13_scale), f"w13 scale values mismatch for TP {tp_idx}"
        assert torch.equal(w2_weight_bufs[tp_idx], expected_w2_weight), f"w2 weight bytes mismatch for TP {tp_idx}"
        assert torch.allclose(w2_scale_bufs[tp_idx], expected_w2_scale), f"w2 scale values mismatch for TP {tp_idx}"

    print(f"\n✓ write_weight_scale_to_buffer passed: extracted {gpu_experts} GPU experts across {gpu_tp_count} TP parts from total {expert_num} experts")
|
||||||
|
|
||||||
|
|
||||||
|
# Entry point: run the write-buffer round-trip test when executed as a script.
if __name__ == "__main__":
    main()
|
||||||
|
|
@ -36,6 +36,7 @@ static const bool _is_plain_ = false;
|
||||||
|
|
||||||
#if defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL)
|
#if defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL)
|
||||||
#include "operators/amx/awq-moe.hpp"
|
#include "operators/amx/awq-moe.hpp"
|
||||||
|
#include "operators/amx/k2-moe.hpp"
|
||||||
#include "operators/amx/la/amx_kernels.hpp"
|
#include "operators/amx/la/amx_kernels.hpp"
|
||||||
#include "operators/amx/moe.hpp"
|
#include "operators/amx/moe.hpp"
|
||||||
#endif
|
#endif
|
||||||
|
|
@ -43,6 +44,7 @@ static const bool _is_plain_ = false;
|
||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
#include "operators/kvcache/kvcache.h"
|
#include "operators/kvcache/kvcache.h"
|
||||||
#include "operators/llamafile/linear.h"
|
#include "operators/llamafile/linear.h"
|
||||||
|
|
@ -225,7 +227,9 @@ void bind_moe_module(py::module_& moe_module, const char* name) {
|
||||||
using MoeClass = TP_MOE<MoeTP>;
|
using MoeClass = TP_MOE<MoeTP>;
|
||||||
using MoeBindings = MOEBindings<MoeTP>;
|
using MoeBindings = MOEBindings<MoeTP>;
|
||||||
|
|
||||||
py::class_<MoeClass, MoE_Interface, std::shared_ptr<MoeClass>>(moe_module, name)
|
auto moe_cls = py::class_<MoeClass, MoE_Interface, std::shared_ptr<MoeClass>>(moe_module, name);
|
||||||
|
|
||||||
|
moe_cls
|
||||||
.def(py::init<GeneralMOEConfig>())
|
.def(py::init<GeneralMOEConfig>())
|
||||||
.def("warm_up_task", &MoeBindings::WarmUpBindings::cpuinfer_interface)
|
.def("warm_up_task", &MoeBindings::WarmUpBindings::cpuinfer_interface)
|
||||||
.def("load_weights_task",
|
.def("load_weights_task",
|
||||||
|
|
@ -244,6 +248,53 @@ void bind_moe_module(py::module_& moe_module, const char* name) {
|
||||||
.def("warm_up", &MoeClass::warm_up)
|
.def("warm_up", &MoeClass::warm_up)
|
||||||
.def("load_weights", &MoeClass::load_weights)
|
.def("load_weights", &MoeClass::load_weights)
|
||||||
.def("forward", &MoeClass::forward_binding);
|
.def("forward", &MoeClass::forward_binding);
|
||||||
|
|
||||||
|
#if defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL)
|
||||||
|
if constexpr (std::is_same_v<MoeTP, AMX_K2_MOE_TP<amx::GemmKernel224Int4SmallKGroup>>) {
|
||||||
|
struct WriteWeightScaleToBufferBindings {
|
||||||
|
struct Args {
|
||||||
|
CPUInfer* cpuinfer;
|
||||||
|
MoeClass* moe;
|
||||||
|
int gpu_tp_count;
|
||||||
|
int gpu_experts_num;
|
||||||
|
std::vector<uintptr_t> w13_weight_ptrs;
|
||||||
|
std::vector<uintptr_t> w13_scale_ptrs;
|
||||||
|
std::vector<uintptr_t> w2_weight_ptrs;
|
||||||
|
std::vector<uintptr_t> w2_scale_ptrs;
|
||||||
|
};
|
||||||
|
|
||||||
|
static void inner(void* args) {
|
||||||
|
Args* args_ = (Args*)args;
|
||||||
|
args_->cpuinfer->enqueue(&MoeClass::write_weight_scale_to_buffer, args_->moe,
|
||||||
|
args_->gpu_tp_count, args_->gpu_experts_num,
|
||||||
|
args_->w13_weight_ptrs, args_->w13_scale_ptrs,
|
||||||
|
args_->w2_weight_ptrs, args_->w2_scale_ptrs);
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::pair<intptr_t, intptr_t> cpuinfer_interface(std::shared_ptr<MoeClass> moe,
|
||||||
|
int gpu_tp_count, int gpu_experts_num,
|
||||||
|
py::list w13_weight_ptrs, py::list w13_scale_ptrs,
|
||||||
|
py::list w2_weight_ptrs, py::list w2_scale_ptrs) {
|
||||||
|
// Convert Python lists to std::vector<uintptr_t>
|
||||||
|
std::vector<uintptr_t> w13_weight_vec, w13_scale_vec, w2_weight_vec, w2_scale_vec;
|
||||||
|
|
||||||
|
for (auto item : w13_weight_ptrs) w13_weight_vec.push_back(py::cast<uintptr_t>(item));
|
||||||
|
for (auto item : w13_scale_ptrs) w13_scale_vec.push_back(py::cast<uintptr_t>(item));
|
||||||
|
for (auto item : w2_weight_ptrs) w2_weight_vec.push_back(py::cast<uintptr_t>(item));
|
||||||
|
for (auto item : w2_scale_ptrs) w2_scale_vec.push_back(py::cast<uintptr_t>(item));
|
||||||
|
|
||||||
|
Args* args = new Args{nullptr, moe.get(), gpu_tp_count, gpu_experts_num,
|
||||||
|
w13_weight_vec, w13_scale_vec, w2_weight_vec, w2_scale_vec};
|
||||||
|
return std::make_pair((intptr_t)&inner, (intptr_t)args);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
moe_cls.def("write_weight_scale_to_buffer_task", &WriteWeightScaleToBufferBindings::cpuinfer_interface,
|
||||||
|
py::arg("gpu_tp_count"), py::arg("gpu_experts_num"),
|
||||||
|
py::arg("w13_weight_ptrs"), py::arg("w13_scale_ptrs"),
|
||||||
|
py::arg("w2_weight_ptrs"), py::arg("w2_scale_ptrs"));
|
||||||
|
}
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
PYBIND11_MODULE(kt_kernel_ext, m) {
|
PYBIND11_MODULE(kt_kernel_ext, m) {
|
||||||
|
|
@ -513,6 +564,7 @@ PYBIND11_MODULE(kt_kernel_ext, m) {
|
||||||
bind_moe_module<AMX_MOE_TP<amx::GemmKernel224Int4>>(moe_module, "AMXInt4_MOE");
|
bind_moe_module<AMX_MOE_TP<amx::GemmKernel224Int4>>(moe_module, "AMXInt4_MOE");
|
||||||
bind_moe_module<AMX_MOE_TP<amx::GemmKernel224Int4_1>>(moe_module, "AMXInt4_1_MOE");
|
bind_moe_module<AMX_MOE_TP<amx::GemmKernel224Int4_1>>(moe_module, "AMXInt4_1_MOE");
|
||||||
bind_moe_module<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4_1_LowKGroup>>(moe_module, "AMXInt4_1KGroup_MOE");
|
bind_moe_module<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4_1_LowKGroup>>(moe_module, "AMXInt4_1KGroup_MOE");
|
||||||
|
bind_moe_module<AMX_K2_MOE_TP<amx::GemmKernel224Int4SmallKGroup>>(moe_module, "AMXInt4_KGroup_MOE");
|
||||||
#endif
|
#endif
|
||||||
#if defined(USE_MOE_KERNEL)
|
#if defined(USE_MOE_KERNEL)
|
||||||
bind_moe_module<MOE_KERNEL_TP<moe_kernel::GemmKernelInt8, _is_plain_>>(moe_module, "Int8_KERNEL_MOE");
|
bind_moe_module<MOE_KERNEL_TP<moe_kernel::GemmKernelInt8, _is_plain_>>(moe_module, "Int8_KERNEL_MOE");
|
||||||
|
|
|
||||||
929
kt-kernel/operators/amx/k2-moe.hpp
Normal file
929
kt-kernel/operators/amx/k2-moe.hpp
Normal file
|
|
@ -0,0 +1,929 @@
|
||||||
|
/**
|
||||||
|
* @Description : Skeleton for K2 AMX MoE operator.
|
||||||
|
* @Author : Codex
|
||||||
|
* @Date : 2024-07-22
|
||||||
|
* @Version : 0.1.0
|
||||||
|
* @LastEditors : Codex
|
||||||
|
* @LastEditTime : 2024-07-22
|
||||||
|
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
||||||
|
**/
|
||||||
|
#ifndef CPUINFER_OPERATOR_AMX_K2_MOE_H
|
||||||
|
#define CPUINFER_OPERATOR_AMX_K2_MOE_H
|
||||||
|
|
||||||
|
// #define DEBUG_K2_MOE
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <cstring>
|
||||||
|
// #define FORWARD_TIME_PROFILE
|
||||||
|
// #define FORWARD_TIME_REPORT
|
||||||
|
|
||||||
|
#include <immintrin.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <chrono>
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstdio>
|
||||||
|
#include <filesystem>
|
||||||
|
#include <fstream>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "../../cpu_backend/shared_mem_buffer.h"
|
||||||
|
#include "../../cpu_backend/worker_pool.h"
|
||||||
|
#include "../common.hpp"
|
||||||
|
#include "../moe-tp.hpp"
|
||||||
|
#include "la/amx.hpp"
|
||||||
|
#include "llama.cpp/ggml.h"
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
class AMX_K2_MOE_TP {
|
||||||
|
private:
|
||||||
|
int tp_part_idx = 0;
|
||||||
|
|
||||||
|
void* gate_proj_ = nullptr; // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]
|
||||||
|
void* up_proj_ = nullptr; // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]
|
||||||
|
void* down_proj_ = nullptr; // [expert_num * hidden_size * intermediate_size ( /32 if quantized)]
|
||||||
|
|
||||||
|
ggml_bf16_t* m_local_input_ = nullptr; // [num_experts_per_tok * max_len * hidden_size]
|
||||||
|
ggml_bf16_t* m_local_gate_output_ = nullptr; // [num_experts_per_tok * max_len * intermediate_size]
|
||||||
|
ggml_bf16_t* m_local_up_output_ = nullptr; // [num_experts_per_tok * max_len * intermediate_size]
|
||||||
|
ggml_bf16_t* m_local_down_output_ = nullptr; // [num_experts_per_tok * max_len * hidden_size]
|
||||||
|
|
||||||
|
std::vector<std::vector<int>> m_local_pos_; // [max_len, num_experts_per_tok]
|
||||||
|
std::vector<int> m_local_num_; // [expert_num]
|
||||||
|
std::vector<int> m_expert_id_map_; // [expert_num]
|
||||||
|
std::vector<ggml_bf16_t*> m_local_input_ptr_; // [expert_num]
|
||||||
|
std::vector<ggml_bf16_t*> m_local_gate_output_ptr_; // [expert_num]
|
||||||
|
std::vector<ggml_bf16_t*> m_local_up_output_ptr_; // [expert_num]
|
||||||
|
std::vector<ggml_bf16_t*> m_local_down_output_ptr_; // [expert_num]
|
||||||
|
|
||||||
|
std::vector<std::shared_ptr<typename T::BufferA>> gate_up_ba_;
|
||||||
|
std::vector<std::shared_ptr<typename T::BufferB>> gate_bb_;
|
||||||
|
std::vector<std::shared_ptr<typename T::BufferC>> gate_bc_;
|
||||||
|
std::vector<std::shared_ptr<typename T::BufferB>> up_bb_;
|
||||||
|
std::vector<std::shared_ptr<typename T::BufferC>> up_bc_;
|
||||||
|
std::vector<std::shared_ptr<typename T::BufferA>> down_ba_;
|
||||||
|
std::vector<std::shared_ptr<typename T::BufferB>> down_bb_;
|
||||||
|
std::vector<std::shared_ptr<typename T::BufferC>> down_bc_;
|
||||||
|
#ifdef CHECK
|
||||||
|
char verify_bb[100000000];
|
||||||
|
char check_bb[100000000];
|
||||||
|
uint8_t compare_expers = 3;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef CHECK
|
||||||
|
// Debug-only hook (compiled under CHECK): intended to reload weights so they
// can be compared against the packed buffers.  Currently a stub.
inline void load_check() {
    // TODO: implement load_check for verification.
}
|
||||||
|
|
||||||
|
// Debug-only hook (compiled under CHECK): intended to validate that weight
// loading produced the expected buffer contents.  Currently a stub.
void verify_load_right() {
    // TODO: implement verification helpers.
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Debug helper: print a summary of one expert's packed weight buffer — the
// first/last scale values, the first/last packed weight bytes (hex), and the
// matrix geometry.  Output goes to stdout and is flushed before returning.
// NOTE(review): assumes BufferB exposes scales via `d` (float-printable) and
// packed int4 weights via `b` — confirm against the BufferB definition.
inline void dump_buffer_b(const std::string &quantization_type, int expert_idx, const std::string &matrix_type,
                          typename T::BufferB *buffer) {
    auto &quant_config = config_.quant_config;
    int &group_size = quant_config.group_size;

    printf("[DUMP_BUFFER_B] TP%d %s Expert%d %s:\n", tp_part_idx, quantization_type.c_str(), expert_idx,
           matrix_type.c_str());

    // Calculate dimensions based on matrix type
    int rows, cols, num_groups;
    size_t scale_elem_count;
    if (matrix_type == "gate" || matrix_type == "up") {
        // gate/up: [intermediate_size x hidden_size], grouped along hidden (K).
        rows = config_.intermediate_size;
        cols = config_.hidden_size;
        num_groups = cols / group_size;
        scale_elem_count = num_groups * rows;
    } else {  // down
        // down: transposed shape [hidden_size x intermediate_size].
        rows = config_.hidden_size;
        cols = config_.intermediate_size;
        num_groups = cols / group_size;
        scale_elem_count = num_groups * rows;
    }

    // Dump scales (as float)
    printf(" Scales[first 16]: ");
    for (int i = 0; i < std::min(16, (int)scale_elem_count); i++) {
        printf("%.6f ", buffer->d[i]);
    }
    printf("\n");

    if (scale_elem_count > 16) {
        printf(" Scales[last 16]: ");
        int start_idx = std::max(0, (int)scale_elem_count - 16);
        for (int i = start_idx; i < (int)scale_elem_count; i++) {
            printf("%.6f ", buffer->d[i]);
        }
        printf("\n");
    }
    // Dump quantized weights (as hex uint8)
    size_t weight_size = (rows * cols) / 2;  // INT4 packed
    uint8_t *weight_ptr = (uint8_t *)buffer->b;

    printf(" Weights[first 32 bytes]: ");
    for (int i = 0; i < std::min(32, (int)weight_size); i++) {
        printf("%02x ", weight_ptr[i]);
    }
    printf("\n");

    if (weight_size > 32) {
        printf(" Weights[last 32 bytes]: ");
        // Start at 32 at minimum so the "first" and "last" windows never overlap.
        int start_idx = std::max(32, (int)weight_size - 32);
        for (int i = start_idx; i < (int)weight_size; i++) {
            printf("%02x ", weight_ptr[i]);
        }
        printf("\n");
    }

    printf(" Matrix dimensions: %dx%d, Groups: %d, Group size: %d, Scale elements: %zu\n", rows, cols, num_groups,
           group_size, scale_elem_count);
    printf("\n");
    fflush(stdout);
}
|
||||||
|
|
||||||
|
#ifdef FORWARD_TIME_REPORT
|
||||||
|
std::chrono::time_point<std::chrono::high_resolution_clock> last_now;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
public:
|
||||||
|
using input_t = ggml_bf16_t;
|
||||||
|
using output_t = float;
|
||||||
|
GeneralMOEConfig config_;
|
||||||
|
static constexpr double ELEMENT_SIZE = T::ELEMENT_SIZE;
|
||||||
|
|
||||||
|
AMX_K2_MOE_TP(GeneralMOEConfig config, int tp_part_idx_) {
|
||||||
|
auto& quant_config = config.quant_config;
|
||||||
|
int& group_size = quant_config.group_size;
|
||||||
|
if (quant_config.group_size == 0 || quant_config.zero_point) {
|
||||||
|
throw std::runtime_error("Kimi-K2 MoE only support KGroup Int4");
|
||||||
|
}
|
||||||
|
printf("Creating AMX_K2_MOE_TP %d at numa %d\n", tp_part_idx_, numa_node_of_cpu(sched_getcpu()));
|
||||||
|
auto& load = config.load;
|
||||||
|
auto& save = config.save;
|
||||||
|
if (load && config.path == "") {
|
||||||
|
load = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
this->tp_part_idx = tp_part_idx_;
|
||||||
|
config_ = config;
|
||||||
|
gate_proj_ = config_.gate_proj;
|
||||||
|
up_proj_ = config_.up_proj;
|
||||||
|
down_proj_ = config_.down_proj;
|
||||||
|
|
||||||
|
MemoryRequest mem_requests;
|
||||||
|
mem_requests.append_pointer(
|
||||||
|
&m_local_input_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok * config_.max_len * config_.hidden_size);
|
||||||
|
mem_requests.append_pointer(&m_local_gate_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *
|
||||||
|
config_.max_len * config_.intermediate_size);
|
||||||
|
mem_requests.append_pointer(&m_local_up_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *
|
||||||
|
config_.max_len * config_.intermediate_size);
|
||||||
|
mem_requests.append_pointer(&m_local_down_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *
|
||||||
|
config_.max_len * config_.hidden_size);
|
||||||
|
|
||||||
|
m_local_pos_.resize(config_.max_len);
|
||||||
|
for (int i = 0; i < config_.max_len; i++) {
|
||||||
|
m_local_pos_[i].resize(config_.num_experts_per_tok);
|
||||||
|
}
|
||||||
|
m_expert_id_map_.resize(config_.expert_num);
|
||||||
|
m_local_num_.resize(config_.expert_num);
|
||||||
|
m_local_input_ptr_.resize(config_.expert_num);
|
||||||
|
m_local_gate_output_ptr_.resize(config_.expert_num);
|
||||||
|
m_local_up_output_ptr_.resize(config_.expert_num);
|
||||||
|
m_local_down_output_ptr_.resize(config_.expert_num);
|
||||||
|
|
||||||
|
for (size_t i = 0; i < config_.expert_num; i++) {
|
||||||
|
gate_up_ba_.push_back(
|
||||||
|
std::make_shared<typename T::BufferA>(config_.max_len, config_.hidden_size, group_size, nullptr));
|
||||||
|
gate_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.intermediate_size, nullptr));
|
||||||
|
up_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.intermediate_size, nullptr));
|
||||||
|
down_ba_.push_back(
|
||||||
|
std::make_shared<typename T::BufferA>(config_.max_len, config_.intermediate_size, group_size, nullptr));
|
||||||
|
down_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.hidden_size, nullptr));
|
||||||
|
|
||||||
|
void* gate_bb_ptr =
|
||||||
|
std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, group_size));
|
||||||
|
gate_bb_.push_back(std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size,
|
||||||
|
group_size, gate_bb_ptr));
|
||||||
|
|
||||||
|
void* up_bb_ptr =
|
||||||
|
std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, group_size));
|
||||||
|
up_bb_.push_back(
|
||||||
|
std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, group_size, up_bb_ptr));
|
||||||
|
|
||||||
|
void* down_bb_ptr =
|
||||||
|
std::aligned_alloc(64, T::BufferB::required_size(config_.hidden_size, config_.intermediate_size, group_size));
|
||||||
|
down_bb_.push_back(std::make_shared<typename T::BufferB>(config_.hidden_size, config_.intermediate_size,
|
||||||
|
group_size, down_bb_ptr));
|
||||||
|
}
|
||||||
|
for (int i = 0; i < config_.expert_num; i++) {
|
||||||
|
mem_requests.append_function([this, i](void* new_ptr) { gate_up_ba_[i]->set_data(new_ptr); },
|
||||||
|
T::BufferA::required_size(config_.max_len, config_.hidden_size, group_size));
|
||||||
|
mem_requests.append_function([this, i](void* new_ptr) { gate_bc_[i]->set_data(new_ptr); },
|
||||||
|
T::BufferC::required_size(config_.max_len, config_.intermediate_size));
|
||||||
|
mem_requests.append_function([this, i](void* new_ptr) { up_bc_[i]->set_data(new_ptr); },
|
||||||
|
T::BufferC::required_size(config_.max_len, config_.intermediate_size));
|
||||||
|
mem_requests.append_function([this, i](void* new_ptr) { down_ba_[i]->set_data(new_ptr); },
|
||||||
|
T::BufferA::required_size(config_.max_len, config_.intermediate_size, group_size));
|
||||||
|
mem_requests.append_function([this, i](void* new_ptr) { down_bc_[i]->set_data(new_ptr); },
|
||||||
|
T::BufferC::required_size(config_.max_len, config_.hidden_size));
|
||||||
|
}
|
||||||
|
shared_mem_buffer_numa.alloc(tp_part_idx, this, mem_requests);
|
||||||
|
}
|
||||||
|
|
||||||
|
~AMX_K2_MOE_TP() = default;
|
||||||
|
|
||||||
|
void load_weights() {
|
||||||
|
auto& quant_config = config_.quant_config;
|
||||||
|
int& group_size = quant_config.group_size;
|
||||||
|
const uint64_t* physical_to_logical_map = (const uint64_t*)config_.physical_to_logical_map;
|
||||||
|
auto pool = config_.pool->get_subpool(tp_part_idx);
|
||||||
|
|
||||||
|
if (quant_config.group_size == 0 || quant_config.zero_point) {
|
||||||
|
throw std::runtime_error("Kimi AVX MOE only support KGroup Int4.");
|
||||||
|
}
|
||||||
|
if (config_.gate_scale == nullptr) {
|
||||||
|
throw std::runtime_error("Kimi AVX MOE only support load native weight.");
|
||||||
|
}
|
||||||
|
// load weight
|
||||||
|
int nth = T::recommended_nth(config_.intermediate_size);
|
||||||
|
pool->do_work_stealing_job(
|
||||||
|
nth * config_.expert_num, nullptr,
|
||||||
|
[this, nth, physical_to_logical_map](int task_id) {
|
||||||
|
uint64_t expert_idx = task_id / nth;
|
||||||
|
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
|
||||||
|
int ith = task_id % nth;
|
||||||
|
// gate part
|
||||||
|
gate_bb_[expert_idx]->from_raw_mat(
|
||||||
|
(uint8_t*)config_.gate_proj +
|
||||||
|
((logical_expert_id * config_.intermediate_size * config_.hidden_size) >> 1),
|
||||||
|
ith, nth);
|
||||||
|
// up part
|
||||||
|
up_bb_[expert_idx]->from_raw_mat(
|
||||||
|
(uint8_t*)config_.up_proj +
|
||||||
|
((logical_expert_id * config_.intermediate_size * config_.hidden_size) >> 1),
|
||||||
|
ith, nth);
|
||||||
|
},
|
||||||
|
nullptr);
|
||||||
|
|
||||||
|
nth = T::recommended_nth(config_.hidden_size);
|
||||||
|
pool->do_work_stealing_job(
|
||||||
|
nth * config_.expert_num, nullptr,
|
||||||
|
[this, nth, physical_to_logical_map](int task_id) {
|
||||||
|
uint64_t expert_idx = task_id / nth;
|
||||||
|
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
|
||||||
|
int ith = task_id % nth;
|
||||||
|
// down part
|
||||||
|
down_bb_[expert_idx]->from_raw_mat(
|
||||||
|
(uint8_t*)config_.down_proj +
|
||||||
|
((logical_expert_id * config_.hidden_size * config_.intermediate_size) >> 1),
|
||||||
|
ith, nth);
|
||||||
|
},
|
||||||
|
nullptr);
|
||||||
|
|
||||||
|
pool->do_work_stealing_job(
|
||||||
|
config_.expert_num, nullptr,
|
||||||
|
[this, physical_to_logical_map](int task_id) {
|
||||||
|
uint64_t expert_idx = task_id;
|
||||||
|
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
|
||||||
|
size_t scale_elem_count =
|
||||||
|
(config_.hidden_size * config_.intermediate_size) / config_.quant_config.group_size;
|
||||||
|
|
||||||
|
// convert scales from BF16 to FP32
|
||||||
|
convert_or_copy(gate_bb_[expert_idx]->d,
|
||||||
|
(ggml_bf16_t*)config_.gate_scale + (logical_expert_id * scale_elem_count),
|
||||||
|
scale_elem_count);
|
||||||
|
convert_or_copy(up_bb_[expert_idx]->d,
|
||||||
|
(ggml_bf16_t*)config_.up_scale + (logical_expert_id * scale_elem_count),
|
||||||
|
scale_elem_count);
|
||||||
|
convert_or_copy(down_bb_[expert_idx]->d,
|
||||||
|
(ggml_bf16_t*)config_.down_scale + (logical_expert_id * scale_elem_count),
|
||||||
|
scale_elem_count);
|
||||||
|
},
|
||||||
|
nullptr);
|
||||||
|
// dump_buffer_b("native", 0, "down", down_bb_[0].get());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reconstruct weights for all experts to the output buffers
|
||||||
|
// This function handles the TP-specific portion of the reconstruction for all experts
|
||||||
|
void write_weights_to_buffer(int gpu_tp_count, int cpu_tp_count, int num_experts, const GeneralMOEConfig& full_config,
|
||||||
|
const std::vector<uintptr_t>& w13_weight_ptrs,
|
||||||
|
const std::vector<uintptr_t>& w13_scale_ptrs,
|
||||||
|
const std::vector<uintptr_t>& w2_weight_ptrs,
|
||||||
|
const std::vector<uintptr_t>& w2_scale_ptrs) const {
|
||||||
|
const int group_size = config_.quant_config.group_size;
|
||||||
|
auto pool = config_.pool->get_subpool(tp_part_idx);
|
||||||
|
|
||||||
|
// Calculate sizes for CPU TP part (this instance)
|
||||||
|
size_t cpu_tp_weight_elem_count = (size_t)config_.intermediate_size * config_.hidden_size;
|
||||||
|
size_t cpu_tp_weight_bytes = cpu_tp_weight_elem_count / 2; // int4 packing
|
||||||
|
size_t cpu_tp_scale_elem_count = cpu_tp_weight_elem_count / group_size;
|
||||||
|
|
||||||
|
// Calculate sizes for GPU TP part
|
||||||
|
size_t gpu_tp_weight_elem_count = (size_t)full_config.intermediate_size * full_config.hidden_size / gpu_tp_count;
|
||||||
|
size_t gpu_tp_weight_bytes = gpu_tp_weight_elem_count / 2; // int4 packing
|
||||||
|
size_t gpu_tp_scale_elem_count = gpu_tp_weight_elem_count / group_size;
|
||||||
|
|
||||||
|
// Determine mapping: which GPU TP parts should this CPU TP part write to?
|
||||||
|
// Since weights are col-major and we slice directly by memory order:
|
||||||
|
// - If cpu_tp_count >= gpu_tp_count: multiple(or one) CPU TPs write to one GPU TP
|
||||||
|
// - If cpu_tp_count < gpu_tp_count: one CPU TP writes to multiple GPU TPs
|
||||||
|
if (cpu_tp_count >= gpu_tp_count) {
|
||||||
|
// Multiple CPU TPs map to one GPU TP
|
||||||
|
int target_gpu_tp = tp_part_idx / (cpu_tp_count / gpu_tp_count);
|
||||||
|
int local_idx = tp_part_idx % (cpu_tp_count / gpu_tp_count);
|
||||||
|
|
||||||
|
// Get pointers for this GPU TP part
|
||||||
|
uint8_t* w13_weight_dst = (uint8_t*)w13_weight_ptrs[target_gpu_tp];
|
||||||
|
ggml_bf16_t* w13_scale_dst = (ggml_bf16_t*)w13_scale_ptrs[target_gpu_tp];
|
||||||
|
uint8_t* w2_weight_dst = (uint8_t*)w2_weight_ptrs[target_gpu_tp];
|
||||||
|
ggml_bf16_t* w2_scale_dst = (ggml_bf16_t*)w2_scale_ptrs[target_gpu_tp];
|
||||||
|
|
||||||
|
// Calculate offset within the GPU TP buffer
|
||||||
|
size_t offset_in_gpu_weight = local_idx * cpu_tp_weight_bytes;
|
||||||
|
size_t offset_in_gpu_scale = local_idx * cpu_tp_scale_elem_count;
|
||||||
|
|
||||||
|
// Process only the first num_experts experts (GPU experts)
|
||||||
|
int nth = T::recommended_nth(config_.intermediate_size);
|
||||||
|
nth = 1;
|
||||||
|
pool->do_work_stealing_job(
|
||||||
|
nth * num_experts, nullptr,
|
||||||
|
[&, this](int task_id) {
|
||||||
|
int expert_id = task_id / nth;
|
||||||
|
// int ith = task_id % nth;
|
||||||
|
// auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);
|
||||||
|
|
||||||
|
// Calculate base offsets for this expert in the GPU buffers
|
||||||
|
// For w13: each expert has gate+up, so the offset needs to account for 2x size
|
||||||
|
size_t w13_expert_base_weight = expert_id * 2 * gpu_tp_weight_bytes;
|
||||||
|
size_t w13_expert_base_scale = expert_id * 2 * gpu_tp_scale_elem_count;
|
||||||
|
size_t w2_expert_base_weight = expert_id * gpu_tp_weight_bytes;
|
||||||
|
size_t w2_expert_base_scale = expert_id * gpu_tp_scale_elem_count;
|
||||||
|
|
||||||
|
// Gate (first part of w13 for this expert)
|
||||||
|
uint8_t* gate_weight_src = (uint8_t*)gate_bb_[expert_id]->b;
|
||||||
|
float* gate_scale_src = gate_bb_[expert_id]->d;
|
||||||
|
std::memcpy(w13_weight_dst + w13_expert_base_weight + offset_in_gpu_weight,
|
||||||
|
gate_weight_src, cpu_tp_weight_bytes);
|
||||||
|
convert_or_copy((ggml_bf16_t*)(w13_scale_dst + w13_expert_base_scale + offset_in_gpu_scale),
|
||||||
|
gate_scale_src, cpu_tp_scale_elem_count);
|
||||||
|
|
||||||
|
// Up (second part of w13 for this expert, immediately after gate)
|
||||||
|
uint8_t* up_weight_src = (uint8_t*)up_bb_[expert_id]->b;
|
||||||
|
float* up_scale_src = up_bb_[expert_id]->d;
|
||||||
|
std::memcpy(w13_weight_dst + w13_expert_base_weight + offset_in_gpu_weight + gpu_tp_weight_bytes,
|
||||||
|
up_weight_src, cpu_tp_weight_bytes);
|
||||||
|
convert_or_copy((ggml_bf16_t*)(w13_scale_dst + w13_expert_base_scale + offset_in_gpu_scale + gpu_tp_scale_elem_count),
|
||||||
|
up_scale_src, cpu_tp_scale_elem_count);
|
||||||
|
|
||||||
|
// Down (w2) - need to handle column-wise slicing
|
||||||
|
// The down matrix is transposed compared to gate/up, so we need to extract by columns
|
||||||
|
// When multiple CPU TPs map to one GPU TP, each CPU TP has a slice of intermediate dimension
|
||||||
|
// CPU TP internal layout: each column has config_.intermediate_size elements
|
||||||
|
// GPU expects: each column has full_config.intermediate_size elements
|
||||||
|
size_t cpu_tps_per_gpu = cpu_tp_count / gpu_tp_count;
|
||||||
|
|
||||||
|
for (size_t col = 0; col < config_.hidden_size; col++) {
|
||||||
|
// GPU buffer column width is full_config.intermediate_size / gpu_tp_count
|
||||||
|
size_t gpu_col_offset = col * ((full_config.intermediate_size / gpu_tp_count) >> 1);
|
||||||
|
size_t cpu_col_offset = col * (config_.intermediate_size >> 1);
|
||||||
|
size_t gpu_col_slice_offset = local_idx * (config_.intermediate_size >> 1);
|
||||||
|
|
||||||
|
std::memcpy(w2_weight_dst + w2_expert_base_weight + gpu_col_offset + gpu_col_slice_offset,
|
||||||
|
(uint8_t*)down_bb_[expert_id]->b + cpu_col_offset,
|
||||||
|
config_.intermediate_size / 2);
|
||||||
|
|
||||||
|
// Same for scales
|
||||||
|
size_t gpu_scale_col_offset = col * ((full_config.intermediate_size / gpu_tp_count) / group_size);
|
||||||
|
size_t cpu_scale_col_offset = col * (config_.intermediate_size / group_size);
|
||||||
|
size_t gpu_scale_slice_offset = local_idx * (config_.intermediate_size / group_size);
|
||||||
|
|
||||||
|
convert_or_copy((ggml_bf16_t*)(w2_scale_dst + w2_expert_base_scale + gpu_scale_col_offset + gpu_scale_slice_offset),
|
||||||
|
down_bb_[expert_id]->d + cpu_scale_col_offset,
|
||||||
|
config_.intermediate_size / group_size);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
nullptr);
|
||||||
|
} else {
|
||||||
|
// cpu_tp_count < gpu_tp_count: one CPU TP writes to multiple GPU TPs
|
||||||
|
// Each CPU TP part contains data for multiple GPU TP parts
|
||||||
|
int gpu_tps_per_cpu_tp = gpu_tp_count / cpu_tp_count;
|
||||||
|
|
||||||
|
// This CPU TP part writes to GPU TP indices: [start_gpu_tp, start_gpu_tp + gpu_tps_per_cpu_tp)
|
||||||
|
int start_gpu_tp = tp_part_idx * gpu_tps_per_cpu_tp;
|
||||||
|
|
||||||
|
// Size of data per GPU TP within this CPU TP
|
||||||
|
size_t data_per_gpu_tp_weight = cpu_tp_weight_bytes / gpu_tps_per_cpu_tp;
|
||||||
|
size_t data_per_gpu_tp_scale = cpu_tp_scale_elem_count / gpu_tps_per_cpu_tp;
|
||||||
|
|
||||||
|
// Process all experts for this GPU TP
|
||||||
|
pool->do_work_stealing_job(
|
||||||
|
gpu_tps_per_cpu_tp * num_experts, nullptr,
|
||||||
|
[&, this](int task_id) {
|
||||||
|
int expert_id = task_id % num_experts;
|
||||||
|
int local_gpu_idx = task_id / num_experts;
|
||||||
|
int gpu_tp_idx = start_gpu_tp + local_gpu_idx;
|
||||||
|
|
||||||
|
// Get pointers for this GPU TP part
|
||||||
|
uint8_t* w13_weight_dst = (uint8_t*)w13_weight_ptrs[gpu_tp_idx];
|
||||||
|
ggml_bf16_t* w13_scale_dst = (ggml_bf16_t*)w13_scale_ptrs[gpu_tp_idx];
|
||||||
|
uint8_t* w2_weight_dst = (uint8_t*)w2_weight_ptrs[gpu_tp_idx];
|
||||||
|
ggml_bf16_t* w2_scale_dst = (ggml_bf16_t*)w2_scale_ptrs[gpu_tp_idx];
|
||||||
|
|
||||||
|
// Calculate offsets within CPU TP buffers
|
||||||
|
size_t cpu_offset_weight = local_gpu_idx * data_per_gpu_tp_weight;
|
||||||
|
size_t cpu_offset_scale = local_gpu_idx * data_per_gpu_tp_scale;
|
||||||
|
|
||||||
|
// Calculate offsets for this expert in GPU buffers
|
||||||
|
// For w13: each expert has gate+up, so the offset needs to account for 2x size
|
||||||
|
size_t w13_gpu_expert_offset_weight = expert_id * 2 * gpu_tp_weight_bytes;
|
||||||
|
size_t w13_gpu_expert_offset_scale = expert_id * 2 * gpu_tp_scale_elem_count;
|
||||||
|
size_t w2_gpu_expert_offset_weight = expert_id * gpu_tp_weight_bytes;
|
||||||
|
size_t w2_gpu_expert_offset_scale = expert_id * gpu_tp_scale_elem_count;
|
||||||
|
|
||||||
|
// Gate (first part of w13 for this expert)
|
||||||
|
uint8_t* gate_weight_src = (uint8_t*)gate_bb_[expert_id]->b + cpu_offset_weight;
|
||||||
|
float* gate_scale_src = gate_bb_[expert_id]->d + cpu_offset_scale;
|
||||||
|
std::memcpy(w13_weight_dst + w13_gpu_expert_offset_weight,
|
||||||
|
gate_weight_src, data_per_gpu_tp_weight);
|
||||||
|
convert_or_copy((ggml_bf16_t*)(w13_scale_dst + w13_gpu_expert_offset_scale),
|
||||||
|
gate_scale_src, data_per_gpu_tp_scale);
|
||||||
|
|
||||||
|
// Up (second part of w13 for this expert, immediately after gate)
|
||||||
|
uint8_t* up_weight_src = (uint8_t*)up_bb_[expert_id]->b + cpu_offset_weight;
|
||||||
|
float* up_scale_src = up_bb_[expert_id]->d + cpu_offset_scale;
|
||||||
|
std::memcpy(w13_weight_dst + w13_gpu_expert_offset_weight + gpu_tp_weight_bytes,
|
||||||
|
up_weight_src, data_per_gpu_tp_weight);
|
||||||
|
convert_or_copy((ggml_bf16_t*)(w13_scale_dst + w13_gpu_expert_offset_scale + gpu_tp_scale_elem_count),
|
||||||
|
up_scale_src, data_per_gpu_tp_scale);
|
||||||
|
|
||||||
|
// Down (w2) - need to handle column-wise slicing
|
||||||
|
// The down matrix is transposed compared to gate/up, so we need to extract by columns
|
||||||
|
for (size_t col = 0; col < config_.hidden_size; col++) {
|
||||||
|
// Calculate the offset within the column for this GPU TP part
|
||||||
|
size_t col_offset_weight = (col * config_.intermediate_size / 2) + (local_gpu_idx * data_per_gpu_tp_weight / config_.hidden_size);
|
||||||
|
size_t col_offset_scale = (col * (config_.intermediate_size / group_size)) + (local_gpu_idx * data_per_gpu_tp_scale / config_.hidden_size);
|
||||||
|
|
||||||
|
// Copy weights column by column
|
||||||
|
std::memcpy(w2_weight_dst + w2_gpu_expert_offset_weight + (col * (config_.intermediate_size / gpu_tps_per_cpu_tp) / 2),
|
||||||
|
(uint8_t*)down_bb_[expert_id]->b + col_offset_weight,
|
||||||
|
(config_.intermediate_size / gpu_tps_per_cpu_tp) / 2);
|
||||||
|
|
||||||
|
// Copy scales column by column
|
||||||
|
convert_or_copy((ggml_bf16_t*)(w2_scale_dst + w2_gpu_expert_offset_scale + col * ((config_.intermediate_size / gpu_tps_per_cpu_tp) / group_size)),
|
||||||
|
down_bb_[expert_id]->d + col_offset_scale,
|
||||||
|
(config_.intermediate_size / gpu_tps_per_cpu_tp) / group_size);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
nullptr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void warm_up() {
|
||||||
|
int qlen = config_.max_len;
|
||||||
|
std::vector<uint8_t> input(sizeof(ggml_bf16_t) * qlen * config_.hidden_size);
|
||||||
|
std::vector<uint8_t> output(sizeof(ggml_bf16_t) * qlen * config_.hidden_size);
|
||||||
|
std::vector<int64_t> expert_ids(qlen * config_.num_experts_per_tok);
|
||||||
|
std::vector<float> weights(qlen * config_.num_experts_per_tok);
|
||||||
|
for (int i = 0; i < qlen * config_.num_experts_per_tok; i++) {
|
||||||
|
expert_ids[i] = i % config_.expert_num;
|
||||||
|
weights[i] = 0.01;
|
||||||
|
}
|
||||||
|
forward(qlen, config_.num_experts_per_tok, expert_ids.data(), weights.data(), input.data(), output.data());
|
||||||
|
}
|
||||||
|
|
||||||
|
void forward(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) {
|
||||||
|
if (qlen > 1) {
|
||||||
|
forward_prefill(qlen, k, expert_ids, weights, input, output);
|
||||||
|
} else {
|
||||||
|
forward_decode(k, expert_ids, weights, input, output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifndef DIRECT_OR_POOL_BY_QLEN
// Run `fn(i)` for i in [0, var): inline on the current thread for small
// batches (qlen < 10, where pool dispatch overhead dominates), otherwise
// fan out through the work-stealing thread pool. Expects `qlen` and `pool`
// to be in scope at the expansion site.
#define DIRECT_OR_POOL_BY_QLEN(var, fn)                           \
  do {                                                            \
    if (qlen < 10) {                                              \
      for (int i = 0; i < (var); i++) {                           \
        (fn)(i);                                                  \
      }                                                           \
    } else {                                                      \
      pool->do_work_stealing_job((var), nullptr, (fn), nullptr);  \
    }                                                             \
  } while (0)
#endif
|
||||||
|
void forward_prefill(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input,
|
||||||
|
void* output) {
|
||||||
|
for (int i = 0; i < qlen; i ++)
|
||||||
|
forward_decode(k, expert_ids + i * k, weights + i * k, (ggml_bf16_t*)input + i * config_.hidden_size, (float*)output + i * config_.hidden_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
void forward_decode(int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) {
|
||||||
|
int qlen = 1;
|
||||||
|
auto pool = config_.pool->get_subpool(tp_part_idx);
|
||||||
|
auto& quant_config = config_.quant_config;
|
||||||
|
int& group_size = quant_config.group_size;
|
||||||
|
#ifdef FORWARD_TIME_PROFILE
|
||||||
|
auto start_time = std::chrono::high_resolution_clock::now();
|
||||||
|
auto last = start_time;
|
||||||
|
// 用于保存各阶段耗时(单位:微秒)
|
||||||
|
long prepare_time = 0, cpy_input_time = 0, q_input_time = 0, up_gate_time = 0;
|
||||||
|
long act_time = 0, q_down_time = 0, down_time = 0, weight_time = 0;
|
||||||
|
int max_local_num = 0; // 记录最大的 local num
|
||||||
|
#endif
|
||||||
|
|
||||||
|
int activated_expert = 0;
|
||||||
|
for (int i = 0; i < k; i++) {
|
||||||
|
if (expert_ids[i] < config_.num_gpu_experts || expert_ids[i] >= config_.expert_num) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
m_expert_id_map_[activated_expert] = expert_ids[i];
|
||||||
|
activated_expert++;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t offset = 0;
|
||||||
|
for (int i = 0; i < activated_expert; i++) {
|
||||||
|
auto expert_idx = m_expert_id_map_[i];
|
||||||
|
m_local_gate_output_ptr_[expert_idx] = m_local_gate_output_ + offset * config_.intermediate_size;
|
||||||
|
m_local_up_output_ptr_[expert_idx] = m_local_up_output_ + offset * config_.intermediate_size;
|
||||||
|
m_local_down_output_ptr_[expert_idx] = m_local_down_output_ + offset * config_.hidden_size;
|
||||||
|
offset += qlen;
|
||||||
|
}
|
||||||
|
|
||||||
|
gate_up_ba_[0]->from_mat(qlen, (ggml_bf16_t*)input, 0, 1);
|
||||||
|
|
||||||
|
#ifdef FORWARD_TIME_PROFILE
|
||||||
|
{
|
||||||
|
auto now_time = std::chrono::high_resolution_clock::now();
|
||||||
|
q_input_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
|
||||||
|
last = now_time;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// calc gate & up
|
||||||
|
int nth = T::recommended_nth(config_.intermediate_size);
|
||||||
|
pool->do_work_stealing_job(
|
||||||
|
nth * activated_expert * 2, [](int _) { T::config(); },
|
||||||
|
[this, nth, qlen](int task_id2) {
|
||||||
|
int& group_size = config_.quant_config.group_size;
|
||||||
|
int task_id = task_id2 / 2;
|
||||||
|
bool do_up = task_id2 % 2;
|
||||||
|
int expert_idx = m_expert_id_map_[task_id / nth];
|
||||||
|
|
||||||
|
int ith = task_id % nth;
|
||||||
|
if (do_up) {
|
||||||
|
amx::vec_mul_kgroup(qlen, config_.intermediate_size, config_.hidden_size, group_size, gate_up_ba_[0],
|
||||||
|
up_bb_[expert_idx], up_bc_[expert_idx], ith, nth);
|
||||||
|
up_bc_[expert_idx]->to_mat(qlen, m_local_up_output_ptr_[expert_idx], ith, nth);
|
||||||
|
} else {
|
||||||
|
amx::vec_mul_kgroup(qlen, config_.intermediate_size, config_.hidden_size, group_size, gate_up_ba_[0],
|
||||||
|
gate_bb_[expert_idx], gate_bc_[expert_idx], ith, nth);
|
||||||
|
gate_bc_[expert_idx]->to_mat(qlen, m_local_gate_output_ptr_[expert_idx], ith, nth);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
nullptr);
|
||||||
|
|
||||||
|
#ifdef DEBUG_K2_MOE
|
||||||
|
if (activated_expert > 0) {
|
||||||
|
int print_elems = std::min(config_.intermediate_size, 16);
|
||||||
|
for (int dbg = 0; dbg < activated_expert; ++dbg) {
|
||||||
|
int sample_expert = m_expert_id_map_[dbg];
|
||||||
|
ggml_bf16_t* gate_ptr = m_local_gate_output_ptr_[sample_expert];
|
||||||
|
if (gate_ptr == nullptr) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
printf("[K2][TP %d] gate_out (expert %d, first %d elems): ", tp_part_idx, sample_expert, print_elems);
|
||||||
|
for (int idx = 0; idx < print_elems; idx++) {
|
||||||
|
float val = ggml_bf16_to_fp32(gate_ptr[idx]);
|
||||||
|
printf("%.6f ", val);
|
||||||
|
}
|
||||||
|
printf("\n");
|
||||||
|
|
||||||
|
int tail_start = config_.intermediate_size > print_elems ? config_.intermediate_size - print_elems : 0;
|
||||||
|
printf("[K2][TP %d] gate_out (expert %d, last %d elems): ", tp_part_idx, sample_expert, print_elems);
|
||||||
|
for (int idx = 0; idx < print_elems; idx++) {
|
||||||
|
float val = ggml_bf16_to_fp32(gate_ptr[tail_start + idx]);
|
||||||
|
printf("%.6f ", val);
|
||||||
|
}
|
||||||
|
printf("\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef FORWARD_TIME_PROFILE
|
||||||
|
{
|
||||||
|
auto now_time = std::chrono::high_resolution_clock::now();
|
||||||
|
up_gate_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
|
||||||
|
last = now_time;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// act
|
||||||
|
for (int task_id = 0; task_id < nth * activated_expert; task_id++) {
|
||||||
|
int expert_idx = m_expert_id_map_[task_id / nth];
|
||||||
|
int ith = task_id % nth;
|
||||||
|
auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);
|
||||||
|
for (int i = 0; i < qlen; i++) {
|
||||||
|
ggml_bf16_t* gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
|
||||||
|
ggml_bf16_t* up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];
|
||||||
|
for (int j = n_start; j < n_end; j += 32) {
|
||||||
|
__m512 gate_val0, gate_val1, up_val0, up_val1;
|
||||||
|
avx512_32xbf16_to_32xfp32((__m512i*)(gate_output_ptr + j), &gate_val0, &gate_val1);
|
||||||
|
avx512_32xbf16_to_32xfp32((__m512i*)(up_output_ptr + j), &up_val0, &up_val1);
|
||||||
|
__m512 result0 = amx::act_fn(gate_val0, up_val0);
|
||||||
|
__m512 result1 = amx::act_fn(gate_val1, up_val1);
|
||||||
|
avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i*)(gate_output_ptr + j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#ifdef FORWARD_TIME_PROFILE
|
||||||
|
{
|
||||||
|
auto now_time = std::chrono::high_resolution_clock::now();
|
||||||
|
act_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
|
||||||
|
last = now_time;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// quant, get down a
|
||||||
|
pool->do_work_stealing_job(
|
||||||
|
activated_expert, nullptr,
|
||||||
|
[this, qlen](int task_id) {
|
||||||
|
int expert_idx = m_expert_id_map_[task_id];
|
||||||
|
down_ba_[expert_idx]->from_mat(qlen, m_local_gate_output_ptr_[expert_idx], 0, 1);
|
||||||
|
},
|
||||||
|
nullptr);
|
||||||
|
#ifdef FORWARD_TIME_PROFILE
|
||||||
|
{
|
||||||
|
auto now_time = std::chrono::high_resolution_clock::now();
|
||||||
|
q_down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
|
||||||
|
last = now_time;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// * down
|
||||||
|
nth = T::recommended_nth(config_.hidden_size);
|
||||||
|
pool->do_work_stealing_job(
|
||||||
|
nth * activated_expert, [](int _) { T::config(); },
|
||||||
|
[this, nth, qlen](int task_id) {
|
||||||
|
int& group_size = config_.quant_config.group_size;
|
||||||
|
int expert_idx = m_expert_id_map_[task_id / nth];
|
||||||
|
int ith = task_id % nth;
|
||||||
|
amx::vec_mul_kgroup(qlen, config_.hidden_size, config_.intermediate_size, group_size, down_ba_[expert_idx],
|
||||||
|
down_bb_[expert_idx], down_bc_[expert_idx], ith, nth);
|
||||||
|
down_bc_[expert_idx]->to_mat(qlen, m_local_down_output_ptr_[expert_idx], ith, nth);
|
||||||
|
},
|
||||||
|
nullptr);
|
||||||
|
|
||||||
|
#ifdef DEBUG_K2_MOE
|
||||||
|
if (activated_expert > 0) {
|
||||||
|
int print_elems = std::min(config_.hidden_size, 16);
|
||||||
|
for (int dbg = 0; dbg < activated_expert; ++dbg) {
|
||||||
|
int sample_expert = m_expert_id_map_[dbg];
|
||||||
|
ggml_bf16_t* down_ptr = m_local_down_output_ptr_[sample_expert];
|
||||||
|
if (down_ptr == nullptr) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
printf("[K2][TP %d] down_out (expert %d, first %d elems): ", tp_part_idx, sample_expert, print_elems);
|
||||||
|
for (int idx = 0; idx < print_elems; idx++) {
|
||||||
|
float val = ggml_bf16_to_fp32(down_ptr[idx]);
|
||||||
|
printf("%.6f ", val);
|
||||||
|
}
|
||||||
|
printf("\n");
|
||||||
|
|
||||||
|
int tail_start = config_.hidden_size > print_elems ? config_.hidden_size - print_elems : 0;
|
||||||
|
printf("[K2][TP %d] down_out (expert %d, last %d elems): ", tp_part_idx, sample_expert, print_elems);
|
||||||
|
for (int idx = 0; idx < print_elems; idx++) {
|
||||||
|
float val = ggml_bf16_to_fp32(down_ptr[tail_start + idx]);
|
||||||
|
printf("%.6f ", val);
|
||||||
|
}
|
||||||
|
printf("\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef FORWARD_TIME_PROFILE
|
||||||
|
{
|
||||||
|
auto now_time = std::chrono::high_resolution_clock::now();
|
||||||
|
down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
|
||||||
|
last = now_time;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// get output
|
||||||
|
for (int e = 0; e < config_.hidden_size; e += 32) {
|
||||||
|
__m512 x0 = _mm512_setzero_ps();
|
||||||
|
__m512 x1 = _mm512_setzero_ps();
|
||||||
|
for (int j = 0; j < k; j++) {
|
||||||
|
if (expert_ids[j] < config_.num_gpu_experts || expert_ids[j] >= config_.expert_num) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
__m512 weight = _mm512_set1_ps(weights[j]);
|
||||||
|
__m512 down_output0, down_output1;
|
||||||
|
avx512_32xbf16_to_32xfp32((__m512i*)(m_local_down_output_ptr_[expert_ids[j]] +
|
||||||
|
m_local_pos_[0][j] * config_.hidden_size + e),
|
||||||
|
&down_output0, &down_output1);
|
||||||
|
x0 = _mm512_fmadd_ps(down_output0, weight, x0);
|
||||||
|
x1 = _mm512_fmadd_ps(down_output1, weight, x1);
|
||||||
|
}
|
||||||
|
auto f32out = (__m512*)((float*)output + e);
|
||||||
|
f32out[0] = x0;
|
||||||
|
f32out[1] = x1;
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef FORWARD_TIME_PROFILE
|
||||||
|
{
|
||||||
|
auto now_time = std::chrono::high_resolution_clock::now();
|
||||||
|
weight_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
|
||||||
|
last = now_time;
|
||||||
|
}
|
||||||
|
auto end_time = std::chrono::high_resolution_clock::now();
|
||||||
|
auto forward_total_time = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count();
|
||||||
|
// 在函数末尾一次性打印所有阶段的耗时,并附带 max_local_num 和 qlen
|
||||||
|
printf(
|
||||||
|
"Profiling Results (numa[%d]): activated_expert: %d, q_input: %ld us, "
|
||||||
|
"up_gate: %ld us, act: %ld us, q_down: %ld us, down: %ld us, weight: %ld us, total: %ld us\n",
|
||||||
|
tp_part_idx, activated_expert, q_input_time, up_gate_time, act_time, q_down_time, down_time, weight_time,
|
||||||
|
forward_total_time);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Tensor-parallel wrapper specialization for the K2 MoE AVX/AMX backend.
// Splits globally packed int4 expert weights across TP parts, exposes a
// GPU-buffer export hook, and reduces per-part partial outputs.
template <typename K>
class TP_MOE<AMX_K2_MOE_TP<K>> : public TP_MOE_Common<AMX_K2_MOE_TP<K>> {
 public:
  using TP_MOE_Common<AMX_K2_MOE_TP<K>>::TP_MOE_Common;

  // Stage per-TP slices of the packed-int4 weights and their bf16 k-group
  // scales into temporary buffers, let each TP part repack/load them
  // (DO_TPS_LOAD_WEIGHTS), then free the staging copies.
  // Throws if the source has no gate_scale (only the packed-int4 + k-group
  // scale format is supported here).
  void load_weights() {
    auto& config = this->config;
    auto& tps = this->tps;
    auto& tp_count = this->tp_count;
    auto pool = config.pool;
    const uint64_t* physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map;

    if (config.gate_scale == nullptr) {
      throw std::runtime_error("K2 MoE only supports Packed Int4 with KGroup Scale");
    }
    printf("From Packed Int4 with KGroup Scale\n");
    int& group_size = config.quant_config.group_size;
    for (auto i = 0; i < tp_count; i++) {
      auto& tpc = tps[i]->config_;
      // Per-expert element count of this TP slice. Weights are packed int4
      // (two values per byte), hence the ">> 1" byte conversions below.
      size_t weight_elem_count = tpc.intermediate_size * tpc.hidden_size;
      tpc.gate_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
      tpc.up_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];
      tpc.down_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];

      // One bf16 scale per k-group of group_size weight elements.
      size_t scales_elem_count = (tpc.hidden_size / group_size) * tpc.intermediate_size;

      tpc.gate_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
      tpc.up_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];
      tpc.down_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];

      if (tps[i]->config_.load == false) {
        pool->get_subpool(i)->do_work_stealing_job(
            tpc.expert_num, nullptr,
            [&](int expert_id_) {  // weight and scale are all in col majored.
              // Translate the physical expert slot to its logical expert id.
              size_t expert_id = expert_map(physical_to_logical_map, expert_id_);

              // weight and scale TP-slicing for gate and up: each TP part i
              // owns a contiguous row range, so one memcpy per expert suffices.
              memcpy((uint8_t*)tpc.gate_proj + ((expert_id * weight_elem_count) >> 1),
                     (uint8_t*)config.gate_proj +
                         ((expert_id * config.intermediate_size * config.hidden_size + i * weight_elem_count) >> 1),
                     ((sizeof(uint8_t) * weight_elem_count) >> 1));

              memcpy((uint8_t*)tpc.up_proj + ((expert_id * weight_elem_count) >> 1),
                     (uint8_t*)config.up_proj +
                         ((expert_id * config.intermediate_size * config.hidden_size + i * weight_elem_count) >> 1),
                     ((sizeof(uint8_t) * weight_elem_count) >> 1));

              memcpy((ggml_bf16_t*)tpc.gate_scale + (expert_id * scales_elem_count),
                     (ggml_bf16_t*)config.gate_scale +
                         (expert_id * (config.hidden_size / group_size) * config.intermediate_size +
                          i * scales_elem_count),
                     sizeof(ggml_bf16_t) * scales_elem_count);

              memcpy((ggml_bf16_t*)tpc.up_scale + (expert_id * scales_elem_count),
                     (ggml_bf16_t*)config.up_scale +
                         (expert_id * (config.hidden_size / group_size) * config.intermediate_size +
                          i * scales_elem_count),
                     sizeof(ggml_bf16_t) * scales_elem_count);

              // weight and scale TP-slicing for down (by column): down is
              // split along intermediate_size, so copy one strip of
              // tpc.intermediate_size values per output column.
              for (size_t col = 0; col < config.hidden_size; col++) {
                memcpy((uint8_t*)tpc.down_proj + ((expert_id * weight_elem_count + col * tpc.intermediate_size) >> 1),
                       (uint8_t*)config.down_proj + ((expert_id * config.intermediate_size * config.hidden_size +
                                                     col * config.intermediate_size + i * tpc.intermediate_size) >>
                                                    1),
                       (sizeof(uint8_t) * tpc.intermediate_size) >> 1);
                memcpy((ggml_bf16_t*)tpc.down_scale + (expert_id * scales_elem_count + col * (tpc.intermediate_size / group_size)),
                       (ggml_bf16_t*)config.down_scale + ((expert_id * (config.intermediate_size / group_size) * config.hidden_size) +
                                                          col * (config.intermediate_size / group_size) + i * (tpc.intermediate_size / group_size)),
                       sizeof(ggml_bf16_t) * (tpc.intermediate_size / group_size));
              }
            },
            nullptr);
      }
      printf("TP %d load weight done.\n", i);
    }

    DO_TPS_LOAD_WEIGHTS(pool);

    // Free the per-TP staging buffers allocated above; presumably the TP
    // parts keep their own repacked copies after DO_TPS_LOAD_WEIGHTS —
    // confirm against the loader implementation.
    for (auto i = 0; i < tp_count; i++) {
      auto& tpc = tps[i]->config_;
      delete[] (uint8_t*)(tpc.gate_proj);
      delete[] (uint8_t*)(tpc.up_proj);
      delete[] (uint8_t*)(tpc.down_proj);

      delete[] (ggml_bf16_t*)(tpc.gate_scale);
      delete[] (ggml_bf16_t*)(tpc.up_scale);
      delete[] (ggml_bf16_t*)(tpc.down_scale);
    }

    this->weights_loaded = true;
  }

  // Export CPU-resident expert weights/scales into caller-provided buffers,
  // one pointer set per GPU TP rank: w13 receives the fused gate+up
  // projection, w2 receives down. Each CPU TP part writes its own slice via
  // AMX_K2_MOE_TP::write_weights_to_buffer.
  // Throws if weights are not loaded, no TP parts exist, or any pointer
  // vector does not have exactly gpu_tp_count entries.
  void write_weight_scale_to_buffer(int gpu_tp_count, int gpu_experts_num,
                                    const std::vector<uintptr_t>& w13_weight_ptrs,
                                    const std::vector<uintptr_t>& w13_scale_ptrs,
                                    const std::vector<uintptr_t>& w2_weight_ptrs,
                                    const std::vector<uintptr_t>& w2_scale_ptrs) {
    if (this->weights_loaded == false) {
      throw std::runtime_error("Not Loaded");
    }
    if (this->tps.empty()) {
      throw std::runtime_error("No TP parts initialized");
    }

    // Validate input vector sizes
    if (w13_weight_ptrs.size() != gpu_tp_count || w13_scale_ptrs.size() != gpu_tp_count ||
        w2_weight_ptrs.size() != gpu_tp_count || w2_scale_ptrs.size() != gpu_tp_count) {
      throw std::runtime_error("Pointer arrays size must match gpu_tp_count");
    }

    // Each TP part writes to its corresponding buffer
    for (int tp_idx = 0; tp_idx < this->tp_count; tp_idx++) {
      // Note: w13 combines gate and up projections
      // Split w13 pointers for gate and up
      this->tps[tp_idx]->write_weights_to_buffer(
          gpu_tp_count, this->tp_count,
          gpu_experts_num, this->config,
          w13_weight_ptrs, w13_scale_ptrs,  // gate + up use w13
          w2_weight_ptrs, w2_scale_ptrs);   // down uses w2
    }
  }

  // Reduce the per-TP-part partial outputs into `output`.
  // Partial sums live as fp32 in local_output_numa[i]; part 0's buffer acts
  // as the accumulator. With `incremental` set, the existing bf16 contents
  // of `output` are added in first. The final fp32 sum is converted back to
  // bf16 into `output`. Assumes hidden_size is a multiple of 32.
  void merge_results(int qlen, void* output, bool incremental) {
    // NOTE(review): `pool` is fetched but unused — tokens are merged
    // serially on the calling thread below.
    auto pool = this->config.pool;
    auto merge_fn = [this, output, incremental](int token_nth) {
      auto& local_output_numa = this->local_output_numa;
      auto& tp_configs = this->tp_configs;
      auto& tp_count = this->tp_count;
      auto& config = this->config;
      // Accumulate everything into TP part 0's fp32 buffer.
      float* merge_to = local_output_numa[0] + token_nth * tp_configs[0].hidden_size;
      if (incremental) {
        // Fold the caller's existing bf16 output into the accumulator.
        for (int e = 0; e < config.hidden_size; e += 32) {
          __m512 x0, x1;
          avx512_32xbf16_to_32xfp32((__m512i*)((ggml_bf16_t*)output + token_nth * config.hidden_size + e), &x0, &x1);
          *((__m512*)(merge_to + e)) = _mm512_add_ps(*((__m512*)(merge_to + e)), x0);
          *((__m512*)(merge_to + e + 16)) = _mm512_add_ps(*((__m512*)(merge_to + e + 16)), x1);
        }
      }
      // Add each remaining TP part's fp32 partial sum.
      for (int i = 1; i < tp_count; i++) {
        float* merge_from = local_output_numa[i] + token_nth * tp_configs[i].hidden_size;
        for (int e = 0; e < tp_configs[i].hidden_size; e += 16) {
          *((__m512*)(merge_to + e)) = _mm512_add_ps(*((__m512*)(merge_to + e)), *((__m512*)(merge_from + e)));
        }
      }
      // Convert the accumulated fp32 result back to bf16 in `output`.
      for (int e = 0; e < config.hidden_size; e += 32) {
        __m512 x0 = *(__m512*)(merge_to + e);
        __m512 x1 = *(__m512*)(merge_to + e + 16);
        avx512_32xfp32_to_32xbf16(&x0, &x1, (__m512i*)((ggml_bf16_t*)output + token_nth * config.hidden_size + e));
      }
    };
    for (int i = 0; i < qlen; i++) {
      merge_fn(i);
    }
  }

  // Non-incremental convenience overload.
  void merge_results(int qlen, void* output) { merge_results(qlen, output, false); }
};
|
||||||
|
|
||||||
|
#endif // CPUINFER_OPERATOR_AMX_K2_MOE_H
|
||||||
|
|
@ -4,6 +4,7 @@
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
|
#include <cstring>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
|
@ -344,9 +345,6 @@ struct BufferAKGroupImpl {
|
||||||
static constexpr int K_STEP = K::K_STEP;
|
static constexpr int K_STEP = K::K_STEP;
|
||||||
static constexpr int K_BLOCK = K::K_BLOCK;
|
static constexpr int K_BLOCK = K::K_BLOCK;
|
||||||
|
|
||||||
using index_t = Packed2DLayout::index_t;
|
|
||||||
Packed2DLayout pack;
|
|
||||||
|
|
||||||
static size_t required_size(int max_m, int k, int k_group_size) {
|
static size_t required_size(int max_m, int k, int k_group_size) {
|
||||||
ASSERT_RELEASE(k % k_group_size == 0, "k must be multiple of k_group_size");
|
ASSERT_RELEASE(k % k_group_size == 0, "k must be multiple of k_group_size");
|
||||||
return sizeof(int8_t) * max_m * k + sizeof(float) * max_m * (k / k_group_size);
|
return sizeof(int8_t) * max_m * k + sizeof(float) * max_m * (k / k_group_size);
|
||||||
|
|
@ -355,18 +353,12 @@ struct BufferAKGroupImpl {
|
||||||
BufferAKGroupImpl(int max_m, int k, int k_group_size, void* ptr)
|
BufferAKGroupImpl(int max_m, int k, int k_group_size, void* ptr)
|
||||||
: max_m(max_m),
|
: max_m(max_m),
|
||||||
k(k),
|
k(k),
|
||||||
k_group_size(k_group_size),
|
k_group_size(k_group_size) {
|
||||||
pack({{static_cast<index_t>(K_STEP), 'c'},
|
|
||||||
{static_cast<index_t>(M_STEP), 'r'},
|
|
||||||
{static_cast<index_t>(k_group_size / K_STEP), 'c'},
|
|
||||||
{static_cast<index_t>(K_BLOCK / k_group_size), 'c'},
|
|
||||||
{static_cast<index_t>(max_m / M_STEP), 'r'},
|
|
||||||
{static_cast<index_t>(k / K_BLOCK), 'c'}}) {
|
|
||||||
ASSERT_RELEASE(k % k_group_size == 0, "k must be multiple of k_group_size");
|
ASSERT_RELEASE(k % k_group_size == 0, "k must be multiple of k_group_size");
|
||||||
ASSERT_RELEASE(max_m % M_STEP == 0, "max_m must be multiple of M_STEP");
|
ASSERT_RELEASE(max_m % M_STEP == 0, "max_m must be multiple of M_STEP");
|
||||||
ASSERT_RELEASE(k % K_STEP == 0, "k must be multiple of K_STEP");
|
ASSERT_RELEASE(k % K_STEP == 0, "k must be multiple of K_STEP");
|
||||||
ASSERT_RELEASE(K_BLOCK % k_group_size == 0, "K_BLOCK must be multiple of k_group_size");
|
ASSERT_RELEASE(K_BLOCK % k_group_size == 0, "K_BLOCK must be multiple of k_group_size");
|
||||||
ASSERT_RELEASE(k % K_BLOCK == 0, "k must be multiple of K_BLOCK");
|
// ASSERT_RELEASE(k % K_BLOCK == 0, "k must be multiple of K_BLOCK");
|
||||||
k_group_count = k / k_group_size;
|
k_group_count = k / k_group_size;
|
||||||
|
|
||||||
set_data(ptr);
|
set_data(ptr);
|
||||||
|
|
@ -922,6 +914,77 @@ struct BufferBInt4WithZeroImpl {
|
||||||
float* get_min(int n, int n_begin) { return mins + n_begin; }
|
float* get_min(int n, int n_begin) { return mins + n_begin; }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// BufferB for Signed Int4 with KGroup Scale (no zero point)
// Used for K2 MoE - signed int4 range: [-8, 7]
// Memory layout: packed int4 weights first (two values per byte), then the
// row-majored fp32 scales, one per k-group of k_group_size elements.
template <typename K>
struct BufferBInt4KGroupImpl {
  using dt = typename K::dt;
  dt* b;     // packed signed int4 weights, col majored
  float* d;  // per-k-group scales (no mins/zero-points), row majored
  int n, k, k_group_size, k_group_count;

  static constexpr int N_STEP = K::N_STEP;
  static constexpr int K_STEP = K::K_STEP;
  static constexpr bool SCALE = true;

  // Bytes needed for packed int4 weights plus scales (no mins stored).
  static size_t required_size(int n, int k, int k_group_size) {
    return sizeof(int8_t) * n * k / 2 + sizeof(float) * n * (k / k_group_size);
  }

  BufferBInt4KGroupImpl(int n, int k, int k_group_size, void* ptr) : n(n), k(k), k_group_size(k_group_size) {
    assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);
    assert(n % N_STEP == 0);
    assert(k % K_STEP == 0);
    const bool misaligned = (n % N_STEP) || (k % K_STEP) || (k % k_group_size);
    if (misaligned) {
      printf("BufferBInt4KGroupImpl: n: %d, k: %d, N_STEP: %d, K_STEP: %d, k_group_size: %d\n", n, k, N_STEP,
             K_STEP, k_group_size);
      throw std::runtime_error("n or k is not aligned to N_STEP or K_STEP");
    }
    k_group_count = k / k_group_size;
    b = reinterpret_cast<dt*>(ptr);
    // Scales start immediately after the packed weight area.
    d = reinterpret_cast<float*>(offset_pointer(b, n * k / 2));
  }

  // Copy pre-packed signed int4 rows straight into `b` (2 values per byte,
  // each nibble in [-8, 7]); this thread handles rows [row_lo, row_hi).
  void from_raw_mat(uint8_t* proj, int ith, int nth) {
    auto [row_lo, row_hi] = K::split_range_n(n, ith, nth);
    if (row_lo >= row_hi) {
      return;
    }
    const size_t bytes_per_row = static_cast<size_t>(k) / 2;
    const size_t start_byte = static_cast<size_t>(row_lo) * bytes_per_row;
    const size_t total_bytes = static_cast<size_t>(row_hi - row_lo) * bytes_per_row;
    std::memcpy(reinterpret_cast<uint8_t*>(b) + start_byte, proj + start_byte, total_bytes);
  }

  // Pointer into the packed weights at logical position (n_begin, k_begin).
  dt* get_submat(int n, int k, int n_begin, int k_begin) {
    uint8_t* base = reinterpret_cast<uint8_t*>(b);
    const size_t byte_offset =
        static_cast<size_t>(n_begin) * (static_cast<size_t>(k) / 2) + static_cast<size_t>(k_begin) / 2;
    return reinterpret_cast<dt*>(base + byte_offset);
  }

  // Scale of row n_begin for the k-group containing column k_begin.
  float* get_scale(int n, int n_begin, int k, int k_begin) {
    return d + n_begin * (k / k_group_size) + (k_begin / k_group_size);
  }

  // Partition the n rows into N_STEP-aligned, near-equal chunks per worker.
  static std::pair<int, int> split_range_n(int n, int ith, int nth) {
    int chunk = (n + nth - 1) / nth;
    chunk = (chunk + N_STEP - 1) / N_STEP * N_STEP;
    const int lo = std::min(ith * chunk, n);
    const int hi = std::min(lo + chunk, n);
    return {lo, hi};
  }
};
|
||||||
|
|
||||||
template <typename K>
|
template <typename K>
|
||||||
struct BufferBInt4WithZeroKGroupImpl {
|
struct BufferBInt4WithZeroKGroupImpl {
|
||||||
using dt = typename K::dt;
|
using dt = typename K::dt;
|
||||||
|
|
|
||||||
|
|
@ -1015,8 +1015,9 @@ struct GemmKernel224Int8 {
|
||||||
static void avx_kernel(int m, int n, int k, int m_begin, int n_begin, int k_block_begin, float* c, BufferA* ba,
|
static void avx_kernel(int m, int n, int k, int m_begin, int n_begin, int k_block_begin, float* c, BufferA* ba,
|
||||||
BufferB* bb) {
|
BufferB* bb) {
|
||||||
__m512i* c512 = (__m512i*)c;
|
__m512i* c512 = (__m512i*)c;
|
||||||
|
int m_block_end = std::min(m - m_begin, M_STEP);
|
||||||
if (k_block_begin == 0) {
|
if (k_block_begin == 0) {
|
||||||
for (int m_i = 0; m_i < m; m_i++) {
|
for (int m_i = 0; m_i < m_block_end; m_i++) {
|
||||||
c512[m_i * 2] = _mm512_setzero_si512();
|
c512[m_i * 2] = _mm512_setzero_si512();
|
||||||
c512[m_i * 2 + 1] = _mm512_setzero_si512();
|
c512[m_i * 2 + 1] = _mm512_setzero_si512();
|
||||||
}
|
}
|
||||||
|
|
@ -1028,7 +1029,7 @@ struct GemmKernel224Int8 {
|
||||||
|
|
||||||
int32_t* a32 = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin);
|
int32_t* a32 = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin);
|
||||||
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin + k_begin);
|
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin + k_begin);
|
||||||
for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
|
for (int m_i = 0; m_i < m_block_end; m_i++) {
|
||||||
for (int k_i = 0; k_i < 16; k_i++) {
|
for (int k_i = 0; k_i < 16; k_i++) {
|
||||||
__m512i ma = _mm512_set1_epi32(a32[m_i * 16 + k_i]);
|
__m512i ma = _mm512_set1_epi32(a32[m_i * 16 + k_i]);
|
||||||
for (int n_i = 0; n_i < 2; n_i++) {
|
for (int n_i = 0; n_i < 2; n_i++) {
|
||||||
|
|
@ -1239,8 +1240,9 @@ struct GemmKernel224Int4 {
|
||||||
BufferB* bb) {
|
BufferB* bb) {
|
||||||
using K = GemmKernel224Int4;
|
using K = GemmKernel224Int4;
|
||||||
__m512i* c512 = (__m512i*)c;
|
__m512i* c512 = (__m512i*)c;
|
||||||
|
int m_block_end = std::min(m - m_begin, M_STEP);
|
||||||
if (k_block_begin == 0) {
|
if (k_block_begin == 0) {
|
||||||
for (int m_i = 0; m_i < m; m_i++) {
|
for (int m_i = 0; m_i < m_block_end; m_i++) {
|
||||||
c512[m_i * 2] = _mm512_setzero_si512();
|
c512[m_i * 2] = _mm512_setzero_si512();
|
||||||
c512[m_i * 2 + 1] = _mm512_setzero_si512();
|
c512[m_i * 2 + 1] = _mm512_setzero_si512();
|
||||||
}
|
}
|
||||||
|
|
@ -1250,7 +1252,7 @@ struct GemmKernel224Int4 {
|
||||||
int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin);
|
int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin);
|
||||||
int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin + K::K_STEP);
|
int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin + K::K_STEP);
|
||||||
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin + k_begin);
|
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin + k_begin);
|
||||||
for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
|
for (int m_i = 0; m_i < m_block_end; m_i++) {
|
||||||
for (int k_i = 0; k_i < 16; k_i++) {
|
for (int k_i = 0; k_i < 16; k_i++) {
|
||||||
__m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]);
|
__m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]);
|
||||||
__m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]);
|
__m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]);
|
||||||
|
|
@ -1533,8 +1535,9 @@ struct GemmKernel224Int4_1 {
|
||||||
BufferB* bb) {
|
BufferB* bb) {
|
||||||
using K = GemmKernel224Int4_1;
|
using K = GemmKernel224Int4_1;
|
||||||
__m512i* c512 = (__m512i*)c;
|
__m512i* c512 = (__m512i*)c;
|
||||||
|
int m_block_end = std::min(m - m_begin, M_STEP);
|
||||||
if (k_block_begin == 0) {
|
if (k_block_begin == 0) {
|
||||||
for (int m_i = 0; m_i < m; m_i++) {
|
for (int m_i = 0; m_i < m_block_end; m_i++) {
|
||||||
c512[m_i * 2] = _mm512_setzero_si512();
|
c512[m_i * 2] = _mm512_setzero_si512();
|
||||||
c512[m_i * 2 + 1] = _mm512_setzero_si512();
|
c512[m_i * 2 + 1] = _mm512_setzero_si512();
|
||||||
}
|
}
|
||||||
|
|
@ -1543,7 +1546,7 @@ struct GemmKernel224Int4_1 {
|
||||||
int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin);
|
int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin);
|
||||||
int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin + K::K_STEP);
|
int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin + K::K_STEP);
|
||||||
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin + k_begin);
|
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin + k_begin);
|
||||||
for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
|
for (int m_i = 0; m_i < m_block_end; m_i++) {
|
||||||
for (int k_i = 0; k_i < 16; k_i++) {
|
for (int k_i = 0; k_i < 16; k_i++) {
|
||||||
__m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]);
|
__m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]);
|
||||||
__m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]);
|
__m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]);
|
||||||
|
|
@ -2193,10 +2196,11 @@ struct GemmKernel224Int4KGroup {
|
||||||
BufferB* bb, int k_group_size) {
|
BufferB* bb, int k_group_size) {
|
||||||
using K = GemmKernel224Int4KGroup;
|
using K = GemmKernel224Int4KGroup;
|
||||||
__m512i* c512 = (__m512i*)int_c;
|
__m512i* c512 = (__m512i*)int_c;
|
||||||
|
int m_block_end = std::min(m - m_begin, M_STEP);
|
||||||
|
|
||||||
// Initialize int_c to zero at the start of k_group
|
// Initialize int_c to zero at the start of k_group
|
||||||
if (k_block_begin % k_group_size == 0) {
|
if (k_block_begin % k_group_size == 0) {
|
||||||
for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
|
for (int m_i = 0; m_i < m_block_end; m_i++) {
|
||||||
c512[m_i * 2] = _mm512_setzero_si512();
|
c512[m_i * 2] = _mm512_setzero_si512();
|
||||||
c512[m_i * 2 + 1] = _mm512_setzero_si512();
|
c512[m_i * 2 + 1] = _mm512_setzero_si512();
|
||||||
}
|
}
|
||||||
|
|
@ -2205,7 +2209,7 @@ struct GemmKernel224Int4KGroup {
|
||||||
if (k_offset == 0) {
|
if (k_offset == 0) {
|
||||||
int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);
|
int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);
|
||||||
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin);
|
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin);
|
||||||
for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
|
for (int m_i = 0; m_i < m_block_end; m_i++) {
|
||||||
for (int k_i = 0; k_i < 16; k_i++) {
|
for (int k_i = 0; k_i < 16; k_i++) {
|
||||||
__m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]);
|
__m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]);
|
||||||
for (int n_i = 0; n_i < 2; n_i++) {
|
for (int n_i = 0; n_i < 2; n_i++) {
|
||||||
|
|
@ -2217,7 +2221,7 @@ struct GemmKernel224Int4KGroup {
|
||||||
} else {
|
} else {
|
||||||
int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);
|
int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);
|
||||||
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin - K::K_STEP);
|
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin - K::K_STEP);
|
||||||
for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
|
for (int m_i = 0; m_i < m_block_end; m_i++) {
|
||||||
for (int k_i = 0; k_i < 16; k_i++) {
|
for (int k_i = 0; k_i < 16; k_i++) {
|
||||||
__m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]);
|
__m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]);
|
||||||
for (int n_i = 0; n_i < 2; n_i++) {
|
for (int n_i = 0; n_i < 2; n_i++) {
|
||||||
|
|
@ -2471,8 +2475,9 @@ struct GemmKernel224Int4_1KGroup {
|
||||||
BufferB* bb, int k_group_size) {
|
BufferB* bb, int k_group_size) {
|
||||||
using K = GemmKernel224Int4_1KGroup;
|
using K = GemmKernel224Int4_1KGroup;
|
||||||
__m512i* c512 = (__m512i*)int_c;
|
__m512i* c512 = (__m512i*)int_c;
|
||||||
|
int m_block_end = std::min(m - m_begin, M_STEP);
|
||||||
if (k_block_begin % k_group_size == 0) {
|
if (k_block_begin % k_group_size == 0) {
|
||||||
for (int m_i = 0; m_i < m; m_i++) {
|
for (int m_i = 0; m_i < m_block_end; m_i++) {
|
||||||
c512[m_i * 2] = _mm512_setzero_si512();
|
c512[m_i * 2] = _mm512_setzero_si512();
|
||||||
c512[m_i * 2 + 1] = _mm512_setzero_si512();
|
c512[m_i * 2 + 1] = _mm512_setzero_si512();
|
||||||
}
|
}
|
||||||
|
|
@ -2481,7 +2486,7 @@ struct GemmKernel224Int4_1KGroup {
|
||||||
if (k_offset == 0) {
|
if (k_offset == 0) {
|
||||||
int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);
|
int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);
|
||||||
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin);
|
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin);
|
||||||
for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
|
for (int m_i = 0; m_i < m_block_end; m_i++) {
|
||||||
for (int k_i = 0; k_i < 16; k_i++) {
|
for (int k_i = 0; k_i < 16; k_i++) {
|
||||||
__m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]);
|
__m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]);
|
||||||
for (int n_i = 0; n_i < 2; n_i++) {
|
for (int n_i = 0; n_i < 2; n_i++) {
|
||||||
|
|
@ -2493,7 +2498,7 @@ struct GemmKernel224Int4_1KGroup {
|
||||||
} else {
|
} else {
|
||||||
int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);
|
int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);
|
||||||
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin - K::K_STEP);
|
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin - K::K_STEP);
|
||||||
for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
|
for (int m_i = 0; m_i < m_block_end; m_i++) {
|
||||||
for (int k_i = 0; k_i < 16; k_i++) {
|
for (int k_i = 0; k_i < 16; k_i++) {
|
||||||
__m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]);
|
__m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]);
|
||||||
for (int n_i = 0; n_i < 2; n_i++) {
|
for (int n_i = 0; n_i < 2; n_i++) {
|
||||||
|
|
@ -2746,8 +2751,9 @@ struct GemmKernel224Int4_1_LowKGroup {
|
||||||
BufferB* bb, int k_group_size) {
|
BufferB* bb, int k_group_size) {
|
||||||
using K = GemmKernel224Int4_1_LowKGroup;
|
using K = GemmKernel224Int4_1_LowKGroup;
|
||||||
__m512i* c512 = (__m512i*)int_c;
|
__m512i* c512 = (__m512i*)int_c;
|
||||||
|
int m_block_end = std::min(m - m_begin, M_STEP);
|
||||||
if (k_block_begin % k_group_size == 0) {
|
if (k_block_begin % k_group_size == 0) {
|
||||||
for (int m_i = 0; m_i < m; m_i++) {
|
for (int m_i = 0; m_i < m_block_end; m_i++) {
|
||||||
c512[m_i * 2] = _mm512_setzero_si512();
|
c512[m_i * 2] = _mm512_setzero_si512();
|
||||||
c512[m_i * 2 + 1] = _mm512_setzero_si512();
|
c512[m_i * 2 + 1] = _mm512_setzero_si512();
|
||||||
}
|
}
|
||||||
|
|
@ -2756,7 +2762,7 @@ struct GemmKernel224Int4_1_LowKGroup {
|
||||||
if (k_offset == 0) {
|
if (k_offset == 0) {
|
||||||
int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);
|
int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);
|
||||||
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin);
|
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin);
|
||||||
for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
|
for (int m_i = 0; m_i < m_block_end; m_i++) {
|
||||||
for (int k_i = 0; k_i < 16; k_i++) {
|
for (int k_i = 0; k_i < 16; k_i++) {
|
||||||
__m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]);
|
__m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]);
|
||||||
for (int n_i = 0; n_i < 2; n_i++) {
|
for (int n_i = 0; n_i < 2; n_i++) {
|
||||||
|
|
@ -2768,7 +2774,7 @@ struct GemmKernel224Int4_1_LowKGroup {
|
||||||
} else {
|
} else {
|
||||||
int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);
|
int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);
|
||||||
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin - K::K_STEP);
|
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin - K::K_STEP);
|
||||||
for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
|
for (int m_i = 0; m_i < m_block_end; m_i++) {
|
||||||
for (int k_i = 0; k_i < 16; k_i++) {
|
for (int k_i = 0; k_i < 16; k_i++) {
|
||||||
__m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]);
|
__m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]);
|
||||||
for (int n_i = 0; n_i < 2; n_i++) {
|
for (int n_i = 0; n_i < 2; n_i++) {
|
||||||
|
|
@ -2837,6 +2843,110 @@ struct GemmKernel224Int4_1_LowKGroup {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// K2 Signed Int4 K-group quantization kernel (AVX only, no AMX)
|
||||||
|
// For K2 MoE - signed int4 range: [-8, 7]
|
||||||
|
struct GemmKernel224Int4SmallKGroup {
|
||||||
|
using dt = uint8_t; // packed int4 type
|
||||||
|
using output_t = int32_t;
|
||||||
|
static constexpr double ELEMENT_SIZE = 0.5;
|
||||||
|
static const int VNNI_BLK = 4;
|
||||||
|
|
||||||
|
static const int M_STEP = 1;
|
||||||
|
static const int N_STEP = 32;
|
||||||
|
static const int K_STEP = 32;
|
||||||
|
|
||||||
|
static inline const int N_BLOCK = 256;
|
||||||
|
// K_BLOCK should match k_group_size for proper scaling
|
||||||
|
static inline const int K_BLOCK = 7168; // Will be overridden by k_group_size
|
||||||
|
|
||||||
|
static std::string name() { return "K2_INT4_KGROUP"; }
|
||||||
|
static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }
|
||||||
|
static std::pair<int, int> split_range_n(int n, int ith, int nth) {
|
||||||
|
int n_start = N_BLOCK * ith;
|
||||||
|
int n_end = std::min(n, N_BLOCK * (ith + 1));
|
||||||
|
return {n_start, n_end};
|
||||||
|
}
|
||||||
|
static void config() {}
|
||||||
|
|
||||||
|
alignas(64) static constexpr uint8_t hi_mask_arr[32] = {
|
||||||
|
0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0,
|
||||||
|
0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0
|
||||||
|
};
|
||||||
|
|
||||||
|
alignas(64) static constexpr uint8_t lo_mask_arr[32] = {
|
||||||
|
0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F,
|
||||||
|
0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F
|
||||||
|
};
|
||||||
|
|
||||||
|
alignas(64) static constexpr uint8_t sign_xor_arr[32] = {
|
||||||
|
0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88,
|
||||||
|
0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88
|
||||||
|
};
|
||||||
|
static __m256i hi_mask() { return *((__m256i*)(&hi_mask_arr[0])); }
|
||||||
|
static __m256i lo_mask() { return *((__m256i*)(&lo_mask_arr[0])); }
|
||||||
|
static __m256i sign_xor_mask() { return *((__m256i*)(&sign_xor_arr[0])); }
|
||||||
|
|
||||||
|
using BufferA = BufferAKGroupImpl<GemmKernel224Int4SmallKGroup>;
|
||||||
|
using BufferB = BufferBInt4KGroupImpl<GemmKernel224Int4SmallKGroup>; // Use new signed int4 buffer
|
||||||
|
using BufferC = BufferCReduceImpl<GemmKernel224Int4SmallKGroup>;
|
||||||
|
|
||||||
|
// K-group aware AVX kernel for signed int4
|
||||||
|
static inline __m512i compressed_int4_to_int8_avx512(__m256i b256) {
|
||||||
|
b256 = _mm256_xor_si256(b256, sign_xor_mask());
|
||||||
|
__m256i b_hi = _mm256_and_si256(b256, hi_mask());
|
||||||
|
__m256i b_lo = _mm256_slli_epi16(_mm256_andnot_si256(hi_mask(), b256), 4);
|
||||||
|
|
||||||
|
__m256i unpack_lo = _mm256_unpacklo_epi8(b_lo, b_hi);
|
||||||
|
__m256i unpack_hi = _mm256_unpackhi_epi8(b_lo, b_hi);
|
||||||
|
__m512i result = _mm512_inserti64x4(_mm512_castsi256_si512(unpack_lo), unpack_hi, 1);
|
||||||
|
const __m512i lane_shuffle = _mm512_set_epi64(7, 6, 3, 2, 5, 4, 1, 0);
|
||||||
|
return _mm512_permutexvar_epi64(lane_shuffle, result);
|
||||||
|
}
|
||||||
|
static inline void integer_mat_vec_kgroup(int m, int n, int k, int k_group_size, BufferA* ba, BufferB *bb, BufferC* bc, int ith, int nth) {
|
||||||
|
auto [n_start, n_end] = split_range_n(n, ith, nth);
|
||||||
|
for (int m_begin = 0; m_begin < m; m_begin ++) {
|
||||||
|
float* c = bc->get_submat(m, n, m_begin, 0);
|
||||||
|
__m512i* a512 = (__m512i*)ba->get_submat(m, k, m_begin, 0);
|
||||||
|
|
||||||
|
for (int n_block_begin = n_start; n_block_begin < n_end; n_block_begin ++) {
|
||||||
|
__m256i* b256 = (__m256i*)bb->get_submat(n, k, n_block_begin, 0);
|
||||||
|
float* as = (float*)ba->get_scale(m, m_begin, k, 0);
|
||||||
|
float* bs = (float*)bb->get_scale(n, n_block_begin, k, 0);
|
||||||
|
|
||||||
|
__m512 sum = _mm512_setzero_ps();
|
||||||
|
#define WORK_K_BLOCK(k_block) \
|
||||||
|
{ \
|
||||||
|
__m256 abscale0 = _mm256_set1_ps(as[(k_block)*2] * bs[(k_block)*2]); \
|
||||||
|
__m256 abscale1 = _mm256_set1_ps(as[(k_block)*2+1] * bs[(k_block)*2+1]); \
|
||||||
|
__m512 abscale = _mm512_insertf32x8(_mm512_castps256_ps512(abscale0), abscale1, 1); \
|
||||||
|
__m512i mul = _mm512_setzero_si512(); \
|
||||||
|
mul = _mm512_dpbssd_epi32(mul, a512[k_block], compressed_int4_to_int8_avx512(b256[k_block])); \
|
||||||
|
sum = _mm512_add_ps(sum, _mm512_mul_ps(abscale, _mm512_cvtepi32_ps(mul))); \
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int k_block = 0; k_block < k / 64; k_block += 2) {
|
||||||
|
WORK_K_BLOCK(k_block);
|
||||||
|
WORK_K_BLOCK(k_block + 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
c[n_block_begin] = _mm512_reduce_add_ps(sum) / 16;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
inline void vec_mul_kgroup(int m, int n, int k, int k_group_size, std::shared_ptr<GemmKernel224Int4SmallKGroup::BufferA> ba,
|
||||||
|
std::shared_ptr<GemmKernel224Int4SmallKGroup::BufferB> bb,
|
||||||
|
std::shared_ptr<GemmKernel224Int4SmallKGroup::BufferC> bc, int ith, int nth) {
|
||||||
|
GemmKernel224Int4SmallKGroup::integer_mat_vec_kgroup(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith, nth);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void mat_mul_kgroup(int m, int n, int k, int k_group_size, std::shared_ptr<GemmKernel224Int4SmallKGroup::BufferA> ba,
|
||||||
|
std::shared_ptr<GemmKernel224Int4SmallKGroup::BufferB> bb,
|
||||||
|
std::shared_ptr<GemmKernel224Int4SmallKGroup::BufferC> bc, int ith, int nth) {
|
||||||
|
GemmKernel224Int4SmallKGroup::integer_mat_vec_kgroup(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith, nth);
|
||||||
|
}
|
||||||
|
|
||||||
// New k-group aware matrix multiplication function
|
// New k-group aware matrix multiplication function
|
||||||
template <typename K, bool amx_or_avx = true>
|
template <typename K, bool amx_or_avx = true>
|
||||||
void integer_mat_mul_kgroup(int m, int n, int k, int k_group_size, typename K::BufferA* ba, typename K::BufferB* bb,
|
void integer_mat_mul_kgroup(int m, int n, int k, int k_group_size, typename K::BufferA* ba, typename K::BufferB* bb,
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ from typing import List, Optional
|
||||||
from .experts_base import BaseMoEWrapper, KExpertsCPUBuffer
|
from .experts_base import BaseMoEWrapper, KExpertsCPUBuffer
|
||||||
|
|
||||||
# Import backend implementations
|
# Import backend implementations
|
||||||
from .utils.amx import AMXMoEWrapper
|
from .utils.amx import AMXMoEWrapper, RAWAMXMoEWrapper
|
||||||
from .utils.llamafile import LlamafileMoEWrapper
|
from .utils.llamafile import LlamafileMoEWrapper
|
||||||
from .utils.moe_kernel import GeneralMoEWrapper
|
from .utils.moe_kernel import GeneralMoEWrapper
|
||||||
|
|
||||||
|
|
@ -77,7 +77,7 @@ class KTMoEWrapper:
|
||||||
chunked_prefill_size: Maximum prefill chunk size
|
chunked_prefill_size: Maximum prefill chunk size
|
||||||
cpu_save: Whether to save weights to CPU memory
|
cpu_save: Whether to save weights to CPU memory
|
||||||
max_deferred_experts_per_token: Number of experts per token to defer. Defaults to 0.
|
max_deferred_experts_per_token: Number of experts per token to defer. Defaults to 0.
|
||||||
method: Backend method ("AMXINT4", "AMXINT8", "LLAMAFILE", "MOE_INT4", "MOE_INT8")
|
method: Backend method ("AMXINT4", "AMXINT8", "RAWINT4", "LLAMAFILE", "MOE_INT4", "MOE_INT8")
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
An instance of the appropriate backend implementation (e.g., AMXMoEWrapper)
|
An instance of the appropriate backend implementation (e.g., AMXMoEWrapper)
|
||||||
|
|
@ -85,6 +85,8 @@ class KTMoEWrapper:
|
||||||
# Select backend based on method
|
# Select backend based on method
|
||||||
if method in ["AMXINT4", "AMXINT8"]:
|
if method in ["AMXINT4", "AMXINT8"]:
|
||||||
backend_cls = AMXMoEWrapper
|
backend_cls = AMXMoEWrapper
|
||||||
|
elif method == "RAWINT4":
|
||||||
|
backend_cls = RAWAMXMoEWrapper
|
||||||
elif method == "LLAMAFILE":
|
elif method == "LLAMAFILE":
|
||||||
backend_cls = LlamafileMoEWrapper
|
backend_cls = LlamafileMoEWrapper
|
||||||
elif method in ["MOE_INT4", "MOE_INT8"]:
|
elif method in ["MOE_INT4", "MOE_INT8"]:
|
||||||
|
|
|
||||||
|
|
@ -4,13 +4,15 @@
|
||||||
Utilities for kt_kernel package.
|
Utilities for kt_kernel package.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .amx import AMXMoEWrapper
|
from .amx import AMXMoEWrapper, RAWAMXMoEWrapper
|
||||||
from .llamafile import LlamafileMoEWrapper
|
from .llamafile import LlamafileMoEWrapper
|
||||||
from .loader import SafeTensorLoader, GGUFLoader
|
from .loader import SafeTensorLoader, GGUFLoader, CompressedSafeTensorLoader
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AMXMoEWrapper",
|
"AMXMoEWrapper",
|
||||||
|
"RAWAMXMoEWrapper",
|
||||||
"LlamafileMoEWrapper",
|
"LlamafileMoEWrapper",
|
||||||
"SafeTensorLoader",
|
"SafeTensorLoader",
|
||||||
|
"CompressedSafeTensorLoader",
|
||||||
"GGUFLoader",
|
"GGUFLoader",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -4,16 +4,16 @@ import ctypes
|
||||||
|
|
||||||
# Use relative imports for package structure
|
# Use relative imports for package structure
|
||||||
from ..experts_base import BaseMoEWrapper
|
from ..experts_base import BaseMoEWrapper
|
||||||
from .loader import SafeTensorLoader
|
from .loader import SafeTensorLoader, CompressedSafeTensorLoader
|
||||||
from kt_kernel_ext.moe import MOEConfig
|
from kt_kernel_ext.moe import MOEConfig
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from kt_kernel_ext.moe import AMXInt4_MOE, AMXInt8_MOE
|
from kt_kernel_ext.moe import AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE
|
||||||
|
|
||||||
_HAS_AMX_SUPPORT = True
|
_HAS_AMX_SUPPORT = True
|
||||||
except (ImportError, AttributeError):
|
except (ImportError, AttributeError):
|
||||||
_HAS_AMX_SUPPORT = False
|
_HAS_AMX_SUPPORT = False
|
||||||
AMXInt4_MOE, AMXInt8_MOE = None, None
|
AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE = None, None, None
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
@ -301,3 +301,152 @@ class AMXMoEWrapper(BaseMoEWrapper):
|
||||||
del self.gate_scales
|
del self.gate_scales
|
||||||
del self.up_scales
|
del self.up_scales
|
||||||
del self.down_scales
|
del self.down_scales
|
||||||
|
|
||||||
|
|
||||||
|
class RAWAMXMoEWrapper(BaseMoEWrapper):
|
||||||
|
"""Wrapper for RAWINT4 experts stored in compressed SafeTensor format."""
|
||||||
|
|
||||||
|
_compressed_loader_instance = None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
layer_idx: int,
|
||||||
|
num_experts: int,
|
||||||
|
num_experts_per_tok: int,
|
||||||
|
hidden_size: int,
|
||||||
|
moe_intermediate_size: int,
|
||||||
|
num_gpu_experts: int,
|
||||||
|
cpuinfer_threads: int,
|
||||||
|
threadpool_count: int,
|
||||||
|
weight_path: str,
|
||||||
|
chunked_prefill_size: int,
|
||||||
|
cpu_save: bool = False,
|
||||||
|
max_deferred_experts_per_token: Optional[int] = None,
|
||||||
|
method: str = "RAWINT4",
|
||||||
|
):
|
||||||
|
if not _HAS_AMX_SUPPORT or AMXInt4_KGroup_MOE is None:
|
||||||
|
raise RuntimeError("AMX backend with RAWINT4 support is not available.")
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
layer_idx=layer_idx,
|
||||||
|
num_experts=num_experts,
|
||||||
|
num_experts_per_tok=num_experts_per_tok,
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
moe_intermediate_size=moe_intermediate_size,
|
||||||
|
num_gpu_experts=num_gpu_experts,
|
||||||
|
cpuinfer_threads=cpuinfer_threads,
|
||||||
|
threadpool_count=threadpool_count,
|
||||||
|
weight_path=weight_path,
|
||||||
|
chunked_prefill_size=chunked_prefill_size,
|
||||||
|
cpu_save=cpu_save,
|
||||||
|
max_deferred_experts_per_token=max_deferred_experts_per_token,
|
||||||
|
method=method,
|
||||||
|
)
|
||||||
|
|
||||||
|
if RAWAMXMoEWrapper._compressed_loader_instance is None:
|
||||||
|
RAWAMXMoEWrapper._compressed_loader_instance = CompressedSafeTensorLoader(weight_path)
|
||||||
|
self.loader = RAWAMXMoEWrapper._compressed_loader_instance
|
||||||
|
|
||||||
|
self.gate_weights = None
|
||||||
|
self.up_weights = None
|
||||||
|
self.down_weights = None
|
||||||
|
self.gate_scales = None
|
||||||
|
self.up_scales = None
|
||||||
|
self.down_scales = None
|
||||||
|
|
||||||
|
def load_weights_from_tensors(
|
||||||
|
self,
|
||||||
|
gate_proj: torch.Tensor,
|
||||||
|
up_proj: torch.Tensor,
|
||||||
|
down_proj: torch.Tensor,
|
||||||
|
physical_to_logical_map_cpu: torch.Tensor,
|
||||||
|
):
|
||||||
|
raise NotImplementedError("RAWINT4 wrapper expects pre-quantized safetensor weights.")
|
||||||
|
|
||||||
|
def load_weights(self, physical_to_logical_map_cpu: torch.Tensor):
|
||||||
|
base_key = f"model.layers.{self.layer_idx}"
|
||||||
|
weights = self.loader.load_experts(base_key)
|
||||||
|
|
||||||
|
self.gate_weights = torch.stack(weights["gate"], dim=0).contiguous()
|
||||||
|
self.up_weights = torch.stack(weights["up"], dim=0).contiguous()
|
||||||
|
self.down_weights = torch.stack(weights["down"], dim=0).contiguous()
|
||||||
|
|
||||||
|
self.gate_scales = torch.stack(weights["gate_scale"], dim=0).to(torch.bfloat16).contiguous()
|
||||||
|
self.up_scales = torch.stack(weights["up_scale"], dim=0).to(torch.bfloat16).contiguous()
|
||||||
|
self.down_scales = torch.stack(weights["down_scale"], dim=0).to(torch.bfloat16).contiguous()
|
||||||
|
|
||||||
|
moe_config = MOEConfig(
|
||||||
|
self.num_experts,
|
||||||
|
self.num_experts_per_tok,
|
||||||
|
self.hidden_size,
|
||||||
|
self.moe_intermediate_size,
|
||||||
|
self.num_gpu_experts,
|
||||||
|
)
|
||||||
|
moe_config.layer_idx = self.layer_idx
|
||||||
|
moe_config.pool = self.cpu_infer.backend_
|
||||||
|
moe_config.max_len = self.chunked_prefill_size
|
||||||
|
|
||||||
|
moe_config.quant_config.bits = 4
|
||||||
|
moe_config.quant_config.group_size = 32
|
||||||
|
moe_config.quant_config.zero_point = False
|
||||||
|
|
||||||
|
moe_config.gate_proj = self.gate_weights.data_ptr()
|
||||||
|
moe_config.up_proj = self.up_weights.data_ptr()
|
||||||
|
moe_config.down_proj = self.down_weights.data_ptr()
|
||||||
|
moe_config.gate_scale = self.gate_scales.data_ptr()
|
||||||
|
moe_config.up_scale = self.up_scales.data_ptr()
|
||||||
|
moe_config.down_scale = self.down_scales.data_ptr()
|
||||||
|
|
||||||
|
self.moe = AMXInt4_KGroup_MOE(moe_config)
|
||||||
|
|
||||||
|
self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
|
||||||
|
self.cpu_infer.sync()
|
||||||
|
|
||||||
|
del self.gate_weights
|
||||||
|
del self.up_weights
|
||||||
|
del self.down_weights
|
||||||
|
del self.gate_scales
|
||||||
|
del self.up_scales
|
||||||
|
del self.down_scales
|
||||||
|
|
||||||
|
def submit_write_weight_scale_to_buffer(
|
||||||
|
self,
|
||||||
|
gpu_tp_count: int,
|
||||||
|
gpu_experts_num: int,
|
||||||
|
w13_weight_ptrs,
|
||||||
|
w13_scale_ptrs,
|
||||||
|
w2_weight_ptrs,
|
||||||
|
w2_scale_ptrs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Submit the write_weight_scale_to_buffer task for RAWINT4 KGroup AMX implementation.
|
||||||
|
|
||||||
|
This method submits the C++-exposed task `write_weight_scale_to_buffer_task` to the
|
||||||
|
shared CPUInfer queue. The pointer lists should be plain integer lists (e.g. from
|
||||||
|
tensor.data_ptr()).
|
||||||
|
"""
|
||||||
|
if self.moe is None:
|
||||||
|
raise RuntimeError("MoE instance not initialized; cannot submit write_weight_scale_to_buffer task.")
|
||||||
|
|
||||||
|
if not hasattr(self.moe, "write_weight_scale_to_buffer_task"):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"write_weight_scale_to_buffer_task is not available for this backend implementation."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.cpu_infer.submit(
|
||||||
|
self.moe.write_weight_scale_to_buffer_task(
|
||||||
|
gpu_tp_count,
|
||||||
|
gpu_experts_num,
|
||||||
|
w13_weight_ptrs,
|
||||||
|
w13_scale_ptrs,
|
||||||
|
w2_weight_ptrs,
|
||||||
|
w2_scale_ptrs,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def sync_write_weight_scale_to_buffer(self):
|
||||||
|
"""
|
||||||
|
Block until previously submitted write_weight_scale_to_buffer tasks finish.
|
||||||
|
"""
|
||||||
|
# The CPUInfer.sync() call blocks until pending tasks complete.
|
||||||
|
self.cpu_infer.sync()
|
||||||
|
|
|
||||||
|
|
@ -237,6 +237,56 @@ class SafeTensorLoader:
|
||||||
return name in self.tensor_file_map
|
return name in self.tensor_file_map
|
||||||
|
|
||||||
|
|
||||||
|
class CompressedSafeTensorLoader(SafeTensorLoader):
|
||||||
|
"""Loader for compressed SafeTensor layouts (RAWINT4 weights)."""
|
||||||
|
|
||||||
|
def load_experts(self, base_key: str, device: str = "cpu"):
|
||||||
|
"""Load raw expert weights stored in compressed safetensor format."""
|
||||||
|
|
||||||
|
experts_prefix = f"{base_key}.mlp.experts"
|
||||||
|
|
||||||
|
expert_idx = 0
|
||||||
|
while self.has_tensor(f"{experts_prefix}.{expert_idx}.up_proj.weight_packed"):
|
||||||
|
expert_idx += 1
|
||||||
|
|
||||||
|
if expert_idx == 0:
|
||||||
|
raise ValueError(f"No experts found for key {experts_prefix}")
|
||||||
|
|
||||||
|
def load_projection(proj_name: str):
|
||||||
|
weight_entries = []
|
||||||
|
scale_entries = []
|
||||||
|
|
||||||
|
for exp_id in range(expert_idx):
|
||||||
|
weight_key = f"{experts_prefix}.{exp_id}.{proj_name}_proj.weight_packed"
|
||||||
|
scale_key = f"{experts_prefix}.{exp_id}.{proj_name}_proj.weight_scale"
|
||||||
|
|
||||||
|
if not self.has_tensor(weight_key):
|
||||||
|
raise KeyError(f"Missing tensor: {weight_key}")
|
||||||
|
if not self.has_tensor(scale_key):
|
||||||
|
raise KeyError(f"Missing tensor: {scale_key}")
|
||||||
|
|
||||||
|
weight_tensor = self.load_tensor(weight_key, device).contiguous()
|
||||||
|
scale_tensor = self.load_tensor(scale_key, device).contiguous()
|
||||||
|
|
||||||
|
weight_entries.append(weight_tensor)
|
||||||
|
scale_entries.append(scale_tensor)
|
||||||
|
|
||||||
|
return weight_entries, scale_entries
|
||||||
|
|
||||||
|
gate_weights, gate_scales = load_projection("gate")
|
||||||
|
up_weights, up_scales = load_projection("up")
|
||||||
|
down_weights, down_scales = load_projection("down")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"gate": gate_weights,
|
||||||
|
"up": up_weights,
|
||||||
|
"down": down_weights,
|
||||||
|
"gate_scale": gate_scales,
|
||||||
|
"up_scale": up_scales,
|
||||||
|
"down_scale": down_scales,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class GGUFLoader:
|
class GGUFLoader:
|
||||||
"""
|
"""
|
||||||
GGUF format loader using the official gguf library (gguf.gguf_reader.GGUFReader)
|
GGUF format loader using the official gguf library (gguf.gguf_reader.GGUFReader)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue