kvcache-ai-ktransformers/kt-kernel/bench/bench_moe_amx.py
mrhaoxx 9544a8960d
feat(sft): AMX MoE SFT backend with LoRA support (#1936)
* feat(sft): AMX MoE SFT backend with LoRA support

Complete SFT (Supervised Fine-Tuning) backend for MoE models using AMX SIMD:

Core C++ implementation:
- sft_moe.hpp: Forward/backward with LoRA fused operations (~5500 lines)
- moe-sft-tp.hpp: Tensor-parallel wrapper for multi-NUMA
- amx/moe-sft-tp.hpp: AMX-specific TP implementation
- avx_kernels.hpp: AVX512 SIMD kernels for LoRA GEMM
- amx_kernels.hpp: AMX tile kernels for Panel5 rank-outer optimization
- worker_pool: RDTSC profiling, Chrome trace output, SFT timer infrastructure
- ext_bindings.cpp: SFT MOE pybind bindings (BF16/INT8/INT4 + SkipLoRA variants)

Python sft/ submodule (kt_kernel.sft):
- base.py: BaseSFTMoEWrapper with buffer management (template method pattern)
- amx.py: AMXSFTMoEWrapper (weight loading, C++ task construction)
- autograd.py: KTMoEFunction (torch.autograd.Function for distributed training)
- layer.py: KTMoELayerWrapper (nn.Module replacing HF MoE layers)
- arch.py: MOEArchConfig (Qwen3/DeepSeek/Mixtral architecture detection)
- weights.py: Expert weight extraction and checkpoint loading
- lora.py: PEFT LoRA adaptation (view buffers, grad buffers, save/load adapter)
- wrapper.py: wrap_moe_layers_with_kt_wrapper, load_kt_model, build_kt_device_map
- config.py: KTConfig dataclass (DeepSpeed-style opaque config passthrough)
- dist_utils.py: Distributed gather/scatter, checkpoint-phase detection

Design decisions:
- Rank-0-only expert pattern: only rank 0 holds C++ wrapper and expert weights
- DeepSpeed-style integration: accelerate keeps only KTransformersPlugin (framework
  interaction fields); all logic lives in kt_kernel.sft
- Inference isolation: importing kt_kernel does not load sft/ submodule
- Old field name compatibility: _get_kt_config() converts kt_xxx→xxx automatically

Verified: Qwen3-235B-A22B 4GPU AMXBF16 training, loss converges normally.
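
A minimal sketch of the rank-0-only forward (hypothetical names; the real
flow lives in autograd.py and dist_utils.py):

    import torch
    import torch.distributed as dist

    def rank0_moe_forward(hidden, run_cpp_moe):
        # Gather every rank's activations to rank 0, which alone holds the
        # C++ wrapper and expert weights, then scatter the outputs back.
        world, rank = dist.get_world_size(), dist.get_rank()
        gathered = [torch.empty_like(hidden) for _ in range(world)] if rank == 0 else None
        dist.gather(hidden, gathered, dst=0)
        outs = [run_cpp_moe(h) for h in gathered] if rank == 0 else None
        out = torch.empty_like(hidden)
        dist.scatter(out, outs, src=0)
        return out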

* refactor(sft): unify KTConfig field names with kt_ prefix, add share_cache_pool, remove dead code

- KTConfig fields all use kt_ prefix matching dict keys — eliminates
  _OLD_TO_NEW mapping and prefix-stripping in wrapper.py
- Add kt_share_cache_pool field, auto-enabled when gradient_checkpointing
  is on (via training_args.py), flows through to C++ cache allocation
- Remove dead checkpoint detection code: in_ckpt_recompute,
  in_ckpt_first_forward vars (assigned but never read), fallback
  _is_in_checkpoint_first_forward() function, unused inspect import
- Remove redundant env var fallbacks in wrapper.py for share_backward_bb
  and share_cache_pool (KTConfig.__post_init__ already handles env vars)
- Simplify layer.py checkpoint logic to single _checkpoint_hook_mode() check

Verified: Qwen3-235B 3-step training on sap4, loss matches baseline
(1.2886 / 1.9824 / 1.377 vs 1.2886 / 1.9766 / 1.3809)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* refactor(sft): share_backward_bb default True, share_cache_pool auto-derived

- kt_share_backward_bb defaults to True (always saves memory)
- kt_share_cache_pool no longer reads from env var; defaults False,
  auto-set to True by trainer_config_process when gradient checkpointing
  is enabled
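
A sketch of that derivation (illustrative; field names are from this commit,
the training_args shape is assumed):

    def trainer_config_process(kt_config, training_args):
        # kt_share_backward_bb already defaults to True; the cache pool is
        # shared only while gradient checkpointing is active.
        if training_args.gradient_checkpointing:
            kt_config.kt_share_cache_pool = True
        return kt_config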

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix: add missing gpu_experts_mask=None to KTMoEWrapper call in SFT wrapper

KTMoEWrapper.__new__() requires gpu_experts_mask as a positional argument,
but the SFT wrapper omitted it, causing MoE layer wrapping to fail silently
and FSDP2 to attempt broadcasting all expert weights (OOM/NCCL crash).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* feat(sft): support transformers v5 fused expert format

Fused experts (e.g. Qwen3MoeExperts) store weights as 3D Parameters
(gate_up_proj [E,2I,H], down_proj [E,H,I]) instead of per-expert
nn.Linear modules. PEFT cannot attach LoRA to these, so we create
KT-managed LoRA buffers with kaiming init, nn.Parameter wrappers
for the optimizer, and pre-assigned .grad for C++ backward.
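
A hedged sketch of that buffer setup (hypothetical helper; the actual code
is _create_fused_expert_lora_buffers in lora.py):

    import math
    import torch
    import torch.nn as nn

    def make_fused_lora(num_experts, in_dim, out_dim, rank, dtype=torch.bfloat16):
        # Per-expert LoRA pair: A gets kaiming init (PEFT convention), B is
        # zero so the adapted weight starts equal to the base weight.
        lora_a = torch.empty(num_experts, rank, in_dim, dtype=dtype)
        nn.init.kaiming_uniform_(lora_a, a=math.sqrt(5))
        lora_b = torch.zeros(num_experts, out_dim, rank, dtype=dtype)
        a, b = nn.Parameter(lora_a), nn.Parameter(lora_b)
        # Pre-assign .grad so the C++ backward can accumulate in place.
        a.grad = torch.zeros_like(a)
        b.grad = torch.zeros_like(b)
        return a, b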

- arch.py: detect_fused_experts() detection
- weights.py: fused format extraction and weight clearing
- wrapper.py: detect fused at wrap time, store _fused_experts/_lora_rank
- lora.py: _create_fused_expert_lora_buffers, save/load fused LoRA,
  get_kt_lora_params collects fused params, deduplicate wrapper finding
- layer.py: handle v5 TopKRouter tuple output, remove dead code
- autograd.py: sync_forward_sft/submit_forward_sft API rename

Verified: v5 loss/expert-LoRA values match v4 baseline, v4 backward compat.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* feat(sft): add Qwen3.5 MoE support + fused checkpoint loading

- arch.py: add Qwen3_5Moe arch match, read config from text_config,
  _get_layers_prefix returns model.language_model.layers for Qwen3.5,
  _get_model_container_and_layers searches language_model attr
- weights.py: load_experts_from_checkpoint_files detects fused format
  (gate_up_proj in weight_map) and splits into gate/up/down
- wrapper.py: hidden_size fallback to text_config

Verified: Qwen3.5-35B-A3B (256 experts, fused format) E2E pass.
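
A sketch of the split under the shapes above (assuming gate occupies the
first half of gate_up_proj; weights.py is authoritative):

    import torch

    def split_fused(gate_up_proj, down_proj):
        # gate_up_proj: [E, 2I, H] -> gate [E, I, H] and up [E, I, H]
        num_experts, two_i, hidden = gate_up_proj.shape
        half = two_i // 2
        gate = gate_up_proj[:, :half, :].contiguous()
        up = gate_up_proj[:, half:, :].contiguous()
        return gate, up, down_proj.contiguous()  # down_proj: [E, H, I]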

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(sft): align Python API with C++ backend after v5 refactor

- wrapper.py: pass gpu_experts_mask=None to KTMoEWrapper (required by C++ signature)
- layer.py: rename submit_forward_sft/sync_forward_sft to submit_forward/sync_forward
- autograd.py: rename sync_forward_sft to sync_forward

The sft-v5 refactor (commits 58d7eab, dd1da65) renamed Python-side method
calls but the C++ backend (AMXSFTMoEWrapper) still exposes the original
method names. This caused AttributeError on Qwen3.5-35B and other models.

* align sft branch with main: revert worker_pool, strip sft_timer, fix inference defaults

- Revert worker_pool.cpp/.h to main (remove RDTSC timer, Chrome Trace,
  sft_timer namespace, ITT API, extended do_work_stealing_job API)
- Strip all sft_timer instrumentation from sft-only files (sft_moe.hpp,
  moe-sft-tp.hpp, avx_kernels.hpp)
- Restore pin_memory=True in KExpertsCPUBuffer (inference path)
- Restore fused tensor transpose logic in convert_cpu_weights.py (main layout)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* revert CMakeLists.txt to main: remove debug flags and cpptrace dep

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* clean up dev artifacts: remove SFT design docs, debug examples, bench scripts

Remove files not needed in the merge:
- docs/SFT+KTWrapper/ (6 Chinese design docs)
- docs/sft_moe_amx/ (21 dev/debug docs)
- 12 debug/test example scripts
- 6 SFT-specific bench scripts and report

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* remove dev version stamps from ext_bindings, sft_moe, moe-sft-tp

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: JimmyPeilinLi <lipeilin@mail.nwpu.edu.cn>
2026-04-22 11:27:01 +08:00


#!/usr/bin/env python
# coding=utf-8
"""
Description :
Author : chenht2022
Date : 2024-07-25 10:32:05
Version : 1.0.0
LastEditors : chenht2022
LastEditTime : 2024-08-06 10:41:28
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
"""
import argparse
import os
import sys
import time
import json
import subprocess
import platform
from tqdm import tqdm
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))
import torch
from kt_kernel import kt_kernel_ext
# Benchmark parameters
expert_num = 256
hidden_size = 7168
intermediate_size = 2048
max_len = 25600
num_experts_per_tok = 8
layer_num = 5
qlen = 1
warm_up_iter = 1000
test_iter = 10000
gen_iter = 3000
show_progress = True
physical_to_logical_map = torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous()
# Thread / NUMA parameters
CPUINFER_PARAM = 64
subpool_count = 2
interop_threads = 1
subpool_thread_count = []

def parse_csv(value: str):
    return [item.strip() for item in value.split(",") if item.strip()]

def refresh_physical_to_logical_map():
    global physical_to_logical_map
    physical_to_logical_map = torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous()

def configure_torch_threads(threads: int, interop: int):
    os.environ["OMP_NUM_THREADS"] = str(threads)
    os.environ["MKL_NUM_THREADS"] = str(threads)
    torch.set_num_threads(threads)
    try:
        torch.set_num_interop_threads(interop)
    except RuntimeError:
        # set_num_interop_threads can only be called before parallel work starts.
        pass

def build_cpuinfer(total_threads: int, num_subpools: int):
    global subpool_thread_count
    if num_subpools <= 0:
        raise ValueError("subpool_count must be positive")
    if total_threads < num_subpools:
        raise ValueError("threads must be >= subpool_count")
    base = total_threads // num_subpools
    remain = total_threads % num_subpools
    subpool_thread_count = [base + (1 if i < remain else 0) for i in range(num_subpools)]
    worker_config = kt_kernel_ext.WorkerPoolConfig()
    worker_config.subpool_count = num_subpools
    worker_config.subpool_numa_map = list(range(num_subpools))
    worker_config.subpool_thread_count = subpool_thread_count
    return kt_kernel_ext.CPUInfer(worker_config)
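
# Illustrative split (not from the original file): build_cpuinfer(64, 2)
# yields subpool_thread_count == [32, 32]; build_cpuinfer(65, 2) yields
# [33, 32], since remainder threads go to the leading subpools.
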
configure_torch_threads(CPUINFER_PARAM, interop_threads)
CPUInfer = build_cpuinfer(CPUINFER_PARAM, subpool_count)

def get_git_commit():
    """
    Fetch the current git commit (hash and commit message) and check
    whether the working tree has uncommitted changes (dirty).
    """
    result = {}
    try:
        commit = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("utf-8").strip()
        commit_msg = subprocess.check_output(["git", "log", "-1", "--pretty=%B"]).decode("utf-8").strip()
        result["commit"] = commit
        result["commit_message"] = commit_msg
        # Check for uncommitted changes
        dirty_output = subprocess.check_output(["git", "status", "--porcelain"]).decode("utf-8").strip()
        if dirty_output:
            result["dirty"] = True
            result["dirty_files"] = dirty_output.splitlines()
        else:
            result["dirty"] = False
    except Exception as e:
        result["commit"] = None
        result["commit_message"] = None
        result["dirty"] = None
        result["error"] = str(e)
    return result

def get_system_info():
    """
    Collect system information: system name, CPU model, memory size (GB),
    CPU core count, and socket count.
    """
    info = {}
    # System name and host name
    uname = platform.uname()
    info["system_name"] = uname.system  # e.g. Linux, Windows
    info["node_name"] = uname.node  # host name
    # CPU model (Linux only)
    cpu_model = None
    if os.path.exists("/proc/cpuinfo"):
        try:
            with open("/proc/cpuinfo", "r") as f:
                for line in f:
                    if "model name" in line:
                        cpu_model = line.split(":", 1)[1].strip()
                        break
        except Exception as e:
            cpu_model = f"Error: {e}"
    info["cpu_model"] = cpu_model
    # Memory size in GB (Linux only)
    mem_total_gb = None
    if os.path.exists("/proc/meminfo"):
        try:
            with open("/proc/meminfo", "r") as f:
                for line in f:
                    if "MemTotal" in line:
                        mem_kb = float(line.split(":", 1)[1].split()[0])
                        mem_total_gb = round(mem_kb / (1024 * 1024), 2)
                        break
        except Exception as e:
            mem_total_gb = f"Error: {e}"
    info["memory_size_GB"] = mem_total_gb
    # Logical CPU core count
    info["cpu_core_count"] = os.cpu_count()
    # Parse /proc/cpuinfo for the socket count
    sockets = set()
    if os.path.exists("/proc/cpuinfo"):
        try:
            with open("/proc/cpuinfo", "r") as f:
                for line in f:
                    if "physical id" in line:
                        sockets.add(line.split(":", 1)[1].strip())
        except Exception:
            sockets = set()
    # Fall back to at least 1 socket if none were detected
    info["cpu_socket_count"] = len(sockets) if len(sockets) > 0 else 1
    return info
script_path = os.path.abspath(__file__)
script_dir = os.path.dirname(script_path)
script_name = os.path.splitext(os.path.basename(script_path))[0]
json_path = os.path.join(script_dir, script_name + ".jsonl")

def record_results(result, filename=json_path):
    """
    Append the result to the file as a single JSON line.
    """
    with open(filename, "a") as f:
        f.write(json.dumps(result) + "\n")
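
# The results file is JSON Lines; a sketch for reading it back:
#   with open(json_path) as f:
#       runs = [json.loads(line) for line in f]
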

def bench_moe(quant_mode: str):
    with torch.inference_mode():
        if quant_mode == "bf16":
            bytes_per_elem = 2.0
        elif quant_mode == "int8":
            bytes_per_elem = 1.0
        elif quant_mode == "int4":
            bytes_per_elem = 0.5
        else:
            raise ValueError(f"Unsupported quant mode: {quant_mode}")
        moes = []
        gate_projs = []
        up_projs = []
        down_projs = []
        for layer_index in range(layer_num):
            gate_proj = torch.randn(
                (expert_num, intermediate_size, hidden_size), dtype=torch.float32, device="cpu"
            ).contiguous()
            up_proj = torch.randn(
                (expert_num, intermediate_size, hidden_size), dtype=torch.float32, device="cpu"
            ).contiguous()
            down_proj = torch.randn(
                (expert_num, hidden_size, intermediate_size), dtype=torch.float32, device="cpu"
            ).contiguous()
            config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
            config.max_len = max_len
            config.gate_proj = gate_proj.data_ptr()
            config.up_proj = up_proj.data_ptr()
            config.down_proj = down_proj.data_ptr()
            config.pool = CPUInfer.backend_
            config.physical_to_logical_map = physical_to_logical_map.data_ptr()
            if quant_mode == "bf16":
                moe = kt_kernel_ext.moe.AMXBF16_MOE(config)
            elif quant_mode == "int8":
                moe = kt_kernel_ext.moe.AMXInt8_MOE(config)
            elif quant_mode == "int4":
                moe = kt_kernel_ext.moe.AMXInt4_MOE(config)
            CPUInfer.submit(moe.load_weights_task())
            CPUInfer.sync()
            gate_projs.append(gate_proj)
            up_projs.append(up_proj)
            down_projs.append(down_proj)
            moes.append(moe)
        # rand + argsort draws a random permutation of expert ids per token;
        # keeping the first num_experts_per_tok gives k distinct experts.
        expert_ids = (
            torch.rand(gen_iter * qlen, expert_num, device="cpu")
            .argsort(dim=-1)[:, :num_experts_per_tok]
            .reshape(gen_iter, qlen * num_experts_per_tok)
            .to("cpu")
            .contiguous()
        )
        weights = (
            torch.rand((gen_iter, qlen, num_experts_per_tok), dtype=torch.float32, device="cpu").to("cpu").contiguous()
        )
        input_tensor = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device="cpu").contiguous()
        output_tensor = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device="cpu").contiguous()
        bsz_tensor = torch.tensor([qlen], dtype=torch.int32, device="cpu")
        # Warm-up iterations
        for i in tqdm(range(warm_up_iter), desc="Warm-up", disable=not show_progress):
            # start_it = time.time_ns()
            CPUInfer.submit(
                moes[i % layer_num].forward_task(
                    bsz_tensor.data_ptr(),
                    num_experts_per_tok,
                    expert_ids[i % gen_iter].data_ptr(),
                    weights[i % gen_iter].data_ptr(),
                    input_tensor[i % layer_num].data_ptr(),
                    output_tensor[i % layer_num].data_ptr(),
                    False,
                )
            )
            CPUInfer.sync()
            # end_it = time.time_ns()
            # print('python Time(ns): ', end_it - start_it)
        # Timed iterations
        start = time.perf_counter()
        for i in tqdm(range(test_iter), desc="Testing", disable=not show_progress):
            # print(f'test iteration {i}')
            # start_it = time.time_ns()
            CPUInfer.submit(
                moes[i % layer_num].forward_task(
                    bsz_tensor.data_ptr(),
                    num_experts_per_tok,
                    expert_ids[i % gen_iter].data_ptr(),
                    weights[i % gen_iter].data_ptr(),
                    input_tensor[i % layer_num].data_ptr(),
                    output_tensor[i % layer_num].data_ptr(),
                    False,
                )
            )
            CPUInfer.sync()
            # end_it = time.time_ns()
            # print('python Time(ns): ', end_it - start_it)
        end = time.perf_counter()
        total_time = end - start
        # Compute performance metrics
        time_per_iter_us = total_time / test_iter * 1e6
        # gate/up/down = 3 GEMMs per routed expert, per token
        work_elems = hidden_size * intermediate_size * qlen * 3 * num_experts_per_tok
        bandwidth = work_elems * bytes_per_elem * test_iter / total_time / 1e9  # GB/s
        flops = work_elems * 2 * test_iter / total_time / 1e12  # TFLOPS
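        # Worked example with the defaults above (illustrative): work_elems
        # = 7168 * 2048 * 1 * 3 * 8 = 352,321,536 elements per iteration,
        # i.e. ~0.35 GB of weights read in int8 and ~0.70 GFLOP of compute.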
print("Quant mode: ", quant_mode)
print("Time(s): ", total_time)
print("Iteration: ", test_iter)
print("Time(us) per iteration: ", time_per_iter_us)
print("Bandwidth: ", bandwidth, "GB/s")
print("Flops: ", flops, "TFLOPS")
print("")
        # Assemble the result record, including test parameters
        result = {
            "quant_mode": quant_mode,
            "total_time_seconds": total_time,
            "iterations": test_iter,
            "time_per_iteration_us": time_per_iter_us,
            "bandwidth_GBs": bandwidth,
            "flops_TFLOPS": flops,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
            "test_parameters": {
                "expert_num": expert_num,
                "hidden_size": hidden_size,
                "intermediate_size": intermediate_size,
                "max_len": max_len,
                "num_experts_per_tok": num_experts_per_tok,
                "layer_num": layer_num,
                "qlen": qlen,
                "warm_up_iter": warm_up_iter,
                "test_iter": test_iter,
                "CPUInfer_parameter": CPUINFER_PARAM,
                "subpool_count": subpool_count,
                "subpool_thread_count": subpool_thread_count,
            },
        }
        # Attach git commit info
        result.update(get_git_commit())
        # Attach system info (including CPU core and socket counts)
        result.update(get_system_info())
        # Append the result to the JSONL file
        record_results(result)

def main():
    global expert_num
    global hidden_size
    global intermediate_size
    global max_len
    global num_experts_per_tok
    global layer_num
    global qlen
    global warm_up_iter
    global test_iter
    global gen_iter
    global CPUINFER_PARAM
    global subpool_count
    global interop_threads
    global show_progress
    global CPUInfer
    parser = argparse.ArgumentParser(description="AMX MoE benchmark")
    parser.add_argument("--expert-num", type=int, default=expert_num)
    parser.add_argument("--hidden-size", type=int, default=hidden_size)
    parser.add_argument("--intermediate-size", type=int, default=intermediate_size)
    parser.add_argument("--max-len", type=int, default=max_len)
    parser.add_argument("--num-experts-per-tok", type=int, default=num_experts_per_tok)
    parser.add_argument("--layer-num", type=int, default=layer_num)
    parser.add_argument("--qlen", type=int, default=qlen)
    parser.add_argument("--warm-up-iter", type=int, default=warm_up_iter)
    parser.add_argument("--test-iter", type=int, default=test_iter)
    parser.add_argument("--gen-iter", type=int, default=gen_iter)
    parser.add_argument("--threads", type=int, default=CPUINFER_PARAM)
    parser.add_argument("--subpool-count", type=int, default=subpool_count)
    parser.add_argument("--interop-threads", type=int, default=interop_threads)
    parser.add_argument("--quant-modes", type=str, default="int8")
    parser.add_argument("--no-progress", action="store_true", default=False)
    args = parser.parse_args()
    expert_num = args.expert_num
    hidden_size = args.hidden_size
    intermediate_size = args.intermediate_size
    max_len = args.max_len
    num_experts_per_tok = args.num_experts_per_tok
    layer_num = args.layer_num
    qlen = args.qlen
    warm_up_iter = args.warm_up_iter
    test_iter = args.test_iter
    gen_iter = args.gen_iter
    CPUINFER_PARAM = args.threads
    subpool_count = args.subpool_count
    interop_threads = args.interop_threads
    show_progress = not args.no_progress
    refresh_physical_to_logical_map()
    configure_torch_threads(CPUINFER_PARAM, interop_threads)
    CPUInfer = build_cpuinfer(CPUINFER_PARAM, subpool_count)
    quant_modes = parse_csv(args.quant_modes)
    print("[config] amx bench")
    print(
        f"[config] E={expert_num}, H={hidden_size}, I={intermediate_size}, topk={num_experts_per_tok}, "
        f"layers={layer_num}, qlen={qlen}"
    )
    print(f"[config] warmup={warm_up_iter}, test={test_iter}, gen_iter={gen_iter}")
    print(f"[config] threads={CPUINFER_PARAM}, interop_threads={interop_threads}")
    print(f"[config] subpool_count={subpool_count}, subpool_thread_count={subpool_thread_count}")
    print(f"[config] quant_modes={quant_modes}, show_progress={show_progress}")
    for mode in quant_modes:
        bench_moe(mode)


if __name__ == "__main__":
    main()
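
# Example invocation (illustrative; all flags default to the globals above):
#   python bench_moe_amx.py --quant-modes bf16,int8,int4 --threads 64 --subpool-count 2
# Each mode appends one JSON line of results to bench_moe_amx.jsonl.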