diff --git a/.clang-format b/.clang-format
new file mode 100644
index 0000000..fd64bc4
--- /dev/null
+++ b/.clang-format
@@ -0,0 +1,5 @@
+---
+BasedOnStyle: LLVM
+ColumnLimit: 120 # set the maximum line width to 120
+IndentWidth: 2
+---
diff --git a/csrc/ktransformers_ext/CMakeLists.txt b/csrc/ktransformers_ext/CMakeLists.txt
index 3a86dd7..e435dee 100644
--- a/csrc/ktransformers_ext/CMakeLists.txt
+++ b/csrc/ktransformers_ext/CMakeLists.txt
@@ -293,9 +293,10 @@ aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/cpu_backend SOURCE_DIR2)
 aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/llamafile SOURCE_DIR3)
 aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llamafile SOURCE_DIR4)
 aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/kvcache SOURCE_DIR5)
+aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/amx SOURCE_DIR6)
-set(ALL_SOURCES ${SOURCE_DIR1} ${SOURCE_DIR2} ${SOURCE_DIR3} ${SOURCE_DIR4} ${SOURCE_DIR5})
+set(ALL_SOURCES ${SOURCE_DIR1} ${SOURCE_DIR2} ${SOURCE_DIR3} ${SOURCE_DIR4} ${SOURCE_DIR5} ${SOURCE_DIR6})
 file(GLOB_RECURSE FMT_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/*.hpp" "${CMAKE_CURRENT_SOURCE_DIR}/*.h")
diff --git a/csrc/ktransformers_ext/bench/bench_moe_amx.py b/csrc/ktransformers_ext/bench/bench_moe_amx.py
new file mode 100644
index 0000000..1101ff8
--- /dev/null
+++ b/csrc/ktransformers_ext/bench/bench_moe_amx.py
@@ -0,0 +1,107 @@
+#!/usr/bin/env python
+# coding=utf-8
+'''
+Description :
+Author : chenht2022
+Date : 2025-04-25 18:28:12
+Version : 1.0.0
+LastEditors : chenht2022
+LastEditTime : 2025-04-25 18:28:12
+Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
+'''
+import os, sys
+import time
+sys.path.append(os.path.dirname(__file__) + '/../build')
+import cpuinfer_ext
+import torch
+
+expert_num = 8
+hidden_size = 7168
+intermediate_size = 2048
+max_len = 25600
+n_routed_experts = 8
+layer_num = 10
+qlen = 1024
+CPUInfer = cpuinfer_ext.CPUInfer(65)
+warm_up_iter = 100
+test_iter = 100
+
+def bench_moe(quant_mode: str):
+    with torch.inference_mode(mode=True):
+        if quant_mode == "bf16":
+            bytes_per_elem = 2.000000
+        elif quant_mode == "int8":
+            bytes_per_elem = 1.000000
+        else:
+            assert(False)
+
+
+        moes = []
+        gate_projs = []
+        up_projs = []
+        down_projs = []
+        for _ in range(layer_num):
+            gate_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device = "cuda").to("cpu").contiguous()
+            up_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device = "cuda").to("cpu").contiguous()
+            down_proj = torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32, device = "cuda").to("cpu").contiguous()
+            config = cpuinfer_ext.moe.AMX_MOEConfig(expert_num, n_routed_experts, hidden_size, intermediate_size, max_len, gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr())
+            if quant_mode == "bf16":
+                moe = cpuinfer_ext.moe.AMXBF16_MOE(config)
+                CPUInfer.submit(moe.load_weights())
+                CPUInfer.sync()
+            elif quant_mode == "int8":
+                moe = cpuinfer_ext.moe.AMXInt8_MOE(config)
+                CPUInfer.submit(moe.load_weights())
+                CPUInfer.sync()
+            gate_projs.append(gate_proj)
+            up_projs.append(up_proj)
+            down_projs.append(down_proj)
+            moes.append(moe)
+        expert_ids = torch.stack([torch.stack([torch.randperm(expert_num, dtype=torch.int64, device = "cuda")[:n_routed_experts] for _ in range(qlen)]) for _ in range(layer_num)]).to("cpu").contiguous()
+        weights = torch.rand((layer_num, qlen, n_routed_experts), dtype=torch.float32, device = 
"cuda").to("cpu").contiguous() + input = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = "cuda").to("cpu").contiguous() + output = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = "cuda").to("cpu").contiguous() + qlen_tensor = torch.tensor([qlen], dtype=torch.int32) + + # warm up + for i in range(warm_up_iter): + CPUInfer.submit( + moes[i % layer_num].forward( + qlen, + n_routed_experts, + expert_ids[i % layer_num].data_ptr(), + weights[i % layer_num].data_ptr(), + input[i % layer_num].data_ptr(), + output[i % layer_num].data_ptr(), + qlen_tensor.data_ptr() + ) + ) + CPUInfer.sync() + + # test + start = time.perf_counter() + for i in range(test_iter): + CPUInfer.submit( + moes[i % layer_num].forward( + qlen, + n_routed_experts, + expert_ids[i % layer_num].data_ptr(), + weights[i % layer_num].data_ptr(), + input[i % layer_num].data_ptr(), + output[i % layer_num].data_ptr(), + qlen_tensor.data_ptr() + ) + ) + CPUInfer.sync() + end = time.perf_counter() + total_time = end - start + print('Quant mode: ', quant_mode) + print('Time(s): ', total_time) + print('Iteration: ', test_iter) + print('Time(us) per iteration: ', total_time / test_iter * 1000000) + print('Bandwidth: ', hidden_size * intermediate_size * 3 * n_routed_experts * bytes_per_elem * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s') + print('Flops: ', hidden_size * intermediate_size * qlen * 3 * n_routed_experts * 2 * test_iter / total_time / 1000 / 1000 / 1000, 'GFLOPS') + print('') + +bench_moe("bf16") +bench_moe("int8") diff --git a/csrc/ktransformers_ext/operators/llamafile/shared_mem_buffer.cpp b/csrc/ktransformers_ext/cpu_backend/shared_mem_buffer.cpp similarity index 96% rename from csrc/ktransformers_ext/operators/llamafile/shared_mem_buffer.cpp rename to csrc/ktransformers_ext/cpu_backend/shared_mem_buffer.cpp index dc2d65d..0fe7275 100644 --- a/csrc/ktransformers_ext/operators/llamafile/shared_mem_buffer.cpp +++ b/csrc/ktransformers_ext/cpu_backend/shared_mem_buffer.cpp @@ -30,7 +30,8 @@ void SharedMemBuffer::alloc(void* object, std::vector> requests) *(request.first) = (uint8_t*)buffer_ + offset; offset += request.second; } -} +} \ No newline at end of file diff --git a/csrc/ktransformers_ext/cpu_backend/shared_mem_buffer.h b/csrc/ktransformers_ext/cpu_backend/shared_mem_buffer.h new file mode 100644 index 0000000..58e7aa6 --- /dev/null +++ b/csrc/ktransformers_ext/cpu_backend/shared_mem_buffer.h @@ -0,0 +1,37 @@ +/** + * @Description : + * @Author : chenht2022 + * @Date : 2024-08-05 04:49:08 + * @Version : 1.0.0 + * @LastEditors : chenht2022 + * @LastEditTime : 2024-08-05 06:36:41 + * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. 
+ **/ + + #ifndef CPUINFER_SHAREDMEMBUFFER_H + #define CPUINFER_SHAREDMEMBUFFER_H + + #include + #include + #include + #include + + class SharedMemBuffer { + public: + SharedMemBuffer(); + ~SharedMemBuffer(); + + void alloc(void* object, std::vector> requests); + void dealloc(void* object); + + private: + void* buffer_; + uint64_t size_; + std::map>>> hist_requests_; + + void arrange(std::vector> requests); + }; + + static SharedMemBuffer shared_mem_buffer; + + #endif \ No newline at end of file diff --git a/csrc/ktransformers_ext/ext_bindings.cpp b/csrc/ktransformers_ext/ext_bindings.cpp index 44e6e6b..7bb0c89 100644 --- a/csrc/ktransformers_ext/ext_bindings.cpp +++ b/csrc/ktransformers_ext/ext_bindings.cpp @@ -17,6 +17,7 @@ #include "operators/llamafile/linear.h" #include "operators/llamafile/mlp.h" #include "operators/llamafile/moe.h" +#include "operators/amx/moe.hpp" #include "pybind11/functional.h" #include "pybind11/operators.h" #include "pybind11/pybind11.h" @@ -563,6 +564,75 @@ class MOEBindings { }; }; +template +class AMX_MOEBindings { + public: + class WarmUpBindings { + public: + struct Args { + CPUInfer *cpuinfer; + AMX_MOE *moe; + }; + static void inner(void *args) { + Args *args_ = (Args *)args; + args_->cpuinfer->enqueue(&AMX_MOE::warm_up, args_->moe); + } + static std::pair cpuinfer_interface(AMX_MOE &moe) { + Args *args = new Args{nullptr, &moe}; + return std::make_pair((intptr_t)&inner, (intptr_t)args); + } + }; + class LoadWeightsBindings { + public: + struct Args { + CPUInfer *cpuinfer; + AMX_MOE *moe; + }; + static void inner(void *args) { + Args *args_ = (Args *)args; + args_->cpuinfer->enqueue(&AMX_MOE::load_weights, args_->moe); + } + static std::pair cpuinfer_interface(AMX_MOE &moe) { + Args *args = new Args{nullptr, &moe}; + return std::make_pair((intptr_t)&inner, (intptr_t)args); + } + }; + class ForwardBindings { + public: + struct Args { + CPUInfer *cpuinfer; + AMX_MOE *moe; + int qlen; + int k; + const uint64_t *expert_ids; + const float *weights; + const void *input; + void *output; + int *batch_size_tensor; + }; + static void inner(void *args) { + Args *args_ = (Args *)args; + args_->cpuinfer->enqueue( + &AMX_MOE::forward, args_->moe, args_->qlen, args_->k, + args_->expert_ids, args_->weights, args_->input, args_->output, args_->batch_size_tensor); + } + static std::pair + cpuinfer_interface(AMX_MOE &moe, int qlen, int k, intptr_t expert_ids, + intptr_t weights, intptr_t input, intptr_t output, intptr_t batch_size_tensor) { + Args *args = new Args{nullptr, + &moe, + qlen, + k, + (const uint64_t *)expert_ids, + (const float *)weights, + (const void *)input, + (void *)output, + (int *)batch_size_tensor}; + return std::make_pair((intptr_t)&inner, (intptr_t)args); + } + }; +}; + PYBIND11_MODULE(cpuinfer_ext, m) { py::class_(m, "CPUInfer") .def(py::init()) @@ -621,6 +691,27 @@ PYBIND11_MODULE(cpuinfer_ext, m) { .def("warm_up", &MOEBindings::WarmUpBindinds::cpuinfer_interface) .def("forward", &MOEBindings::ForwardBindings::cpuinfer_interface); + py::class_(moe_module, "AMX_MOEConfig") + .def(py::init([](int expert_num, int routed_expert_num, int hidden_size, + int intermediate_size, + int max_len, intptr_t gate_proj, + intptr_t up_proj, intptr_t down_proj) { + return AMX_MOEConfig(expert_num, routed_expert_num, hidden_size, + intermediate_size, + max_len, (void *)gate_proj, + (void *)up_proj, (void *)down_proj); + })); + py::class_>(moe_module, "AMXBF16_MOE") + .def(py::init()) + .def("warm_up", &AMX_MOEBindings::WarmUpBindings::cpuinfer_interface) + 
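// Illustrative sketch (not part of the patch): the AMX_MOEBindings classes above all use the
// same trampoline pattern -- each Python-visible method packs its arguments into a heap-allocated
// Args struct and returns a (function pointer, argument pointer) pair of intptr_t, which the
// CPUInfer worker later casts back and invokes. CPUInfer/Backend/pybind11 plumbing is omitted
// here, and FakeMoe/make_task are placeholder names; the real Args also carries a CPUInfer*
// that submit() is expected to fill in.
#include <cstdint>
#include <cstdio>
#include <utility>

struct FakeMoe { int id; };

struct ForwardArgs {   // mirrors the Args structs in the bindings
  FakeMoe *moe;
  int qlen;
};

static void forward_trampoline(void *args) {  // plays the role of "inner"
  auto *a = static_cast<ForwardArgs *>(args);
  std::printf("running forward on moe %d with qlen=%d\n", a->moe->id, a->qlen);
  delete a;  // ownership is handled differently in the real bindings
}

// plays the role of cpuinfer_interface: pack arguments, hand back two raw addresses
std::pair<intptr_t, intptr_t> make_task(FakeMoe &moe, int qlen) {
  auto *args = new ForwardArgs{&moe, qlen};
  return {reinterpret_cast<intptr_t>(&forward_trampoline), reinterpret_cast<intptr_t>(args)};
}

int main() {
  FakeMoe moe{0};
  auto [fn, args] = make_task(moe, 1024);  // what .forward(...) hands to Python
  reinterpret_cast<void (*)(void *)>(fn)(reinterpret_cast<void *>(args));  // what the worker does
  return 0;
}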
.def("load_weights", &AMX_MOEBindings::LoadWeightsBindings::cpuinfer_interface) + .def("forward", &AMX_MOEBindings::ForwardBindings::cpuinfer_interface); + py::class_>(moe_module, "AMXInt8_MOE") + .def(py::init()) + .def("warm_up", &AMX_MOEBindings::WarmUpBindings::cpuinfer_interface) + .def("load_weights", &AMX_MOEBindings::LoadWeightsBindings::cpuinfer_interface) + .def("forward", &AMX_MOEBindings::ForwardBindings::cpuinfer_interface); + auto kvcache_module = m.def_submodule("kvcache"); py::enum_(kvcache_module, "AnchorType") diff --git a/csrc/ktransformers_ext/operators/amx/la/amx.hpp b/csrc/ktransformers_ext/operators/amx/la/amx.hpp new file mode 100644 index 0000000..3338e09 --- /dev/null +++ b/csrc/ktransformers_ext/operators/amx/la/amx.hpp @@ -0,0 +1,974 @@ +/** + * @Description : + * @Author : chenht2022 + * @Date : 2025-04-25 18:28:12 + * @Version : 1.0.0 + * @LastEditors : chenht2022 + * @LastEditTime : 2025-04-25 18:28:12 + * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. + **/ + +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "utils.hpp" +#include + +#if (defined(_WIN32) || defined(_WIN64)) +#define RESTRICT __restrict +#else +#define RESTRICT __restrict__ +#endif + +#if (defined(_WIN32) || defined(_WIN64)) +#define ALWAYS_INLINE __forceinline +#elif __has_attribute(always_inline) || defined(__GNUC__) +#define ALWAYS_INLINE __attribute__((__always_inline__)) inline +#else +#define ALWAYS_INLINE inline +#endif + +namespace amx { + +#define ARCH_GET_XCOMP_PERM 0x1022 +#define ARCH_REQ_XCOMP_PERM 0x1023 +#define XFEATURE_XTILECFG 17 +#define XFEATURE_XTILEDATA 18 + +const int TMMCount = 8; +const int MaxTileHeight = 16; +const int MaxTileWidth = 64; + +const int AMX_BLK_SIZE = 32; + +#define TMM0 0 +#define TMM1 1 +#define TMM2 2 +#define TMM3 3 +#define TMM4 4 +#define TMM5 5 +#define TMM6 6 +#define TMM7 7 + +inline bool enable_amx() { + static thread_local bool initialized = false; + if (initialized) { + return true; + } + initialized = true; + + if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) { + printf("\n Fail to do XFEATURE_XTILEDATA \n\n"); + return false; + } else { + // printf("\n TILE DATA USE SET - OK \n\n"); + return true; + } + return true; +} + +struct alignas(64) TileConfig { + uint8_t palette; + uint8_t start_row; + std::array __0 = {}; + std::array colsb; + std::array __1 = {}; + std::array rows; + std::array __2 = {}; + + TileConfig() { + palette = 1; + start_row = 0; + for (int i = 0; i < 8; i++) { + set_row_col(i, 0, 0); + } + } + + void set_row_col(int i, uint8_t row, uint16_t col) { + colsb[i] = col; + rows[i] = row; + } + + void set_config() { _tile_loadconfig(this); } + + static void load_data(int to, void *from, size_t stride) { + switch (to) { + case 0: + _tile_loadd(0, from, stride); + break; + case 1: + _tile_loadd(1, from, stride); + break; + case 2: + _tile_loadd(2, from, stride); + break; + case 3: + _tile_loadd(3, from, stride); + break; + case 4: + _tile_loadd(4, from, stride); + break; + case 5: + _tile_loadd(5, from, stride); + break; + case 6: + _tile_loadd(6, from, stride); + break; + case 7: + _tile_loadd(7, from, stride); + break; + default: + throw std::runtime_error("no such tile"); + } + } + + static void store_data(int from, void *to, size_t stride) { + switch (from) { + case 0: + _tile_stored(0, to, stride); + break; + case 1: + _tile_stored(1, to, stride); + break; + case 2: + _tile_stored(2, to, stride); + break; + 
case 3: + _tile_stored(3, to, stride); + break; + case 4: + _tile_stored(4, to, stride); + break; + case 5: + _tile_stored(5, to, stride); + break; + case 6: + _tile_stored(6, to, stride); + break; + case 7: + _tile_stored(7, to, stride); + break; + default: + throw std::runtime_error("no such tile"); + } + } +}; + +static_assert(sizeof(TileConfig) == 64); + +inline void debug_tile(int t) { + printf("Tile %d\n", t); + uint8_t data[16][64] = {}; + TileConfig::store_data(t, data, 64); + for (int i = 0; i < 16; i++) { + for (int j = 0; j < 64; j++) { + printf("%3d ", data[i][j]); + } + printf("\n"); + } + printf("\n"); +} + +inline void debug_tiles(int to = 8) { + for (int i = 0; i < to; i++) { + debug_tile(i); + } +} + +inline void debug_m512(__m512 x) { + float data[16]; + _mm512_storeu_ps(data, x); + for (int i = 0; i < 16; i++) { + printf("%f ", data[i]); + } + printf("\n"); +} + +// transpose utils +inline void transpose_16x16_32bit(__m512i *v) { + __m512i v1[16]; + v1[0] = _mm512_unpacklo_epi32(v[0], v[1]); + v1[1] = _mm512_unpackhi_epi32(v[0], v[1]); + v1[2] = _mm512_unpacklo_epi32(v[2], v[3]); + v1[3] = _mm512_unpackhi_epi32(v[2], v[3]); + v1[4] = _mm512_unpacklo_epi32(v[4], v[5]); + v1[5] = _mm512_unpackhi_epi32(v[4], v[5]); + v1[6] = _mm512_unpacklo_epi32(v[6], v[7]); + v1[7] = _mm512_unpackhi_epi32(v[6], v[7]); + v1[8] = _mm512_unpacklo_epi32(v[8], v[9]); + v1[9] = _mm512_unpackhi_epi32(v[8], v[9]); + v1[10] = _mm512_unpacklo_epi32(v[10], v[11]); + v1[11] = _mm512_unpackhi_epi32(v[10], v[11]); + v1[12] = _mm512_unpacklo_epi32(v[12], v[13]); + v1[13] = _mm512_unpackhi_epi32(v[12], v[13]); + v1[14] = _mm512_unpacklo_epi32(v[14], v[15]); + v1[15] = _mm512_unpackhi_epi32(v[14], v[15]); + + v[0] = _mm512_unpacklo_epi64(v1[0], v1[2]); + v[1] = _mm512_unpackhi_epi64(v1[0], v1[2]); + v[2] = _mm512_unpacklo_epi64(v1[1], v1[3]); + v[3] = _mm512_unpackhi_epi64(v1[1], v1[3]); + v[4] = _mm512_unpacklo_epi64(v1[4], v1[6]); + v[5] = _mm512_unpackhi_epi64(v1[4], v1[6]); + v[6] = _mm512_unpacklo_epi64(v1[5], v1[7]); + v[7] = _mm512_unpackhi_epi64(v1[5], v1[7]); + v[8] = _mm512_unpacklo_epi64(v1[8], v1[10]); + v[9] = _mm512_unpackhi_epi64(v1[8], v1[10]); + v[10] = _mm512_unpacklo_epi64(v1[9], v1[11]); + v[11] = _mm512_unpackhi_epi64(v1[9], v1[11]); + v[12] = _mm512_unpacklo_epi64(v1[12], v1[14]); + v[13] = _mm512_unpackhi_epi64(v1[12], v1[14]); + v[14] = _mm512_unpacklo_epi64(v1[13], v1[15]); + v[15] = _mm512_unpackhi_epi64(v1[13], v1[15]); + + v1[0] = _mm512_shuffle_i32x4(v[0], v[4], 0x88); + v1[1] = _mm512_shuffle_i32x4(v[1], v[5], 0x88); + v1[2] = _mm512_shuffle_i32x4(v[2], v[6], 0x88); + v1[3] = _mm512_shuffle_i32x4(v[3], v[7], 0x88); + v1[4] = _mm512_shuffle_i32x4(v[0], v[4], 0xdd); + v1[5] = _mm512_shuffle_i32x4(v[1], v[5], 0xdd); + v1[6] = _mm512_shuffle_i32x4(v[2], v[6], 0xdd); + v1[7] = _mm512_shuffle_i32x4(v[3], v[7], 0xdd); + v1[8] = _mm512_shuffle_i32x4(v[8], v[12], 0x88); + v1[9] = _mm512_shuffle_i32x4(v[9], v[13], 0x88); + v1[10] = _mm512_shuffle_i32x4(v[10], v[14], 0x88); + v1[11] = _mm512_shuffle_i32x4(v[11], v[15], 0x88); + v1[12] = _mm512_shuffle_i32x4(v[8], v[12], 0xdd); + v1[13] = _mm512_shuffle_i32x4(v[9], v[13], 0xdd); + v1[14] = _mm512_shuffle_i32x4(v[10], v[14], 0xdd); + v1[15] = _mm512_shuffle_i32x4(v[11], v[15], 0xdd); + + v[0] = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88); + v[1] = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88); + v[2] = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88); + v[3] = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88); + v[4] = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88); 
+ v[5] = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88); + v[6] = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88); + v[7] = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88); + v[8] = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd); + v[9] = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd); + v[10] = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd); + v[11] = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd); + v[12] = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd); + v[13] = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd); + v[14] = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd); + v[15] = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd); +} + +/* + Transpose 16x16 32-bit elements + Note that v must be 64 byte aligned +*/ +inline void transpose_16x16_32bit(__m512i *v, size_t stride) { + assert(reinterpret_cast(v) % 64 == 0 && "v must be 64 aligned"); + + auto stride_v = [=](int i) { return offset_pointer(v, i * stride); }; + __m512i v1[16]; + + v1[0] = _mm512_unpacklo_epi32(*stride_v(0), *stride_v(1)); + v1[1] = _mm512_unpackhi_epi32(*stride_v(0), *stride_v(1)); + v1[2] = _mm512_unpacklo_epi32(*stride_v(2), *stride_v(3)); + v1[3] = _mm512_unpackhi_epi32(*stride_v(2), *stride_v(3)); + v1[4] = _mm512_unpacklo_epi32(*stride_v(4), *stride_v(5)); + v1[5] = _mm512_unpackhi_epi32(*stride_v(4), *stride_v(5)); + v1[6] = _mm512_unpacklo_epi32(*stride_v(6), *stride_v(7)); + v1[7] = _mm512_unpackhi_epi32(*stride_v(6), *stride_v(7)); + v1[8] = _mm512_unpacklo_epi32(*stride_v(8), *stride_v(9)); + v1[9] = _mm512_unpackhi_epi32(*stride_v(8), *stride_v(9)); + v1[10] = _mm512_unpacklo_epi32(*stride_v(10), *stride_v(11)); + v1[11] = _mm512_unpackhi_epi32(*stride_v(10), *stride_v(11)); + v1[12] = _mm512_unpacklo_epi32(*stride_v(12), *stride_v(13)); + v1[13] = _mm512_unpackhi_epi32(*stride_v(12), *stride_v(13)); + v1[14] = _mm512_unpacklo_epi32(*stride_v(14), *stride_v(15)); + v1[15] = _mm512_unpackhi_epi32(*stride_v(14), *stride_v(15)); + + *stride_v(0) = _mm512_unpacklo_epi64(v1[0], v1[2]); + *stride_v(1) = _mm512_unpackhi_epi64(v1[0], v1[2]); + *stride_v(2) = _mm512_unpacklo_epi64(v1[1], v1[3]); + *stride_v(3) = _mm512_unpackhi_epi64(v1[1], v1[3]); + *stride_v(4) = _mm512_unpacklo_epi64(v1[4], v1[6]); + *stride_v(5) = _mm512_unpackhi_epi64(v1[4], v1[6]); + *stride_v(6) = _mm512_unpacklo_epi64(v1[5], v1[7]); + *stride_v(7) = _mm512_unpackhi_epi64(v1[5], v1[7]); + *stride_v(8) = _mm512_unpacklo_epi64(v1[8], v1[10]); + *stride_v(9) = _mm512_unpackhi_epi64(v1[8], v1[10]); + *stride_v(10) = _mm512_unpacklo_epi64(v1[9], v1[11]); + *stride_v(11) = _mm512_unpackhi_epi64(v1[9], v1[11]); + *stride_v(12) = _mm512_unpacklo_epi64(v1[12], v1[14]); + *stride_v(13) = _mm512_unpackhi_epi64(v1[12], v1[14]); + *stride_v(14) = _mm512_unpacklo_epi64(v1[13], v1[15]); + *stride_v(15) = _mm512_unpackhi_epi64(v1[13], v1[15]); + + v1[0] = _mm512_shuffle_i32x4(*stride_v(0), *stride_v(4), 0x88); + v1[1] = _mm512_shuffle_i32x4(*stride_v(1), *stride_v(5), 0x88); + v1[2] = _mm512_shuffle_i32x4(*stride_v(2), *stride_v(6), 0x88); + v1[3] = _mm512_shuffle_i32x4(*stride_v(3), *stride_v(7), 0x88); + v1[4] = _mm512_shuffle_i32x4(*stride_v(0), *stride_v(4), 0xdd); + v1[5] = _mm512_shuffle_i32x4(*stride_v(1), *stride_v(5), 0xdd); + v1[6] = _mm512_shuffle_i32x4(*stride_v(2), *stride_v(6), 0xdd); + v1[7] = _mm512_shuffle_i32x4(*stride_v(3), *stride_v(7), 0xdd); + v1[8] = _mm512_shuffle_i32x4(*stride_v(8), *stride_v(12), 0x88); + v1[9] = _mm512_shuffle_i32x4(*stride_v(9), *stride_v(13), 0x88); + v1[10] = _mm512_shuffle_i32x4(*stride_v(10), *stride_v(14), 0x88); + v1[11] = _mm512_shuffle_i32x4(*stride_v(11), 
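// Illustrative sketch (not part of the patch): each __m512i holds 16 32-bit lanes, so the
// unpack/shuffle network in transpose_16x16_32bit above is an in-register transpose of a
// 16x16 matrix of 32-bit elements. A scalar reference with the same effect:
#include <cstdint>

void transpose_16x16_32bit_ref(uint32_t m[16][16]) {
  for (int i = 0; i < 16; i++) {
    for (int j = i + 1; j < 16; j++) {
      uint32_t tmp = m[i][j];
      m[i][j] = m[j][i];
      m[j][i] = tmp;
    }
  }
}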
*stride_v(15), 0x88); + v1[12] = _mm512_shuffle_i32x4(*stride_v(8), *stride_v(12), 0xdd); + v1[13] = _mm512_shuffle_i32x4(*stride_v(9), *stride_v(13), 0xdd); + v1[14] = _mm512_shuffle_i32x4(*stride_v(10), *stride_v(14), 0xdd); + v1[15] = _mm512_shuffle_i32x4(*stride_v(11), *stride_v(15), 0xdd); + + *stride_v(0) = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88); + *stride_v(1) = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88); + *stride_v(2) = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88); + *stride_v(3) = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88); + *stride_v(4) = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88); + *stride_v(5) = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88); + *stride_v(6) = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88); + *stride_v(7) = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88); + *stride_v(8) = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd); + *stride_v(9) = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd); + *stride_v(10) = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd); + *stride_v(11) = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd); + *stride_v(12) = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd); + *stride_v(13) = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd); + *stride_v(14) = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd); + *stride_v(15) = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd); +} + +struct GemmKernel224BF { + using dt = ggml_bf16_t; + using output_t = float; + static const int TILE_M = 16; + static const int TILE_K = 32; + static const int TILE_N = 16; + static const int VNNI_BLK = 2; + + static const int M_STEP = TILE_M * 2; + static const int N_STEP = TILE_N * 2; + static const int K_STEP = TILE_K; + + static inline const int N_BLOCK = 256; + static inline const int K_BLOCK = 1792; + + static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; } + + static std::pair split_range_n(int n, int ith, int nth) { + int n_start = N_BLOCK * ith; + int n_end = std::min(n, N_BLOCK * (ith + 1)); + return {n_start, n_end}; + } + + static void config() { + enable_amx(); + TileConfig tile_config; + + // size is 16 x 32 + for (int i = 0; i < 2; i++) + tile_config.set_row_col(i, TILE_M, TILE_K * sizeof(dt)); + + // size is 16 x 32 + for (int i = 2; i < 4; i++) + tile_config.set_row_col(i, TILE_K / VNNI_BLK, TILE_N * VNNI_BLK * sizeof(dt)); + + // size is 16 x 16 + for (int i = 4; i < 8; i++) + tile_config.set_row_col(i, TILE_M, TILE_N * sizeof(output_t)); + + tile_config.set_config(); + } + + static void load_a(dt *a, size_t lda) { + _tile_loadd(0, a, lda); + _tile_loadd(1, offset_pointer(a, lda * TILE_M), lda); + } + + static void load_b(dt *b, size_t ldb) { + _tile_loadd(2, b, ldb); + _tile_loadd(3, offset_pointer(b, ldb * TILE_N), ldb); + } + + static void clean_c() { + _tile_zero(4); + _tile_zero(5); + _tile_zero(6); + _tile_zero(7); + } + + static void load_c(output_t *c, size_t ldc) { + _tile_loadd(4, c, ldc); + _tile_loadd(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc); + _tile_loadd(6, offset_pointer(c, ldc * TILE_M), ldc); + _tile_loadd(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc); + } + + static void store_c(output_t *c, size_t ldc) { + _tile_stored(4, c, ldc); + _tile_stored(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc); + _tile_stored(6, offset_pointer(c, ldc * TILE_M), ldc); + _tile_stored(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc); + } + + static void run_tile() { + _tile_dpbf16ps(4, 0, 2); + _tile_dpbf16ps(5, 0, 3); + _tile_dpbf16ps(6, 1, 2); + _tile_dpbf16ps(7, 1, 3); + } + + struct BufferA { + ggml_bf16_t *a; + int max_m, k; + + static size_t 
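// Illustrative sketch (not part of the patch): one GemmKernel224BF step above multiplies a
// 32x32 bf16 block of A by a 32x32 bf16 block of B into a 32x32 fp32 block of C using a 2x2
// arrangement of 16-wide tiles (run_tile: TMM4..7 += TMM0..1 x TMM2..3). Numerically it is the
// plain blocked GEMM below; float stands in for the bf16 inputs here.
void microkernel_ref(const float *A, const float *B, float *C, int M = 32, int N = 32, int K = 32) {
  // A is M x K (row-major), B is K x N, C is M x N and is accumulated in fp32
  for (int m = 0; m < M; m++)
    for (int n = 0; n < N; n++) {
      float acc = C[m * N + n];
      for (int k = 0; k < K; k++) acc += A[m * K + k] * B[k * N + n];
      C[m * N + n] = acc;
    }
}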
required_size(int max_m, int k) { return max_m * k * sizeof(ggml_bf16_t); } + + BufferA(int max_m, int k, void *ptr) : max_m(max_m), k(k) { + assert(reinterpret_cast(ptr) % 64 == 0); + assert(max_m % M_STEP == 0); + assert(k % K_STEP == 0); + a = reinterpret_cast(ptr); + } + + void from_mat(int m, ggml_bf16_t *src, int ith, int nth) { + assert(m <= max_m); + assert(ith == 0 && nth == 1); + int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP; + for (int m_begin = 0; m_begin < m; m_begin += M_STEP) { + for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) { + int k_block_size = std::min(K_BLOCK, k - k_block_begin); + for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) { + for (int i = 0; i < M_STEP && m_begin + i < m; i++) { + __m512i *s = (__m512i *)(src + (m_begin + i) * k + k_block_begin + k_begin); + __m512i *d = (__m512i *)(a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * M_STEP + + i * K_STEP); + avx512_copy_32xbf16(s, d); + } + } + } + } + } + + ggml_bf16_t *get_submat(int m, int k, int m_begin, int k_begin) { + int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP; + int k_block_begin = k_begin / K_BLOCK * K_BLOCK; + k_begin -= k_block_begin; + int k_block_size = std::min(K_BLOCK, k - k_block_begin); + return a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * M_STEP; + } + }; + + struct BufferB { + ggml_bf16_t *b; + int n, k; + + static size_t required_size(int n, int k) { return n * k * sizeof(ggml_bf16_t); } + + BufferB(int n, int k, void *ptr) : n(n), k(k) { + assert(reinterpret_cast(ptr) % 64 == 0); + assert(n % N_STEP == 0); + assert(k % K_STEP == 0); + b = reinterpret_cast(ptr); + } + + void from_mat(ggml_bf16_t *src, int ith, int nth) { + auto [n_start, n_end] = split_range_n(n, ith, nth); + int n_block_begin = n_start; + int n_block_size = n_end - n_block_begin; + for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) { + for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) { + int k_block_size = std::min(K_BLOCK, k - k_block_begin); + for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) { + for (int i = 0; i < N_STEP; i++) { + __m512i *s = (__m512i *)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin); + __m512i *d = (__m512i *)(b + n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size + + k_begin * N_STEP + i * K_STEP); + avx512_copy_32xbf16(s, d); + } + transpose_16x16_32bit((__m512i *)(b + n_block_begin * k + k_block_begin * n_block_size + + n_begin * k_block_size + k_begin * N_STEP)); + transpose_16x16_32bit((__m512i *)(b + n_block_begin * k + k_block_begin * n_block_size + + n_begin * k_block_size + k_begin * N_STEP + TILE_N * K_STEP)); + } + } + } + } + + ggml_bf16_t *get_submat(int n, int k, int n_begin, int k_begin) { + int n_block_begin = n_begin / N_BLOCK * N_BLOCK; + n_begin -= n_block_begin; + int n_block_size = std::min(N_BLOCK, n - n_block_begin); + int k_block_begin = k_begin / K_BLOCK * K_BLOCK; + k_begin -= k_block_begin; + int k_block_size = std::min(K_BLOCK, k - k_block_begin); + return b + n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size + k_begin * N_STEP; + } + }; + + struct BufferC { + float *c; + int max_m, n; + + static size_t required_size(int max_m, int n) { return max_m * n * sizeof(float); } + + BufferC(int max_m, int n, void *ptr) : max_m(max_m), n(n) { + assert(reinterpret_cast(ptr) % 64 == 0); + assert(max_m % M_STEP == 0); + assert(n % N_STEP == 0); + c = 
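// Illustrative sketch (not part of the patch): BufferA above repacks the activation matrix into
// K_BLOCK-by-M_STEP panels so that one AMX tile load reads an M_STEP x K_STEP contiguous slab of
// bf16 values. This scalar helper reproduces the index arithmetic of BufferA::from_mat /
// get_submat for the BF16 kernel (M_STEP = 32, K_STEP = 32, K_BLOCK = 1792).
#include <algorithm>
#include <cstddef>

constexpr int M_STEP = 32, K_STEP = 32, K_BLOCK = 1792;

// flat element index of logical element (row, col) of an m x k matrix inside the packed buffer
std::size_t packed_index_a(int m, int k, int row, int col) {
  int m_block_size  = (m + M_STEP - 1) / M_STEP * M_STEP;  // m rounded up to a tile pair
  int m_begin       = row / M_STEP * M_STEP;               // which 32-row stripe
  int k_block_begin = col / K_BLOCK * K_BLOCK;             // which 1792-wide K block
  int k_block_size  = std::min(K_BLOCK, k - k_block_begin);
  int k_begin       = col - k_block_begin;                 // offset inside the K block
  int k_panel       = k_begin / K_STEP * K_STEP;           // which 32-wide K panel
  std::size_t panel = std::size_t(k_block_begin) * m_block_size +
                      std::size_t(m_begin) * k_block_size +
                      std::size_t(k_panel) * M_STEP;       // start of the M_STEP x K_STEP panel
  return panel + std::size_t(row - m_begin) * K_STEP + (k_begin - k_panel);
}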
reinterpret_cast(ptr); + } + + void to_mat(int m, ggml_bf16_t *dst, int ith, int nth) { + assert(m <= max_m); + auto [n_start, n_end] = split_range_n(n, ith, nth); + int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP; + int n_block_begin = n_start; + int n_block_size = n_end - n_block_begin; + for (int m_begin = 0; m_begin < m; m_begin += M_STEP) { + for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) { + for (int i = 0; i < M_STEP && m_begin + i < m; i++) { + __m512 *x0 = + (__m512 *)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP + i * N_STEP); + __m512 *x1 = (__m512 *)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP + + i * N_STEP + 16); + avx512_32xfp32_to_32xbf16(x0, x1, (__m512i *)(dst + (m_begin + i) * n + n_block_begin + n_begin)); + } + } + } + } + + float *get_submat(int m, int n, int m_begin, int n_begin) { + int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP; + int n_block_begin = n_begin / N_BLOCK * N_BLOCK; + int n_block_size = std::min(N_BLOCK, n - n_block_begin); + n_begin -= n_block_begin; + return c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP; + } + }; +}; + +struct GemmKernel224Int8 { + using dt = int8_t; + using output_t = int32_t; + static const int TILE_M = 16; + static const int TILE_K = 64; + static const int TILE_N = 16; + static const int VNNI_BLK = 4; + + static const int M_STEP = TILE_M * 2; + static const int N_STEP = TILE_N * 2; + static const int K_STEP = TILE_K; + + static inline const int N_BLOCK = 256; + static inline const int K_BLOCK = 3584; + + static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; } + + static std::pair split_range_n(int n, int ith, int nth) { + int n_start = N_BLOCK * ith; + int n_end = std::min(n, N_BLOCK * (ith + 1)); + return {n_start, n_end}; + } + + static void config() { + enable_amx(); + TileConfig tile_config; + + // size is 16 x 64 + for (int i = 0; i < 2; i++) + tile_config.set_row_col(i, TILE_M, TILE_K * sizeof(dt)); + + // size is 16 x 64 + for (int i = 2; i < 4; i++) + tile_config.set_row_col(i, TILE_K / VNNI_BLK, TILE_N * VNNI_BLK * sizeof(dt)); + + // size is 16 x 16 + for (int i = 4; i < 8; i++) + tile_config.set_row_col(i, TILE_M, TILE_N * sizeof(output_t)); + + tile_config.set_config(); + } + + static void load_a(dt *a, size_t lda) { + _tile_loadd(0, a, lda); + _tile_loadd(1, offset_pointer(a, lda * TILE_M), lda); + } + + static void load_b(dt *b, size_t ldb) { + _tile_loadd(2, b, ldb); + _tile_loadd(3, offset_pointer(b, ldb * TILE_N), ldb); + } + + static void clean_c() { + _tile_zero(4); + _tile_zero(5); + _tile_zero(6); + _tile_zero(7); + } + + static void load_c(output_t *c, size_t ldc) { + _tile_loadd(4, c, ldc); + _tile_loadd(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc); + _tile_loadd(6, offset_pointer(c, ldc * TILE_M), ldc); + _tile_loadd(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc); + } + + static void store_c(output_t *c, size_t ldc) { + _tile_stored(4, c, ldc); + _tile_stored(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc); + _tile_stored(6, offset_pointer(c, ldc * TILE_M), ldc); + _tile_stored(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc); + } + + static void run_tile() { + _tile_dpbssd(4, 0, 2); + _tile_dpbssd(5, 0, 3); + _tile_dpbssd(6, 1, 2); + _tile_dpbssd(7, 1, 3); + } + + struct BufferA { + int8_t *a; + float *d; + int max_m, k; + + static size_t required_size(int max_m, int k) { return max_m * k * sizeof(int8_t) + 
max_m * sizeof(float); } + + BufferA(int max_m, int k, void *ptr) : max_m(max_m), k(k) { + assert(reinterpret_cast(ptr) % 64 == 0); + assert(max_m % M_STEP == 0); + assert(k % K_STEP == 0); + a = reinterpret_cast(ptr); + d = reinterpret_cast(a + max_m * k); + } + + void from_mat(int m, ggml_bf16_t *src, int ith, int nth) { + assert(m <= max_m); + assert(ith == 0 && nth == 1); + for (int m_begin = 0; m_begin < m; m_begin += M_STEP) { + for (int i = 0; i < M_STEP && m_begin + i < m; i++) { + float amax = 0.0f; + for (int j = 0; j < k; j += 32) { + __m512 f0, f1; + avx512_32xbf16_to_32xfp32((__m512i *)(src + (m_begin + i) * k + j), &f0, &f1); + amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f0))); + amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f1))); + } + d[m_begin + i] = amax / ((1 << 7) - 1); + } + } + int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP; + for (int m_begin = 0; m_begin < m; m_begin += M_STEP) { + for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) { + int k_block_size = std::min(K_BLOCK, k - k_block_begin); + for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) { + for (int i = 0; i < M_STEP && m_begin + i < m; i++) { + __m512 id = _mm512_set1_ps(d[m_begin + i] ? 1.0f / d[m_begin + i] : 0.0f); + int8_t *dst = a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * M_STEP + i * K_STEP; + __m512 f0, f1, f2, f3; + avx512_32xbf16_to_32xfp32((__m512i *)(src + (m_begin + i) * k + k_block_begin + k_begin), &f0, &f1); + avx512_32xbf16_to_32xfp32((__m512i *)(src + (m_begin + i) * k + k_block_begin + k_begin) + 1, &f2, &f3); + __m512i i0 = _mm512_cvtps_epi32(_mm512_mul_ps(f0, id)); + __m512i i1 = _mm512_cvtps_epi32(_mm512_mul_ps(f1, id)); + __m512i i2 = _mm512_cvtps_epi32(_mm512_mul_ps(f2, id)); + __m512i i3 = _mm512_cvtps_epi32(_mm512_mul_ps(f3, id)); + __m128i s0 = _mm512_cvtsepi32_epi8(i0); + __m128i s1 = _mm512_cvtsepi32_epi8(i1); + __m128i s2 = _mm512_cvtsepi32_epi8(i2); + __m128i s3 = _mm512_cvtsepi32_epi8(i3); + _mm_storeu_si128((__m128i *)dst, s0); + _mm_storeu_si128((__m128i *)(dst + 16), s1); + _mm_storeu_si128((__m128i *)(dst + 32), s2); + _mm_storeu_si128((__m128i *)(dst + 48), s3); + } + } + } + } + } + + int8_t *get_submat(int m, int k, int m_begin, int k_begin) { + int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP; + int k_block_begin = k_begin / K_BLOCK * K_BLOCK; + k_begin -= k_block_begin; + int k_block_size = std::min(K_BLOCK, k - k_block_begin); + return a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * M_STEP; + } + + float *get_scale(int m, int m_begin) { return d + m_begin; } + }; + + struct BufferB { + int8_t *b; + float *d; + int n, k; + + static size_t required_size(int n, int k) { return n * k * sizeof(int8_t) + n * sizeof(float); } + + BufferB(int n, int k, void *ptr) : n(n), k(k) { + assert(reinterpret_cast(ptr) % 64 == 0); + assert(n % N_STEP == 0); + assert(k % K_STEP == 0); + b = reinterpret_cast(ptr); + d = reinterpret_cast(b + n * k); + } + + void from_mat(ggml_bf16_t *src, int ith, int nth) { + auto [n_start, n_end] = split_range_n(n, ith, nth); + int n_block_begin = n_start; + int n_block_size = n_end - n_block_begin; + for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) { + for (int i = 0; i < N_STEP; i++) { + float amax = 0.0f; + for (int j = 0; j < k; j += 32) { + __m512 f0, f1; + avx512_32xbf16_to_32xfp32((__m512i *)(src + (n_block_begin + n_begin + i) * k + j), &f0, &f1); + amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f0))); + amax = 
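// Illustrative sketch (not part of the patch): GemmKernel224Int8::BufferA above quantizes each
// activation row symmetrically -- scale d = max|x| / 127, values divided by d, rounded and
// saturated to int8 (which is what _mm512_cvtps_epi32 + _mm512_cvtsepi32_epi8 do). Scalar
// reference of the same per-row scheme:
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

void quantize_row_q8(const float *x, int k, std::vector<int8_t> &q, float &d) {
  float amax = 0.0f;
  for (int j = 0; j < k; j++) amax = std::max(amax, std::fabs(x[j]));
  d = amax / 127.0f;              // amax / ((1 << 7) - 1) in the patch
  float id = d ? 1.0f / d : 0.0f; // all-zero rows quantize to zeros
  q.resize(k);
  for (int j = 0; j < k; j++) {
    float v = std::nearbyint(x[j] * id);  // round to nearest, like cvtps_epi32
    q[j] = (int8_t)std::clamp(v, -128.0f, 127.0f);
  }
}
// Dequantization later multiplies the int32 accumulator by d_row_of_A * d_row_of_B.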
MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f1))); + } + d[n_block_begin + n_begin + i] = amax / ((1 << 7) - 1); + } + } + for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) { + for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) { + int k_block_size = std::min(K_BLOCK, k - k_block_begin); + for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) { + for (int i = 0; i < N_STEP; i++) { + __m512 id = _mm512_set1_ps(d[n_block_begin + n_begin + i] ? 1.0f / d[n_block_begin + n_begin + i] : 0.0f); + int8_t *dst = b + n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size + + k_begin * N_STEP + i * K_STEP; + __m512 f0, f1, f2, f3; + avx512_32xbf16_to_32xfp32((__m512i *)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin), + &f0, &f1); + avx512_32xbf16_to_32xfp32( + (__m512i *)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin) + 1, &f2, &f3); + __m512i i0 = _mm512_cvtps_epi32(_mm512_mul_ps(f0, id)); + __m512i i1 = _mm512_cvtps_epi32(_mm512_mul_ps(f1, id)); + __m512i i2 = _mm512_cvtps_epi32(_mm512_mul_ps(f2, id)); + __m512i i3 = _mm512_cvtps_epi32(_mm512_mul_ps(f3, id)); + __m128i s0 = _mm512_cvtsepi32_epi8(i0); + __m128i s1 = _mm512_cvtsepi32_epi8(i1); + __m128i s2 = _mm512_cvtsepi32_epi8(i2); + __m128i s3 = _mm512_cvtsepi32_epi8(i3); + _mm_storeu_si128((__m128i *)dst, s0); + _mm_storeu_si128((__m128i *)(dst + 16), s1); + _mm_storeu_si128((__m128i *)(dst + 32), s2); + _mm_storeu_si128((__m128i *)(dst + 48), s3); + } + transpose_16x16_32bit((__m512i *)(b + n_block_begin * k + k_block_begin * n_block_size + + n_begin * k_block_size + k_begin * N_STEP)); + transpose_16x16_32bit((__m512i *)(b + n_block_begin * k + k_block_begin * n_block_size + + n_begin * k_block_size + k_begin * N_STEP + TILE_N * K_STEP)); + } + } + } + } + + int8_t *get_submat(int n, int k, int n_begin, int k_begin) { + int n_block_begin = n_begin / N_BLOCK * N_BLOCK; + n_begin -= n_block_begin; + int n_block_size = std::min(N_BLOCK, n - n_block_begin); + int k_block_begin = k_begin / K_BLOCK * K_BLOCK; + k_begin -= k_block_begin; + int k_block_size = std::min(K_BLOCK, k - k_block_begin); + return b + n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size + k_begin * N_STEP; + } + + float *get_scale(int n, int n_begin) { return d + n_begin; } + }; + + struct BufferC { + float *c; + int max_m, n; + + static size_t required_size(int max_m, int n) { return max_m * n * sizeof(float); } + + BufferC(int max_m, int n, void *ptr) : max_m(max_m), n(n) { + assert(reinterpret_cast(ptr) % 64 == 0); + assert(max_m % M_STEP == 0); + assert(n % N_STEP == 0); + c = reinterpret_cast(ptr); + } + + void to_mat(int m, ggml_bf16_t *dst, int ith, int nth) { + assert(m <= max_m); + auto [n_start, n_end] = split_range_n(n, ith, nth); + int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP; + int n_block_begin = n_start; + int n_block_size = n_end - n_block_begin; + for (int m_begin = 0; m_begin < m; m_begin += M_STEP) { + for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) { + for (int i = 0; i < M_STEP && m_begin + i < m; i++) { + __m512 *x0 = + (__m512 *)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP + i * N_STEP); + __m512 *x1 = (__m512 *)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP + + i * N_STEP + 16); + avx512_32xfp32_to_32xbf16(x0, x1, (__m512i *)(dst + (m_begin + i) * n + n_block_begin + n_begin)); + } + } + } + } + + float *get_submat(int m, int n, int 
m_begin, int n_begin) { + int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP; + int n_block_begin = n_begin / N_BLOCK * N_BLOCK; + int n_block_size = std::min(N_BLOCK, n - n_block_begin); + n_begin -= n_block_begin; + return c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP; + } + }; +}; + +inline void mat_mul(int m, int n, int k, std::shared_ptr ba, + std::shared_ptr bb, std::shared_ptr bc, int ith, + int nth, bool use_amx) { + using K = GemmKernel224BF; + assert(n % K::N_STEP == 0); + assert(k % K::K_STEP == 0); + + auto [n_start, n_end] = K::split_range_n(n, ith, nth); + + for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K::K_BLOCK) { + for (int m_begin = 0; m_begin < m; m_begin += K::M_STEP) { + for (int n_begin = n_start; n_begin < n_end; n_begin += K::N_STEP) { + + float *c = bc->get_submat(m, n, m_begin, n_begin); + if (!use_amx) { + __m512 *c512 = (__m512 *)c; + if (k_block_begin == 0) { + for (int m_i = 0; m_i < m; m_i++) { + c512[m_i * 2] = _mm512_setzero_ps(); + c512[m_i * 2 + 1] = _mm512_setzero_ps(); + } + } + + for (int k_begin = 0; k_begin < K::K_BLOCK && k_block_begin + k_begin < k; k_begin += K::K_STEP) { + int32_t *a32 = (int32_t *)ba->get_submat(m, k, m_begin, k_block_begin + k_begin); + __m512bh *b512 = (__m512bh *)bb->get_submat(n, k, n_begin, k_block_begin + k_begin); + for (int m_i = 0; m_i < m; m_i++) { + for (int k_i = 0; k_i < 16; k_i++) { + __m512bh ma = (__m512bh)_mm512_set1_epi32(a32[m_i * 16 + k_i]); + for (int n_i = 0; n_i < 2; n_i++) { + c512[m_i * 2 + n_i] = _mm512_dpbf16_ps(c512[m_i * 2 + n_i], ma, b512[n_i * 16 + k_i]); + } + } + } + } + + } else { + if (k_block_begin == 0) { + K::clean_c(); + } else { + K::load_c(c, K::N_STEP * sizeof(float)); + } + for (int k_begin = 0; k_begin < K::K_BLOCK && k_block_begin + k_begin < k; k_begin += K::K_STEP) { + K::load_a(ba->get_submat(m, k, m_begin, k_block_begin + k_begin), K::K_STEP * sizeof(ggml_bf16_t)); + K::load_b(bb->get_submat(n, k, n_begin, k_block_begin + k_begin), K::K_STEP * sizeof(ggml_bf16_t)); + K::run_tile(); + } + K::store_c(c, K::N_STEP * sizeof(float)); + } + } + } + } +} + +inline __m512i _mm512_dpbssd_epi32(__m512i src, __m512i a, __m512i b) { + __m256i a_lo = _mm512_extracti64x4_epi64(a, 0); + __m256i a_hi = _mm512_extracti64x4_epi64(a, 1); + __m256i b_lo = _mm512_extracti64x4_epi64(b, 0); + __m256i b_hi = _mm512_extracti64x4_epi64(b, 1); + + b_lo = _mm256_sign_epi8(b_lo, a_lo); + b_hi = _mm256_sign_epi8(b_hi, a_hi); + + b = _mm512_inserti64x4(b, b_lo, 0); + b = _mm512_inserti64x4(b, b_hi, 1); + + a = _mm512_abs_epi8(a); + + return _mm512_dpbusd_epi32(src, a, b); +} + +inline void mat_mul(int m, int n, int k, std::shared_ptr ba, + std::shared_ptr bb, std::shared_ptr bc, + int ith, int nth, bool use_amx) { + using K = GemmKernel224Int8; + assert(n % K::N_STEP == 0); + assert(k % K::K_STEP == 0); + + auto [n_start, n_end] = K::split_range_n(n, ith, nth); + + for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K::K_BLOCK) { + for (int m_begin = 0; m_begin < m; m_begin += K::M_STEP) { + for (int n_begin = n_start; n_begin < n_end; n_begin += K::N_STEP) { + float *c = bc->get_submat(m, n, m_begin, n_begin); + + if (!use_amx) { + __m512i *c512 = (__m512i *)c; + if (k_block_begin == 0) { + for (int m_i = 0; m_i < m; m_i++) { + c512[m_i * 2] = _mm512_setzero_si512(); + c512[m_i * 2 + 1] = _mm512_setzero_si512(); + } + } + + for (int k_begin = 0; k_begin < K::K_BLOCK && k_block_begin + k_begin < k; k_begin += K::K_STEP) { + static_assert(K::K_STEP 
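// Illustrative sketch (not part of the patch): the _mm512_dpbssd_epi32 emulation above provides
// the signed x signed VNNI dot product, presumably because 512-bit AVX512-VNNI only offers the
// unsigned x signed form (_mm512_dpbusd_epi32). Per 32-bit lane the intended operation is:
#include <cstdint>

// one lane: accumulate the dot product of four signed int8 pairs into an int32
int32_t dpbssd_lane_ref(int32_t acc, const int8_t a[4], const int8_t b[4]) {
  for (int i = 0; i < 4; i++) acc += int32_t(a[i]) * int32_t(b[i]);
  return acc;
}
// The emulation uses a*b == |a| * sign(a)*b: _mm512_abs_epi8 makes the first operand safe to
// treat as unsigned, and _mm256_sign_epi8 flips (or zeroes) b wherever a was negative (or zero),
// so the unsigned x signed dpbusd instruction reproduces the signed result.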
* sizeof(int8_t) == sizeof(__m512i)); + static_assert(K::N_STEP / K::TILE_N == 2, "Must be lke this"); + + int32_t *a32 = (int32_t *)ba->get_submat(m, k, m_begin, k_block_begin + k_begin); + __m512i *b512 = (__m512i *)bb->get_submat(n, k, n_begin, k_block_begin + k_begin); + for (int m_i = 0; m_i < m; m_i++) { + for (int k_i = 0; k_i < 16; k_i++) { + __m512i ma = _mm512_set1_epi32(a32[m_i * 16 + k_i]); + for (int n_i = 0; n_i < 2; n_i++) { + c512[m_i * 2 + n_i] = _mm512_dpbssd_epi32(c512[m_i * 2 + n_i], ma, b512[n_i * 16 + k_i]); + } + } + } + } + } else { + if (k_block_begin == 0) { + K::clean_c(); + } else { + K::load_c((int32_t *)c, K::N_STEP * sizeof(int32_t)); + } + for (int k_begin = 0; k_begin < K::K_BLOCK && k_block_begin + k_begin < k; k_begin += K::K_STEP) { + K::load_a(ba->get_submat(m, k, m_begin, k_block_begin + k_begin), K::K_STEP * sizeof(int8_t)); + K::load_b(bb->get_submat(n, k, n_begin, k_block_begin + k_begin), K::K_STEP * sizeof(int8_t)); + K::run_tile(); + } + K::store_c((int32_t *)c, K::N_STEP * sizeof(int32_t)); + } + + if (k_block_begin + K::K_BLOCK >= k) { + int to = m - m_begin; + if (m - m_begin > K::M_STEP) { + to = K::M_STEP; + } + for (int i = 0; i < to; i++) { + __m512 as = _mm512_set1_ps(*ba->get_scale(m, m_begin + i)); + __m512 bs = _mm512_load_ps(bb->get_scale(n, n_begin)); + __m512i now = _mm512_load_si512((__m512i *)(c + i * K::N_STEP)); + __m512 result = _mm512_mul_ps(_mm512_mul_ps(as, bs), _mm512_cvtepi32_ps(now)); + _mm512_store_ps((__m512 *)(c + i * K::N_STEP), result); + bs = _mm512_load_ps(bb->get_scale(n, n_begin) + K::TILE_N); + now = _mm512_load_si512((__m512i *)(c + i * K::N_STEP + K::TILE_N)); + result = _mm512_mul_ps(_mm512_mul_ps(as, bs), _mm512_cvtepi32_ps(now)); + _mm512_store_ps((__m512 *)(c + i * K::N_STEP + K::TILE_N), result); + } + } + } + } + } +} + +} // namespace amx \ No newline at end of file diff --git a/csrc/ktransformers_ext/operators/amx/la/utils.hpp b/csrc/ktransformers_ext/operators/amx/la/utils.hpp new file mode 100644 index 0000000..a95a2e7 --- /dev/null +++ b/csrc/ktransformers_ext/operators/amx/la/utils.hpp @@ -0,0 +1,46 @@ +/** + * @Description : + * @Author : chenht2022 + * @Date : 2025-04-25 18:28:12 + * @Version : 1.0.0 + * @LastEditors : chenht2022 + * @LastEditTime : 2025-04-25 18:28:12 + * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. 
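// Illustrative sketch (not part of the patch): once the last K block has been accumulated, the
// int8 path above rescales each int32 accumulator by the two per-row scales stored by BufferA
// and BufferB during quantization. Scalar form of that epilogue:
#include <cstdint>

void dequantize_row(const int32_t *acc, float a_scale, const float *b_scales, int n, float *out) {
  for (int j = 0; j < n; j++)
    out[j] = a_scale * b_scales[j] * float(acc[j]);  // matches _mm512_mul_ps(as * bs, _mm512_cvtepi32_ps(now))
}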
+ **/ + +#pragma once +#include + + +template +T* offset_pointer(T* ptr, std::size_t byte_offset) { + return reinterpret_cast(reinterpret_cast(ptr) + byte_offset); +} + +template +const T* offset_pointer(const T* ptr, std::size_t byte_offset) { + return reinterpret_cast(reinterpret_cast(ptr) + byte_offset); +} + +template +T* offset_pointer_row_major(T* t, int row, int col, std::size_t ld) { + return offset_pointer(t, row * ld) + col; +} + +template +T* offset_pointer_col_major(T* t, int row, int col, std::size_t ld) { + return offset_pointer(t, col * ld) + row; +} + +static inline void avx512_copy_32xbf16(__m512i* src, __m512i* dst) { + _mm512_storeu_si512(dst, _mm512_loadu_si512(src)); +} + +static inline void avx512_32xfp32_to_32xbf16(__m512* src0, __m512* src1, __m512i* dst) { + _mm512_storeu_si512(dst, __m512i(_mm512_cvtne2ps_pbh(*src1, *src0))); +} + +static inline void avx512_32xbf16_to_32xfp32(__m512i* src, __m512* dst0, __m512* dst1) { + _mm512_storeu_ps(dst0, _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(src))), 16))); + _mm512_storeu_ps(dst1, _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(src) + 1)), 16))); +} \ No newline at end of file diff --git a/csrc/ktransformers_ext/operators/amx/moe.hpp b/csrc/ktransformers_ext/operators/amx/moe.hpp new file mode 100644 index 0000000..7e966ae --- /dev/null +++ b/csrc/ktransformers_ext/operators/amx/moe.hpp @@ -0,0 +1,398 @@ +/** + * @Description : + * @Author : chenht2022 + * @Date : 2025-04-25 18:28:12 + * @Version : 1.0.0 + * @LastEditors : chenht2022 + * @LastEditTime : 2025-04-25 18:28:12 + * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. + **/ +#ifndef CPUINFER_OPERATOR_AMX_MOE_H +#define CPUINFER_OPERATOR_AMX_MOE_H + +#include +#include +#include +#include +#include + +#include "../../cpu_backend/backend.h" +#include "../../cpu_backend/shared_mem_buffer.h" +#include "llama.cpp/ggml-impl.h" +#include "llama.cpp/ggml-quants.h" +#include "llama.cpp/ggml.h" +#include "llamafile/sgemm.h" + +#include "la/amx.hpp" + +#ifdef USE_NUMA +#include +#include +void *numa_alloc_aligned(size_t size, int node, size_t alignment) { + void *ptr = numa_alloc_onnode(size, node); + assert(reinterpret_cast(ptr) % 64 == 0); + return ptr; +} +#endif + +static inline __m512 exp_avx512(__m512 x) { + const __m512 log2e = _mm512_set1_ps(1.44269504089f); + const __m512 c1 = _mm512_set1_ps(0.69314718056f); + + __m512 y = _mm512_mul_ps(x, log2e); + __m512i int_part = _mm512_cvtps_epi32(y); + __m512 frac_part = _mm512_sub_ps(y, _mm512_cvtepi32_ps(int_part)); + + const __m512 poly_1 = _mm512_set1_ps(0.9999999995f); + const __m512 poly_2 = _mm512_set1_ps(0.6931471805f); + const __m512 poly_3 = _mm512_set1_ps(0.2402265069f); + const __m512 poly_4 = _mm512_set1_ps(0.0555041087f); + const __m512 poly_5 = _mm512_set1_ps(0.0096181291f); + const __m512 poly_6 = _mm512_set1_ps(0.0013333558f); + + __m512 frac_exp = _mm512_fmadd_ps( + frac_part, poly_6, + _mm512_fmadd_ps(frac_part, poly_5, + _mm512_fmadd_ps(frac_part, poly_4, + _mm512_fmadd_ps(frac_part, poly_3, _mm512_fmadd_ps(frac_part, poly_2, poly_1))))); + + __m512 two_pow_i = _mm512_scalef_ps(_mm512_set1_ps(1.0f), _mm512_cvtepi32_ps(int_part)); + return _mm512_mul_ps(two_pow_i, frac_exp); +} + +static inline __m512 act_fn(__m512 gate_val, __m512 up_val) { + __m512 neg_gate_val = _mm512_sub_ps(_mm512_setzero_ps(), gate_val); + __m512 exp_neg_gate = exp_avx512(neg_gate_val); + __m512 denom = 
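// Illustrative sketch (not part of the patch): the avx512_*bf16* helpers above rely on bfloat16
// being simply the upper 16 bits of an IEEE-754 float. Scalar equivalents:
#include <cstdint>
#include <cstring>

float bf16_to_fp32(uint16_t h) {  // what the shift-left-by-16 in avx512_32xbf16_to_32xfp32 does
  uint32_t bits = uint32_t(h) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

uint16_t fp32_to_bf16_truncate(float f) {  // _mm512_cvtne2ps_pbh additionally rounds to nearest-even
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return uint16_t(bits >> 16);
}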
_mm512_add_ps(_mm512_set1_ps(1.0f), exp_neg_gate); + __m512 act_val = _mm512_div_ps(gate_val, denom); + + return _mm512_mul_ps(act_val, up_val); +} + +struct AMX_MOEConfig { + int expert_num; + int routed_expert_num; + int hidden_size; + int intermediate_size; + int max_len; + void *gate_proj; + void *up_proj; + void *down_proj; + + AMX_MOEConfig() {} + + AMX_MOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int max_len, + void *gate_proj, void *up_proj, void *down_proj) + : expert_num(expert_num), routed_expert_num(routed_expert_num), hidden_size(hidden_size), + intermediate_size(intermediate_size), max_len(max_len), gate_proj(gate_proj), up_proj(up_proj), + down_proj(down_proj) {} +}; + +template class AMX_MOE { +private: + AMX_MOEConfig config_; + void *gate_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if quantized)] + void *up_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if quantized)] + void *down_proj_; // [expert_num * hidden_size * intermediate_size ( /32 if quantized)] + + ggml_bf16_t *m_local_input_; // [routed_expert_num * max_len * hidden_size] + ggml_bf16_t *m_local_gate_output_; // [routed_expert_num * max_len * intermediate_size] + ggml_bf16_t *m_local_up_output_; // [routed_expert_num * max_len * intermediate_size] + ggml_bf16_t *m_local_down_output_; // [routed_expert_num * max_len * hidden_size] + + std::vector> m_local_pos_; // [max_len, routed_expert_num] + std::vector m_local_num_; // [expert_num] + std::vector m_expert_id_map_; // [expert_num] + std::vector m_local_input_ptr_; // [expert_num] + std::vector m_local_gate_output_ptr_; // [expert_num] + std::vector m_local_up_output_ptr_; // [expert_num] + std::vector m_local_down_output_ptr_; // [expert_num] + + std::vector> gate_up_ba_; + std::vector> gate_bc_; + std::vector> up_bc_; + std::vector> down_ba_; + std::vector> down_bc_; + +#ifdef USE_NUMA + std::vector>> gate_bb_numa_; + std::vector>> up_bb_numa_; + std::vector>> down_bb_numa_; +#else + std::vector> gate_bb_; + std::vector> up_bb_; + std::vector> down_bb_; +#endif + +public: + AMX_MOE(AMX_MOEConfig config) { + config_ = config; + gate_proj_ = config_.gate_proj; + up_proj_ = config_.up_proj; + down_proj_ = config_.down_proj; + + std::vector> m_mem_requests; + m_mem_requests.push_back({(void **)&m_local_input_, + sizeof(ggml_bf16_t) * config_.routed_expert_num * config_.max_len * config_.hidden_size}); + m_mem_requests.push_back({(void **)&m_local_gate_output_, sizeof(ggml_bf16_t) * config_.routed_expert_num * + config_.max_len * config_.intermediate_size}); + m_mem_requests.push_back({(void **)&m_local_up_output_, sizeof(ggml_bf16_t) * config_.routed_expert_num * + config_.max_len * config_.intermediate_size}); + m_mem_requests.push_back({(void **)&m_local_down_output_, + sizeof(ggml_bf16_t) * config_.routed_expert_num * config_.max_len * config_.hidden_size}); + std::vector gate_up_ba_ptr(config_.expert_num); + std::vector gate_bc_ptr(config_.expert_num); + std::vector up_bc_ptr(config_.expert_num); + std::vector down_ba_ptr(config_.expert_num); + std::vector down_bc_ptr(config_.expert_num); + for (int i = 0; i < config_.expert_num; i++) { + m_mem_requests.push_back( + {(void **)&gate_up_ba_ptr[i], T::BufferA::required_size(config_.max_len, config_.hidden_size)}); + m_mem_requests.push_back( + {(void **)&gate_bc_ptr[i], T::BufferC::required_size(config_.max_len, config_.intermediate_size)}); + m_mem_requests.push_back( + {(void **)&up_bc_ptr[i], 
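// Illustrative sketch (not part of the patch): act_fn above is the usual SiLU-gated MLP
// activation, silu(gate) * up with silu(x) = x / (1 + exp(-x)); exp_avx512 approximates exp(x)
// as 2^i * p(f) with i = round(x * log2(e)), f the remaining fraction and p the degree-5
// polynomial whose coefficients appear above. Scalar reference of the activation:
#include <cmath>

float act_fn_ref(float gate, float up) {
  float silu = gate / (1.0f + std::exp(-gate));
  return silu * up;
}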
T::BufferC::required_size(config_.max_len, config_.intermediate_size)}); + m_mem_requests.push_back( + {(void **)&down_ba_ptr[i], T::BufferA::required_size(config_.max_len, config_.intermediate_size)}); + m_mem_requests.push_back( + {(void **)&down_bc_ptr[i], T::BufferC::required_size(config_.max_len, config_.hidden_size)}); + } + shared_mem_buffer.alloc(this, m_mem_requests); + + m_local_pos_.resize(config_.max_len); + for (int i = 0; i < config_.max_len; i++) { + m_local_pos_[i].resize(config_.routed_expert_num); + } + m_expert_id_map_.resize(config_.expert_num); + m_local_num_.resize(config_.expert_num); + m_local_input_ptr_.resize(config_.expert_num); + m_local_gate_output_ptr_.resize(config_.expert_num); + m_local_up_output_ptr_.resize(config_.expert_num); + m_local_down_output_ptr_.resize(config_.expert_num); + + for (uint64_t i = 0; i < config_.expert_num; i++) { + gate_up_ba_.push_back( + std::make_shared(config_.max_len, config_.hidden_size, gate_up_ba_ptr[i])); + gate_bc_.push_back( + std::make_shared(config_.max_len, config_.intermediate_size, gate_bc_ptr[i])); + up_bc_.push_back(std::make_shared(config_.max_len, config_.intermediate_size, up_bc_ptr[i])); + down_ba_.push_back( + std::make_shared(config_.max_len, config_.intermediate_size, down_ba_ptr[i])); + down_bc_.push_back(std::make_shared(config_.max_len, config_.hidden_size, down_bc_ptr[i])); + +#ifdef USE_NUMA + int numa_nodes = numa_num_configured_nodes(); + gate_bb_numa_.resize(numa_nodes); + up_bb_numa_.resize(numa_nodes); + down_bb_numa_.resize(numa_nodes); + for (int j = 0; j < numa_nodes; j++) { + void *gate_bb_ptr = + numa_alloc_aligned(T::BufferB::required_size(config_.intermediate_size, config_.hidden_size), j, 64); + gate_bb_numa_[j].push_back( + std::make_shared(config_.intermediate_size, config_.hidden_size, gate_bb_ptr)); + void *up_bb_ptr = + numa_alloc_aligned(T::BufferB::required_size(config_.intermediate_size, config_.hidden_size), j, 64); + up_bb_numa_[j].push_back( + std::make_shared(config_.intermediate_size, config_.hidden_size, up_bb_ptr)); + void *down_bb_ptr = + numa_alloc_aligned(T::BufferB::required_size(config_.hidden_size, config_.intermediate_size), j, 64); + down_bb_numa_[j].push_back( + std::make_shared(config_.hidden_size, config_.intermediate_size, down_bb_ptr)); + } +#else + void *gate_bb_ptr = + std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size)); + gate_bb_.push_back( + std::make_shared(config_.intermediate_size, config_.hidden_size, gate_bb_ptr)); + + void *up_bb_ptr = + std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size)); + up_bb_.push_back( + std::make_shared(config_.intermediate_size, config_.hidden_size, up_bb_ptr)); + + void *down_bb_ptr = + std::aligned_alloc(64, T::BufferB::required_size(config_.hidden_size, config_.intermediate_size)); + down_bb_.push_back( + std::make_shared(config_.hidden_size, config_.intermediate_size, down_bb_ptr)); +#endif + } + } + + ~AMX_MOE() { shared_mem_buffer.dealloc(this); } + + void load_weights(Backend *backend) { + int nth = T::recommended_nth(config_.intermediate_size); + backend->do_work_stealing_job( + nth * config_.expert_num, nullptr, + [&](int task_id) { + uint64_t expert_idx = task_id / nth; + int ith = task_id % nth; +#ifdef USE_NUMA + int numa_nodes = numa_num_configured_nodes(); + for (int j = 0; j < numa_nodes; j++) { + gate_bb_numa_[j][expert_idx]->from_mat((ggml_bf16_t *)config_.gate_proj + + expert_idx * config_.intermediate_size * 
config_.hidden_size, + ith, nth); + up_bb_numa_[j][expert_idx]->from_mat((ggml_bf16_t *)config_.up_proj + + expert_idx * config_.intermediate_size * config_.hidden_size, + ith, nth); + } +#else + gate_bb_[expert_idx]->from_mat((ggml_bf16_t *)config_.gate_proj + + expert_idx * config_.intermediate_size * config_.hidden_size, + ith, nth); + up_bb_[expert_idx]->from_mat( + (ggml_bf16_t *)config_.up_proj + expert_idx * config_.intermediate_size * config_.hidden_size, ith, nth); +#endif + }, + nullptr); + nth = T::recommended_nth(config_.hidden_size); + backend->do_work_stealing_job( + nth * config_.expert_num, nullptr, + [&](int task_id) { + uint64_t expert_idx = task_id / nth; + int ith = task_id % nth; +#ifdef USE_NUMA + int numa_nodes = numa_num_configured_nodes(); + for (int j = 0; j < numa_nodes; j++) { + down_bb_numa_[j][expert_idx]->from_mat((ggml_bf16_t *)config_.down_proj + + expert_idx * config_.hidden_size * config_.intermediate_size, + ith, nth); + } +#else + down_bb_[expert_idx]->from_mat((ggml_bf16_t *)config_.down_proj + + expert_idx * config_.hidden_size * config_.intermediate_size, + ith, nth); +#endif + }, + nullptr); + } + + void warm_up(Backend *backend) {} + + void forward(int qlen, int k, const uint64_t *expert_ids, const float *weights, const void *input, void *output, + int *batch_size_tensor, Backend *backend) { + bool use_amx = (qlen > 4 * config_.expert_num / config_.routed_expert_num); + qlen = batch_size_tensor[0]; + int activated_expert = 0; + for (int i = 0; i < config_.expert_num; i++) { + m_local_num_[i] = 0; + } + for (int i = 0; i < qlen; i++) { + for (int j = 0; j < k; j++) { + m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++; + } + } + for (int i = 0; i < config_.expert_num; i++) { + if (m_local_num_[i] > 0) { + m_expert_id_map_[activated_expert] = i; + activated_expert++; + } + } + uint64_t offset = 0; + for (int i = 0; i < config_.expert_num; i++) { + m_local_input_ptr_[i] = m_local_input_ + offset * config_.hidden_size; + m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * config_.intermediate_size; + m_local_up_output_ptr_[i] = m_local_up_output_ + offset * config_.intermediate_size; + m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size; + offset += m_local_num_[i]; + } + backend->do_work_stealing_job( + qlen, nullptr, + [&](int i) { + for (int j = 0; j < k; j++) { + memcpy(m_local_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size, + (ggml_bf16_t *)input + i * config_.hidden_size, sizeof(ggml_bf16_t) * config_.hidden_size); + } + }, + nullptr); + backend->do_work_stealing_job( + activated_expert, nullptr, + [&](int task_id) { + int expert_idx = m_expert_id_map_[task_id]; + + gate_up_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_input_ptr_[expert_idx], 0, 1); + }, + nullptr); + int nth = T::recommended_nth(config_.intermediate_size); + backend->do_work_stealing_job( + nth * activated_expert, [&](int _) { T::config(); }, + [&](int task_id) { + int expert_idx = m_expert_id_map_[task_id / nth]; + int ith = task_id % nth; +#ifdef USE_NUMA + amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size, + gate_up_ba_[expert_idx], gate_bb_numa_[Backend::numa_node][expert_idx], gate_bc_[expert_idx], + ith, nth, use_amx); + amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size, + gate_up_ba_[expert_idx], up_bb_numa_[Backend::numa_node][expert_idx], up_bc_[expert_idx], ith, + nth, use_amx); +#else + 
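// Illustrative sketch (not part of the patch): the routing prologue of forward() above buckets
// the qlen x k (token, expert) assignments into per-expert slots so each activated expert's
// tokens become one dense GEMM. The same bookkeeping in plain C++:
#include <cstdint>
#include <vector>

struct Routing {
  std::vector<int> local_num;               // tokens routed to each expert (m_local_num_)
  std::vector<std::vector<int>> local_pos;  // slot of token i's j-th expert in that expert's batch (m_local_pos_)
  std::vector<int> activated;               // experts that received at least one token (m_expert_id_map_)
  std::vector<int> offset;                  // start row of each expert in the packed buffers
};

Routing route(const uint64_t *expert_ids, int qlen, int k, int expert_num) {
  Routing r;
  r.local_num.assign(expert_num, 0);
  r.local_pos.assign(qlen, std::vector<int>(k));
  for (int i = 0; i < qlen; i++)
    for (int j = 0; j < k; j++)
      r.local_pos[i][j] = r.local_num[expert_ids[i * k + j]]++;  // next free slot for that expert
  r.offset.assign(expert_num, 0);
  int running = 0;
  for (int e = 0; e < expert_num; e++) {
    if (r.local_num[e] > 0) r.activated.push_back(e);
    r.offset[e] = running;  // exclusive prefix sum, as in the patch
    running += r.local_num[e];
  }
  return r;
}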
amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size, + gate_up_ba_[expert_idx], gate_bb_[expert_idx], gate_bc_[expert_idx], ith, nth, use_amx); + amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size, + gate_up_ba_[expert_idx], up_bb_[expert_idx], up_bc_[expert_idx], ith, nth, use_amx); +#endif + gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth); + up_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_up_output_ptr_[expert_idx], ith, nth); + auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth); + for (int i = 0; i < m_local_num_[expert_idx]; i++) { + ggml_bf16_t *gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size]; + ggml_bf16_t *up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size]; + for (int j = n_start; j < n_end; j += 32) { + __m512 gate_val0, gate_val1, up_val0, up_val1; + avx512_32xbf16_to_32xfp32((__m512i *)(gate_output_ptr + j), &gate_val0, &gate_val1); + avx512_32xbf16_to_32xfp32((__m512i *)(up_output_ptr + j), &up_val0, &up_val1); + __m512 result0 = act_fn(gate_val0, up_val0); + __m512 result1 = act_fn(gate_val1, up_val1); + avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i *)(gate_output_ptr + j)); + } + } + }, + nullptr); + backend->do_work_stealing_job( + activated_expert, nullptr, + [&](int task_id) { + int expert_idx = m_expert_id_map_[task_id]; + down_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], 0, 1); + }, + nullptr); + nth = T::recommended_nth(config_.hidden_size); + backend->do_work_stealing_job( + nth * activated_expert, [&](int _) { T::config(); }, + [&](int task_id) { + int expert_idx = m_expert_id_map_[task_id / nth]; + int ith = task_id % nth; +#ifdef USE_NUMA + amx::mat_mul(m_local_num_[expert_idx], config_.hidden_size, config_.intermediate_size, down_ba_[expert_idx], + down_bb_numa_[Backend::numa_node][expert_idx], down_bc_[expert_idx], ith, nth, use_amx); +#else + amx::mat_mul(m_local_num_[expert_idx], config_.hidden_size, config_.intermediate_size, down_ba_[expert_idx], + down_bb_[expert_idx], down_bc_[expert_idx], ith, nth, use_amx); +#endif + down_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_down_output_ptr_[expert_idx], ith, nth); + }, + nullptr); + backend->do_work_stealing_job( + qlen, nullptr, + [&](int i) { + for (int e = 0; e < config_.hidden_size; e += 32) { + __m512 x0 = _mm512_setzero_ps(); + __m512 x1 = _mm512_setzero_ps(); + for (int j = 0; j < k; j++) { + __m512 weight = _mm512_set1_ps(weights[i * k + j]); + __m512 down_output0, down_output1; + avx512_32xbf16_to_32xfp32((__m512i *)(m_local_down_output_ptr_[expert_ids[i * k + j]] + + m_local_pos_[i][j] * config_.hidden_size + e), + &down_output0, &down_output1); + x0 = _mm512_fmadd_ps(down_output0, weight, x0); + x1 = _mm512_fmadd_ps(down_output1, weight, x1); + } + avx512_32xfp32_to_32xbf16(&x0, &x1, (__m512i *)((ggml_bf16_t *)output + i * config_.hidden_size + e)); + } + }, + nullptr); + } +}; + +#endif \ No newline at end of file diff --git a/csrc/ktransformers_ext/operators/llamafile/linear.h b/csrc/ktransformers_ext/operators/llamafile/linear.h index fd856f9..f8ae7ae 100644 --- a/csrc/ktransformers_ext/operators/llamafile/linear.h +++ b/csrc/ktransformers_ext/operators/llamafile/linear.h @@ -17,12 +17,12 @@ #include #include "../../cpu_backend/backend.h" +#include "../../cpu_backend/shared_mem_buffer.h" #include 
"conversion.h" #include "llama.cpp/ggml-impl.h" #include "llama.cpp/ggml-quants.h" #include "llama.cpp/ggml.h" #include "llamafile/sgemm.h" -#include "shared_mem_buffer.h" struct LinearConfig { int input_size; diff --git a/csrc/ktransformers_ext/operators/llamafile/mlp.h b/csrc/ktransformers_ext/operators/llamafile/mlp.h index eb93294..7e6e5cc 100644 --- a/csrc/ktransformers_ext/operators/llamafile/mlp.h +++ b/csrc/ktransformers_ext/operators/llamafile/mlp.h @@ -17,12 +17,12 @@ #include #include "../../cpu_backend/backend.h" +#include "../../cpu_backend/shared_mem_buffer.h" #include "conversion.h" #include "llama.cpp/ggml-impl.h" #include "llama.cpp/ggml-quants.h" #include "llama.cpp/ggml.h" #include "llamafile/sgemm.h" -#include "shared_mem_buffer.h" struct MLPConfig { int hidden_size; diff --git a/csrc/ktransformers_ext/operators/llamafile/moe.h b/csrc/ktransformers_ext/operators/llamafile/moe.h index 9a8b6cd..28d7ad3 100644 --- a/csrc/ktransformers_ext/operators/llamafile/moe.h +++ b/csrc/ktransformers_ext/operators/llamafile/moe.h @@ -17,12 +17,12 @@ #include #include "../../cpu_backend/backend.h" +#include "../../cpu_backend/shared_mem_buffer.h" #include "conversion.h" #include "llama.cpp/ggml-impl.h" #include "llama.cpp/ggml-quants.h" #include "llama.cpp/ggml.h" #include "llamafile/sgemm.h" -#include "shared_mem_buffer.h" struct MOEConfig { int expert_num; diff --git a/csrc/ktransformers_ext/operators/llamafile/shared_mem_buffer.h b/csrc/ktransformers_ext/operators/llamafile/shared_mem_buffer.h deleted file mode 100644 index eeaccd4..0000000 --- a/csrc/ktransformers_ext/operators/llamafile/shared_mem_buffer.h +++ /dev/null @@ -1,37 +0,0 @@ -/** - * @Description : - * @Author : chenht2022 - * @Date : 2024-08-05 04:49:08 - * @Version : 1.0.0 - * @LastEditors : chenht2022 - * @LastEditTime : 2024-08-05 06:36:41 - * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. 
- **/
-
-#ifndef CPUINFER_SHAREDMEMBUFFER_H
-#define CPUINFER_SHAREDMEMBUFFER_H
-
-#include
-#include
-#include
-#include
-
-class SharedMemBuffer {
-   public:
-    SharedMemBuffer();
-    ~SharedMemBuffer();
-
-    void alloc(void* object, std::vector<std::pair<void**, uint64_t>> requests);
-    void dealloc(void* object);
-
-   private:
-    void* buffer_;
-    uint64_t size_;
-    std::map<void*, std::vector<std::vector<std::pair<void**, uint64_t>>>> hist_requests_;
-
-    void arrange(std::vector<std::pair<void**, uint64_t>> requests);
-};
-
-static SharedMemBuffer shared_mem_buffer;
-
-#endif
\ No newline at end of file
diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py
index 4a4f3c1..f73c4c3 100644
--- a/ktransformers/operators/experts.py
+++ b/ktransformers/operators/experts.py
@@ -25,8 +25,9 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext
 sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Debug"))
 import cpuinfer_ext
 from cpuinfer_ext.moe import MOEConfig, MOE
+from cpuinfer_ext.moe import AMX_MOEConfig, AMXBF16_MOE, AMXInt8_MOE
 import ctypes
-from ktransformers.util.custom_gguf import GGUFLoader
+from ktransformers.util.custom_gguf import GGMLQuantizationType, GGUFLoader
 from ktransformers.util.utils import InferenceState
 from ktransformers.server.config.config import Config
 from transformers.activations import ACT2FN
@@ -141,6 +142,7 @@ class KExpertsCPU(KExpertsBase):
         assert device.lower() == "cpu", "KExpertsCPU can only be loaded on CPU"
         self.n_routed_experts = n_routed_experts
         self.out_device = out_device
+        self.backend = kwargs.get("backend", "llamafile")
 
     def load(self, w: dict | nn.Parameter | tuple | None = None, device:str|None = None, warmup:bool = False):
         if device:
@@ -163,27 +165,62 @@ class KExpertsCPU(KExpertsBase):
         )
         # print(self.gate_qtype, self.up_qtype, self.down_qtype)
         n_routed_experts = self.n_routed_experts
+        self.cpu_infer = KExpertsCPU.CPU_INFER
         # n_routed_experts = len(self.orig_module)
-        moe_config = MOEConfig(
-            n_routed_experts,
-            self.config.num_experts_per_tok,
-            self.config.hidden_size,
-            self.config.moe_intermediate_size,
-            64,
-            10,
-            1024,
-            gate_ptr,
-            up_ptr,
-            down_ptr,
-            self.gate_type,
-            self.up_type,
-            self.down_type,
-            30, # TODO: get from model.dtype
-        )
+        if self.backend == "llamafile":
+            moe_config = MOEConfig(
+                n_routed_experts,
+                self.config.num_experts_per_tok,
+                self.config.hidden_size,
+                self.config.moe_intermediate_size,
+                64,
+                10,
+                1024,
+                gate_ptr,
+                up_ptr,
+                down_ptr,
+                self.gate_type,
+                self.up_type,
+                self.down_type,
+                30, # TODO: get from model.dtype
+            )
+            self.moe = MOE(moe_config)
+        elif self.backend == "AMXBF16":
+            assert self.gate_type == GGMLQuantizationType.BF16
+            assert self.up_type == GGMLQuantizationType.BF16
+            assert self.down_type == GGMLQuantizationType.BF16
+            moe_config = AMX_MOEConfig(
+                n_routed_experts,
+                self.config.num_experts_per_tok,
+                self.config.hidden_size,
+                self.config.moe_intermediate_size,
+                25600,
+                gate_ptr,
+                up_ptr,
+                down_ptr,
+            )
+            self.moe = AMXBF16_MOE(moe_config)
+            self.cpu_infer.submit(self.moe.load_weights())
+            self.cpu_infer.sync()
+        elif self.backend == "AMXInt8":
+            assert self.gate_type == GGMLQuantizationType.BF16
+            assert self.up_type == GGMLQuantizationType.BF16
+            assert self.down_type == GGMLQuantizationType.BF16
+            moe_config = AMX_MOEConfig(
+                n_routed_experts,
+                self.config.num_experts_per_tok,
+                self.config.hidden_size,
+                self.config.moe_intermediate_size,
+                25600,
+                gate_ptr,
+                up_ptr,
+                down_ptr,
+            )
+            self.moe = AMXInt8_MOE(moe_config)
+            self.cpu_infer.submit(self.moe.load_weights())
+            self.cpu_infer.sync()
         # print(n_routed_experts, hidden_size, moe_intermediate_size)
         num_experts_per_tok = self.config.num_experts_per_tok
-        self.moe = MOE(moe_config)
-        self.cpu_infer = KExpertsCPU.CPU_INFER
         if warmup:
             self.cpu_infer.submit(self.moe.warm_up())
             self.cpu_infer.sync()
diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-amx.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-amx.yaml
new file mode 100644
index 0000000..724e1a4
--- /dev/null
+++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-amx.yaml
@@ -0,0 +1,77 @@
+- match:
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
+  replace:
+    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+
+- match:
+    name: "^lm_head$"  # regular expression
+    class: torch.nn.Linear  # only match modules that match both name and class
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear  # optimized kernel on quantized data types
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
+
+- match:
+    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
+    class: torch.nn.Linear  # only match modules that match both name and class
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear  # optimized kernel on quantized data types
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
+- match:
+    name: "^model\\.layers\\..*\\.mlp$"
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
+  replace:
+    class: ktransformers.operators.experts.KDeepseekV3MoE  # MLP module with custom forward function
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+- match:
+    class: ktransformers.models.modeling_deepseek_v3.MoEGate
+  replace:
+    class: ktransformers.operators.gate.KMoEGate
+    kwargs:
+      generate_device: "cuda:0"
+      prefill_device: "cuda:0"
+- match:
+    name: "^model\\.layers\\..*\\.mlp\\.experts$"
+  replace:
+    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE kernel with expert parallelism
+    kwargs:
+      prefill_device: "cuda"
+      prefill_op: "KExpertsTorch"
+      generate_device: "cpu"
+      generate_op: "KExpertsCPU"
+      out_device: "cuda"
+      backend: "AMXInt8"  # or "AMXBF16" or "llamafile" (default)
+  recursive: False  # don't recursively inject submodules of this module
+- match:
+    name: "^model\\.layers\\..*\\.self_attn$"
+  replace:
+    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+      absorb_for_prefill: False  # set to True to enable long context (prefill may be slower)
+- match:
+    name: "^model$"
+  replace:
+    class: "ktransformers.operators.models.KDeepseekV2Model"
+    kwargs:
+      per_layer_prefill_intput_threshold: 0  # 0 disables layer-wise prefill
+- match:
+    name: "^model.embed_tokens"
+  replace:
+    class: "default"
+    kwargs:
+      generate_device: "cpu"
+      prefill_device: "cpu"
\ No newline at end of file
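
Note on the `backend` kwarg above: the two AMX paths added in `KExpertsCPU.load` assert that the GGUF stores the expert gate/up/down tensors as BF16, while the default `llamafile` path also accepts the quantized GGML types. A minimal sketch of that selection rule, for illustration only (the helper `pick_expert_backend` is hypothetical and not part of this patch; the dtype enum is the one imported in experts.py):

from ktransformers.util.custom_gguf import GGMLQuantizationType

def pick_expert_backend(gate_type, up_type, down_type, prefer_int8: bool = True) -> str:
    # AMX kernels only load BF16 expert weights; anything else must use llamafile.
    if all(t == GGMLQuantizationType.BF16 for t in (gate_type, up_type, down_type)):
        # AMXInt8 repacks the BF16 weights at load time, trading some accuracy for lower bandwidth.
        return "AMXInt8" if prefer_int8 else "AMXBF16"
    return "llamafile"

Whatever string is placed under `backend:` in the YAML rule is what `kwargs.get("backend", "llamafile")` reads in `KExpertsCPU.__init__`, so switching kernels is a config-only change.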
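
For readers tracing the new `AMX_MOE::forward`: it first counts how many routed tokens land on each expert (`m_local_num_`/`m_local_pos_`), packs those tokens into per-expert contiguous buffers, runs the gate/up GEMMs with a fused activation and then the down GEMM once per activated expert, and finally scatters the per-expert outputs back into token order weighted by the router scores. A rough PyTorch equivalent of that data flow, offered only as a hedged reference (the names and the per-expert callables are illustrative, not part of the patch):

import torch

def moe_forward_reference(x, expert_ids, weights, experts):
    # x: (qlen, hidden), expert_ids/weights: (qlen, k), experts: list of per-expert callables.
    out = torch.zeros_like(x)
    for e in range(len(experts)):
        # Group all tokens routed to expert e, mirroring the m_local_pos_/m_local_num_ bookkeeping.
        token_idx, slot = (expert_ids == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue  # inactive experts are skipped, like the activated_expert list in the C++ code
        y = experts[e](x[token_idx])  # one grouped gate/up/act/down chain per expert
        # Scatter back, weighted by the router score for this (token, slot) pair.
        out.index_add_(0, token_idx, (y * weights[token_idx, slot].unsqueeze(-1)).to(out.dtype))
    return out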