mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-05-05 15:40:13 +00:00
[feat](kt-kernel): adapt MXFP4 MoE backend for DeepSeek-V4-Flash (#1950)
V4-Flash routed experts ship as native MXFP4 (E2M1 nibble + ue8m0 group
scale). Expose AMXFP4_KGroup_MOE through NativeMoEWrapper, add a loader
that handles V4's `layers.{L}.ffn.experts.{i}.{w1,w3,w2}.{weight,scale}`
naming and converts ue8m0 → bf16 via a lossless bit-cast, register the
model entry, and ship an end-to-end numerical validation script.
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
5c5d7d48c0
commit
8484ef8b16
5 changed files with 322 additions and 2 deletions
|
|
@@ -11,6 +11,7 @@ from .loader import (
|
|||
FP8SafeTensorLoader,
|
||||
BF16SafeTensorLoader,
|
||||
GPTQSafeTensorLoader,
|
||||
MXFP4SafeTensorLoader,
|
||||
)
|
||||
from kt_kernel_ext.moe import MOEConfig
|
||||
import kt_kernel_ext.moe as _moe_mod
|
||||
|
|
@@ -18,6 +19,7 @@ import kt_kernel_ext.moe as _moe_mod
|
|||
AMXInt4_MOE = getattr(_moe_mod, "AMXInt4_MOE", None)
|
||||
AMXInt8_MOE = getattr(_moe_mod, "AMXInt8_MOE", None)
|
||||
AMXInt4_KGroup_MOE = getattr(_moe_mod, "AMXInt4_KGroup_MOE", None)
|
||||
AMXFP4_KGroup_MOE = getattr(_moe_mod, "AMXFP4_KGroup_MOE", None)
|
||||
AMXFP8_MOE = getattr(_moe_mod, "AMXFP8_MOE", None)
|
||||
AMXBF16_MOE = getattr(_moe_mod, "AMXBF16_MOE", None)
|
||||
AMXFP8PerChannel_MOE = getattr(_moe_mod, "AMXFP8PerChannel_MOE", None)
|
||||
|
|
@@ -29,6 +31,7 @@ AVXVNNI256GPTQInt4_MOE = getattr(_moe_mod, "AVXVNNI256GPTQInt4_MOE", None)
|
|||
_HAS_AMXINT4_SUPPORT = AMXInt4_MOE is not None
|
||||
_HAS_AMXINT8_SUPPORT = AMXInt8_MOE is not None
|
||||
_HAS_RAWINT4_SUPPORT = AMXInt4_KGroup_MOE is not None
|
||||
_HAS_MXFP4_SUPPORT = AMXFP4_KGroup_MOE is not None
|
||||
_HAS_FP8_SUPPORT = AMXFP8_MOE is not None
|
||||
_HAS_BF16_SUPPORT = AMXBF16_MOE is not None
|
||||
_HAS_FP8_PERCHANNEL_SUPPORT = AMXFP8PerChannel_MOE is not None
|
||||
|
|
@@ -444,6 +447,12 @@ class NativeMoEWrapper(BaseMoEWrapper):
|
|||
"Please recompile kt_kernel_ext with GPTQ INT4 support enabled.\n"
|
||||
"AVX-VNNI-256 will be selected automatically when available on the current CPU."
|
||||
)
|
||||
if method == "MXFP4" and not _HAS_MXFP4_SUPPORT:
|
||||
raise RuntimeError(
|
||||
"MXFP4 backend not available. Required ISA:\n"
|
||||
" - AVX512F + AVX512BW + AVX512_BF16\n"
|
||||
"Please recompile kt_kernel_ext with AVX512 + BF16 enabled."
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
layer_idx=layer_idx,
|
||||
|
|
@@ -474,6 +483,8 @@ class NativeMoEWrapper(BaseMoEWrapper):
|
|||
NativeMoEWrapper._native_loader_instance = BF16SafeTensorLoader(weight_path)
|
||||
elif method == "GPTQ_INT4":
|
||||
NativeMoEWrapper._native_loader_instance = GPTQSafeTensorLoader(weight_path)
|
||||
elif method == "MXFP4":
|
||||
NativeMoEWrapper._native_loader_instance = MXFP4SafeTensorLoader(weight_path)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported method for NativeMoEWrapper: {method}")
|
||||
self.loader = NativeMoEWrapper._native_loader_instance
|
||||
|
|
@@ -541,6 +552,10 @@ class NativeMoEWrapper(BaseMoEWrapper):
|
|||
self.up_scales = [t.to(torch.float32).contiguous() for t in weights["up_scale"]]
|
||||
self.down_scales = [t.to(torch.float32).contiguous() for t in weights["down_scale"]]
|
||||
assert self.gate_scales[0].dtype == torch.float32, "Expected float32 scales for FP8_PERCHANNEL"
|
||||
elif self.method == "MXFP4":
|
||||
# ue8m0 is losslessly representable in bf16 (8-bit exponent, 0 mantissa);
|
||||
# the loader has already done that conversion.
|
||||
assert self.gate_scales[0].dtype == torch.bfloat16, "Expected bf16 scales for MXFP4"
|
||||
|
||||
t2 = time.time()
|
||||
|
||||
|
|
@@ -590,6 +605,13 @@ class NativeMoEWrapper(BaseMoEWrapper):
|
|||
moe_config.quant_config.group_size = group_size
|
||||
moe_config.quant_config.zero_point = False
|
||||
self.moe = AMXInt4_KGroup_MOE(moe_config)
|
||||
elif self.method == "MXFP4":
|
||||
# MXFP4: E2M1 nibble-packed weights, ue8m0/bf16 per-32 group scale
|
||||
group_size = self.hidden_size // self.gate_scales[0].shape[1]
|
||||
moe_config.quant_config.bits = 4
|
||||
moe_config.quant_config.group_size = group_size
|
||||
moe_config.quant_config.zero_point = False
|
||||
self.moe = AMXFP4_KGroup_MOE(moe_config)
|
||||
elif self.method == "FP8":
|
||||
moe_config.quant_config.bits = 8
|
||||
moe_config.quant_config.group_size = 128
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue