mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-28 20:00:06 +00:00
[feat](kt-kernel): adapt MXFP4 MoE backend for DeepSeek-V4-Flash (#1950)
V4-Flash routed experts ship as native MXFP4 (E2M1 nibble + ue8m0 group
scale). Expose AMXFP4_KGroup_MOE through NativeMoEWrapper, add a loader
that handles V4's `layers.{L}.ffn.experts.{i}.{w1,w3,w2}.{weight,scale}`
naming and converts ue8m0 → bf16 via a lossless bit shift, register the
model entry, and ship an end-to-end numerical validation script.
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
5c5d7d48c0
commit
8484ef8b16
5 changed files with 322 additions and 2 deletions
|
|
@ -1078,3 +1078,89 @@ class GPTQSafeTensorLoader(FP8SafeTensorLoader):
|
|||
"up_scale": up_scales,
|
||||
"down_scale": down_scales,
|
||||
}
|
||||
|
||||
|
||||
class MXFP4SafeTensorLoader(SafeTensorLoader):
    """Loader for native MXFP4 expert weights (DeepSeek-V4-Flash format).

    Per-expert checkpoint layout (for expert ``i`` under ``{base}.ffn.experts``):

    ========================================  ==========  ============================
    key                                       dtype       meaning
    ========================================  ==========  ============================
    ``...{i}.w1.weight``                      I8          [N, K/2] nibble-packed E2M1 (gate)
    ``...{i}.w1.scale``                       F8_E8M0     [N, K/32] ue8m0 group scale
    ``...{i}.w3.{weight,scale}``                          up projection
    ``...{i}.w2.{weight,scale}``                          down projection
    ========================================  ==========  ============================

    V4 checkpoint keys carry no ``model.`` prefix, so the loader also probes the
    stripped form — callers may keep passing ``base_key="model.layers.{L}"``.
    ue8m0 → bf16 is a lossless bit shift (both formats have an 8-bit exponent and
    ue8m0 has no mantissa bits), and the AMX FP4 backend already consumes bf16
    scales.
    """

    EXPERTS_PATH_TPL = "{base}.ffn.experts"
    PROJ_NAMES = ("w1", "w3", "w2")  # (gate, up, down)

    def _experts_prefix_candidates(self, base_key: str) -> list[str]:
        """Return deduplicated experts prefixes: as given, plus a ``model.``-stripped form."""
        # A dict preserves insertion order while deduplicating, so the as-given
        # prefix is always probed first.
        ordered: dict[str, None] = {self.EXPERTS_PATH_TPL.format(base=base_key): None}
        if base_key.startswith("model."):
            stripped = base_key[len("model.") :]
            ordered[self.EXPERTS_PATH_TPL.format(base=stripped)] = None
        return list(ordered)

    @staticmethod
    def _ue8m0_to_bf16(scale_t: torch.Tensor) -> torch.Tensor:
        """Reinterpret ue8m0 group scales as bf16, losslessly.

        bf16 = [sign(1) | exp(8) | mant(7)]; placing the ue8m0 byte ``e`` in the
        exponent field (mantissa zero) yields exactly 2^(e-127) for e ∈ [1, 254].
        e=0 maps to bf16 +0 (acceptable: ue8m0=0 encodes 2^-127, below the bf16
        normal range) and e=255 maps to +inf.
        """
        raw = scale_t if scale_t.dtype == torch.uint8 else scale_t.view(torch.uint8)
        widened = raw.to(torch.uint16) << 7  # shift exponent byte into bf16's exp field
        return widened.view(torch.bfloat16).contiguous()

    def load_experts(self, base_key: str, device: str = "cpu"):
        """Load all MXFP4 experts under *base_key*.

        Probes each candidate prefix by counting consecutive gate-weight keys,
        then loads packed E2M1 weights (as uint8) and bf16-converted scales for
        every expert. Returns a dict of per-expert tensor lists keyed by
        ``gate``/``up``/``down`` and ``gate_scale``/``up_scale``/``down_scale``.

        Raises ``ValueError`` when no expert tensors exist under any candidate.
        """
        gate_name, up_name, down_name = self.PROJ_NAMES

        # Find which prefix actually holds experts, and how many there are.
        prefix = None
        num_experts = 0
        for cand in self._experts_prefix_candidates(base_key):
            found = 0
            while self.has_tensor(f"{cand}.{found}.{gate_name}.weight"):
                found += 1
            if found:
                prefix, num_experts = cand, found
                break
        if prefix is None:
            raise ValueError(
                f"No MXFP4 experts found under any of: {self._experts_prefix_candidates(base_key)}"
            )

        gate_weights: list = []
        up_weights: list = []
        down_weights: list = []
        gate_scales: list = []
        up_scales: list = []
        down_scales: list = []

        weight_buckets = ((gate_name, gate_weights), (up_name, up_weights), (down_name, down_weights))
        scale_buckets = ((gate_name, gate_scales), (up_name, up_scales), (down_name, down_scales))

        for exp_id in range(num_experts):
            # Nibble-packed E2M1 payloads: force a uint8 byte view for the backend.
            for proj, bucket in weight_buckets:
                packed = self.load_tensor(f"{prefix}.{exp_id}.{proj}.weight", device).contiguous()
                if packed.dtype != torch.uint8:
                    packed = packed.view(torch.uint8)
                bucket.append(packed)

            # ue8m0 group scales: convert to bf16 via the lossless bit shift.
            for proj, bucket in scale_buckets:
                raw_scale = self.load_tensor(f"{prefix}.{exp_id}.{proj}.scale", device)
                bucket.append(self._ue8m0_to_bf16(raw_scale))

        print(f"[MXFP4SafeTensorLoader] Loaded {num_experts} experts from {prefix}")
        return {
            "gate": gate_weights,
            "up": up_weights,
            "down": down_weights,
            "gate_scale": gate_scales,
            "up_scale": up_scales,
            "down_scale": down_scales,
        }
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue