[feat]: add AVX512F+BW fallback for FP8 and BF16 under AMX backend (#1908)
Some checks failed
Deploy / deploy (macos-latest) (push) Has been cancelled
Deploy / deploy (ubuntu-latest) (push) Has been cancelled
Deploy / deploy (windows-latest) (push) Has been cancelled
Book-CI / test (push) Has been cancelled
Book-CI / test-1 (push) Has been cancelled
Book-CI / test-2 (push) Has been cancelled

This commit is contained in:
Jim James 2026-04-03 00:46:22 -04:00 committed by GitHub
parent db9326302b
commit 8a427c9321
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 52 additions and 6 deletions

View file

@ -36,11 +36,9 @@ static const bool _is_plain_ = false;
#if defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL)
#include "operators/amx/awq-moe.hpp"
#if defined(__AVX512BF16__)
#include "operators/amx/bf16-moe.hpp" // Native BF16 MoE using CRTP pattern
#include "operators/amx/fp8-moe.hpp" // FP8 MoE requires AVX512 BF16 support
#include "operators/amx/bf16-moe.hpp" // Native BF16 MoE using CRTP pattern, with fallback for AVX512F
#include "operators/amx/fp8-moe.hpp" // FP8 MoE requires AVX512 BF16 support, with fallback for AVX512F+BW
#include "operators/amx/fp8-perchannel-moe.hpp" // FP8 Per-Channel MoE for GLM-4.7-FP8
#endif
#include "operators/amx/k2-moe.hpp"
#include "operators/amx/la/amx_kernels.hpp"
#include "operators/amx/moe.hpp"
@ -579,7 +577,7 @@ PYBIND11_MODULE(kt_kernel_ext, m) {
bind_moe_module<AMX_MOE_TP<amx::GemmKernel224Int4_1>>(moe_module, "AMXInt4_1_MOE");
bind_moe_module<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4_1_LowKGroup>>(moe_module, "AMXInt4_1KGroup_MOE");
bind_moe_module<AMX_K2_MOE_TP<amx::GemmKernel224Int4SmallKGroup>>(moe_module, "AMXInt4_KGroup_MOE");
#if defined(__AVX512BF16__)
#if defined(__AVX512F__)
bind_moe_module<AMX_BF16_MOE_TP<amx::GemmKernel224BF16>>(moe_module, "AMXBF16_MOE");
bind_moe_module<AMX_FP8_MOE_TP<amx::GemmKernel224FP8>>(moe_module, "AMXFP8_MOE");
bind_moe_module<AMX_FP8_PERCHANNEL_MOE_TP<amx::GemmKernel224FP8PerChannel>>(moe_module, "AMXFP8PerChannel_MOE");

View file

@ -14,6 +14,54 @@
#define ALWAYS_INLINE inline
#endif
#include <immintrin.h>
#if !defined(__AVX512BF16__)
#ifndef __m512bh
#define __m512bh __m512i
#endif
// BF16 emulation via AVX512F
static ALWAYS_INLINE __m512 _mm512_dpbf16_ps_emulated(__m512 src, __m512i a, __m512i b) {
__m512 a_low = _mm512_castsi512_ps(_mm512_slli_epi32(a, 16));
__m512 b_low = _mm512_castsi512_ps(_mm512_slli_epi32(b, 16));
__m512i mask = _mm512_set1_epi32(0xFFFF0000);
__m512 a_high = _mm512_castsi512_ps(_mm512_and_si512(a, mask));
__m512 b_high = _mm512_castsi512_ps(_mm512_and_si512(b, mask));
__m512 res = _mm512_fmadd_ps(a_low, b_low, src);
res = _mm512_fmadd_ps(a_high, b_high, res);
return res;
}
#ifndef _mm512_dpbf16_ps
#define _mm512_dpbf16_ps _mm512_dpbf16_ps_emulated
#endif
#endif
#if !defined(__AVX512VBMI__)
// VBMI emulation via AVX512F+BW
static ALWAYS_INLINE __m512i _mm512_permutex2var_epi8_emulated(__m512i a, __m512i idx, __m512i b) {
__m256i idx_lo_256 = _mm512_castsi512_si256(idx);
__m256i idx_hi_256 = _mm512_extracti64x4_epi64(idx, 1);
__m512i idx_lo_w = _mm512_cvtepu8_epi16(idx_lo_256);
__m512i idx_hi_w = _mm512_cvtepu8_epi16(idx_hi_256);
__m512i p_lo = _mm512_srli_epi16(idx_lo_w, 1);
__m512i p_hi = _mm512_srli_epi16(idx_hi_w, 1);
__m512i w_lo = _mm512_permutex2var_epi16(a, p_lo, b);
__m512i w_hi = _mm512_permutex2var_epi16(a, p_hi, b);
__m512i shift_lo = _mm512_slli_epi16(_mm512_and_si512(idx_lo_w, _mm512_set1_epi16(1)), 3);
__m512i shift_hi = _mm512_slli_epi16(_mm512_and_si512(idx_hi_w, _mm512_set1_epi16(1)), 3);
__m512i res_lo_w = _mm512_srlv_epi16(w_lo, shift_lo);
__m512i res_hi_w = _mm512_srlv_epi16(w_hi, shift_hi);
__m256i res_lo = _mm512_cvtepi16_epi8(res_lo_w);
__m256i res_hi = _mm512_cvtepi16_epi8(res_hi_w);
return _mm512_inserti64x4(_mm512_castsi256_si512(res_lo), res_hi, 1);
}
#ifndef _mm512_permutex2var_epi8
#define _mm512_permutex2var_epi8 _mm512_permutex2var_epi8_emulated
#endif
#endif
#if defined(__AMX__) || defined(__AMXINT8__) || defined(__AMXBF16__) || defined(__AMX_TILE__) || defined(HAVE_AMX)
#ifndef HAVE_AMX
#define HAVE_AMX
@ -187,4 +235,4 @@ static_assert(sizeof(TileConfig) == 64);
} // namespace amx
#endif // defined(__AMX__)
#endif // AMX_CONFIG_HPP
#endif // AMX_CONFIG_HPP