// kvcache-ai-ktransformers/csrc/ktransformers_ext/operators/llamafile/moe.cpp
/**
* @Description : CPU MoE expert operator: routed expert gate/up/down projections built on llamafile_sgemm.
* @Author : chenht2022
* @Date : 2024-07-22 02:03:22
* @Version : 1.0.0
* @LastEditors : kkk1nak0
* @LastEditTime : 2024-08-15 07:43:41
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#include "moe.h"
#include <iostream>
#include <cstdint>
#ifdef USE_NUMA
#include <numa.h>
#include <numaif.h>
#endif
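// Constructor: keeps the caller-provided expert weight pointers, optionally replicates
// them on every NUMA node, and registers all scratch buffers for the single-token
// (s_*) and batched (m_*) paths with the shared memory buffer.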
MOE::MOE(MOEConfig config) {
config_ = config;
gate_proj_ = config_.gate_proj;
up_proj_ = config_.up_proj;
down_proj_ = config_.down_proj;
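// With USE_NUMA, keep a full copy of the expert weights on every NUMA node so worker
// threads can read them from node-local memory (selected later via Backend::numa_node).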
#ifdef USE_NUMA
int numa_nodes = numa_num_configured_nodes();
gate_proj_numa_.resize(numa_nodes);
up_proj_numa_.resize(numa_nodes);
down_proj_numa_.resize(numa_nodes);
size_t exp_inter_hidden_mul_ = (size_t)config.expert_num * config.intermediate_size * config.hidden_size;
for (int i = 0; i < numa_nodes; i++) {
gate_proj_numa_[i] = numa_alloc_onnode(exp_inter_hidden_mul_ * ggml_type_size(config.gate_type) / ggml_blck_size(config.gate_type), i);
up_proj_numa_[i] = numa_alloc_onnode(exp_inter_hidden_mul_ * ggml_type_size(config.up_type) / ggml_blck_size(config.up_type), i);
down_proj_numa_[i] = numa_alloc_onnode(exp_inter_hidden_mul_ * ggml_type_size(config.down_type) / ggml_blck_size(config.down_type), i);
if (!gate_proj_numa_[i] || !up_proj_numa_[i] || !down_proj_numa_[i]) {
std::cerr << "Memory allocation failed for expert weights on NUMA node " << i << std::endl;
std::abort();  // continuing would memcpy into a null pointer below
}
memcpy(gate_proj_numa_[i], gate_proj_, exp_inter_hidden_mul_ * ggml_type_size(config.gate_type) / ggml_blck_size(config.gate_type));
memcpy(up_proj_numa_[i], up_proj_, exp_inter_hidden_mul_ * ggml_type_size(config.up_type) / ggml_blck_size(config.up_type));
memcpy(down_proj_numa_[i], down_proj_, exp_inter_hidden_mul_ * ggml_type_size(config.down_type) / ggml_blck_size(config.down_type));
}
#endif
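// Scratch buffers for the single-token path (forward_one): fp32 input, quantized gate/up
// inputs, per-routed-expert intermediate buffers, and the fp32 output accumulator.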
std::vector<std::pair<void**, uint64_t>> s_mem_requests;
s_mem_requests.push_back({(void**)&s_input_fp32_, sizeof(float) * config_.hidden_size});
s_mem_requests.push_back({(void**)&s_gate_input_, config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type)});
s_mem_requests.push_back({(void**)&s_up_input_, config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type)});
s_gate_output_.resize(config_.routed_expert_num);
s_up_output_.resize(config_.routed_expert_num);
s_intermediate_fp32_.resize(config_.routed_expert_num);
s_down_input_.resize(config_.routed_expert_num);
s_down_output_.resize(config_.routed_expert_num);
for (int i = 0; i < config_.routed_expert_num; i++) {
s_mem_requests.push_back({(void**)&s_gate_output_[i], sizeof(float) * config_.intermediate_size});
s_mem_requests.push_back({(void**)&s_up_output_[i], sizeof(float) * config_.intermediate_size});
s_mem_requests.push_back({(void**)&s_intermediate_fp32_[i], sizeof(float) * config_.intermediate_size});
s_mem_requests.push_back({(void**)&s_down_input_[i], config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type)});
s_mem_requests.push_back({(void**)&s_down_output_[i], sizeof(float) * config_.hidden_size});
}
s_mem_requests.push_back({(void**)&s_output_fp32_, sizeof(float) * config_.hidden_size});
shared_mem_buffer.alloc(this, s_mem_requests);
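// Scratch buffers for the batched path (forward_many): per-token inputs/outputs plus
// packed per-expert buffers sized for routed_expert_num * group_max_len rows.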
std::vector<std::pair<void**, uint64_t>> m_mem_requests;
m_input_fp32_.resize(config_.group_max_len);
m_gate_input_.resize(config_.group_max_len);
m_up_input_.resize(config_.group_max_len);
for (int i = 0; i < config_.group_max_len; i++) {
m_mem_requests.push_back({(void**)&m_input_fp32_[i], sizeof(float) * config_.hidden_size});
m_mem_requests.push_back({(void**)&m_gate_input_[i], config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type)});
m_mem_requests.push_back({(void**)&m_up_input_[i], config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type)});
}
m_mem_requests.push_back({(void**)&m_local_gate_input_, config_.routed_expert_num * config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type)});
m_mem_requests.push_back({(void**)&m_local_up_input_, config_.routed_expert_num * config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type)});
m_mem_requests.push_back({(void**)&m_local_gate_output_, sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size});
m_mem_requests.push_back({(void**)&m_local_up_output_, sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size});
m_mem_requests.push_back({(void**)&m_local_intermediate_fp32_, sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size});
m_mem_requests.push_back({(void**)&m_local_down_input_, config_.routed_expert_num * config_.group_max_len * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type)});
m_mem_requests.push_back({(void**)&m_local_down_output_, sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.hidden_size});
m_output_fp32_.resize(config_.group_max_len);
for (int i = 0; i < config_.group_max_len; i++) {
m_mem_requests.push_back({(void**)&m_output_fp32_[i], sizeof(float) * config_.hidden_size});
}
shared_mem_buffer.alloc(this, m_mem_requests);
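// Routing bookkeeping for forward_many: each token's slot inside its expert's packed
// buffer, per-expert token counts, and per-expert base pointers into the packed buffers.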
m_local_pos_.resize(config_.group_max_len);
for (int i = 0; i < config_.group_max_len; i++) {
m_local_pos_[i].resize(config_.routed_expert_num);
}
m_local_num_.resize(config_.expert_num);
m_local_gate_input_ptr_.resize(config_.expert_num);
m_local_up_input_ptr_.resize(config_.expert_num);
m_local_gate_output_ptr_.resize(config_.expert_num);
m_local_up_output_ptr_.resize(config_.expert_num);
m_local_intermediate_fp32_ptr_.resize(config_.expert_num);
m_local_down_input_ptr_.resize(config_.expert_num);
m_local_down_output_ptr_.resize(config_.expert_num);
}
MOE::~MOE() {
shared_mem_buffer.dealloc(this);
#ifdef USE_NUMA
int numa_nodes = numa_num_configured_nodes();
for (int i = 0; i < numa_nodes; i++) {
numa_free(gate_proj_numa_[i], config_.expert_num * config_.intermediate_size * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type));
numa_free(up_proj_numa_[i], config_.expert_num * config_.intermediate_size * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type));
numa_free(down_proj_numa_[i], config_.expert_num * config_.hidden_size * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type));
}
#endif
}
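// Runs a zero input through every expert once so that all expert weights are touched.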
void MOE::warm_up(Backend* backend) {
std::vector<float> input_fp32(config_.hidden_size);
std::vector<uint8_t> input(config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type));
std::vector<uint8_t> output(config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type));
for (int i = 0; i < config_.hidden_size; i++) {
input_fp32[i] = 0;
}
from_float(input_fp32.data(), input.data(), config_.hidden_size, config_.hidden_type);
for (int i = 0; i < config_.expert_num; i++) {
uint64_t expert_id = i;
float weight = 0;
forward_one(1, &expert_id, &weight, input.data(), output.data(), backend);
}
}
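// SiLU activation: x * sigmoid(x), applied to the gate projection output.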
static float act_fn(float x) {
return x / (1.0f + expf(-x));
}
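// Single-token path: run one token through its k routed experts
// (gate/up projections, SiLU(gate) * up, down projection, weighted sum over experts).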
void MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend) {
const void* gate_input_ptr;
const void* up_input_ptr;
if (config_.hidden_type == ggml_internal_get_type_traits(config_.gate_type).vec_dot_type && config_.hidden_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {
gate_input_ptr = up_input_ptr = input;
} else {
to_float(input, s_input_fp32_, config_.hidden_size, config_.hidden_type);
if (ggml_internal_get_type_traits(config_.gate_type).vec_dot_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {
from_float(s_input_fp32_, s_gate_input_, config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);
gate_input_ptr = up_input_ptr = s_gate_input_;
} else {
if (config_.hidden_type != ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) {
from_float(s_input_fp32_, s_gate_input_, config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);
gate_input_ptr = s_gate_input_;
} else {
gate_input_ptr = input;
}
if (config_.hidden_type != ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {
from_float(s_input_fp32_, s_up_input_, config_.hidden_size, ggml_internal_get_type_traits(config_.up_type).vec_dot_type);
up_input_ptr = s_up_input_;
} else {
up_input_ptr = input;
}
}
}
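// Stage 1: gate and up projections. nth * k tasks, each covering one stride-wide slice of
// one routed expert's intermediate dimension; the slice is activated (SiLU(gate) * up) and,
// when stride is a multiple of the down vec_dot block size, quantized for the down GEMM.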
int nth = config_.intermediate_size / config_.stride;
backend->do_work_stealing_job(nth * k, nullptr, [&](int task_id) {
int expert_idx = task_id / nth;
uint64_t expert_id = expert_ids[expert_idx];
int ith = task_id % nth;
#ifdef USE_NUMA
void* gate_proj_ptr = (uint8_t*)gate_proj_numa_[Backend::numa_node] + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);
#else
void* gate_proj_ptr = (uint8_t*)gate_proj_ + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);
#endif
float* gate_output_ptr = s_gate_output_[expert_idx] + ith * config_.stride;
llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
#ifdef USE_NUMA
void* up_proj_ptr = (uint8_t*)up_proj_numa_[Backend::numa_node] + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);
#else
void* up_proj_ptr = (uint8_t*)up_proj_ + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);
#endif
float* up_output_ptr = s_up_output_[expert_idx] + ith * config_.stride;
llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {
s_intermediate_fp32_[expert_idx][i] = act_fn(s_gate_output_[expert_idx][i]) * s_up_output_[expert_idx][i];
}
if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) == 0) {
float* intermediate_fp32_ptr = s_intermediate_fp32_[expert_idx] + ith * config_.stride;
void* down_input_ptr = s_down_input_[expert_idx] + ith * config_.stride * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
from_float(intermediate_fp32_ptr, down_input_ptr, config_.stride, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
}
}, nullptr);
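// Fallback: if stride is not a multiple of the down vec_dot block size, quantize each
// expert's full intermediate vector here instead of inside the per-slice tasks above.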
if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) != 0) {
for (int i = 0; i < k; i++) {
from_float(s_intermediate_fp32_[i], s_down_input_[i], config_.intermediate_size, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
}
}
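// Stage 2: down projection. One task per stride-wide slice of the hidden dimension; each
// task accumulates the weighted contributions of all k experts into s_output_fp32_ and
// converts its slice to hidden_type when stride aligns with that type's block size.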
nth = config_.hidden_size / config_.stride;
backend->do_work_stealing_job(nth, nullptr, [&](int task_id) {
int ith = task_id;
for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {
s_output_fp32_[i] = 0;
}
for (int expert_idx = 0; expert_idx < k; expert_idx++) {
uint64_t expert_id = expert_ids[expert_idx];
#ifdef USE_NUMA
void* down_proj_ptr = (uint8_t*)down_proj_numa_[Backend::numa_node] + (expert_id * config_.hidden_size + ith * config_.stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);
#else
void* down_proj_ptr = (uint8_t*)down_proj_ + (expert_id * config_.hidden_size + ith * config_.stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);
#endif
float* down_output_ptr = s_down_output_[expert_idx] + ith * config_.stride;
llamafile_sgemm(config_.stride, 1, config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), s_down_input_[expert_idx], config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {
s_output_fp32_[i] += s_down_output_[expert_idx][i] * weights[expert_idx];
}
}
if (config_.stride % ggml_blck_size(config_.hidden_type) == 0) {
float* output_fp32_ptr = s_output_fp32_ + ith * config_.stride;
void* output_ptr = (uint8_t*)output + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);
from_float(output_fp32_ptr, output_ptr, config_.stride, config_.hidden_type);
}
}, nullptr);
if (config_.stride % ggml_blck_size(config_.hidden_type) != 0) {
from_float(s_output_fp32_, output, config_.hidden_size, config_.hidden_type);
}
}
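// Batched path: group the qlen tokens by expert so that each expert multiplies all of its
// assigned tokens in a single GEMM per weight matrix, then scatter the results back to
// token order.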
void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend) {
for (int i = 0; i < config_.expert_num; i++) {
m_local_num_[i] = 0;
}
for (int i = 0; i < qlen; i++) {
for (int j = 0; j < k; j++) {
m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++;
}
}
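// Compute each expert's base offset (prefix sum of its token count) into the packed buffers.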
uint64_t offset = 0;
for (int i = 0; i < config_.expert_num; i++) {
m_local_gate_input_ptr_[i] = m_local_gate_input_ + offset * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);
m_local_up_input_ptr_[i] = m_local_up_input_ + offset * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type);
m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * config_.intermediate_size;
m_local_up_output_ptr_[i] = m_local_up_output_ + offset * config_.intermediate_size;
m_local_intermediate_fp32_ptr_[i] = m_local_intermediate_fp32_ + offset * config_.intermediate_size;
m_local_down_input_ptr_[i] = m_local_down_input_ + offset * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size;
offset += m_local_num_[i];
}
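// Scatter: quantize each token's input (when needed) and copy it into the packed gate/up
// input buffers of each of its k routed experts, at the slot recorded in m_local_pos_.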
backend->do_work_stealing_job(qlen, nullptr, [&](int i) {
const void* gate_input_ptr;
const void* up_input_ptr;
if (config_.hidden_type == ggml_internal_get_type_traits(config_.gate_type).vec_dot_type && config_.hidden_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {
gate_input_ptr = up_input_ptr = (uint8_t*)input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);
} else {
to_float((uint8_t*)input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), m_input_fp32_[i], config_.hidden_size, config_.hidden_type);
if (ggml_internal_get_type_traits(config_.gate_type).vec_dot_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {
from_float(m_input_fp32_[i], m_gate_input_[i], config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);
gate_input_ptr = up_input_ptr = m_gate_input_[i];
} else {
if (config_.hidden_type != ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) {
from_float(m_input_fp32_[i], m_gate_input_[i], config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);
gate_input_ptr = m_gate_input_[i];
} else {
gate_input_ptr = (uint8_t*)input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);
}
if (config_.hidden_type != ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {
from_float(m_input_fp32_[i], m_up_input_[i], config_.hidden_size, ggml_internal_get_type_traits(config_.up_type).vec_dot_type);
up_input_ptr = m_up_input_[i];
} else {
up_input_ptr = (uint8_t*)input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);
}
}
}
for (int j = 0; j < k; j++) {
memcpy(m_local_gate_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type), gate_input_ptr, config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type));
memcpy(m_local_up_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type), up_input_ptr, config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type));
}
}, nullptr);
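// Stage 1 (per expert): gate and up GEMMs over all tokens routed to that expert, one task
// per QK_K-wide slice of the intermediate dimension, followed by SiLU(gate) * up and
// quantization of the slice of every row for the down projection.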
int stride = QK_K;
int nth = config_.intermediate_size / stride;
backend->do_work_stealing_job(nth * config_.expert_num, nullptr, [&](int task_id) {
uint64_t expert_idx = task_id / nth;
int ith = task_id % nth;
void* gate_input_ptr = m_local_gate_input_ptr_[expert_idx];
#ifdef USE_NUMA
void* gate_proj_ptr = (uint8_t*)gate_proj_numa_[Backend::numa_node] + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);
#else
void* gate_proj_ptr = (uint8_t*)gate_proj_ + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);
#endif
float* gate_output_ptr = m_local_gate_output_ptr_[expert_idx] + ith * stride;
llamafile_sgemm(stride, m_local_num_[expert_idx], config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
void* up_input_ptr = m_local_up_input_ptr_[expert_idx];
#ifdef USE_NUMA
void* up_proj_ptr = (uint8_t*)up_proj_numa_[Backend::numa_node] + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);
#else
void* up_proj_ptr = (uint8_t*)up_proj_ + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);
#endif
float* up_output_ptr = m_local_up_output_ptr_[expert_idx] + ith * stride;
llamafile_sgemm(stride, m_local_num_[expert_idx], config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
for (int i = 0; i < m_local_num_[expert_idx]; i++) {
for (int j = ith * stride; j < (ith + 1) * stride; j++) {
m_local_intermediate_fp32_ptr_[expert_idx][i * config_.intermediate_size + j] = act_fn(m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size + j]) * m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size + j];
}
float* intermediate_fp32_ptr = m_local_intermediate_fp32_ptr_[expert_idx] + i * config_.intermediate_size + ith * stride;
void* down_input_ptr = m_local_down_input_ptr_[expert_idx] + i * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) + ith * stride * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
from_float(intermediate_fp32_ptr, down_input_ptr, stride, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
}
}, nullptr);
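// Stage 2 (per expert): down projection over all tokens routed to that expert, one task per
// QK_K-wide slice of the hidden dimension.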
stride = QK_K;
nth = config_.hidden_size / stride;
backend->do_work_stealing_job(nth * config_.expert_num, nullptr, [&](int task_id) {
uint64_t expert_idx = task_id / nth;
int ith = task_id % nth;
void* down_input_ptr = m_local_down_input_ptr_[expert_idx];
#ifdef USE_NUMA
void* down_proj_ptr = (uint8_t*)down_proj_numa_[Backend::numa_node] + (expert_idx * config_.hidden_size + ith * stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);
#else
void* down_proj_ptr = (uint8_t*)down_proj_ + (expert_idx * config_.hidden_size + ith * stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);
#endif
float* down_output_ptr = m_local_down_output_ptr_[expert_idx] + ith * stride;
llamafile_sgemm(stride, m_local_num_[expert_idx], config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_input_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.hidden_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
}, nullptr);
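// Gather: for each token, sum its experts' down outputs weighted by the routing weights and
// convert the result to hidden_type at the token's position in the output buffer.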
backend->do_work_stealing_job(qlen, nullptr, [&](int i) {
for (int e = 0; e < config_.hidden_size; e++) {
m_output_fp32_[i][e] = 0;
}
for (int j = 0; j < k; j++) {
for (int e = 0; e < config_.hidden_size; e++) {
m_output_fp32_[i][e] += m_local_down_output_ptr_[expert_ids[i * k + j]][m_local_pos_[i][j] * config_.hidden_size + e] * weights[i * k + j];
}
}
from_float(m_output_fp32_[i], (uint8_t*)output + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), config_.hidden_size, config_.hidden_type);
}, nullptr);
}
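// Entry point: reads the live batch size from batch_size_tensor; batches below group_min_len
// go through forward_one token by token, larger ones are processed in chunks of up to
// group_max_len via forward_many, recursing on the remainder.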
void MOE::forward(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, int* batch_size_tensor, Backend* backend) {
qlen = batch_size_tensor[0];
if (qlen < config_.group_min_len) {
for (int i = 0; i < qlen; i++) {
forward_one(k, expert_ids + i * k, weights + i * k, (uint8_t*)input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), (uint8_t*)output + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend);
}
return;
}
int forward_len = std::min(config_.group_max_len, qlen);
forward_many(forward_len, k, expert_ids, weights, input, output, backend);
batch_size_tensor[0] -= forward_len;
forward(qlen - forward_len, k, expert_ids + forward_len * k, weights + forward_len * k, (uint8_t*)input + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), (uint8_t*)output + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), batch_size_tensor, backend);
}
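/*
 * Minimal usage sketch (illustrative only). The field names below follow how config_ is
 * used in this file; the exact MOEConfig construction, Backend setup, expert_ids/weights
 * buffers and weight layout are assumptions to be checked against moe.h and the bindings.
 *
 *   MOEConfig cfg;                         // assumed field-assignable; the real MOEConfig
 *   cfg.expert_num = 64;                   // may take these through its constructor
 *   cfg.routed_expert_num = 6;             // experts activated per token (k)
 *   cfg.hidden_size = 4096;
 *   cfg.intermediate_size = 1408;
 *   cfg.stride = 32;                       // slice width used by forward_one
 *   cfg.group_min_len = 10;                // below this, tokens go through forward_one
 *   cfg.group_max_len = 1024;              // max tokens per forward_many call
 *   cfg.gate_type = cfg.up_type = cfg.down_type = GGML_TYPE_Q4_K;
 *   cfg.hidden_type = GGML_TYPE_F32;
 *   cfg.gate_proj = gate_weights;          // [expert_num, intermediate_size, hidden_size]
 *   cfg.up_proj = up_weights;              // [expert_num, intermediate_size, hidden_size]
 *   cfg.down_proj = down_weights;          // [expert_num, hidden_size, intermediate_size]
 *
 *   MOE moe(cfg);
 *   moe.warm_up(&backend);
 *   int batch_size = qlen;                 // expert_ids/weights hold qlen * k entries
 *   moe.forward(qlen, cfg.routed_expert_num, expert_ids, weights, input, output,
 *               &batch_size, &backend);
 */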