mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-26 10:50:59 +00:00
[feat]: add AVX512F+BW fallback for FP8 and BF16 under AMX backend (#1908)
Some checks failed
Deploy / deploy (macos-latest) (push) Has been cancelled
Deploy / deploy (ubuntu-latest) (push) Has been cancelled
Deploy / deploy (windows-latest) (push) Has been cancelled
Book-CI / test (push) Has been cancelled
Book-CI / test-1 (push) Has been cancelled
Book-CI / test-2 (push) Has been cancelled
Some checks failed
Deploy / deploy (macos-latest) (push) Has been cancelled
Deploy / deploy (ubuntu-latest) (push) Has been cancelled
Deploy / deploy (windows-latest) (push) Has been cancelled
Book-CI / test (push) Has been cancelled
Book-CI / test-1 (push) Has been cancelled
Book-CI / test-2 (push) Has been cancelled
This commit is contained in:
parent
db9326302b
commit
8a427c9321
2 changed files with 52 additions and 6 deletions
|
|
@ -36,11 +36,9 @@ static const bool _is_plain_ = false;
|
|||
|
||||
#if defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL)
|
||||
#include "operators/amx/awq-moe.hpp"
|
||||
#if defined(__AVX512BF16__)
|
||||
#include "operators/amx/bf16-moe.hpp" // Native BF16 MoE using CRTP pattern
|
||||
#include "operators/amx/fp8-moe.hpp" // FP8 MoE requires AVX512 BF16 support
|
||||
#include "operators/amx/bf16-moe.hpp" // Native BF16 MoE using CRTP pattern, with fallback for AVX512F
|
||||
#include "operators/amx/fp8-moe.hpp" // FP8 MoE requires AVX512 BF16 support, with fallback for AVX512F+BW
|
||||
#include "operators/amx/fp8-perchannel-moe.hpp" // FP8 Per-Channel MoE for GLM-4.7-FP8
|
||||
#endif
|
||||
#include "operators/amx/k2-moe.hpp"
|
||||
#include "operators/amx/la/amx_kernels.hpp"
|
||||
#include "operators/amx/moe.hpp"
|
||||
|
|
@ -579,7 +577,7 @@ PYBIND11_MODULE(kt_kernel_ext, m) {
|
|||
bind_moe_module<AMX_MOE_TP<amx::GemmKernel224Int4_1>>(moe_module, "AMXInt4_1_MOE");
|
||||
bind_moe_module<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4_1_LowKGroup>>(moe_module, "AMXInt4_1KGroup_MOE");
|
||||
bind_moe_module<AMX_K2_MOE_TP<amx::GemmKernel224Int4SmallKGroup>>(moe_module, "AMXInt4_KGroup_MOE");
|
||||
#if defined(__AVX512BF16__)
|
||||
#if defined(__AVX512F__)
|
||||
bind_moe_module<AMX_BF16_MOE_TP<amx::GemmKernel224BF16>>(moe_module, "AMXBF16_MOE");
|
||||
bind_moe_module<AMX_FP8_MOE_TP<amx::GemmKernel224FP8>>(moe_module, "AMXFP8_MOE");
|
||||
bind_moe_module<AMX_FP8_PERCHANNEL_MOE_TP<amx::GemmKernel224FP8PerChannel>>(moe_module, "AMXFP8PerChannel_MOE");
|
||||
|
|
|
|||
|
|
@ -14,6 +14,54 @@
|
|||
#define ALWAYS_INLINE inline
|
||||
#endif
|
||||
#include <immintrin.h>
|
||||
|
||||
#if !defined(__AVX512BF16__)
|
||||
#ifndef __m512bh
|
||||
#define __m512bh __m512i
|
||||
#endif
|
||||
|
||||
// BF16 emulation via AVX512F
|
||||
static ALWAYS_INLINE __m512 _mm512_dpbf16_ps_emulated(__m512 src, __m512i a, __m512i b) {
|
||||
__m512 a_low = _mm512_castsi512_ps(_mm512_slli_epi32(a, 16));
|
||||
__m512 b_low = _mm512_castsi512_ps(_mm512_slli_epi32(b, 16));
|
||||
__m512i mask = _mm512_set1_epi32(0xFFFF0000);
|
||||
__m512 a_high = _mm512_castsi512_ps(_mm512_and_si512(a, mask));
|
||||
__m512 b_high = _mm512_castsi512_ps(_mm512_and_si512(b, mask));
|
||||
__m512 res = _mm512_fmadd_ps(a_low, b_low, src);
|
||||
res = _mm512_fmadd_ps(a_high, b_high, res);
|
||||
return res;
|
||||
}
|
||||
|
||||
#ifndef _mm512_dpbf16_ps
|
||||
#define _mm512_dpbf16_ps _mm512_dpbf16_ps_emulated
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if !defined(__AVX512VBMI__)
|
||||
// VBMI emulation via AVX512F+BW
|
||||
static ALWAYS_INLINE __m512i _mm512_permutex2var_epi8_emulated(__m512i a, __m512i idx, __m512i b) {
|
||||
__m256i idx_lo_256 = _mm512_castsi512_si256(idx);
|
||||
__m256i idx_hi_256 = _mm512_extracti64x4_epi64(idx, 1);
|
||||
__m512i idx_lo_w = _mm512_cvtepu8_epi16(idx_lo_256);
|
||||
__m512i idx_hi_w = _mm512_cvtepu8_epi16(idx_hi_256);
|
||||
__m512i p_lo = _mm512_srli_epi16(idx_lo_w, 1);
|
||||
__m512i p_hi = _mm512_srli_epi16(idx_hi_w, 1);
|
||||
__m512i w_lo = _mm512_permutex2var_epi16(a, p_lo, b);
|
||||
__m512i w_hi = _mm512_permutex2var_epi16(a, p_hi, b);
|
||||
__m512i shift_lo = _mm512_slli_epi16(_mm512_and_si512(idx_lo_w, _mm512_set1_epi16(1)), 3);
|
||||
__m512i shift_hi = _mm512_slli_epi16(_mm512_and_si512(idx_hi_w, _mm512_set1_epi16(1)), 3);
|
||||
__m512i res_lo_w = _mm512_srlv_epi16(w_lo, shift_lo);
|
||||
__m512i res_hi_w = _mm512_srlv_epi16(w_hi, shift_hi);
|
||||
__m256i res_lo = _mm512_cvtepi16_epi8(res_lo_w);
|
||||
__m256i res_hi = _mm512_cvtepi16_epi8(res_hi_w);
|
||||
return _mm512_inserti64x4(_mm512_castsi256_si512(res_lo), res_hi, 1);
|
||||
}
|
||||
|
||||
#ifndef _mm512_permutex2var_epi8
|
||||
#define _mm512_permutex2var_epi8 _mm512_permutex2var_epi8_emulated
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if defined(__AMX__) || defined(__AMXINT8__) || defined(__AMXBF16__) || defined(__AMX_TILE__) || defined(HAVE_AMX)
|
||||
#ifndef HAVE_AMX
|
||||
#define HAVE_AMX
|
||||
|
|
@ -187,4 +235,4 @@ static_assert(sizeof(TileConfig) == 64);
|
|||
|
||||
} // namespace amx
|
||||
#endif // defined(__AMX__)
|
||||
#endif // AMX_CONFIG_HPP
|
||||
#endif // AMX_CONFIG_HPP
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue