[feat]: add AVX512F+BW fallback for FP8 and BF16 under AMX backend (#1908)
Some checks failed
Deploy / deploy (macos-latest) (push) Has been cancelled
Deploy / deploy (ubuntu-latest) (push) Has been cancelled
Deploy / deploy (windows-latest) (push) Has been cancelled
Book-CI / test (push) Has been cancelled
Book-CI / test-1 (push) Has been cancelled
Book-CI / test-2 (push) Has been cancelled

This commit is contained in:
Jim James 2026-04-03 00:46:22 -04:00 committed by GitHub
parent db9326302b
commit 8a427c9321
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 52 additions and 6 deletions

View file

@ -36,11 +36,9 @@ static const bool _is_plain_ = false;
#if defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL)
#include "operators/amx/awq-moe.hpp"
#if defined(__AVX512BF16__)
#include "operators/amx/bf16-moe.hpp" // Native BF16 MoE using CRTP pattern
#include "operators/amx/fp8-moe.hpp" // FP8 MoE requires AVX512 BF16 support
#include "operators/amx/bf16-moe.hpp" // Native BF16 MoE using CRTP pattern, with fallback for AVX512F
#include "operators/amx/fp8-moe.hpp" // FP8 MoE requires AVX512 BF16 support, with fallback for AVX512F+BW
#include "operators/amx/fp8-perchannel-moe.hpp" // FP8 Per-Channel MoE for GLM-4.7-FP8
#endif
#include "operators/amx/k2-moe.hpp"
#include "operators/amx/la/amx_kernels.hpp"
#include "operators/amx/moe.hpp"
@ -579,7 +577,7 @@ PYBIND11_MODULE(kt_kernel_ext, m) {
bind_moe_module<AMX_MOE_TP<amx::GemmKernel224Int4_1>>(moe_module, "AMXInt4_1_MOE");
bind_moe_module<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4_1_LowKGroup>>(moe_module, "AMXInt4_1KGroup_MOE");
bind_moe_module<AMX_K2_MOE_TP<amx::GemmKernel224Int4SmallKGroup>>(moe_module, "AMXInt4_KGroup_MOE");
#if defined(__AVX512BF16__)
#if defined(__AVX512F__)
bind_moe_module<AMX_BF16_MOE_TP<amx::GemmKernel224BF16>>(moe_module, "AMXBF16_MOE");
bind_moe_module<AMX_FP8_MOE_TP<amx::GemmKernel224FP8>>(moe_module, "AMXFP8_MOE");
bind_moe_module<AMX_FP8_PERCHANNEL_MOE_TP<amx::GemmKernel224FP8PerChannel>>(moe_module, "AMXFP8PerChannel_MOE");