feat(kt-kernel): adapt MXFP4 MoE backend for DeepSeek-V4-Flash (#1950)

V4-Flash routed experts ship as native MXFP4 (E2M1 nibble + ue8m0 group
scale). Expose AMXFP4_KGroup_MOE through NativeMoEWrapper, add a loader
that handles V4's `layers.{L}.ffn.experts.{i}.{w1,w3,w2}.{weight,scale}`
naming and converts ue8m0 → bf16 losslessly (the scale byte maps
straight into bf16's exponent field), register the
model entry, and ship an end-to-end numerical validation script.

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Benjamin F committed 2026-04-25 18:11:53 +08:00
parent 5c5d7d48c0
commit 8484ef8b16
5 changed files with 322 additions and 2 deletions
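
For reference, below is a minimal sketch of the numerics the new loader path implements: decoding E2M1 nibbles and widening ue8m0 group scales into bf16. The function names, low-nibble-first packing order, and tensor shapes are illustrative assumptions, not the shipped code.

import torch

# E2M1 code -> value (1 sign, 2 exponent, 1 mantissa bits); the 16-entry
# MXFP4 value set is {0, 0.5, 1, 1.5, 2, 3, 4, 6} with both signs.
_E2M1_LUT = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
    dtype=torch.float32,
)

def ue8m0_to_bf16(scale_u8: torch.Tensor) -> torch.Tensor:
    # ue8m0 is just an 8-bit biased exponent (value = 2**(e - 127)); bf16
    # has the same 8-bit exponent, so shifting the byte into bf16's
    # exponent field (bits 14..7) is lossless.
    return (scale_u8.to(torch.int32) << 7).to(torch.int16).view(torch.bfloat16)

def dequant_mxfp4(packed_u8: torch.Tensor, scale_u8: torch.Tensor) -> torch.Tensor:
    # packed_u8: [..., K//2] uint8, two E2M1 codes per byte (low-nibble-first
    # is an assumption); scale_u8: [..., K//32] ue8m0 group scales.
    lo, hi = packed_u8 & 0x0F, packed_u8 >> 4
    codes = torch.stack((lo, hi), dim=-1).flatten(-2)  # interleave -> [..., K]
    vals = _E2M1_LUT[codes.long()]
    scales = ue8m0_to_bf16(scale_u8).to(torch.float32)
    return (vals.unflatten(-1, (scale_u8.shape[-1], 32))
            * scales.unsqueeze(-1)).flatten(-2)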

@@ -11,6 +11,7 @@ from .loader import (
FP8SafeTensorLoader,
BF16SafeTensorLoader,
GPTQSafeTensorLoader,
MXFP4SafeTensorLoader,
)
from kt_kernel_ext.moe import MOEConfig
import kt_kernel_ext.moe as _moe_mod
@@ -18,6 +19,7 @@ import kt_kernel_ext.moe as _moe_mod
AMXInt4_MOE = getattr(_moe_mod, "AMXInt4_MOE", None)
AMXInt8_MOE = getattr(_moe_mod, "AMXInt8_MOE", None)
AMXInt4_KGroup_MOE = getattr(_moe_mod, "AMXInt4_KGroup_MOE", None)
AMXFP4_KGroup_MOE = getattr(_moe_mod, "AMXFP4_KGroup_MOE", None)
AMXFP8_MOE = getattr(_moe_mod, "AMXFP8_MOE", None)
AMXBF16_MOE = getattr(_moe_mod, "AMXBF16_MOE", None)
AMXFP8PerChannel_MOE = getattr(_moe_mod, "AMXFP8PerChannel_MOE", None)
@@ -29,6 +31,7 @@ AVXVNNI256GPTQInt4_MOE = getattr(_moe_mod, "AVXVNNI256GPTQInt4_MOE", None)
_HAS_AMXINT4_SUPPORT = AMXInt4_MOE is not None
_HAS_AMXINT8_SUPPORT = AMXInt8_MOE is not None
_HAS_RAWINT4_SUPPORT = AMXInt4_KGroup_MOE is not None
_HAS_MXFP4_SUPPORT = AMXFP4_KGroup_MOE is not None
_HAS_FP8_SUPPORT = AMXFP8_MOE is not None
_HAS_BF16_SUPPORT = AMXBF16_MOE is not None
_HAS_FP8_PERCHANNEL_SUPPORT = AMXFP8PerChannel_MOE is not None
@@ -444,6 +447,12 @@ class NativeMoEWrapper(BaseMoEWrapper):
"Please recompile kt_kernel_ext with GPTQ INT4 support enabled.\n"
"AVX-VNNI-256 will be selected automatically when available on the current CPU."
)
if method == "MXFP4" and not _HAS_MXFP4_SUPPORT:
raise RuntimeError(
"MXFP4 backend not available. Required ISA:\n"
" - AVX512F + AVX512BW + AVX512_BF16\n"
"Please recompile kt_kernel_ext with AVX512 + BF16 enabled."
)
super().__init__(
layer_idx=layer_idx,
@@ -474,6 +483,8 @@ class NativeMoEWrapper(BaseMoEWrapper):
NativeMoEWrapper._native_loader_instance = BF16SafeTensorLoader(weight_path)
elif method == "GPTQ_INT4":
NativeMoEWrapper._native_loader_instance = GPTQSafeTensorLoader(weight_path)
elif method == "MXFP4":
NativeMoEWrapper._native_loader_instance = MXFP4SafeTensorLoader(weight_path)
else:
raise NotImplementedError(f"Unsupported method for NativeMoEWrapper: {method}")
self.loader = NativeMoEWrapper._native_loader_instance
@@ -541,6 +552,10 @@ class NativeMoEWrapper(BaseMoEWrapper):
self.up_scales = [t.to(torch.float32).contiguous() for t in weights["up_scale"]]
self.down_scales = [t.to(torch.float32).contiguous() for t in weights["down_scale"]]
assert self.gate_scales[0].dtype == torch.float32, "Expected float32 scales for FP8_PERCHANNEL"
elif self.method == "MXFP4":
# ue8m0 is losslessly representable in bf16 (8-bit exponent, 0 mantissa);
# the loader has already done that conversion.
assert self.gate_scales[0].dtype == torch.bfloat16, "Expected bf16 scales for MXFP4"
t2 = time.time()
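
As a quick check of the lossless claim in the comment above: all 256 ue8m0 exponent bytes round-trip through bf16 bit-exactly. An illustrative snippet, using the same bit layout as the sketch further up:

import torch

# Every ue8m0 byte survives the trip into bf16's exponent field and back
# (0 maps to +0.0, 255 to +inf, both reversible at the bit level).
e = torch.arange(256, dtype=torch.int32)
bf = (e << 7).to(torch.int16).view(torch.bfloat16)
back = bf.view(torch.int16).to(torch.int32) >> 7
assert torch.equal(e, back)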
@@ -590,6 +605,13 @@ class NativeMoEWrapper(BaseMoEWrapper):
moe_config.quant_config.group_size = group_size
moe_config.quant_config.zero_point = False
self.moe = AMXInt4_KGroup_MOE(moe_config)
elif self.method == "MXFP4":
# MXFP4: E2M1 nibble-packed weights, ue8m0/bf16 per-32 group scale
group_size = self.hidden_size // self.gate_scales[0].shape[1]
moe_config.quant_config.bits = 4
moe_config.quant_config.group_size = group_size
moe_config.quant_config.zero_point = False
self.moe = AMXFP4_KGroup_MOE(moe_config)
elif self.method == "FP8":
moe_config.quant_config.bits = 8
moe_config.quant_config.group_size = 128
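
For concreteness, the group_size derivation in the MXFP4 branch above with assumed numbers (hidden_size here is illustrative, not a confirmed V4-Flash value):

# Assumed numbers for illustration only.
hidden_size = 7168                      # model hidden dim (assumption)
scale_cols = hidden_size // 32          # one ue8m0 scale per 32 weights -> 224
group_size = hidden_size // scale_cols  # recovers 32, MXFP4's group size
assert group_size == 32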