[feat](kt-kernel): adapt MXFP4 MoE backend for DeepSeek-V4-Flash (#1950)

V4-Flash routed experts ship as native MXFP4 (E2M1 nibble + ue8m0 group
scale). Expose AMXFP4_KGroup_MOE through NativeMoEWrapper, add a loader
that handles V4's `layers.{L}.ffn.experts.{i}.{w1,w3,w2}.{weight,scale}`
naming and converts ue8m0 → bf16 via a lossless bit-cast, register the
model entry, and ship an end-to-end numerical validation script.

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Benjamin F 2026-04-25 18:11:53 +08:00 committed by GitHub
parent 5c5d7d48c0
commit 8484ef8b16
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 322 additions and 2 deletions

View file

@ -0,0 +1,178 @@
"""End-to-end MXFP4 MoE validation against the native DeepSeek-V4-Flash ckpt.
Loads layer-`LAYER_ID` experts via :class:`MXFP4SafeTensorLoader`, runs the AMX
FP4 backend, and compares against a torch reference that dequantizes the same
nibble-packed weights with the OCP E2M1 LUT.
Usage:
python test_fp4_moe_v4.py --weight-path /path/to/DeepSeek-V4-Flash [--layer 1]
"""
from __future__ import annotations
import argparse
import os
import sys
from typing import Tuple
import torch
# Allow running from kt-kernel/examples without install.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + "/build")
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + "/python")
from kt_kernel import kt_kernel_ext # noqa: E402
from kt_kernel.utils.loader import MXFP4SafeTensorLoader # noqa: E402
# OCP E2M1 codepoints in our LUT order (matches operators/amx/fp4-moe.hpp).
E2M1_VALUES = torch.tensor(
[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
-0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
dtype=torch.float32,
)
def dequantize_mxfp4(weight_u8: torch.Tensor, scale_bf16: torch.Tensor, group_size: int) -> torch.Tensor:
"""Decode a [N, K/2] uint8 tensor of nibble-packed E2M1 with [N, K/gs] bf16
scales into a [N, K] bf16 weight tensor.
Layout (matches kernel's mxfp4_to_bf16_32): byte `b` low nibble = element K=2b,
high nibble = element K=2b+1.
"""
n, k_packed = weight_u8.shape
k = k_packed * 2
assert k % group_size == 0, f"K={k} must be divisible by group_size={group_size}"
assert scale_bf16.shape == (n, k // group_size)
lo = (weight_u8 & 0x0F).to(torch.long)
hi = ((weight_u8 >> 4) & 0x0F).to(torch.long)
nibbles = torch.stack([lo, hi], dim=-1).view(n, k) # interleave back to K order
decoded = E2M1_VALUES.to(weight_u8.device)[nibbles] # [N, K] fp32
scale_fp32 = scale_bf16.to(torch.float32)
scale_full = scale_fp32.repeat_interleave(group_size, dim=-1) # [N, K]
return (decoded * scale_full).to(torch.bfloat16).contiguous()
def reference_mlp(x: torch.Tensor, gate: torch.Tensor, up: torch.Tensor, down: torch.Tensor) -> torch.Tensor:
g = torch.mm(x, gate.t())
u = torch.mm(x, up.t())
silu = g / (1.0 + torch.exp(-g.float())).to(g.dtype)
return torch.mm(silu * u, down.t())
def reference_moe(
hidden: torch.Tensor,
expert_ids: torch.Tensor,
weights: torch.Tensor,
gate_w: torch.Tensor, # [E, N, K]
up_w: torch.Tensor,
down_w: torch.Tensor,
) -> torch.Tensor:
out = torch.zeros_like(hidden, dtype=torch.float32)
for tok in range(hidden.shape[0]):
for slot in range(expert_ids.shape[1]):
eid = int(expert_ids[tok, slot])
w = float(weights[tok, slot])
x = hidden[tok : tok + 1]
y = reference_mlp(x, gate_w[eid], up_w[eid], down_w[eid])
out[tok] += w * y[0].float()
return out.to(hidden.dtype)
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser()
p.add_argument("--weight-path", required=True, help="Path to DeepSeek-V4-Flash safetensors directory.")
p.add_argument("--layer", type=int, default=1, help="Layer index to validate (default: 1).")
p.add_argument("--qlen", type=int, default=1, help="Number of tokens to test.")
p.add_argument("--top-k", type=int, default=6, help="num_experts_per_tok (V4 default 6).")
p.add_argument("--cpu-threads", type=int, default=32)
p.add_argument("--max-experts", type=int, default=0, help="Cap number of experts loaded (0 = all).")
return p.parse_args()
def main() -> int:
args = parse_args()
torch.manual_seed(0)
print(f"[V4-MXFP4] Loading layer {args.layer} from {args.weight_path}")
loader = MXFP4SafeTensorLoader(args.weight_path)
weights = loader.load_experts(f"model.layers.{args.layer}")
expert_num = len(weights["gate"])
if args.max_experts and args.max_experts < expert_num:
for k in ("gate", "up", "down", "gate_scale", "up_scale", "down_scale"):
weights[k] = weights[k][: args.max_experts]
expert_num = args.max_experts
print(f"[V4-MXFP4] expert_num={expert_num}")
gate0 = weights["gate"][0]
down0 = weights["down"][0]
intermediate_size = gate0.shape[0]
hidden_size = gate0.shape[1] * 2 # nibble-packed K
assert down0.shape == (hidden_size, intermediate_size // 2), f"unexpected down shape {down0.shape}"
group_size = hidden_size // weights["gate_scale"][0].shape[1]
print(f"[V4-MXFP4] hidden={hidden_size} inter={intermediate_size} gs={group_size}")
assert group_size == 32, "MXFP4 backend hard-codes group_size=32"
physical_to_logical = torch.arange(expert_num, dtype=torch.int64).contiguous()
# ----- AMX FP4 forward -----
cpu_infer = kt_kernel_ext.CPUInfer(args.cpu_threads)
cfg = kt_kernel_ext.moe.MOEConfig(expert_num, args.top_k, hidden_size, intermediate_size, 0)
cfg.layer_idx = args.layer
cfg.max_len = max(args.qlen, 1)
cfg.pool = cpu_infer.backend_
cfg.quant_config.bits = 4
cfg.quant_config.group_size = group_size
cfg.quant_config.zero_point = False
cfg.gate_projs = [[t.data_ptr() for t in weights["gate"]]]
cfg.up_projs = [[t.data_ptr() for t in weights["up"]]]
cfg.down_projs = [[t.data_ptr() for t in weights["down"]]]
cfg.gate_scales = [[t.data_ptr() for t in weights["gate_scale"]]]
cfg.up_scales = [[t.data_ptr() for t in weights["up_scale"]]]
cfg.down_scales = [[t.data_ptr() for t in weights["down_scale"]]]
moe = kt_kernel_ext.moe.AMXFP4_KGroup_MOE(cfg)
cpu_infer.submit(moe.load_weights_task(physical_to_logical.data_ptr()))
cpu_infer.sync()
qlen = args.qlen
top_k = args.top_k
bsz = torch.tensor([qlen], dtype=torch.int32)
expert_ids = torch.stack([torch.randperm(expert_num)[:top_k] for _ in range(qlen)]).to(torch.int32).contiguous()
routing = torch.randn((qlen, top_k), dtype=torch.float32).contiguous()
x = (torch.randn((qlen, hidden_size), dtype=torch.bfloat16) / 100).contiguous()
y_amx = torch.empty((qlen, hidden_size), dtype=torch.bfloat16).contiguous()
cpu_infer.submit(
moe.forward_task(
bsz.data_ptr(), top_k, expert_ids.data_ptr(), routing.data_ptr(),
x.data_ptr(), y_amx.data_ptr(), False,
)
)
cpu_infer.sync()
# ----- Torch reference (dequantize same nibbles + scales) -----
print("[V4-MXFP4] Building torch reference (dequantizing all loaded experts)…")
gate_bf16 = torch.stack([dequantize_mxfp4(weights["gate"][i], weights["gate_scale"][i], group_size) for i in range(expert_num)])
up_bf16 = torch.stack([dequantize_mxfp4(weights["up"][i], weights["up_scale"][i], group_size) for i in range(expert_num)])
down_bf16 = torch.stack([dequantize_mxfp4(weights["down"][i], weights["down_scale"][i], group_size) for i in range(expert_num)])
y_ref = reference_moe(x, expert_ids, routing, gate_bf16, up_bf16, down_bf16)
diff = (y_amx.float() - y_ref.float()).abs()
rel = diff.mean() / (y_ref.float().abs().mean() + 1e-12)
print(f"[V4-MXFP4] mean abs diff = {diff.mean().item():.4e}")
print(f"[V4-MXFP4] max abs diff = {diff.max().item():.4e}")
print(f"[V4-MXFP4] rel mean diff = {rel.item()*100:.3f}%")
print(f"[V4-MXFP4] amx[:8] = {y_amx.flatten()[:8]}")
print(f"[V4-MXFP4] ref[:8] = {y_ref.flatten()[:8]}")
return 0 if rel.item() < 0.10 else 1
if __name__ == "__main__":
sys.exit(main())

View file

@ -81,6 +81,26 @@ BUILTIN_MODELS: list[ModelInfo] = [
description="DeepSeek R1-0528 reasoning model (May 2025, improved reasoning depth)",
description_zh="DeepSeek R1-0528 推理模型2025年5月改进的推理深度",
),
ModelInfo(
name="DeepSeek-V4-Flash",
hf_repo="deepseek-ai/DeepSeek-V4-Flash",
aliases=["deepseek-v4-flash", "deepseek-v4", "dsv4", "v4-flash", "v4"],
type="moe",
default_params={
"kt-method": "MXFP4",
"kt-gpu-prefill-token-threshold": 4096,
"attention-backend": "flashinfer",
"max-total-tokens": 100000,
"max-running-requests": 16,
"chunked-prefill-size": 32768,
"mem-fraction-static": 0.80,
"watchdog-timeout": 3000,
"served-model-name": "DeepSeek-V4-Flash",
"disable-shared-experts-fusion": True,
},
description="DeepSeek V4-Flash MoE model (native MXFP4 experts, MQA + sparse index attention)",
description_zh="DeepSeek V4-Flash MoE 模型(原生 MXFP4 专家MQA + 稀疏索引注意力)",
),
ModelInfo(
name="Kimi-K2-Thinking",
hf_repo="moonshotai/Kimi-K2-Thinking",
@ -368,6 +388,19 @@ def compute_deepseek_v3_gpu_experts(tensor_parallel_size: int, vram_per_gpu_gb:
return total_vram // 3
def compute_deepseek_v4_gpu_experts(tensor_parallel_size: int, vram_per_gpu_gb: float) -> int:
"""Compute kt-num-gpu-experts for DeepSeek-V4-Flash.
V4 uses MXFP4 experts (~0.5 bytes/param vs V3 FP8's 1 byte/param) so each GPU
can hold ~2x more experts per VRAM unit than V3 at the same fragmentation.
"""
per_gpu_gb = 16
if vram_per_gpu_gb < per_gpu_gb:
return 0
total_vram = int(tensor_parallel_size * (vram_per_gpu_gb - per_gpu_gb))
return total_vram * 2 // 3
def compute_kimi_k2_thinking_gpu_experts(tensor_parallel_size: int, vram_per_gpu_gb: float) -> int:
"""Compute kt-num-gpu-experts for Kimi K2 Thinking."""
per_gpu_gb = 16
@ -393,6 +426,7 @@ MODEL_COMPUTE_FUNCTIONS: dict[str, Callable[[int, float], int]] = {
"DeepSeek-V3-0324": compute_deepseek_v3_gpu_experts,
"DeepSeek-V3.2": compute_deepseek_v3_gpu_experts, # Same as V3-0324
"DeepSeek-R1-0528": compute_deepseek_v3_gpu_experts, # Same as V3-0324
"DeepSeek-V4-Flash": compute_deepseek_v4_gpu_experts,
"Kimi-K2-Thinking": compute_kimi_k2_thinking_gpu_experts,
"MiniMax-M2": compute_minimax_m2_gpu_experts,
"MiniMax-M2.1": compute_minimax_m2_gpu_experts, # Same as M2

View file

@ -87,7 +87,7 @@ class KTMoEWrapper:
cpu_save: Whether to save weights to CPU memory
max_deferred_experts_per_token: Number of experts per token to defer. Defaults to 0.
numa_nodes: Explicit list of NUMA node IDs for subpool mapping. If None, defaults to sequential.
method: Backend method ("AMXINT4", "AMXINT8", "RAWINT4", "FP8", "BF16", "LLAMAFILE", "MOE_INT4", "MOE_INT8")
method: Backend method ("AMXINT4", "AMXINT8", "RAWINT4", "MXFP4", "FP8", "BF16", "LLAMAFILE", "MOE_INT4", "MOE_INT8")
Returns:
An instance of the appropriate backend implementation (e.g., AMXMoEWrapper)
@ -95,7 +95,7 @@ class KTMoEWrapper:
# Select backend based on method
if method in ["AMXINT4", "AMXINT8"]:
backend_cls = AMXMoEWrapper
elif method in ["RAWINT4", "FP8", "BF16", "FP8_PERCHANNEL", "GPTQ_INT4"]:
elif method in ["RAWINT4", "FP8", "BF16", "FP8_PERCHANNEL", "GPTQ_INT4", "MXFP4"]:
backend_cls = NativeMoEWrapper
elif method == "LLAMAFILE":
backend_cls = LlamafileMoEWrapper

View file

@ -11,6 +11,7 @@ from .loader import (
FP8SafeTensorLoader,
BF16SafeTensorLoader,
GPTQSafeTensorLoader,
MXFP4SafeTensorLoader,
)
from kt_kernel_ext.moe import MOEConfig
import kt_kernel_ext.moe as _moe_mod
@ -18,6 +19,7 @@ import kt_kernel_ext.moe as _moe_mod
AMXInt4_MOE = getattr(_moe_mod, "AMXInt4_MOE", None)
AMXInt8_MOE = getattr(_moe_mod, "AMXInt8_MOE", None)
AMXInt4_KGroup_MOE = getattr(_moe_mod, "AMXInt4_KGroup_MOE", None)
AMXFP4_KGroup_MOE = getattr(_moe_mod, "AMXFP4_KGroup_MOE", None)
AMXFP8_MOE = getattr(_moe_mod, "AMXFP8_MOE", None)
AMXBF16_MOE = getattr(_moe_mod, "AMXBF16_MOE", None)
AMXFP8PerChannel_MOE = getattr(_moe_mod, "AMXFP8PerChannel_MOE", None)
@ -29,6 +31,7 @@ AVXVNNI256GPTQInt4_MOE = getattr(_moe_mod, "AVXVNNI256GPTQInt4_MOE", None)
_HAS_AMXINT4_SUPPORT = AMXInt4_MOE is not None
_HAS_AMXINT8_SUPPORT = AMXInt8_MOE is not None
_HAS_RAWINT4_SUPPORT = AMXInt4_KGroup_MOE is not None
_HAS_MXFP4_SUPPORT = AMXFP4_KGroup_MOE is not None
_HAS_FP8_SUPPORT = AMXFP8_MOE is not None
_HAS_BF16_SUPPORT = AMXBF16_MOE is not None
_HAS_FP8_PERCHANNEL_SUPPORT = AMXFP8PerChannel_MOE is not None
@ -444,6 +447,12 @@ class NativeMoEWrapper(BaseMoEWrapper):
"Please recompile kt_kernel_ext with GPTQ INT4 support enabled.\n"
"AVX-VNNI-256 will be selected automatically when available on the current CPU."
)
if method == "MXFP4" and not _HAS_MXFP4_SUPPORT:
raise RuntimeError(
"MXFP4 backend not available. Required ISA:\n"
" - AVX512F + AVX512BW + AVX512_BF16\n"
"Please recompile kt_kernel_ext with AVX512 + BF16 enabled."
)
super().__init__(
layer_idx=layer_idx,
@ -474,6 +483,8 @@ class NativeMoEWrapper(BaseMoEWrapper):
NativeMoEWrapper._native_loader_instance = BF16SafeTensorLoader(weight_path)
elif method == "GPTQ_INT4":
NativeMoEWrapper._native_loader_instance = GPTQSafeTensorLoader(weight_path)
elif method == "MXFP4":
NativeMoEWrapper._native_loader_instance = MXFP4SafeTensorLoader(weight_path)
else:
raise NotImplementedError(f"Unsupported method for NativeMoEWrapper: {method}")
self.loader = NativeMoEWrapper._native_loader_instance
@ -541,6 +552,10 @@ class NativeMoEWrapper(BaseMoEWrapper):
self.up_scales = [t.to(torch.float32).contiguous() for t in weights["up_scale"]]
self.down_scales = [t.to(torch.float32).contiguous() for t in weights["down_scale"]]
assert self.gate_scales[0].dtype == torch.float32, "Expected float32 scales for FP8_PERCHANNEL"
elif self.method == "MXFP4":
# ue8m0 is losslessly representable in bf16 (8-bit exponent, 0 mantissa);
# the loader has already done that conversion.
assert self.gate_scales[0].dtype == torch.bfloat16, "Expected bf16 scales for MXFP4"
t2 = time.time()
@ -590,6 +605,13 @@ class NativeMoEWrapper(BaseMoEWrapper):
moe_config.quant_config.group_size = group_size
moe_config.quant_config.zero_point = False
self.moe = AMXInt4_KGroup_MOE(moe_config)
elif self.method == "MXFP4":
# MXFP4: E2M1 nibble-packed weights, ue8m0/bf16 per-32 group scale
group_size = self.hidden_size // self.gate_scales[0].shape[1]
moe_config.quant_config.bits = 4
moe_config.quant_config.group_size = group_size
moe_config.quant_config.zero_point = False
self.moe = AMXFP4_KGroup_MOE(moe_config)
elif self.method == "FP8":
moe_config.quant_config.bits = 8
moe_config.quant_config.group_size = 128

View file

@ -1078,3 +1078,89 @@ class GPTQSafeTensorLoader(FP8SafeTensorLoader):
"up_scale": up_scales,
"down_scale": down_scales,
}
class MXFP4SafeTensorLoader(SafeTensorLoader):
"""Loader for native MXFP4 expert weights (DeepSeek-V4-Flash format).
Per expert layout:
{base}.ffn.experts.{i}.w1.weight I8 [N, K/2] nibble-packed E2M1 (gate)
{base}.ffn.experts.{i}.w1.scale F8_E8M0 [N, K/32] ue8m0 group scale
{base}.ffn.experts.{i}.w3.{weight,scale} up
{base}.ffn.experts.{i}.w2.{weight,scale} down
V4 ckpt keys are not prefixed with ``model.``; we also probe the stripped form so
callers can keep passing ``base_key="model.layers.{L}"``. ue8m0 bf16 is a lossless
bit shift (both have an 8-bit exponent and zero mantissa for ue8m0), and the AMX
FP4 backend already consumes bf16 scales.
"""
EXPERTS_PATH_TPL = "{base}.ffn.experts"
PROJ_NAMES = ("w1", "w3", "w2") # (gate, up, down)
def _experts_prefix_candidates(self, base_key: str) -> list[str]:
candidates = [self.EXPERTS_PATH_TPL.format(base=base_key)]
if base_key.startswith("model."):
candidates.append(self.EXPERTS_PATH_TPL.format(base=base_key[len("model.") :]))
return list(dict.fromkeys(candidates))
@staticmethod
def _ue8m0_to_bf16(scale_t: torch.Tensor) -> torch.Tensor:
if scale_t.dtype != torch.uint8:
scale_t = scale_t.view(torch.uint8)
# bf16 = [sign(1) | exp(8) | mant(7)]; setting mant=0, exp=e gives 2^(e-127),
# which is exactly the value encoded by ue8m0 for e ∈ [1, 254]. e=0 → bf16 +0
# (acceptable: ue8m0=0 represents 2^-127, below bf16 normal range), e=255 → +inf.
return (scale_t.to(torch.uint16) << 7).view(torch.bfloat16).contiguous()
def load_experts(self, base_key: str, device: str = "cpu"):
gate_name, up_name, down_name = self.PROJ_NAMES
prefix = None
expert_count = 0
for cand in self._experts_prefix_candidates(base_key):
expert_count = 0
while self.has_tensor(f"{cand}.{expert_count}.{gate_name}.weight"):
expert_count += 1
if expert_count > 0:
prefix = cand
break
if prefix is None:
raise ValueError(
f"No MXFP4 experts found under any of: {self._experts_prefix_candidates(base_key)}"
)
gate_weights = [None] * expert_count
up_weights = [None] * expert_count
down_weights = [None] * expert_count
gate_scales = [None] * expert_count
up_scales = [None] * expert_count
down_scales = [None] * expert_count
for exp_id in range(expert_count):
for proj, dst in (
(gate_name, gate_weights),
(up_name, up_weights),
(down_name, down_weights),
):
w = self.load_tensor(f"{prefix}.{exp_id}.{proj}.weight", device).contiguous()
if w.dtype != torch.uint8:
w = w.view(torch.uint8)
dst[exp_id] = w
for proj, dst in (
(gate_name, gate_scales),
(up_name, up_scales),
(down_name, down_scales),
):
s = self.load_tensor(f"{prefix}.{exp_id}.{proj}.scale", device)
dst[exp_id] = self._ue8m0_to_bf16(s)
print(f"[MXFP4SafeTensorLoader] Loaded {expert_count} experts from {prefix}")
return {
"gate": gate_weights,
"up": up_weights,
"down": down_weights,
"gate_scale": gate_scales,
"up_scale": up_scales,
"down_scale": down_scales,
}