[feat](kt-kernel): support avx2 only inference for bf16 fp8 and gptq int4 (#1892)
Some checks are pending
Book-CI / test (push) Waiting to run
Book-CI / test-1 (push) Waiting to run
Book-CI / test-2 (push) Waiting to run
Deploy / deploy (macos-latest) (push) Waiting to run
Deploy / deploy (ubuntu-latest) (push) Waiting to run
Deploy / deploy (windows-latest) (push) Waiting to run

* feat: support avx2 bf16 fp8 inference

* feat: support avx2 gptq int4 inference

* fix: numeric issues in fp8 dequant

* Tutorial avx2 (#1900)

* fix: prevent injecting -DLLAMA_AVX512=ON on AVX2-only machines

* docs: add AVX2 tutorial for running KTransformers on AVX2-only CPUs

* Tutorial avx2 (#1901)

* fix: prevent injecting -DLLAMA_AVX512=ON on AVX2-only machines

* docs: add AVX2 tutorial for running KTransformers on AVX2-only CPUs

* docs: update README.md

---------

Co-authored-by: Benjamin F <159887351+yyj6666667@users.noreply.github.com>
This commit is contained in:
mrhaoxx 2026-03-27 14:45:02 +08:00 committed by GitHub
parent 8561a71dd1
commit 7a9daf0cd4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 3472 additions and 12 deletions

View file

@ -5,7 +5,7 @@ from typing import Optional
# Use relative imports for package structure
from ..experts_base import BaseMoEWrapper
from .loader import SafeTensorLoader, CompressedSafeTensorLoader, FP8SafeTensorLoader, BF16SafeTensorLoader
from .loader import SafeTensorLoader, CompressedSafeTensorLoader, FP8SafeTensorLoader, BF16SafeTensorLoader, GPTQSafeTensorLoader
from kt_kernel_ext.moe import MOEConfig
import kt_kernel_ext.moe as _moe_mod
@ -15,6 +15,9 @@ AMXInt4_KGroup_MOE = getattr(_moe_mod, "AMXInt4_KGroup_MOE", None)
AMXFP8_MOE = getattr(_moe_mod, "AMXFP8_MOE", None)
AMXBF16_MOE = getattr(_moe_mod, "AMXBF16_MOE", None)
AMXFP8PerChannel_MOE = getattr(_moe_mod, "AMXFP8PerChannel_MOE", None)
AVX2BF16_MOE = getattr(_moe_mod, "AVX2BF16_MOE", None)
AVX2FP8_MOE = getattr(_moe_mod, "AVX2FP8_MOE", None)
AVX2GPTQInt4_MOE = getattr(_moe_mod, "AVX2GPTQInt4_MOE", None)
_HAS_AMXINT4_SUPPORT = AMXInt4_MOE is not None
_HAS_AMXINT8_SUPPORT = AMXInt8_MOE is not None
@ -22,6 +25,9 @@ _HAS_RAWINT4_SUPPORT = AMXInt4_KGroup_MOE is not None
_HAS_FP8_SUPPORT = AMXFP8_MOE is not None
_HAS_BF16_SUPPORT = AMXBF16_MOE is not None
_HAS_FP8_PERCHANNEL_SUPPORT = AMXFP8PerChannel_MOE is not None
_HAS_AVX2_BF16_SUPPORT = AVX2BF16_MOE is not None
_HAS_AVX2_FP8_SUPPORT = AVX2FP8_MOE is not None
_HAS_AVX2_GPTQ_INT4_SUPPORT = AVX2GPTQInt4_MOE is not None
class AMXMoEWrapper(BaseMoEWrapper):
@ -346,10 +352,11 @@ class NativeMoEWrapper(BaseMoEWrapper):
" - AVX512F + AVX512BW (VNNI optional)\n"
"Please recompile kt_kernel_ext with AVX512 enabled."
)
if method == "FP8" and not _HAS_FP8_SUPPORT:
if method == "FP8" and not _HAS_FP8_SUPPORT and not _HAS_AVX2_FP8_SUPPORT:
raise RuntimeError(
"FP8 backend not available. Required ISA:\n"
" - AVX512F + AVX512BW + AVX512_BF16 + AVX512_VBMI\n"
" - AVX512F + AVX512BW + AVX512_BF16 + AVX512_VBMI (for AMX), or\n"
" - AVX2 + FMA (for AVX2 fallback)\n"
"Please recompile kt_kernel_ext with AVX512 + BF16 + VBMI enabled."
)
if method == "FP8_PERCHANNEL" and not _HAS_FP8_PERCHANNEL_SUPPORT:
@ -358,11 +365,17 @@ class NativeMoEWrapper(BaseMoEWrapper):
" - AVX512F + AVX512BW + AVX512_BF16 + AVX512_VBMI\n"
"Please recompile kt_kernel_ext with AVX512 + BF16 + VBMI enabled."
)
if method == "BF16" and not _HAS_BF16_SUPPORT:
if method == "BF16" and not _HAS_BF16_SUPPORT and not _HAS_AVX2_BF16_SUPPORT:
raise RuntimeError(
"BF16 backend not available. Required ISA:\n"
" - AVX512F + AVX512BW + AVX512_BF16\n"
"Please recompile kt_kernel_ext with AVX512 + BF16 enabled."
" - AVX512F + AVX512BW + AVX512_BF16 (for AMX backend), or\n"
" - AVX2 + FMA (for AVX2 fallback backend)\n"
"Please recompile kt_kernel_ext with AVX512+BF16 or AVX2 enabled."
)
if method == "GPTQ_INT4" and not _HAS_AVX2_GPTQ_INT4_SUPPORT:
raise RuntimeError(
"GPTQ_INT4 backend not available.\n"
"Please recompile kt_kernel_ext with AVX2 enabled."
)
super().__init__(
@ -391,6 +404,8 @@ class NativeMoEWrapper(BaseMoEWrapper):
NativeMoEWrapper._native_loader_instance = FP8SafeTensorLoader(weight_path, scale_suffix="weight_scale")
elif method == "BF16":
NativeMoEWrapper._native_loader_instance = BF16SafeTensorLoader(weight_path)
elif method == "GPTQ_INT4":
NativeMoEWrapper._native_loader_instance = GPTQSafeTensorLoader(weight_path)
else:
raise NotImplementedError(f"Unsupported method for NativeMoEWrapper: {method}")
self.loader = NativeMoEWrapper._native_loader_instance
@ -506,15 +521,31 @@ class NativeMoEWrapper(BaseMoEWrapper):
moe_config.quant_config.bits = 8
moe_config.quant_config.group_size = 128
moe_config.quant_config.zero_point = False
self.moe = AMXFP8_MOE(moe_config)
if _HAS_FP8_SUPPORT:
self.moe = AMXFP8_MOE(moe_config)
else:
self.moe = AVX2FP8_MOE(moe_config)
elif self.method == "FP8_PERCHANNEL":
moe_config.quant_config.bits = 8
moe_config.quant_config.per_channel = True
moe_config.quant_config.zero_point = False
self.moe = AMXFP8PerChannel_MOE(moe_config)
elif self.method == "GPTQ_INT4":
# GPTQ symmetric INT4: qweight (int32) + scales (fp32)
group_size = self.gate_scales[0].shape[0] # scales shape [K/gs, N], first dim = num_groups
# hidden_size / num_groups = group_size
actual_gs = self.hidden_size // group_size
moe_config.quant_config.bits = 4
moe_config.quant_config.group_size = actual_gs
moe_config.quant_config.zero_point = False
self.moe = AVX2GPTQInt4_MOE(moe_config)
elif self.method == "BF16":
# BF16 has no quantization config needed
self.moe = AMXBF16_MOE(moe_config)
# Prefer AMX backend, fall back to AVX2
if _HAS_BF16_SUPPORT:
self.moe = AMXBF16_MOE(moe_config)
else:
self.moe = AVX2BF16_MOE(moe_config)
t4 = time.time()
self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))