/**
 * @Description  :
 * @Author       : chenht2022
 * @Date         : 2024-07-22 02:03:22
 * @Version      : 1.0.0
 * @LastEditors  : kkk1nak0
 * @LastEditTime : 2024-08-15 07:43:41
 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 **/
#include "moe.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <new>
#include <utility>
#include <vector>

#ifdef USE_NUMA
#include <numa.h>
#include <numaif.h>
#endif

namespace {

/**
 * Returns n * ggml_type_size(type) / ggml_blck_size(type): the byte-size formula used for every
 * buffer and offset in this file. Evaluated in size_t so large expert tensors cannot overflow int.
 */
size_t quant_bytes(ggml_type type, size_t n) {
    return n * ggml_type_size(type) / ggml_blck_size(type);
}

/** The vec_dot_type ggml's type traits associate with `weight_type`; used as the activation
 *  operand type in every llamafile_sgemm call below. */
ggml_type vec_dot_type_of(ggml_type weight_type) {
    return ggml_internal_get_type_traits(weight_type).vec_dot_type;
}

/** SiLU activation: x * sigmoid(x) == x / (1 + e^-x). */
float act_fn(float x) {
    return x / (1.0f + std::exp(-x));
}

/** ReLU activation: x for x > 0, otherwise 0 (NaN also maps to 0). */
float act_fn_relu(float x) {
    return x > 0.0f ? x : 0.0f;
}

/**
 * Picks the activation pointers fed to the gate and up matmuls for one token.
 *
 * If `hidden_type` already equals both vec_dot types, `input` is used as-is. Otherwise the token is
 * decoded once into `fp32_buf` and re-encoded into `gate_buf` (shared by both matmuls when their
 * vec_dot types match) and/or `up_buf`; a side whose vec_dot type equals `hidden_type` keeps `input`.
 *
 * @param input     `n` elements of `hidden_type`.
 * @param fp32_buf  scratch for `n` floats.
 * @param gate_buf  scratch for `n` elements of `gate_vec_dot`.
 * @param up_buf    scratch for `n` elements of `up_vec_dot`.
 * @param gate_out  receives the gate-matmul activation pointer.
 * @param up_out    receives the up-matmul activation pointer.
 */
void prepare_gate_up_inputs(const void* input, int n, ggml_type hidden_type, ggml_type gate_vec_dot,
                            ggml_type up_vec_dot, float* fp32_buf, void* gate_buf, void* up_buf,
                            const void*& gate_out, const void*& up_out) {
    if (hidden_type == gate_vec_dot && hidden_type == up_vec_dot) {
        gate_out = up_out = input;
        return;
    }
    to_float(input, fp32_buf, n, hidden_type);
    if (gate_vec_dot == up_vec_dot) {
        // One re-encoding serves both matmuls.
        from_float(fp32_buf, gate_buf, n, gate_vec_dot);
        gate_out = up_out = gate_buf;
        return;
    }
    if (hidden_type != gate_vec_dot) {
        from_float(fp32_buf, gate_buf, n, gate_vec_dot);
        gate_out = gate_buf;
    } else {
        gate_out = input;
    }
    if (hidden_type != up_vec_dot) {
        from_float(fp32_buf, up_buf, n, up_vec_dot);
        up_out = up_buf;
    } else {
        up_out = input;
    }
}

#ifdef USE_NUMA
/** numa_free() that skips null pointers, so partially populated replica tables can be released. */
void numa_release(void* ptr, size_t bytes) {
    if (ptr != nullptr) {
        numa_free(ptr, bytes);
    }
}
#endif

}  // namespace

/**
 * Builds the MoE operator from `config`: records the expert weight pointers, replicates them on
 * every configured NUMA node in USE_NUMA builds, and requests the scratch buffers used by
 * forward_one() (s_* members) and forward_many() (m_* members) from shared_mem_buffer.
 *
 * @throws std::bad_alloc if a per-node weight replica cannot be allocated (USE_NUMA builds only).
 */
MOE::MOE(MOEConfig config) {
    config_ = config;
    gate_proj_ = config_.gate_proj;
    up_proj_ = config_.up_proj;
    down_proj_ = config_.down_proj;

    const ggml_type gate_vec_dot = vec_dot_type_of(config_.gate_type);
    const ggml_type up_vec_dot = vec_dot_type_of(config_.up_type);
    const ggml_type down_vec_dot = vec_dot_type_of(config_.down_type);

#ifdef USE_NUMA
    int numa_nodes = numa_num_configured_nodes();
    gate_proj_numa_.resize(numa_nodes);
    up_proj_numa_.resize(numa_nodes);
    down_proj_numa_.resize(numa_nodes);
    // size_t arithmetic: expert_num * intermediate_size * hidden_size can exceed INT_MAX.
    const size_t weight_elems = (size_t)config_.expert_num * config_.intermediate_size * config_.hidden_size;
    const size_t gate_bytes = quant_bytes(config_.gate_type, weight_elems);
    const size_t up_bytes = quant_bytes(config_.up_type, weight_elems);
    const size_t down_bytes = quant_bytes(config_.down_type, weight_elems);
    for (int i = 0; i < numa_nodes; i++) {
        gate_proj_numa_[i] = numa_alloc_onnode(gate_bytes, i);
        up_proj_numa_[i] = numa_alloc_onnode(up_bytes, i);
        down_proj_numa_[i] = numa_alloc_onnode(down_bytes, i);
        if (!gate_proj_numa_[i]) {
            std::cerr << "Memory allocation failed for gate_proj_numa_ on node " << i << '\n';
        }
        if (!up_proj_numa_[i]) {
            std::cerr << "Memory allocation failed for up_proj_numa_ on node " << i << '\n';
        }
        if (!down_proj_numa_[i]) {
            std::cerr << "Memory allocation failed for down_proj_numa_ on node " << i << '\n';
        }
        if (!gate_proj_numa_[i] || !up_proj_numa_[i] || !down_proj_numa_[i]) {
            // Copying into a null replica would crash. The destructor does not run for a throwing
            // constructor, so release every replica allocated so far before reporting the failure.
            for (int j = 0; j <= i; j++) {
                numa_release(gate_proj_numa_[j], gate_bytes);
                numa_release(up_proj_numa_[j], up_bytes);
                numa_release(down_proj_numa_[j], down_bytes);
            }
            throw std::bad_alloc();
        }
        memcpy(gate_proj_numa_[i], gate_proj_, gate_bytes);
        memcpy(up_proj_numa_[i], up_proj_, up_bytes);
        memcpy(down_proj_numa_[i], down_proj_, down_bytes);
    }
#endif

    // Single-token scratch for forward_one(); per-expert slots are sized for routed_expert_num.
    std::vector<std::pair<void**, uint64_t>> s_mem_requests;
    s_mem_requests.push_back({(void**)&s_input_fp32_, sizeof(float) * config_.hidden_size});
    s_mem_requests.push_back({(void**)&s_gate_input_, quant_bytes(gate_vec_dot, config_.hidden_size)});
    s_mem_requests.push_back({(void**)&s_up_input_, quant_bytes(up_vec_dot, config_.hidden_size)});
    s_gate_output_.resize(config_.routed_expert_num);
    s_up_output_.resize(config_.routed_expert_num);
    s_intermediate_fp32_.resize(config_.routed_expert_num);
    s_down_input_.resize(config_.routed_expert_num);
    s_down_output_.resize(config_.routed_expert_num);
    for (int i = 0; i < config_.routed_expert_num; i++) {
        s_mem_requests.push_back({(void**)&s_gate_output_[i], sizeof(float) * config_.intermediate_size});
        s_mem_requests.push_back({(void**)&s_up_output_[i], sizeof(float) * config_.intermediate_size});
        s_mem_requests.push_back({(void**)&s_intermediate_fp32_[i], sizeof(float) * config_.intermediate_size});
        s_mem_requests.push_back({(void**)&s_down_input_[i], quant_bytes(down_vec_dot, config_.intermediate_size)});
        s_mem_requests.push_back({(void**)&s_down_output_[i], sizeof(float) * config_.hidden_size});
    }
    s_mem_requests.push_back({(void**)&s_output_fp32_, sizeof(float) * config_.hidden_size});
    shared_mem_buffer.alloc(this, s_mem_requests);

    // Batched scratch for forward_many(): per-token conversion buffers (group_max_len tokens) plus
    // packed per-expert buffers holding routed_expert_num * group_max_len rows in total.
    std::vector<std::pair<void**, uint64_t>> m_mem_requests;
    m_input_fp32_.resize(config_.group_max_len);
    m_gate_input_.resize(config_.group_max_len);
    m_up_input_.resize(config_.group_max_len);
    for (int i = 0; i < config_.group_max_len; i++) {
        m_mem_requests.push_back({(void**)&m_input_fp32_[i], sizeof(float) * config_.hidden_size});
        m_mem_requests.push_back({(void**)&m_gate_input_[i], quant_bytes(gate_vec_dot, config_.hidden_size)});
        m_mem_requests.push_back({(void**)&m_up_input_[i], quant_bytes(up_vec_dot, config_.hidden_size)});
    }
    const size_t local_rows = (size_t)config_.routed_expert_num * config_.group_max_len;
    m_mem_requests.push_back({(void**)&m_local_gate_input_, quant_bytes(gate_vec_dot, local_rows * config_.hidden_size)});
    m_mem_requests.push_back({(void**)&m_local_up_input_, quant_bytes(up_vec_dot, local_rows * config_.hidden_size)});
    m_mem_requests.push_back({(void**)&m_local_gate_output_, sizeof(float) * local_rows * config_.intermediate_size});
    m_mem_requests.push_back({(void**)&m_local_up_output_, sizeof(float) * local_rows * config_.intermediate_size});
    m_mem_requests.push_back({(void**)&m_local_intermediate_fp32_, sizeof(float) * local_rows * config_.intermediate_size});
    m_mem_requests.push_back({(void**)&m_local_down_input_, quant_bytes(down_vec_dot, local_rows * config_.intermediate_size)});
    m_mem_requests.push_back({(void**)&m_local_down_output_, sizeof(float) * local_rows * config_.hidden_size});
    m_output_fp32_.resize(config_.group_max_len);
    for (int i = 0; i < config_.group_max_len; i++) {
        m_mem_requests.push_back({(void**)&m_output_fp32_[i], sizeof(float) * config_.hidden_size});
    }
    shared_mem_buffer.alloc(this, m_mem_requests);

    // Routing bookkeeping for forward_many().
    m_local_pos_.resize(config_.group_max_len);
    for (int i = 0; i < config_.group_max_len; i++) {
        m_local_pos_[i].resize(config_.routed_expert_num);
    }
    m_local_num_.resize(config_.expert_num);
    m_local_gate_input_ptr_.resize(config_.expert_num);
    m_local_up_input_ptr_.resize(config_.expert_num);
    m_local_gate_output_ptr_.resize(config_.expert_num);
    m_local_up_output_ptr_.resize(config_.expert_num);
    m_local_intermediate_fp32_ptr_.resize(config_.expert_num);
    m_local_down_input_ptr_.resize(config_.expert_num);
    m_local_down_output_ptr_.resize(config_.expert_num);
}

/** Releases the scratch buffers and, in USE_NUMA builds, the per-node weight replicas. */
MOE::~MOE() {
    shared_mem_buffer.dealloc(this);
#ifdef USE_NUMA
    // Computed in size_t exactly as in the constructor: the int product
    // expert_num * intermediate_size * hidden_size can exceed INT_MAX for large models.
    const size_t weight_elems = (size_t)config_.expert_num * config_.intermediate_size * config_.hidden_size;
    for (size_t i = 0; i < gate_proj_numa_.size(); i++) {
        numa_release(gate_proj_numa_[i], quant_bytes(config_.gate_type, weight_elems));
        numa_release(up_proj_numa_[i], quant_bytes(config_.up_type, weight_elems));
        numa_release(down_proj_numa_[i], quant_bytes(config_.down_type, weight_elems));
    }
#endif
}

/**
 * Runs forward_one() once per expert on an all-zero token with routing weight 0, so every expert's
 * weights are read once. The output is discarded.
 */
void MOE::warm_up(Backend* backend) {
    std::vector<float> input_fp32(config_.hidden_size, 0.0f);
    std::vector<uint8_t> input(quant_bytes(config_.hidden_type, config_.hidden_size));
    std::vector<uint8_t> output(quant_bytes(config_.hidden_type, config_.hidden_size));
    from_float(input_fp32.data(), input.data(), config_.hidden_size, config_.hidden_type);
    for (int i = 0; i < config_.expert_num; i++) {
        uint64_t expert_id = i;
        float weight = 0;
        forward_one(1, &expert_id, &weight, input.data(), output.data(), backend);
    }
}

/**
 * Computes the MoE output for a single token:
 *   output = sum_i weights[i] * down_e(act(gate_e(x)) * up_e(x)),  e = expert_ids[i], i < k,
 * where act is SiLU when config_.use_silu is set and ReLU otherwise.
 *
 * @param k          number of selected experts; must not exceed config_.routed_expert_num (the
 *                   s_* scratch has that many per-expert slots).
 * @param expert_ids `k` expert indices.
 * @param weights    `k` routing weights aligned with expert_ids.
 * @param input      one token: hidden_size elements of config_.hidden_type.
 * @param output     receives one token: hidden_size elements of config_.hidden_type.
 */
void MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend) {
    const ggml_type gate_vec_dot = vec_dot_type_of(config_.gate_type);
    const ggml_type up_vec_dot = vec_dot_type_of(config_.up_type);
    const ggml_type down_vec_dot = vec_dot_type_of(config_.down_type);

    const void* gate_input_ptr;
    const void* up_input_ptr;
    prepare_gate_up_inputs(input, config_.hidden_size, config_.hidden_type, gate_vec_dot, up_vec_dot,
                           s_input_fp32_, s_gate_input_, s_up_input_, gate_input_ptr, up_input_ptr);

    const int64_t gate_k = config_.hidden_size / ggml_blck_size(config_.gate_type);
    const int64_t up_k = config_.hidden_size / ggml_blck_size(config_.up_type);
    const int64_t down_k = config_.intermediate_size / ggml_blck_size(config_.down_type);
    // A stride-sized slice can be re-encoded on its own only if it spans whole quantization blocks;
    // otherwise the full rows are converted after the parallel job.
    const bool down_slices_whole = config_.stride % ggml_blck_size(down_vec_dot) == 0;
    const bool output_slices_whole = config_.stride % ggml_blck_size(config_.hidden_type) == 0;

    // Phase 1: gate/up matmuls + activation, one task per (selected expert, stride-row slice).
    const int inter_slices = config_.intermediate_size / config_.stride;
    backend->do_work_stealing_job(inter_slices * k, nullptr, [&](int task_id) {
        const int expert_idx = task_id / inter_slices;
        const uint64_t expert_id = expert_ids[expert_idx];
        const int ith = task_id % inter_slices;
        const int begin = ith * config_.stride;
        const int end = begin + config_.stride;
        // Backend::numa_node is read inside the task (presumably the worker's node; assumed
        // thread-local) to select the node-local weight replica.
#ifdef USE_NUMA
        uint8_t* gate_base = (uint8_t*)gate_proj_numa_[Backend::numa_node];
        uint8_t* up_base = (uint8_t*)up_proj_numa_[Backend::numa_node];
#else
        uint8_t* gate_base = (uint8_t*)gate_proj_;
        uint8_t* up_base = (uint8_t*)up_proj_;
#endif
        // Weight rows [begin, end) of this expert; each row holds hidden_size elements.
        const uint64_t weight_row = expert_id * config_.intermediate_size + begin;

        void* gate_proj_ptr = gate_base + quant_bytes(config_.gate_type, weight_row * config_.hidden_size);
        float* gate_output_ptr = s_gate_output_[expert_idx] + begin;
        llamafile_sgemm(config_.stride, 1, gate_k, gate_proj_ptr, gate_k, gate_input_ptr, gate_k, gate_output_ptr,
                        config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, gate_vec_dot,
                        GGML_TYPE_F32, GGML_PREC_DEFAULT);

        void* up_proj_ptr = up_base + quant_bytes(config_.up_type, weight_row * config_.hidden_size);
        float* up_output_ptr = s_up_output_[expert_idx] + begin;
        llamafile_sgemm(config_.stride, 1, up_k, up_proj_ptr, up_k, up_input_ptr, up_k, up_output_ptr,
                        config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, up_vec_dot,
                        GGML_TYPE_F32, GGML_PREC_DEFAULT);

        const float* gate_out = s_gate_output_[expert_idx];
        const float* up_out = s_up_output_[expert_idx];
        float* act_out = s_intermediate_fp32_[expert_idx];
        if (config_.use_silu) {
            for (int i = begin; i < end; i++) {
                act_out[i] = act_fn(gate_out[i]) * up_out[i];
            }
        } else {
            for (int i = begin; i < end; i++) {
                act_out[i] = act_fn_relu(gate_out[i]) * up_out[i];
            }
        }

        if (down_slices_whole) {
            void* down_input_ptr = s_down_input_[expert_idx] + quant_bytes(down_vec_dot, begin);
            from_float(act_out + begin, down_input_ptr, config_.stride, down_vec_dot);
        }
    }, nullptr);

    if (!down_slices_whole) {
        for (int i = 0; i < k; i++) {
            from_float(s_intermediate_fp32_[i], s_down_input_[i], config_.intermediate_size, down_vec_dot);
        }
    }

    // Phase 2: down projections, one task per stride-sized slice of the hidden dimension; each task
    // accumulates the weighted contributions of all k experts for its slice.
    const int hidden_slices = config_.hidden_size / config_.stride;
    backend->do_work_stealing_job(hidden_slices, nullptr, [&](int ith) {
        const int begin = ith * config_.stride;
        const int end = begin + config_.stride;
#ifdef USE_NUMA
        uint8_t* down_base = (uint8_t*)down_proj_numa_[Backend::numa_node];
#else
        uint8_t* down_base = (uint8_t*)down_proj_;
#endif
        for (int i = begin; i < end; i++) {
            s_output_fp32_[i] = 0;
        }
        for (int expert_idx = 0; expert_idx < k; expert_idx++) {
            const uint64_t expert_id = expert_ids[expert_idx];
            // Weight rows [begin, end) of this expert; each row holds intermediate_size elements.
            const uint64_t weight_row = expert_id * config_.hidden_size + begin;
            void* down_proj_ptr = down_base + quant_bytes(config_.down_type, weight_row * config_.intermediate_size);
            float* down_output_ptr = s_down_output_[expert_idx] + begin;
            llamafile_sgemm(config_.stride, 1, down_k, down_proj_ptr, down_k, s_down_input_[expert_idx], down_k,
                            down_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type,
                            down_vec_dot, GGML_TYPE_F32, GGML_PREC_DEFAULT);
            const float* down_out = s_down_output_[expert_idx];
            for (int i = begin; i < end; i++) {
                s_output_fp32_[i] += down_out[i] * weights[expert_idx];
            }
        }
        if (output_slices_whole) {
            void* output_ptr = (uint8_t*)output + quant_bytes(config_.hidden_type, begin);
            from_float(s_output_fp32_ + begin, output_ptr, config_.stride, config_.hidden_type);
        }
    }, nullptr);

    if (!output_slices_whole) {
        from_float(s_output_fp32_, output, config_.hidden_size, config_.hidden_type);
    }
}

/**
 * Computes the MoE output for `qlen` tokens at once by grouping each expert's tokens and running
 * one matmul per (expert, QK_K-row slice).
 *
 * @param qlen       number of tokens; must not exceed config_.group_max_len (size of the m_* tables).
 * @param k          experts per token; must not exceed config_.routed_expert_num.
 * @param expert_ids qlen * k expert indices, token-major.
 * @param weights    qlen * k routing weights aligned with expert_ids.
 * @param input      qlen tokens of hidden_size elements in config_.hidden_type, back to back.
 * @param output     receives qlen tokens in the same layout as `input`.
 */
void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend) {
    const ggml_type gate_vec_dot = vec_dot_type_of(config_.gate_type);
    const ggml_type up_vec_dot = vec_dot_type_of(config_.up_type);
    const ggml_type down_vec_dot = vec_dot_type_of(config_.down_type);

    // Count the rows routed to each expert; m_local_pos_[i][j] is the row that token i's j-th
    // selection occupies inside its expert's group.
    std::fill(m_local_num_.begin(), m_local_num_.end(), 0);
    for (int i = 0; i < qlen; i++) {
        for (int j = 0; j < k; j++) {
            m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++;
        }
    }

    // Pack the groups back to back: expert i owns rows [offset, offset + m_local_num_[i]).
    uint64_t offset = 0;
    for (int i = 0; i < config_.expert_num; i++) {
        m_local_gate_input_ptr_[i] = m_local_gate_input_ + quant_bytes(gate_vec_dot, offset * config_.hidden_size);
        m_local_up_input_ptr_[i] = m_local_up_input_ + quant_bytes(up_vec_dot, offset * config_.hidden_size);
        m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * config_.intermediate_size;
        m_local_up_output_ptr_[i] = m_local_up_output_ + offset * config_.intermediate_size;
        m_local_intermediate_fp32_ptr_[i] = m_local_intermediate_fp32_ + offset * config_.intermediate_size;
        m_local_down_input_ptr_[i] = m_local_down_input_ + quant_bytes(down_vec_dot, offset * config_.intermediate_size);
        m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size;
        offset += m_local_num_[i];
    }

    // Convert each token once and scatter a copy into the group of every expert it selected.
    const size_t gate_row_bytes = quant_bytes(gate_vec_dot, config_.hidden_size);
    const size_t up_row_bytes = quant_bytes(up_vec_dot, config_.hidden_size);
    backend->do_work_stealing_job(qlen, nullptr, [&](int i) {
        const uint8_t* token = (const uint8_t*)input + quant_bytes(config_.hidden_type, (size_t)i * config_.hidden_size);
        const void* gate_input_ptr;
        const void* up_input_ptr;
        prepare_gate_up_inputs(token, config_.hidden_size, config_.hidden_type, gate_vec_dot, up_vec_dot,
                               m_input_fp32_[i], m_gate_input_[i], m_up_input_[i], gate_input_ptr, up_input_ptr);
        for (int j = 0; j < k; j++) {
            const uint64_t expert_id = expert_ids[i * k + j];
            const size_t row = m_local_pos_[i][j];
            memcpy(m_local_gate_input_ptr_[expert_id] + quant_bytes(gate_vec_dot, row * config_.hidden_size),
                   gate_input_ptr, gate_row_bytes);
            memcpy(m_local_up_input_ptr_[expert_id] + quant_bytes(up_vec_dot, row * config_.hidden_size),
                   up_input_ptr, up_row_bytes);
        }
    }, nullptr);

    const int stride = QK_K;
    const int64_t gate_k = config_.hidden_size / ggml_blck_size(config_.gate_type);
    const int64_t up_k = config_.hidden_size / ggml_blck_size(config_.up_type);
    const int64_t down_k = config_.intermediate_size / ggml_blck_size(config_.down_type);

    // Phase 1: per (expert, QK_K-row slice), gate/up matmuls over all of the expert's rows, then the
    // activation and re-encoding of that slice into the down projection's vec_dot type.
    const int inter_slices = config_.intermediate_size / stride;
    backend->do_work_stealing_job(inter_slices * config_.expert_num, nullptr, [&](int task_id) {
        const uint64_t expert_idx = task_id / inter_slices;
        const int ith = task_id % inter_slices;
        const int rows = m_local_num_[expert_idx];
        if (rows == 0) {
            return;  // no token in this batch selected this expert
        }
        const int begin = ith * stride;
        const int end = begin + stride;
#ifdef USE_NUMA
        uint8_t* gate_base = (uint8_t*)gate_proj_numa_[Backend::numa_node];
        uint8_t* up_base = (uint8_t*)up_proj_numa_[Backend::numa_node];
#else
        uint8_t* gate_base = (uint8_t*)gate_proj_;
        uint8_t* up_base = (uint8_t*)up_proj_;
#endif
        const uint64_t weight_row = expert_idx * config_.intermediate_size + begin;

        void* gate_proj_ptr = gate_base + quant_bytes(config_.gate_type, weight_row * config_.hidden_size);
        float* gate_output_ptr = m_local_gate_output_ptr_[expert_idx] + begin;
        llamafile_sgemm(stride, rows, gate_k, gate_proj_ptr, gate_k, m_local_gate_input_ptr_[expert_idx], gate_k,
                        gate_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type,
                        gate_vec_dot, GGML_TYPE_F32, GGML_PREC_DEFAULT);

        void* up_proj_ptr = up_base + quant_bytes(config_.up_type, weight_row * config_.hidden_size);
        float* up_output_ptr = m_local_up_output_ptr_[expert_idx] + begin;
        llamafile_sgemm(stride, rows, up_k, up_proj_ptr, up_k, m_local_up_input_ptr_[expert_idx], up_k,
                        up_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type,
                        up_vec_dot, GGML_TYPE_F32, GGML_PREC_DEFAULT);

        for (int i = 0; i < rows; i++) {
            const float* gate_row = m_local_gate_output_ptr_[expert_idx] + i * config_.intermediate_size;
            const float* up_row = m_local_up_output_ptr_[expert_idx] + i * config_.intermediate_size;
            float* act_row = m_local_intermediate_fp32_ptr_[expert_idx] + i * config_.intermediate_size;
            if (config_.use_silu) {
                for (int j = begin; j < end; j++) {
                    act_row[j] = act_fn(gate_row[j]) * up_row[j];
                }
            } else {
                for (int j = begin; j < end; j++) {
                    act_row[j] = act_fn_relu(gate_row[j]) * up_row[j];
                }
            }
            void* down_input_ptr = m_local_down_input_ptr_[expert_idx] +
                                   quant_bytes(down_vec_dot, (size_t)i * config_.intermediate_size) +
                                   quant_bytes(down_vec_dot, begin);
            from_float(act_row + begin, down_input_ptr, stride, down_vec_dot);
        }
    }, nullptr);

    // Phase 2: per (expert, QK_K-row slice of the hidden dimension), the down matmul over all rows.
    const int hidden_slices = config_.hidden_size / stride;
    backend->do_work_stealing_job(hidden_slices * config_.expert_num, nullptr, [&](int task_id) {
        const uint64_t expert_idx = task_id / hidden_slices;
        const int ith = task_id % hidden_slices;
        const int rows = m_local_num_[expert_idx];
        if (rows == 0) {
            return;  // no token in this batch selected this expert
        }
        const int begin = ith * stride;
#ifdef USE_NUMA
        uint8_t* down_base = (uint8_t*)down_proj_numa_[Backend::numa_node];
#else
        uint8_t* down_base = (uint8_t*)down_proj_;
#endif
        const uint64_t weight_row = expert_idx * config_.hidden_size + begin;
        void* down_proj_ptr = down_base + quant_bytes(config_.down_type, weight_row * config_.intermediate_size);
        float* down_output_ptr = m_local_down_output_ptr_[expert_idx] + begin;
        llamafile_sgemm(stride, rows, down_k, down_proj_ptr, down_k, m_local_down_input_ptr_[expert_idx], down_k,
                        down_output_ptr, config_.hidden_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type,
                        down_vec_dot, GGML_TYPE_F32, GGML_PREC_DEFAULT);
    }, nullptr);

    // Gather: each token sums its experts' rows weighted by the routing weights, then is encoded
    // back into config_.hidden_type.
    backend->do_work_stealing_job(qlen, nullptr, [&](int i) {
        float* acc = m_output_fp32_[i];
        for (int e = 0; e < config_.hidden_size; e++) {
            acc[e] = 0;
        }
        for (int j = 0; j < k; j++) {
            const float* src = m_local_down_output_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size;
            const float w = weights[i * k + j];
            for (int e = 0; e < config_.hidden_size; e++) {
                acc[e] += src[e] * w;
            }
        }
        from_float(acc, (uint8_t*)output + quant_bytes(config_.hidden_type, (size_t)i * config_.hidden_size),
                   config_.hidden_size, config_.hidden_type);
    }, nullptr);
}

/**
 * Computes the MoE output for a batch of tokens.
 *
 * The token count is taken from batch_size_tensor[0]; the `qlen` argument is ignored. While at
 * least group_min_len tokens remain, chunks of up to group_max_len tokens go through
 * forward_many() and batch_size_tensor[0] is decremented by each chunk size. The remaining tokens
 * (fewer than group_min_len) are processed one by one with forward_one().
 */
void MOE::forward(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, int* batch_size_tensor, Backend* backend) {
    (void)qlen;  // superseded by batch_size_tensor[0]
    const uint8_t* in = (const uint8_t*)input;
    uint8_t* out = (uint8_t*)output;
    int remaining = batch_size_tensor[0];
    // Stop once no tokens remain, so a non-positive group_min_len or group_max_len cannot spin on
    // empty chunks forever.
    while (remaining > 0 && remaining >= config_.group_min_len && config_.group_max_len > 0) {
        const int chunk = std::min(config_.group_max_len, remaining);
        forward_many(chunk, k, expert_ids, weights, in, out, backend);
        batch_size_tensor[0] -= chunk;
        remaining = batch_size_tensor[0];
        expert_ids += (size_t)chunk * k;
        weights += (size_t)chunk * k;
        const size_t chunk_bytes = quant_bytes(config_.hidden_type, (size_t)chunk * config_.hidden_size);
        in += chunk_bytes;
        out += chunk_bytes;
    }
    for (int i = 0; i < remaining; i++) {
        const size_t token_offset = quant_bytes(config_.hidden_type, (size_t)i * config_.hidden_size);
        forward_one(k, expert_ids + (size_t)i * k, weights + (size_t)i * k, in + token_offset, out + token_offset, backend);
    }
}