diff --git a/kt-kernel/ext_bindings.cpp b/kt-kernel/ext_bindings.cpp
index 950e09b8..4c6b48cc 100644
--- a/kt-kernel/ext_bindings.cpp
+++ b/kt-kernel/ext_bindings.cpp
@@ -36,11 +36,9 @@ static const bool _is_plain_ = false;
 
 #if defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL)
 #include "operators/amx/awq-moe.hpp"
-#if defined(__AVX512BF16__)
-#include "operators/amx/bf16-moe.hpp" // Native BF16 MoE using CRTP pattern
-#include "operators/amx/fp8-moe.hpp" // FP8 MoE requires AVX512 BF16 support
+#include "operators/amx/bf16-moe.hpp" // Native BF16 MoE using CRTP pattern, with fallback for AVX512F
+#include "operators/amx/fp8-moe.hpp" // FP8 MoE requires AVX512 BF16 support, with fallback for AVX512F+BW
 #include "operators/amx/fp8-perchannel-moe.hpp" // FP8 Per-Channel MoE for GLM-4.7-FP8
-#endif
 #include "operators/amx/k2-moe.hpp"
 #include "operators/amx/la/amx_kernels.hpp"
 #include "operators/amx/moe.hpp"
@@ -579,7 +577,7 @@ PYBIND11_MODULE(kt_kernel_ext, m) {
   bind_moe_module>(moe_module, "AMXInt4_1_MOE");
   bind_moe_module>(moe_module, "AMXInt4_1KGroup_MOE");
   bind_moe_module>(moe_module, "AMXInt4_KGroup_MOE");
-#if defined(__AVX512BF16__)
+#if defined(__AVX512F__)
   bind_moe_module>(moe_module, "AMXBF16_MOE");
   bind_moe_module>(moe_module, "AMXFP8_MOE");
   bind_moe_module>(moe_module, "AMXFP8PerChannel_MOE");
diff --git a/kt-kernel/operators/amx/la/amx_config.hpp b/kt-kernel/operators/amx/la/amx_config.hpp
index 142a4b73..82f4b581 100644
--- a/kt-kernel/operators/amx/la/amx_config.hpp
+++ b/kt-kernel/operators/amx/la/amx_config.hpp
@@ -14,6 +14,54 @@
 #define ALWAYS_INLINE inline
 #endif
 #include <immintrin.h>
+
+#if !defined(__AVX512BF16__)
+#ifndef __m512bh
+#define __m512bh __m512i
+#endif
+
+// BF16 emulation via AVX512F
+static ALWAYS_INLINE __m512 _mm512_dpbf16_ps_emulated(__m512 src, __m512i a, __m512i b) {
+  __m512 a_low = _mm512_castsi512_ps(_mm512_slli_epi32(a, 16));
+  __m512 b_low = _mm512_castsi512_ps(_mm512_slli_epi32(b, 16));
+  __m512i mask = _mm512_set1_epi32(0xFFFF0000);
+  __m512 a_high = _mm512_castsi512_ps(_mm512_and_si512(a, mask));
+  __m512 b_high = _mm512_castsi512_ps(_mm512_and_si512(b, mask));
+  __m512 res = _mm512_fmadd_ps(a_low, b_low, src);
+  res = _mm512_fmadd_ps(a_high, b_high, res);
+  return res;
+}
+
+#ifndef _mm512_dpbf16_ps
+#define _mm512_dpbf16_ps _mm512_dpbf16_ps_emulated
+#endif
+#endif
+
+#if !defined(__AVX512VBMI__)
+// VBMI emulation via AVX512F+BW
+static ALWAYS_INLINE __m512i _mm512_permutex2var_epi8_emulated(__m512i a, __m512i idx, __m512i b) {
+  __m256i idx_lo_256 = _mm512_castsi512_si256(idx);
+  __m256i idx_hi_256 = _mm512_extracti64x4_epi64(idx, 1);
+  __m512i idx_lo_w = _mm512_cvtepu8_epi16(idx_lo_256);
+  __m512i idx_hi_w = _mm512_cvtepu8_epi16(idx_hi_256);
+  __m512i p_lo = _mm512_srli_epi16(idx_lo_w, 1);
+  __m512i p_hi = _mm512_srli_epi16(idx_hi_w, 1);
+  __m512i w_lo = _mm512_permutex2var_epi16(a, p_lo, b);
+  __m512i w_hi = _mm512_permutex2var_epi16(a, p_hi, b);
+  __m512i shift_lo = _mm512_slli_epi16(_mm512_and_si512(idx_lo_w, _mm512_set1_epi16(1)), 3);
+  __m512i shift_hi = _mm512_slli_epi16(_mm512_and_si512(idx_hi_w, _mm512_set1_epi16(1)), 3);
+  __m512i res_lo_w = _mm512_srlv_epi16(w_lo, shift_lo);
+  __m512i res_hi_w = _mm512_srlv_epi16(w_hi, shift_hi);
+  __m256i res_lo = _mm512_cvtepi16_epi8(res_lo_w);
+  __m256i res_hi = _mm512_cvtepi16_epi8(res_hi_w);
+  return _mm512_inserti64x4(_mm512_castsi256_si512(res_lo), res_hi, 1);
+}
+
+#ifndef _mm512_permutex2var_epi8
+#define _mm512_permutex2var_epi8 _mm512_permutex2var_epi8_emulated
+#endif
+#endif
+
 #if defined(__AMX__) || defined(__AMXINT8__) || defined(__AMXBF16__) || defined(__AMX_TILE__) || defined(HAVE_AMX)
 #ifndef HAVE_AMX
 #define HAVE_AMX
@@ -187,4 +235,4 @@ static_assert(sizeof(TileConfig) == 64);
 } // namespace amx
 #endif // defined(__AMX__)
 
-#endif // AMX_CONFIG_HPP
\ No newline at end of file
+#endif // AMX_CONFIG_HPP