[Feature] Add AVX-based Kimi-K2 support (#1656)

* Support Kimi-K2-Thinking original weights; fix an AMX kernel bug

* Update the K2 AVX kernel.

* feat: add CPUInfer write buffer task

* [feat]: add Kimi-K2 CPU write buffer support

- Implement write_weights_to_buffer function in k2-moe.hpp for extracting GPU expert weights
- Fix down (w2) weight column-wise slicing for different TP configurations
- Support three TP scenarios: cpu_tp == gpu_tp, cpu_tp > gpu_tp, cpu_tp < gpu_tp (see the sketch after this list)
- Add comprehensive test cases for weight extraction validation
- Ensure compatibility with Kimi model's MoE architecture
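
A minimal sketch of the column-wise w2 slicing idea, with hypothetical names (the real implementation lives in k2-moe.hpp): each rank owns a contiguous column slice of the down-projection weight, and the CPU slices are derived from the ratio between cpu_tp and gpu_tp.

struct Slice { int col_begin, col_end; };

// Columns of w2 owned by CPU rank `cpu_rank`, out of `n_cols` total
// (assumes n_cols divides evenly by cpu_tp).
inline Slice cpu_w2_slice(int n_cols, int cpu_tp, int cpu_rank) {
    int per_rank = n_cols / cpu_tp;
    return {cpu_rank * per_rank, (cpu_rank + 1) * per_rank};
}

// cpu_tp == gpu_tp: CPU and GPU slices coincide.
// cpu_tp >  gpu_tp: each GPU slice is split across several CPU ranks.
// cpu_tp <  gpu_tp: each CPU rank concatenates columns from several GPU slices.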

* [fix]: correct write_weight_scale_to_buffer expert offset calculation

Fixed a bug in write_weight_scale_to_buffer_task where expert offsets in GPU buffers were calculated incorrectly: the code used per_expert_gpu sizes where full gpu_tp sizes were needed. Using the full gpu_tp sizes restores the correct memory layout for multi-expert scenarios.
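
A minimal sketch of the offset fix, with hypothetical names, under the assumption (following the description above) that the GPU buffer lays each expert out at a stride covering all gpu_tp shards, so striding by the per-shard size lands inside the wrong expert from the second expert onward:

#include <cstddef>

inline std::size_t expert_offset(std::size_t expert_id, std::size_t per_expert_gpu, std::size_t gpu_tp) {
    // before: expert_id * per_expert_gpu      (per-shard stride, too small)
    return expert_id * per_expert_gpu * gpu_tp;  // stride over all gpu_tp shards
}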

Also added benchmark scripts for the K2 MoE and write-buffer operations, and cleaned up debug output in the test files.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* [feat]: add write buffer wrapper

* [fix]: fix comment

---------

Co-authored-by: ouqingliang <1692110604@qq.com>
Co-authored-by: Claude <noreply@anthropic.com>
Jiaqi Liao 2025-12-02 16:01:07 +08:00 committed by GitHub
parent c2b8c60c4e
commit fcf8882075
12 changed files with 2649 additions and 34 deletions


@@ -1015,8 +1015,9 @@ struct GemmKernel224Int8 {
static void avx_kernel(int m, int n, int k, int m_begin, int n_begin, int k_block_begin, float* c, BufferA* ba,
BufferB* bb) {
__m512i* c512 = (__m512i*)c;
+int m_block_end = std::min(m - m_begin, M_STEP);
if (k_block_begin == 0) {
-for (int m_i = 0; m_i < m; m_i++) {
+for (int m_i = 0; m_i < m_block_end; m_i++) {
c512[m_i * 2] = _mm512_setzero_si512();
c512[m_i * 2 + 1] = _mm512_setzero_si512();
}
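
(The same fix repeats in each AVX kernel below: the per-block accumulator c512 holds at most M_STEP rows, so the old bound m_i < m could write past it, while the m_i < m && m_i < M_STEP variant processed rows beyond the matrix tail in the final partial block. A minimal scalar sketch of the corrected bound, with an assumed M_STEP:)

#include <algorithm>
#include <cstdio>

int main() {
    const int M_STEP = 2;  // rows per accumulator block (assumed)
    int m = 5;             // total rows, deliberately not a multiple of M_STEP
    for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {
        int m_block_end = std::min(m - m_begin, M_STEP);  // rows in this block
        for (int m_i = 0; m_i < m_block_end; m_i++)       // was: m_i < m
            std::printf("block %d handles row %d\n", m_begin, m_begin + m_i);
    }
    return 0;
}
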
@@ -1028,7 +1029,7 @@ struct GemmKernel224Int8 {
int32_t* a32 = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin);
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin + k_begin);
-for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
+for (int m_i = 0; m_i < m_block_end; m_i++) {
for (int k_i = 0; k_i < 16; k_i++) {
__m512i ma = _mm512_set1_epi32(a32[m_i * 16 + k_i]);
for (int n_i = 0; n_i < 2; n_i++) {
@@ -1239,8 +1240,9 @@ struct GemmKernel224Int4 {
BufferB* bb) {
using K = GemmKernel224Int4;
__m512i* c512 = (__m512i*)c;
+int m_block_end = std::min(m - m_begin, M_STEP);
if (k_block_begin == 0) {
-for (int m_i = 0; m_i < m; m_i++) {
+for (int m_i = 0; m_i < m_block_end; m_i++) {
c512[m_i * 2] = _mm512_setzero_si512();
c512[m_i * 2 + 1] = _mm512_setzero_si512();
}
@@ -1250,7 +1252,7 @@ struct GemmKernel224Int4 {
int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin);
int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin + K::K_STEP);
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin + k_begin);
-for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
+for (int m_i = 0; m_i < m_block_end; m_i++) {
for (int k_i = 0; k_i < 16; k_i++) {
__m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]);
__m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]);
@@ -1533,8 +1535,9 @@ struct GemmKernel224Int4_1 {
BufferB* bb) {
using K = GemmKernel224Int4_1;
__m512i* c512 = (__m512i*)c;
+int m_block_end = std::min(m - m_begin, M_STEP);
if (k_block_begin == 0) {
-for (int m_i = 0; m_i < m; m_i++) {
+for (int m_i = 0; m_i < m_block_end; m_i++) {
c512[m_i * 2] = _mm512_setzero_si512();
c512[m_i * 2 + 1] = _mm512_setzero_si512();
}
@@ -1543,7 +1546,7 @@ struct GemmKernel224Int4_1 {
int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin);
int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin + K::K_STEP);
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin + k_begin);
-for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
+for (int m_i = 0; m_i < m_block_end; m_i++) {
for (int k_i = 0; k_i < 16; k_i++) {
__m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]);
__m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]);
@@ -2193,10 +2196,11 @@ struct GemmKernel224Int4KGroup {
BufferB* bb, int k_group_size) {
using K = GemmKernel224Int4KGroup;
__m512i* c512 = (__m512i*)int_c;
+int m_block_end = std::min(m - m_begin, M_STEP);
// Initialize int_c to zero at the start of k_group
if (k_block_begin % k_group_size == 0) {
-for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
+for (int m_i = 0; m_i < m_block_end; m_i++) {
c512[m_i * 2] = _mm512_setzero_si512();
c512[m_i * 2 + 1] = _mm512_setzero_si512();
}
@@ -2205,7 +2209,7 @@ struct GemmKernel224Int4KGroup {
if (k_offset == 0) {
int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin);
-for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
+for (int m_i = 0; m_i < m_block_end; m_i++) {
for (int k_i = 0; k_i < 16; k_i++) {
__m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]);
for (int n_i = 0; n_i < 2; n_i++) {
@@ -2217,7 +2221,7 @@ struct GemmKernel224Int4KGroup {
} else {
int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin - K::K_STEP);
-for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
+for (int m_i = 0; m_i < m_block_end; m_i++) {
for (int k_i = 0; k_i < 16; k_i++) {
__m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]);
for (int n_i = 0; n_i < 2; n_i++) {
@@ -2471,8 +2475,9 @@ struct GemmKernel224Int4_1KGroup {
BufferB* bb, int k_group_size) {
using K = GemmKernel224Int4_1KGroup;
__m512i* c512 = (__m512i*)int_c;
+int m_block_end = std::min(m - m_begin, M_STEP);
if (k_block_begin % k_group_size == 0) {
-for (int m_i = 0; m_i < m; m_i++) {
+for (int m_i = 0; m_i < m_block_end; m_i++) {
c512[m_i * 2] = _mm512_setzero_si512();
c512[m_i * 2 + 1] = _mm512_setzero_si512();
}
@@ -2481,7 +2486,7 @@ struct GemmKernel224Int4_1KGroup {
if (k_offset == 0) {
int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin);
-for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
+for (int m_i = 0; m_i < m_block_end; m_i++) {
for (int k_i = 0; k_i < 16; k_i++) {
__m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]);
for (int n_i = 0; n_i < 2; n_i++) {
@@ -2493,7 +2498,7 @@ struct GemmKernel224Int4_1KGroup {
} else {
int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin - K::K_STEP);
-for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
+for (int m_i = 0; m_i < m_block_end; m_i++) {
for (int k_i = 0; k_i < 16; k_i++) {
__m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]);
for (int n_i = 0; n_i < 2; n_i++) {
@@ -2746,8 +2751,9 @@ struct GemmKernel224Int4_1_LowKGroup {
BufferB* bb, int k_group_size) {
using K = GemmKernel224Int4_1_LowKGroup;
__m512i* c512 = (__m512i*)int_c;
+int m_block_end = std::min(m - m_begin, M_STEP);
if (k_block_begin % k_group_size == 0) {
-for (int m_i = 0; m_i < m; m_i++) {
+for (int m_i = 0; m_i < m_block_end; m_i++) {
c512[m_i * 2] = _mm512_setzero_si512();
c512[m_i * 2 + 1] = _mm512_setzero_si512();
}
@@ -2756,7 +2762,7 @@ struct GemmKernel224Int4_1_LowKGroup {
if (k_offset == 0) {
int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin);
-for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
+for (int m_i = 0; m_i < m_block_end; m_i++) {
for (int k_i = 0; k_i < 16; k_i++) {
__m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]);
for (int n_i = 0; n_i < 2; n_i++) {
@@ -2768,7 +2774,7 @@ struct GemmKernel224Int4_1_LowKGroup {
} else {
int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin - K::K_STEP);
-for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
+for (int m_i = 0; m_i < m_block_end; m_i++) {
for (int k_i = 0; k_i < 16; k_i++) {
__m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]);
for (int n_i = 0; n_i < 2; n_i++) {
@@ -2837,6 +2843,110 @@ struct GemmKernel224Int4_1_LowKGroup {
}
};
// K2 Signed Int4 K-group quantization kernel (AVX only, no AMX)
// For K2 MoE - signed int4 range: [-8, 7]
struct GemmKernel224Int4SmallKGroup {
using dt = uint8_t; // packed int4 type
using output_t = int32_t;
static constexpr double ELEMENT_SIZE = 0.5;
static const int VNNI_BLK = 4;
static const int M_STEP = 1;
static const int N_STEP = 32;
static const int K_STEP = 32;
static inline const int N_BLOCK = 256;
// K_BLOCK should match k_group_size for proper scaling
static inline const int K_BLOCK = 7168; // Will be overridden by k_group_size
static std::string name() { return "K2_INT4_KGROUP"; }
static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }
static std::pair<int, int> split_range_n(int n, int ith, int nth) {
int n_start = N_BLOCK * ith;
int n_end = std::min(n, N_BLOCK * (ith + 1));
return {n_start, n_end};
}
static void config() {}
alignas(64) static constexpr uint8_t hi_mask_arr[32] = {
0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0,
0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0
};
alignas(64) static constexpr uint8_t lo_mask_arr[32] = {
0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F,
0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F
};
// 0x88 flips bit 3 of each nibble, i.e. the sign bit of both packed int4 values in a byte
alignas(64) static constexpr uint8_t sign_xor_arr[32] = {
0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88,
0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88
};
static __m256i hi_mask() { return *((__m256i*)(&hi_mask_arr[0])); }
static __m256i lo_mask() { return *((__m256i*)(&lo_mask_arr[0])); }
static __m256i sign_xor_mask() { return *((__m256i*)(&sign_xor_arr[0])); }
using BufferA = BufferAKGroupImpl<GemmKernel224Int4SmallKGroup>;
using BufferB = BufferBInt4KGroupImpl<GemmKernel224Int4SmallKGroup>; // Use new signed int4 buffer
using BufferC = BufferCReduceImpl<GemmKernel224Int4SmallKGroup>;
// K-group aware AVX kernel for signed int4
static inline __m512i compressed_int4_to_int8_avx512(__m256i b256) {
// Flip each nibble's sign bit, converting the packed 4-bit fields between
// offset-binary and two's-complement representations.
b256 = _mm256_xor_si256(b256, sign_xor_mask());
// Widen each nibble into the high half of a byte, so every byte carries
// 16x the signed int4 value; the factor is divided back out after the
// dot product.
__m256i b_hi = _mm256_and_si256(b256, hi_mask());
__m256i b_lo = _mm256_slli_epi16(_mm256_andnot_si256(hi_mask(), b256), 4);
// Interleave the low/high nibbles into byte order, then repair the lane
// order (unpack operates within 128-bit lanes).
__m256i unpack_lo = _mm256_unpacklo_epi8(b_lo, b_hi);
__m256i unpack_hi = _mm256_unpackhi_epi8(b_lo, b_hi);
__m512i result = _mm512_inserti64x4(_mm512_castsi256_si512(unpack_lo), unpack_hi, 1);
const __m512i lane_shuffle = _mm512_set_epi64(7, 6, 3, 2, 5, 4, 1, 0);
return _mm512_permutexvar_epi64(lane_shuffle, result);
}
static inline void integer_mat_vec_kgroup(int m, int n, int k, int k_group_size, BufferA* ba, BufferB* bb, BufferC* bc, int ith, int nth) {
auto [n_start, n_end] = split_range_n(n, ith, nth);
for (int m_begin = 0; m_begin < m; m_begin++) {
float* c = bc->get_submat(m, n, m_begin, 0);
__m512i* a512 = (__m512i*)ba->get_submat(m, k, m_begin, 0);
for (int n_block_begin = n_start; n_block_begin < n_end; n_block_begin++) {
__m256i* b256 = (__m256i*)bb->get_submat(n, k, n_block_begin, 0);
float* as = (float*)ba->get_scale(m, m_begin, k, 0);
float* bs = (float*)bb->get_scale(n, n_block_begin, k, 0);
__m512 sum = _mm512_setzero_ps();
#define WORK_K_BLOCK(k_block) \
{ \
__m256 abscale0 = _mm256_set1_ps(as[(k_block)*2] * bs[(k_block)*2]); \
__m256 abscale1 = _mm256_set1_ps(as[(k_block)*2+1] * bs[(k_block)*2+1]); \
__m512 abscale = _mm512_insertf32x8(_mm512_castps256_ps512(abscale0), abscale1, 1); \
__m512i mul = _mm512_setzero_si512(); \
mul = _mm512_dpbssd_epi32(mul, a512[k_block], compressed_int4_to_int8_avx512(b256[k_block])); \
sum = _mm512_add_ps(sum, _mm512_mul_ps(abscale, _mm512_cvtepi32_ps(mul))); \
}
// Unrolled by two (assumes an even number of 64-element k blocks).
for (int k_block = 0; k_block < k / 64; k_block += 2) {
WORK_K_BLOCK(k_block);
WORK_K_BLOCK(k_block + 1);
}
#undef WORK_K_BLOCK
// Divide out the 16x factor introduced by widening each int4 nibble into
// the high half of a byte in compressed_int4_to_int8_avx512.
c[n_block_begin] = _mm512_reduce_add_ps(sum) / 16;
}
}
}
};
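
A self-contained scalar check of the nibble trick used in compressed_int4_to_int8_avx512 (my reading of the code, not part of the diff): assuming each packed 4-bit field stores v + 8 (excess-8), XOR-ing the field with 0x8 restores two's complement, and shifting it into the high half of a byte yields exactly 16 * v as a signed int8, which the kernel divides back out with the final "/ 16".

#include <cstdint>
#include <cstdio>

int main() {
    for (int v = -8; v <= 7; v++) {
        uint8_t stored  = uint8_t(v + 8) & 0x0F;  // excess-8 int4 field
        uint8_t flipped = stored ^ 0x8;           // two's-complement nibble
        int8_t  widened = int8_t(flipped << 4);   // nibble in high half of byte
        std::printf("v=%3d widened=%4d (16*v=%4d)\n", v, widened, 16 * v);
    }
    return 0;
}
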
inline void vec_mul_kgroup(int m, int n, int k, int k_group_size, std::shared_ptr<GemmKernel224Int4SmallKGroup::BufferA> ba,
std::shared_ptr<GemmKernel224Int4SmallKGroup::BufferB> bb,
std::shared_ptr<GemmKernel224Int4SmallKGroup::BufferC> bc, int ith, int nth) {
GemmKernel224Int4SmallKGroup::integer_mat_vec_kgroup(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith, nth);
}
// Note: the "matrix" wrapper below currently routes through the same
// row-at-a-time vector kernel as vec_mul_kgroup.
inline void mat_mul_kgroup(int m, int n, int k, int k_group_size, std::shared_ptr<GemmKernel224Int4SmallKGroup::BufferA> ba,
std::shared_ptr<GemmKernel224Int4SmallKGroup::BufferB> bb,
std::shared_ptr<GemmKernel224Int4SmallKGroup::BufferC> bc, int ith, int nth) {
GemmKernel224Int4SmallKGroup::integer_mat_vec_kgroup(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith, nth);
}
// New k-group aware matrix multiplication function
template <typename K, bool amx_or_avx = true>
void integer_mat_mul_kgroup(int m, int n, int k, int k_group_size, typename K::BufferA* ba, typename K::BufferB* bb,