/**
 * @Description  :
 * @Author       : chenht2022
 * @Date         : 2025-04-25 18:28:12
 * @Version      : 1.0.0
 * @LastEditors  : chenht2022
 * @LastEditTime : 2025-04-25 18:28:12
 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 **/
#ifndef CPUINFER_OPERATOR_AMX_MOE_H
#define CPUINFER_OPERATOR_AMX_MOE_H

// assert, uint64_t, aligned_alloc, memcpy, AVX-512 intrinsics, shared_ptr,
// pair, vector
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <immintrin.h>
#include <memory>
#include <utility>
#include <vector>

#include "../../cpu_backend/backend.h"
#include "../../cpu_backend/shared_mem_buffer.h"
#include "llama.cpp/ggml-impl.h"
#include "llama.cpp/ggml-quants.h"
#include "llama.cpp/ggml.h"
#include "llamafile/sgemm.h"
#include "la/amx.hpp"

#ifdef USE_NUMA
#include <numa.h>
#include <numaif.h>

// Allocate memory on a specific NUMA node. numa_alloc_onnode returns
// page-aligned memory, so the alignment the AMX kernels request (64 bytes)
// always holds; the assert documents that contract.
static inline void *numa_alloc_aligned(size_t size, int node, size_t alignment) {
    void *ptr = numa_alloc_onnode(size, node);
    assert(reinterpret_cast<uintptr_t>(ptr) % alignment == 0);
    return ptr;
}
#endif

// Vectorized exp() for 16 floats. Writes x = (i + f) * ln2 with integer i and
// f in [-0.5, 0.5] (cvtps rounds to nearest), then computes
// exp(x) = 2^i * 2^f, where 2^f = e^(f*ln2) is a degree-5 polynomial whose
// coefficients are poly_{k+1} = ln2^k / k!.
static inline __m512 exp_avx512(__m512 x) {
    const __m512 log2e = _mm512_set1_ps(1.44269504089f);
    __m512 y = _mm512_mul_ps(x, log2e);
    __m512i int_part = _mm512_cvtps_epi32(y);
    __m512 frac_part = _mm512_sub_ps(y, _mm512_cvtepi32_ps(int_part));
    const __m512 poly_1 = _mm512_set1_ps(0.9999999995f);
    const __m512 poly_2 = _mm512_set1_ps(0.6931471805f);
    const __m512 poly_3 = _mm512_set1_ps(0.2402265069f);
    const __m512 poly_4 = _mm512_set1_ps(0.0555041087f);
    const __m512 poly_5 = _mm512_set1_ps(0.0096181291f);
    const __m512 poly_6 = _mm512_set1_ps(0.0013333558f);
    // Horner evaluation of the polynomial in frac_part.
    __m512 frac_exp = _mm512_fmadd_ps(
        frac_part, poly_6,
        _mm512_fmadd_ps(frac_part, poly_5,
                        _mm512_fmadd_ps(frac_part, poly_4,
                                        _mm512_fmadd_ps(frac_part, poly_3,
                                                        _mm512_fmadd_ps(frac_part, poly_2, poly_1)))));
    // 2^int_part via scalef: 1.0 * 2^int_part.
    __m512 two_pow_i = _mm512_scalef_ps(_mm512_set1_ps(1.0f), _mm512_cvtepi32_ps(int_part));
    return _mm512_mul_ps(two_pow_i, frac_exp);
}

// SwiGLU activation: SiLU(gate) * up, with silu(x) = x / (1 + e^(-x)).
static inline __m512 act_fn(__m512 gate_val, __m512 up_val) {
    __m512 neg_gate_val = _mm512_sub_ps(_mm512_setzero_ps(), gate_val);
    __m512 exp_neg_gate = exp_avx512(neg_gate_val);
    __m512 denom = _mm512_add_ps(_mm512_set1_ps(1.0f), exp_neg_gate);
    __m512 act_val = _mm512_div_ps(gate_val, denom);
    return _mm512_mul_ps(act_val, up_val);
}
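// Scalar reference for the two helpers above (a sketch for clarity, not used
// by the code): per lane,
//   exp_avx512(x) ~ expf(x)                        // relative error roughly 1e-6
//   act_fn(g, u) == (g / (1.0f + expf(-g))) * u    // SiLU(g) * u
// so one row of an expert MLP computes down(SiLU(x*W_gate) * (x*W_up)).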
struct AMX_MOEConfig {
    int expert_num;
    int routed_expert_num;
    int hidden_size;
    int intermediate_size;
    int max_len;
    void *gate_proj;
    void *up_proj;
    void *down_proj;
    AMX_MOEConfig() {}
    AMX_MOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int max_len,
                  void *gate_proj, void *up_proj, void *down_proj)
        : expert_num(expert_num),
          routed_expert_num(routed_expert_num),
          hidden_size(hidden_size),
          intermediate_size(intermediate_size),
          max_len(max_len),
          gate_proj(gate_proj),
          up_proj(up_proj),
          down_proj(down_proj) {}
};

template <class T>
class AMX_MOE {
   private:
    AMX_MOEConfig config_;
    void *gate_proj_;  // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]
    void *up_proj_;    // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]
    void *down_proj_;  // [expert_num * hidden_size * intermediate_size ( /32 if quantized)]

    ggml_bf16_t *m_local_input_;        // [routed_expert_num * max_len * hidden_size]
    ggml_bf16_t *m_local_gate_output_;  // [routed_expert_num * max_len * intermediate_size]
    ggml_bf16_t *m_local_up_output_;    // [routed_expert_num * max_len * intermediate_size]
    ggml_bf16_t *m_local_down_output_;  // [routed_expert_num * max_len * hidden_size]

    std::vector<std::vector<int>> m_local_pos_;           // [max_len, routed_expert_num]
    std::vector<int> m_local_num_;                        // [expert_num]
    std::vector<int> m_expert_id_map_;                    // [expert_num]
    std::vector<ggml_bf16_t *> m_local_input_ptr_;        // [expert_num]
    std::vector<ggml_bf16_t *> m_local_gate_output_ptr_;  // [expert_num]
    std::vector<ggml_bf16_t *> m_local_up_output_ptr_;    // [expert_num]
    std::vector<ggml_bf16_t *> m_local_down_output_ptr_;  // [expert_num]

    // Per-expert packed GEMM buffers: A (activations), B (weights), C (outputs).
    std::vector<std::shared_ptr<typename T::BufferA>> gate_up_ba_;
    std::vector<std::shared_ptr<typename T::BufferC>> gate_bc_;
    std::vector<std::shared_ptr<typename T::BufferC>> up_bc_;
    std::vector<std::shared_ptr<typename T::BufferA>> down_ba_;
    std::vector<std::shared_ptr<typename T::BufferC>> down_bc_;

#ifdef USE_NUMA
    // One weight copy per NUMA node, indexed [node][expert].
    std::vector<std::vector<std::shared_ptr<typename T::BufferB>>> gate_bb_numa_;
    std::vector<std::vector<std::shared_ptr<typename T::BufferB>>> up_bb_numa_;
    std::vector<std::vector<std::shared_ptr<typename T::BufferB>>> down_bb_numa_;
#else
    std::vector<std::shared_ptr<typename T::BufferB>> gate_bb_;
    std::vector<std::shared_ptr<typename T::BufferB>> up_bb_;
    std::vector<std::shared_ptr<typename T::BufferB>> down_bb_;
#endif

   public:
    AMX_MOE(AMX_MOEConfig config) {
        config_ = config;
        gate_proj_ = config_.gate_proj;
        up_proj_ = config_.up_proj;
        down_proj_ = config_.down_proj;

        // All transient buffers come from one shared pool, so operators that
        // never run concurrently can reuse the same memory.
        std::vector<std::pair<void **, uint64_t>> m_mem_requests;
        m_mem_requests.push_back({(void **)&m_local_input_, sizeof(ggml_bf16_t) * config_.routed_expert_num *
                                                                config_.max_len * config_.hidden_size});
        m_mem_requests.push_back({(void **)&m_local_gate_output_, sizeof(ggml_bf16_t) * config_.routed_expert_num *
                                                                      config_.max_len * config_.intermediate_size});
        m_mem_requests.push_back({(void **)&m_local_up_output_, sizeof(ggml_bf16_t) * config_.routed_expert_num *
                                                                    config_.max_len * config_.intermediate_size});
        m_mem_requests.push_back({(void **)&m_local_down_output_, sizeof(ggml_bf16_t) * config_.routed_expert_num *
                                                                      config_.max_len * config_.hidden_size});
        std::vector<void *> gate_up_ba_ptr(config_.expert_num);
        std::vector<void *> gate_bc_ptr(config_.expert_num);
        std::vector<void *> up_bc_ptr(config_.expert_num);
        std::vector<void *> down_ba_ptr(config_.expert_num);
        std::vector<void *> down_bc_ptr(config_.expert_num);
        for (int i = 0; i < config_.expert_num; i++) {
            m_mem_requests.push_back(
                {(void **)&gate_up_ba_ptr[i], T::BufferA::required_size(config_.max_len, config_.hidden_size)});
            m_mem_requests.push_back(
                {(void **)&gate_bc_ptr[i], T::BufferC::required_size(config_.max_len, config_.intermediate_size)});
            m_mem_requests.push_back(
                {(void **)&up_bc_ptr[i], T::BufferC::required_size(config_.max_len, config_.intermediate_size)});
            m_mem_requests.push_back(
                {(void **)&down_ba_ptr[i], T::BufferA::required_size(config_.max_len, config_.intermediate_size)});
            m_mem_requests.push_back(
                {(void **)&down_bc_ptr[i], T::BufferC::required_size(config_.max_len, config_.hidden_size)});
        }
        shared_mem_buffer.alloc(this, m_mem_requests);

        m_local_pos_.resize(config_.max_len);
        for (int i = 0; i < config_.max_len; i++) {
            m_local_pos_[i].resize(config_.routed_expert_num);
        }
        m_expert_id_map_.resize(config_.expert_num);
        m_local_num_.resize(config_.expert_num);
        m_local_input_ptr_.resize(config_.expert_num);
        m_local_gate_output_ptr_.resize(config_.expert_num);
        m_local_up_output_ptr_.resize(config_.expert_num);
        m_local_down_output_ptr_.resize(config_.expert_num);
        for (uint64_t i = 0; i < (uint64_t)config_.expert_num; i++) {
            gate_up_ba_.push_back(
                std::make_shared<typename T::BufferA>(config_.max_len, config_.hidden_size, gate_up_ba_ptr[i]));
            gate_bc_.push_back(
                std::make_shared<typename T::BufferC>(config_.max_len, config_.intermediate_size, gate_bc_ptr[i]));
            up_bc_.push_back(
                std::make_shared<typename T::BufferC>(config_.max_len, config_.intermediate_size, up_bc_ptr[i]));
            down_ba_.push_back(
                std::make_shared<typename T::BufferA>(config_.max_len, config_.intermediate_size, down_ba_ptr[i]));
            down_bc_.push_back(
                std::make_shared<typename T::BufferC>(config_.max_len, config_.hidden_size, down_bc_ptr[i]));
#ifdef USE_NUMA
            // Replicate the packed weights on every NUMA node so each worker
            // thread reads from node-local memory.
            int numa_nodes = numa_num_configured_nodes();
            gate_bb_numa_.resize(numa_nodes);
            up_bb_numa_.resize(numa_nodes);
            down_bb_numa_.resize(numa_nodes);
            for (int j = 0; j < numa_nodes; j++) {
                void *gate_bb_ptr = numa_alloc_aligned(
                    T::BufferB::required_size(config_.intermediate_size, config_.hidden_size), j, 64);
                gate_bb_numa_[j].push_back(std::make_shared<typename T::BufferB>(config_.intermediate_size,
                                                                                 config_.hidden_size, gate_bb_ptr));
                void *up_bb_ptr = numa_alloc_aligned(
                    T::BufferB::required_size(config_.intermediate_size, config_.hidden_size), j, 64);
                up_bb_numa_[j].push_back(std::make_shared<typename T::BufferB>(config_.intermediate_size,
                                                                               config_.hidden_size, up_bb_ptr));
                void *down_bb_ptr = numa_alloc_aligned(
                    T::BufferB::required_size(config_.hidden_size, config_.intermediate_size), j, 64);
                down_bb_numa_[j].push_back(std::make_shared<typename T::BufferB>(
                    config_.hidden_size, config_.intermediate_size, down_bb_ptr));
            }
#else
            // std::aligned_alloc requires size to be a multiple of the
            // alignment; BufferB::required_size is assumed to satisfy that.
            void *gate_bb_ptr =
                std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size));
            gate_bb_.push_back(
                std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, gate_bb_ptr));
            void *up_bb_ptr =
                std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size));
            up_bb_.push_back(
                std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, up_bb_ptr));
            void *down_bb_ptr =
                std::aligned_alloc(64, T::BufferB::required_size(config_.hidden_size, config_.intermediate_size));
            down_bb_.push_back(
                std::make_shared<typename T::BufferB>(config_.hidden_size, config_.intermediate_size, down_bb_ptr));
#endif
        }
    }

    // Releases only the pooled transient buffers; the per-expert weight
    // copies (BufferB) stay allocated for the lifetime of the process.
    ~AMX_MOE() { shared_mem_buffer.dealloc(this); }
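    // Typical call sequence (a sketch; `kernel_t` stands for whichever AMX
    // GEMM kernel type from la/amx.hpp is instantiated as T, and `backend`
    // for an existing cpu_backend Backend):
    //
    //   AMX_MOEConfig cfg(expert_num, routed_expert_num, hidden_size,
    //                     intermediate_size, max_len, gate_w, up_w, down_w);
    //   AMX_MOE<kernel_t> moe(cfg);
    //   moe.load_weights(backend);                      // pack + replicate weights
    //   moe.forward(qlen, k, expert_ids, weights,       // one MoE layer step
    //               input, output, batch_size, backend);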
    // Pack the bf16 projection weights into the kernel's BufferB layout,
    // splitting each expert's matrix across nth work-stealing tasks.
    void load_weights(Backend *backend) {
        int nth = T::recommended_nth(config_.intermediate_size);
        backend->do_work_stealing_job(
            nth * config_.expert_num, nullptr,
            [&](int task_id) {
                uint64_t expert_idx = task_id / nth;
                int ith = task_id % nth;
#ifdef USE_NUMA
                // Fill the copy on every NUMA node.
                int numa_nodes = numa_num_configured_nodes();
                for (int j = 0; j < numa_nodes; j++) {
                    gate_bb_numa_[j][expert_idx]->from_mat(
                        (ggml_bf16_t *)config_.gate_proj + expert_idx * config_.intermediate_size * config_.hidden_size,
                        ith, nth);
                    up_bb_numa_[j][expert_idx]->from_mat(
                        (ggml_bf16_t *)config_.up_proj + expert_idx * config_.intermediate_size * config_.hidden_size,
                        ith, nth);
                }
#else
                gate_bb_[expert_idx]->from_mat(
                    (ggml_bf16_t *)config_.gate_proj + expert_idx * config_.intermediate_size * config_.hidden_size,
                    ith, nth);
                up_bb_[expert_idx]->from_mat(
                    (ggml_bf16_t *)config_.up_proj + expert_idx * config_.intermediate_size * config_.hidden_size,
                    ith, nth);
#endif
            },
            nullptr);
        nth = T::recommended_nth(config_.hidden_size);
        backend->do_work_stealing_job(
            nth * config_.expert_num, nullptr,
            [&](int task_id) {
                uint64_t expert_idx = task_id / nth;
                int ith = task_id % nth;
#ifdef USE_NUMA
                int numa_nodes = numa_num_configured_nodes();
                for (int j = 0; j < numa_nodes; j++) {
                    down_bb_numa_[j][expert_idx]->from_mat(
                        (ggml_bf16_t *)config_.down_proj + expert_idx * config_.hidden_size * config_.intermediate_size,
                        ith, nth);
                }
#else
                down_bb_[expert_idx]->from_mat(
                    (ggml_bf16_t *)config_.down_proj + expert_idx * config_.hidden_size * config_.intermediate_size,
                    ith, nth);
#endif
            },
            nullptr);
    }

    void warm_up(Backend *backend) {}  // no-op
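    // forward() runs one MoE layer over qlen tokens, k routed experts per token:
    //   1. count rows per expert; m_local_pos_[i][j] remembers where token i
    //      lands inside the batch of its j-th routed expert;
    //   2. scatter input rows into per-expert contiguous buffers;
    //   3. per activated expert: pack A, GEMM against gate and up weights,
    //      fuse SiLU(gate) * up in place;
    //   4. per activated expert: GEMM the activated rows against down weights;
    //   5. gather: output[i] = sum_j weights[i*k+j] * down_out[e_ij][pos_ij].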
    void forward(int qlen, int k, const uint64_t *expert_ids, const float *weights, const void *input, void *output,
                 int *batch_size_tensor, Backend *backend) {
        // The effective batch size comes from batch_size_tensor, overriding
        // the qlen argument.
        qlen = batch_size_tensor[0];
        // Heuristic: qlen > 4 * expert_num / routed_expert_num means each
        // activated expert sees more than ~4 rows on average, enough to pay
        // for the AMX tile path; otherwise mat_mul takes its non-AMX path.
        bool use_amx = (qlen > 4 * config_.expert_num / config_.routed_expert_num);
        int activated_expert = 0;
        for (int i = 0; i < config_.expert_num; i++) {
            m_local_num_[i] = 0;
        }
        for (int i = 0; i < qlen; i++) {
            for (int j = 0; j < k; j++) {
                m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++;
            }
        }
        // Compact list of experts that received at least one row.
        for (int i = 0; i < config_.expert_num; i++) {
            if (m_local_num_[i] > 0) {
                m_expert_id_map_[activated_expert] = i;
                activated_expert++;
            }
        }
        // Carve per-expert slices out of the flat local buffers.
        uint64_t offset = 0;
        for (int i = 0; i < config_.expert_num; i++) {
            m_local_input_ptr_[i] = m_local_input_ + offset * config_.hidden_size;
            m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * config_.intermediate_size;
            m_local_up_output_ptr_[i] = m_local_up_output_ + offset * config_.intermediate_size;
            m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size;
            offset += m_local_num_[i];
        }
        // Scatter: copy each token row into the buffers of its k experts.
        backend->do_work_stealing_job(
            qlen, nullptr,
            [&](int i) {
                for (int j = 0; j < k; j++) {
                    memcpy(m_local_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size,
                           (ggml_bf16_t *)input + i * config_.hidden_size, sizeof(ggml_bf16_t) * config_.hidden_size);
                }
            },
            nullptr);
        // Pack each activated expert's rows into its GEMM A buffer.
        backend->do_work_stealing_job(
            activated_expert, nullptr,
            [&](int task_id) {
                int expert_idx = m_expert_id_map_[task_id];
                gate_up_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_input_ptr_[expert_idx], 0, 1);
            },
            nullptr);
        // Gate and up projections plus the fused activation. The first lambda
        // configures the AMX tiles once per worker thread.
        int nth = T::recommended_nth(config_.intermediate_size);
        backend->do_work_stealing_job(
            nth * activated_expert, [&](int _) { T::config(); },
            [&](int task_id) {
                int expert_idx = m_expert_id_map_[task_id / nth];
                int ith = task_id % nth;
#ifdef USE_NUMA
                amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,
                             gate_up_ba_[expert_idx], gate_bb_numa_[Backend::numa_node][expert_idx],
                             gate_bc_[expert_idx], ith, nth, use_amx);
                amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,
                             gate_up_ba_[expert_idx], up_bb_numa_[Backend::numa_node][expert_idx], up_bc_[expert_idx],
                             ith, nth, use_amx);
#else
                amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,
                             gate_up_ba_[expert_idx], gate_bb_[expert_idx], gate_bc_[expert_idx], ith, nth, use_amx);
                amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,
                             gate_up_ba_[expert_idx], up_bb_[expert_idx], up_bc_[expert_idx], ith, nth, use_amx);
#endif
                gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth);
                up_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_up_output_ptr_[expert_idx], ith, nth);
                // SiLU(gate) * up, 32 bf16 lanes at a time, written back into
                // the gate output buffer.
                auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);
                for (int i = 0; i < m_local_num_[expert_idx]; i++) {
                    ggml_bf16_t *gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
                    ggml_bf16_t *up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];
                    for (int j = n_start; j < n_end; j += 32) {
                        __m512 gate_val0, gate_val1, up_val0, up_val1;
                        avx512_32xbf16_to_32xfp32((__m512i *)(gate_output_ptr + j), &gate_val0, &gate_val1);
                        avx512_32xbf16_to_32xfp32((__m512i *)(up_output_ptr + j), &up_val0, &up_val1);
                        __m512 result0 = act_fn(gate_val0, up_val0);
                        __m512 result1 = act_fn(gate_val1, up_val1);
                        avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i *)(gate_output_ptr + j));
                    }
                }
            },
            nullptr);
        // Repack the activated rows as the A operand of the down projection.
        backend->do_work_stealing_job(
            activated_expert, nullptr,
            [&](int task_id) {
                int expert_idx = m_expert_id_map_[task_id];
                down_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], 0, 1);
            },
            nullptr);
        // Down projection back to hidden_size.
        nth = T::recommended_nth(config_.hidden_size);
        backend->do_work_stealing_job(
            nth * activated_expert, [&](int _) { T::config(); },
            [&](int task_id) {
                int expert_idx = m_expert_id_map_[task_id / nth];
                int ith = task_id % nth;
#ifdef USE_NUMA
                amx::mat_mul(m_local_num_[expert_idx], config_.hidden_size, config_.intermediate_size,
                             down_ba_[expert_idx], down_bb_numa_[Backend::numa_node][expert_idx], down_bc_[expert_idx],
                             ith, nth, use_amx);
#else
                amx::mat_mul(m_local_num_[expert_idx], config_.hidden_size, config_.intermediate_size,
                             down_ba_[expert_idx], down_bb_[expert_idx], down_bc_[expert_idx], ith, nth, use_amx);
#endif
                down_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_down_output_ptr_[expert_idx], ith, nth);
            },
            nullptr);
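        // Gather: each token's output row is the routing-weighted sum of its
        // k experts' down outputs, 32 columns (two fp32 zmm registers) at a
        // time.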
        backend->do_work_stealing_job(
            qlen, nullptr,
            [&](int i) {
                for (int e = 0; e < config_.hidden_size; e += 32) {
                    __m512 x0 = _mm512_setzero_ps();
                    __m512 x1 = _mm512_setzero_ps();
                    for (int j = 0; j < k; j++) {
                        __m512 weight = _mm512_set1_ps(weights[i * k + j]);
                        __m512 down_output0, down_output1;
                        avx512_32xbf16_to_32xfp32((__m512i *)(m_local_down_output_ptr_[expert_ids[i * k + j]] +
                                                              m_local_pos_[i][j] * config_.hidden_size + e),
                                                  &down_output0, &down_output1);
                        x0 = _mm512_fmadd_ps(down_output0, weight, x0);
                        x1 = _mm512_fmadd_ps(down_output1, weight, x1);
                    }
                    avx512_32xfp32_to_32xbf16(&x0, &x1,
                                              (__m512i *)((ggml_bf16_t *)output + i * config_.hidden_size + e));
                }
            },
            nullptr);
    }
};

#endif