mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-05-02 13:41:15 +00:00
[Feature] Add avx-based kimi-k2 support (#1656)
Some checks are pending
Book-CI / test-2 (push) Waiting to run
Book-CI / test (push) Waiting to run
Book-CI / test-1 (push) Waiting to run
Deploy / deploy (macos-latest) (push) Waiting to run
Deploy / deploy (ubuntu-latest) (push) Waiting to run
Deploy / deploy (windows-latest) (push) Waiting to run
Some checks are pending
Book-CI / test-2 (push) Waiting to run
Book-CI / test (push) Waiting to run
Book-CI / test-1 (push) Waiting to run
Deploy / deploy (macos-latest) (push) Waiting to run
Deploy / deploy (ubuntu-latest) (push) Waiting to run
Deploy / deploy (windows-latest) (push) Waiting to run
* support Kimi-K2-Thinking original weight fix amx kernel bug * update k2 avx kernel. * feat: add CPUInfer write buffer task * [feat]: add kimi k2 cpu write buffer support - Implement write_weights_to_buffer function in k2-moe.hpp for extracting GPU expert weights - Fix down (w2) weight column-wise slicing for different TP configurations - Support three TP scenarios: cpu_tp == gpu_tp, cpu_tp > gpu_tp, cpu_tp < gpu_tp - Add comprehensive test cases for weight extraction validation - Ensure compatibility with Kimi model's MoE architecture * [fix]: correct write_weight_scale_to_buffer expert offset calculation Fixed the bug in write_weight_scale_to_buffer_task where expert offsets in GPU buffers were incorrectly calculated. Changed from using per_expert_gpu sizes to using full gpu_tp sizes, ensuring correct memory layout for multi-expert scenarios. Also added benchmark scripts for k2 moe and write buffer operations, and cleaned up debug output in test files. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * [feat]: add write buffer wrapper * [fix] fix comment --------- Co-authored-by: ouqingliang <1692110604@qq.com> Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
parent
c2b8c60c4e
commit
fcf8882075
12 changed files with 2649 additions and 34 deletions
|
|
@ -1015,8 +1015,9 @@ struct GemmKernel224Int8 {
|
|||
static void avx_kernel(int m, int n, int k, int m_begin, int n_begin, int k_block_begin, float* c, BufferA* ba,
|
||||
BufferB* bb) {
|
||||
__m512i* c512 = (__m512i*)c;
|
||||
int m_block_end = std::min(m - m_begin, M_STEP);
|
||||
if (k_block_begin == 0) {
|
||||
for (int m_i = 0; m_i < m; m_i++) {
|
||||
for (int m_i = 0; m_i < m_block_end; m_i++) {
|
||||
c512[m_i * 2] = _mm512_setzero_si512();
|
||||
c512[m_i * 2 + 1] = _mm512_setzero_si512();
|
||||
}
|
||||
|
|
@ -1028,7 +1029,7 @@ struct GemmKernel224Int8 {
|
|||
|
||||
int32_t* a32 = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin);
|
||||
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin + k_begin);
|
||||
for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
|
||||
for (int m_i = 0; m_i < m_block_end; m_i++) {
|
||||
for (int k_i = 0; k_i < 16; k_i++) {
|
||||
__m512i ma = _mm512_set1_epi32(a32[m_i * 16 + k_i]);
|
||||
for (int n_i = 0; n_i < 2; n_i++) {
|
||||
|
|
@ -1239,8 +1240,9 @@ struct GemmKernel224Int4 {
|
|||
BufferB* bb) {
|
||||
using K = GemmKernel224Int4;
|
||||
__m512i* c512 = (__m512i*)c;
|
||||
int m_block_end = std::min(m - m_begin, M_STEP);
|
||||
if (k_block_begin == 0) {
|
||||
for (int m_i = 0; m_i < m; m_i++) {
|
||||
for (int m_i = 0; m_i < m_block_end; m_i++) {
|
||||
c512[m_i * 2] = _mm512_setzero_si512();
|
||||
c512[m_i * 2 + 1] = _mm512_setzero_si512();
|
||||
}
|
||||
|
|
@ -1250,7 +1252,7 @@ struct GemmKernel224Int4 {
|
|||
int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin);
|
||||
int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin + K::K_STEP);
|
||||
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin + k_begin);
|
||||
for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
|
||||
for (int m_i = 0; m_i < m_block_end; m_i++) {
|
||||
for (int k_i = 0; k_i < 16; k_i++) {
|
||||
__m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]);
|
||||
__m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]);
|
||||
|
|
@ -1533,8 +1535,9 @@ struct GemmKernel224Int4_1 {
|
|||
BufferB* bb) {
|
||||
using K = GemmKernel224Int4_1;
|
||||
__m512i* c512 = (__m512i*)c;
|
||||
int m_block_end = std::min(m - m_begin, M_STEP);
|
||||
if (k_block_begin == 0) {
|
||||
for (int m_i = 0; m_i < m; m_i++) {
|
||||
for (int m_i = 0; m_i < m_block_end; m_i++) {
|
||||
c512[m_i * 2] = _mm512_setzero_si512();
|
||||
c512[m_i * 2 + 1] = _mm512_setzero_si512();
|
||||
}
|
||||
|
|
@ -1543,7 +1546,7 @@ struct GemmKernel224Int4_1 {
|
|||
int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin);
|
||||
int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin + K::K_STEP);
|
||||
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin + k_begin);
|
||||
for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
|
||||
for (int m_i = 0; m_i < m_block_end; m_i++) {
|
||||
for (int k_i = 0; k_i < 16; k_i++) {
|
||||
__m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]);
|
||||
__m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]);
|
||||
|
|
@ -2193,10 +2196,11 @@ struct GemmKernel224Int4KGroup {
|
|||
BufferB* bb, int k_group_size) {
|
||||
using K = GemmKernel224Int4KGroup;
|
||||
__m512i* c512 = (__m512i*)int_c;
|
||||
int m_block_end = std::min(m - m_begin, M_STEP);
|
||||
|
||||
// Initialize int_c to zero at the start of k_group
|
||||
if (k_block_begin % k_group_size == 0) {
|
||||
for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
|
||||
for (int m_i = 0; m_i < m_block_end; m_i++) {
|
||||
c512[m_i * 2] = _mm512_setzero_si512();
|
||||
c512[m_i * 2 + 1] = _mm512_setzero_si512();
|
||||
}
|
||||
|
|
@ -2205,7 +2209,7 @@ struct GemmKernel224Int4KGroup {
|
|||
if (k_offset == 0) {
|
||||
int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);
|
||||
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin);
|
||||
for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
|
||||
for (int m_i = 0; m_i < m_block_end; m_i++) {
|
||||
for (int k_i = 0; k_i < 16; k_i++) {
|
||||
__m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]);
|
||||
for (int n_i = 0; n_i < 2; n_i++) {
|
||||
|
|
@ -2217,7 +2221,7 @@ struct GemmKernel224Int4KGroup {
|
|||
} else {
|
||||
int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);
|
||||
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin - K::K_STEP);
|
||||
for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
|
||||
for (int m_i = 0; m_i < m_block_end; m_i++) {
|
||||
for (int k_i = 0; k_i < 16; k_i++) {
|
||||
__m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]);
|
||||
for (int n_i = 0; n_i < 2; n_i++) {
|
||||
|
|
@ -2471,8 +2475,9 @@ struct GemmKernel224Int4_1KGroup {
|
|||
BufferB* bb, int k_group_size) {
|
||||
using K = GemmKernel224Int4_1KGroup;
|
||||
__m512i* c512 = (__m512i*)int_c;
|
||||
int m_block_end = std::min(m - m_begin, M_STEP);
|
||||
if (k_block_begin % k_group_size == 0) {
|
||||
for (int m_i = 0; m_i < m; m_i++) {
|
||||
for (int m_i = 0; m_i < m_block_end; m_i++) {
|
||||
c512[m_i * 2] = _mm512_setzero_si512();
|
||||
c512[m_i * 2 + 1] = _mm512_setzero_si512();
|
||||
}
|
||||
|
|
@ -2481,7 +2486,7 @@ struct GemmKernel224Int4_1KGroup {
|
|||
if (k_offset == 0) {
|
||||
int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);
|
||||
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin);
|
||||
for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
|
||||
for (int m_i = 0; m_i < m_block_end; m_i++) {
|
||||
for (int k_i = 0; k_i < 16; k_i++) {
|
||||
__m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]);
|
||||
for (int n_i = 0; n_i < 2; n_i++) {
|
||||
|
|
@ -2493,7 +2498,7 @@ struct GemmKernel224Int4_1KGroup {
|
|||
} else {
|
||||
int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);
|
||||
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin - K::K_STEP);
|
||||
for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
|
||||
for (int m_i = 0; m_i < m_block_end; m_i++) {
|
||||
for (int k_i = 0; k_i < 16; k_i++) {
|
||||
__m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]);
|
||||
for (int n_i = 0; n_i < 2; n_i++) {
|
||||
|
|
@ -2746,8 +2751,9 @@ struct GemmKernel224Int4_1_LowKGroup {
|
|||
BufferB* bb, int k_group_size) {
|
||||
using K = GemmKernel224Int4_1_LowKGroup;
|
||||
__m512i* c512 = (__m512i*)int_c;
|
||||
int m_block_end = std::min(m - m_begin, M_STEP);
|
||||
if (k_block_begin % k_group_size == 0) {
|
||||
for (int m_i = 0; m_i < m; m_i++) {
|
||||
for (int m_i = 0; m_i < m_block_end; m_i++) {
|
||||
c512[m_i * 2] = _mm512_setzero_si512();
|
||||
c512[m_i * 2 + 1] = _mm512_setzero_si512();
|
||||
}
|
||||
|
|
@ -2756,7 +2762,7 @@ struct GemmKernel224Int4_1_LowKGroup {
|
|||
if (k_offset == 0) {
|
||||
int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);
|
||||
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin);
|
||||
for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
|
||||
for (int m_i = 0; m_i < m_block_end; m_i++) {
|
||||
for (int k_i = 0; k_i < 16; k_i++) {
|
||||
__m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]);
|
||||
for (int n_i = 0; n_i < 2; n_i++) {
|
||||
|
|
@ -2768,7 +2774,7 @@ struct GemmKernel224Int4_1_LowKGroup {
|
|||
} else {
|
||||
int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);
|
||||
__m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin - K::K_STEP);
|
||||
for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) {
|
||||
for (int m_i = 0; m_i < m_block_end; m_i++) {
|
||||
for (int k_i = 0; k_i < 16; k_i++) {
|
||||
__m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]);
|
||||
for (int n_i = 0; n_i < 2; n_i++) {
|
||||
|
|
@ -2837,6 +2843,110 @@ struct GemmKernel224Int4_1_LowKGroup {
|
|||
}
|
||||
};
|
||||
|
||||
// K2 Signed Int4 K-group quantization kernel (AVX only, no AMX)
|
||||
// For K2 MoE - signed int4 range: [-8, 7]
|
||||
struct GemmKernel224Int4SmallKGroup {
|
||||
using dt = uint8_t; // packed int4 type
|
||||
using output_t = int32_t;
|
||||
static constexpr double ELEMENT_SIZE = 0.5;
|
||||
static const int VNNI_BLK = 4;
|
||||
|
||||
static const int M_STEP = 1;
|
||||
static const int N_STEP = 32;
|
||||
static const int K_STEP = 32;
|
||||
|
||||
static inline const int N_BLOCK = 256;
|
||||
// K_BLOCK should match k_group_size for proper scaling
|
||||
static inline const int K_BLOCK = 7168; // Will be overridden by k_group_size
|
||||
|
||||
static std::string name() { return "K2_INT4_KGROUP"; }
|
||||
static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }
|
||||
static std::pair<int, int> split_range_n(int n, int ith, int nth) {
|
||||
int n_start = N_BLOCK * ith;
|
||||
int n_end = std::min(n, N_BLOCK * (ith + 1));
|
||||
return {n_start, n_end};
|
||||
}
|
||||
static void config() {}
|
||||
|
||||
alignas(64) static constexpr uint8_t hi_mask_arr[32] = {
|
||||
0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0,
|
||||
0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0
|
||||
};
|
||||
|
||||
alignas(64) static constexpr uint8_t lo_mask_arr[32] = {
|
||||
0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F,
|
||||
0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F
|
||||
};
|
||||
|
||||
alignas(64) static constexpr uint8_t sign_xor_arr[32] = {
|
||||
0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88,
|
||||
0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88
|
||||
};
|
||||
static __m256i hi_mask() { return *((__m256i*)(&hi_mask_arr[0])); }
|
||||
static __m256i lo_mask() { return *((__m256i*)(&lo_mask_arr[0])); }
|
||||
static __m256i sign_xor_mask() { return *((__m256i*)(&sign_xor_arr[0])); }
|
||||
|
||||
using BufferA = BufferAKGroupImpl<GemmKernel224Int4SmallKGroup>;
|
||||
using BufferB = BufferBInt4KGroupImpl<GemmKernel224Int4SmallKGroup>; // Use new signed int4 buffer
|
||||
using BufferC = BufferCReduceImpl<GemmKernel224Int4SmallKGroup>;
|
||||
|
||||
// K-group aware AVX kernel for signed int4
|
||||
static inline __m512i compressed_int4_to_int8_avx512(__m256i b256) {
|
||||
b256 = _mm256_xor_si256(b256, sign_xor_mask());
|
||||
__m256i b_hi = _mm256_and_si256(b256, hi_mask());
|
||||
__m256i b_lo = _mm256_slli_epi16(_mm256_andnot_si256(hi_mask(), b256), 4);
|
||||
|
||||
__m256i unpack_lo = _mm256_unpacklo_epi8(b_lo, b_hi);
|
||||
__m256i unpack_hi = _mm256_unpackhi_epi8(b_lo, b_hi);
|
||||
__m512i result = _mm512_inserti64x4(_mm512_castsi256_si512(unpack_lo), unpack_hi, 1);
|
||||
const __m512i lane_shuffle = _mm512_set_epi64(7, 6, 3, 2, 5, 4, 1, 0);
|
||||
return _mm512_permutexvar_epi64(lane_shuffle, result);
|
||||
}
|
||||
static inline void integer_mat_vec_kgroup(int m, int n, int k, int k_group_size, BufferA* ba, BufferB *bb, BufferC* bc, int ith, int nth) {
|
||||
auto [n_start, n_end] = split_range_n(n, ith, nth);
|
||||
for (int m_begin = 0; m_begin < m; m_begin ++) {
|
||||
float* c = bc->get_submat(m, n, m_begin, 0);
|
||||
__m512i* a512 = (__m512i*)ba->get_submat(m, k, m_begin, 0);
|
||||
|
||||
for (int n_block_begin = n_start; n_block_begin < n_end; n_block_begin ++) {
|
||||
__m256i* b256 = (__m256i*)bb->get_submat(n, k, n_block_begin, 0);
|
||||
float* as = (float*)ba->get_scale(m, m_begin, k, 0);
|
||||
float* bs = (float*)bb->get_scale(n, n_block_begin, k, 0);
|
||||
|
||||
__m512 sum = _mm512_setzero_ps();
|
||||
#define WORK_K_BLOCK(k_block) \
|
||||
{ \
|
||||
__m256 abscale0 = _mm256_set1_ps(as[(k_block)*2] * bs[(k_block)*2]); \
|
||||
__m256 abscale1 = _mm256_set1_ps(as[(k_block)*2+1] * bs[(k_block)*2+1]); \
|
||||
__m512 abscale = _mm512_insertf32x8(_mm512_castps256_ps512(abscale0), abscale1, 1); \
|
||||
__m512i mul = _mm512_setzero_si512(); \
|
||||
mul = _mm512_dpbssd_epi32(mul, a512[k_block], compressed_int4_to_int8_avx512(b256[k_block])); \
|
||||
sum = _mm512_add_ps(sum, _mm512_mul_ps(abscale, _mm512_cvtepi32_ps(mul))); \
|
||||
}
|
||||
|
||||
for (int k_block = 0; k_block < k / 64; k_block += 2) {
|
||||
WORK_K_BLOCK(k_block);
|
||||
WORK_K_BLOCK(k_block + 1);
|
||||
}
|
||||
|
||||
c[n_block_begin] = _mm512_reduce_add_ps(sum) / 16;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
inline void vec_mul_kgroup(int m, int n, int k, int k_group_size, std::shared_ptr<GemmKernel224Int4SmallKGroup::BufferA> ba,
|
||||
std::shared_ptr<GemmKernel224Int4SmallKGroup::BufferB> bb,
|
||||
std::shared_ptr<GemmKernel224Int4SmallKGroup::BufferC> bc, int ith, int nth) {
|
||||
GemmKernel224Int4SmallKGroup::integer_mat_vec_kgroup(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith, nth);
|
||||
}
|
||||
|
||||
inline void mat_mul_kgroup(int m, int n, int k, int k_group_size, std::shared_ptr<GemmKernel224Int4SmallKGroup::BufferA> ba,
|
||||
std::shared_ptr<GemmKernel224Int4SmallKGroup::BufferB> bb,
|
||||
std::shared_ptr<GemmKernel224Int4SmallKGroup::BufferC> bc, int ith, int nth) {
|
||||
GemmKernel224Int4SmallKGroup::integer_mat_vec_kgroup(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith, nth);
|
||||
}
|
||||
|
||||
// New k-group aware matrix multiplication function
|
||||
template <typename K, bool amx_or_avx = true>
|
||||
void integer_mat_mul_kgroup(int m, int n, int k, int k_group_size, typename K::BufferA* ba, typename K::BufferB* bb,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue