ggml-cpu: fuse RMS_NORM + MUL on CPU backend (#22423)

zzzzwc 2026-05-06 15:41:14 +08:00 committed by GitHub
parent 07eaf919ed
commit f08f20a0e3
3 changed files with 115 additions and 17 deletions

View file

@@ -2965,6 +2965,45 @@ struct ggml_cplan ggml_graph_plan(
     return cplan;
 }
 
+static bool ggml_cpu_disable_fusion = false; // initialized once in ggml_cpu_init(), read-only afterwards
+
+// Try to fuse the node at node_n with the nodes that follow it for better performance.
+// Returns the number of extra nodes consumed (skipped) by fusion (>= 1), or 0 if no fusion was applied.
+static int ggml_cpu_try_fuse_ops(
+        const struct ggml_cgraph         * cgraph,
+        const int                          node_n,
+        const struct ggml_compute_params * params,
+        const struct ggml_cplan          * cplan) {
+    if (ggml_cpu_disable_fusion || cplan->use_ref) {
+        return 0;
+    }
+
+    struct ggml_tensor * node = cgraph->nodes[node_n];
+
+    if (node->op == GGML_OP_RMS_NORM) {
+        // RMS_NORM + MUL fusion
+        const enum ggml_op fuse_ops[] = { GGML_OP_RMS_NORM, GGML_OP_MUL };
+        if (ggml_can_fuse(cgraph, node_n, fuse_ops, 2)) {
+            struct ggml_tensor * mul_node = cgraph->nodes[node_n + 1];
+            const struct ggml_tensor * mul_w = (mul_node->src[0] == node) ? mul_node->src[1] : mul_node->src[0];
+
+            if (node->src[0]->type == GGML_TYPE_F32 &&
+                mul_node->type     == GGML_TYPE_F32 &&
+                mul_w->type        == GGML_TYPE_F32 &&
+                mul_w->ne[0]       == node->ne[0]   &&
+                mul_w->nb[0]       == sizeof(float)) {
+                ggml_compute_forward_rms_norm_mul_fused(params, node, mul_node);
+                return 1; // the MUL node was consumed as well
+            }
+        }
+    }
+
+    return 0;
+}
+
 static thread_ret_t ggml_graph_compute_thread(void * data) {
     struct ggml_compute_state * state = (struct ggml_compute_state *) data;
     struct ggml_threadpool * tp = state->threadpool;
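The skip above is only safe because ggml_can_fuse() vets the pattern first. The sketch below is an illustration of the conditions such a check has to enforce, not ggml's actual implementation (a real check must also reject intermediates that are graph outputs or views, and verify that no other node consumes them; can_fuse_sketch is an invented name):

static bool can_fuse_sketch(const struct ggml_cgraph * cgraph, int i, const enum ggml_op * ops, int n_ops) {
    if (i + n_ops > cgraph->n_nodes) {
        return false; // not enough nodes left in the graph
    }
    for (int k = 0; k < n_ops; k++) {
        if (cgraph->nodes[i + k]->op != ops[k]) {
            return false; // op sequence does not match
        }
    }
    for (int k = 0; k < n_ops - 1; k++) {
        // every intermediate result must feed the op that follows it,
        // otherwise skipping its standalone evaluation would drop data
        const struct ggml_tensor * cur  = cgraph->nodes[i + k];
        const struct ggml_tensor * next = cgraph->nodes[i + k + 1];
        if (next->src[0] != cur && next->src[1] != cur) {
            return false;
        }
    }
    return true;
}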
@@ -3001,7 +3040,14 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
                 continue;
             }
 
-            ggml_compute_forward(&params, node);
+            // TODO: move fused-op detection into ggml_graph_plan so fusion decisions are made once at planning time
+            // try fused ops first, fall back to the normal single-op compute
+            const int n_fused = ggml_cpu_try_fuse_ops(cgraph, node_n, &params, cplan);
+            if (n_fused > 0) {
+                node_n += n_fused; // skip the nodes consumed by the fused kernel
+            } else {
+                ggml_compute_forward(&params, node);
+            }
 
             if (state->ith == 0 && cplan->abort_callback &&
                 cplan->abort_callback(cplan->abort_callback_data)) {
@@ -3763,6 +3809,11 @@ void ggml_cpu_init(void) {
         ggml_init_riscv_arch_features();
 #endif
 
+        {
+            const char * env = getenv("GGML_CPU_DISABLE_FUSION");
+            ggml_cpu_disable_fusion = (env != NULL && atoi(env) == 1);
+        }
+
         is_first_call = false;
     }
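Because the flag is latched once here and treated as read-only afterwards, the switch has to be in the environment before the first ggml call. A minimal usage sketch (setenv is POSIX; the variable name is the one this commit introduces):

#include <stdlib.h>

int main(void) {
    // must run before the first ggml call so ggml_cpu_init() sees it
    setenv("GGML_CPU_DISABLE_FUSION", "1", /*overwrite=*/1);
    // ... initialize ggml and compute graphs as usual ...
    return 0;
}

From a shell, GGML_CPU_DISABLE_FUSION=1 before the command achieves the same thing.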

View file

@@ -3713,11 +3713,27 @@ void ggml_compute_forward_norm(
 
 // ggml_compute_forward_group_rms_norm
 
+// fusion kinds that can be combined with the rms_norm computation in a single pass.
+// extend this enum when adding new fused variants (e.g. FUSE_ADD, FUSE_MUL_ADD, ...).
+enum ggml_rms_norm_fuse_op {
+    GGML_RMS_NORM_FUSE_OP_NONE,
+    GGML_RMS_NORM_FUSE_OP_MUL,
+};
+
+template <ggml_rms_norm_fuse_op FUSE_OP>
 static void ggml_compute_forward_rms_norm_f32(
         const ggml_compute_params * params,
-        ggml_tensor * dst) {
+        ggml_tensor * dst_rms_norm,
+        ggml_tensor * dst_fused = nullptr) {
-    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src0 = dst_rms_norm->src[0];
+    // src1 must stay non-null: GGML_TENSOR_BINARY_OP_LOCALS below dereferences it
+    // even in the unfused instantiation (the resulting locals are simply unused)
+    const ggml_tensor * src1 = dst_rms_norm;
+
+    ggml_tensor * dst = dst_rms_norm;
+    if constexpr (FUSE_OP == GGML_RMS_NORM_FUSE_OP_MUL) {
+        // the fused MUL's other operand becomes src1 and its output becomes dst
+        src1 = (dst_fused->src[0] == dst_rms_norm) ? dst_fused->src[1] : dst_fused->src[0];
+        dst  = dst_fused;
+    }
 
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
@@ -3726,11 +3742,10 @@ static void ggml_compute_forward_rms_norm_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_TENSOR_BINARY_OP_LOCALS
 
     float eps;
-    memcpy(&eps, dst->op_params, sizeof(float));
+    memcpy(&eps, dst_rms_norm->op_params, sizeof(float));
 
     GGML_ASSERT(eps >= 0.0f);
 
     // TODO: optimize
@@ -3740,25 +3755,32 @@ static void ggml_compute_forward_rms_norm_f32(
                 const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
 
                 ggml_float sum = 0.0;
                 // worth switching to explicit SIMD?
                 for (int64_t i00 = 0; i00 < ne00; i00++) {
                     sum += (ggml_float)(x[i00] * x[i00]);
                 }
 
                 const float mean = sum/ne00;
 
-                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
-                memcpy(y, x, ne00 * sizeof(float));
-                // for (int i00 = 0; i00 < ne00; i00++) {
-                //     y[i00] = x[i00];
-                // }
 
                 const float scale = 1.0f/sqrtf(mean + eps);
 
                 // if you hit this, likely you got an inf somewhere earlier
                 assert(scale > 0.0f);
 
-                ggml_vec_scale_f32(ne00, y, scale);
+                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+                if constexpr (FUSE_OP == GGML_RMS_NORM_FUSE_OP_MUL) {
+                    // src1 broadcasts across rows/planes via modulo indexing
+                    const int64_t i11 = i01 % ne11;
+                    const int64_t i12 = i02 % ne12;
+                    const int64_t i13 = i03 % ne13;
+                    const float * w = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
+
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        y[i00] = x[i00] * scale * w[i00];
+                    }
+                } else {
+                    memcpy(y, x, ne00 * sizeof(float));
+                    ggml_vec_scale_f32(ne00, y, scale);
+                }
             }
         }
     }
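The fused branch has to reproduce exactly what RMS_NORM followed by MUL would have produced. As a cross-check, here are the per-row semantics in isolation (a sketch with an invented name, not the production kernel; the modulo indexing above additionally lets a smaller src1 broadcast across rows, which a single-row view sidesteps):

#include <math.h>
#include <stdint.h>

// reference: y[i] = x[i] / sqrt(mean(x^2) + eps) * w[i] for one row of n floats
static void rms_norm_mul_row_ref(int64_t n, float * y, const float * x, const float * w, float eps) {
    double sum = 0.0;
    for (int64_t i = 0; i < n; i++) {
        sum += (double) x[i] * x[i]; // accumulate in double, like the kernel's ggml_float
    }
    const float scale = 1.0f / sqrtf((float) (sum / n) + eps);
    for (int64_t i = 0; i < n; i++) {
        y[i] = x[i] * scale * w[i]; // normalization and element-wise weight in one pass
    }
}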
@@ -3773,7 +3795,31 @@ void ggml_compute_forward_rms_norm(
 
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_rms_norm_f32(params, dst);
+                ggml_compute_forward_rms_norm_f32<GGML_RMS_NORM_FUSE_OP_NONE>(params, dst);
             } break;
         default:
             {
                 GGML_ABORT("fatal error");
             }
     }
 }
 
+// Fused RMS_NORM + MUL: computes dst = rms_norm(src0) * src1 in a single pass.
+// This avoids materializing the intermediate rms_norm result in memory.
+void ggml_compute_forward_rms_norm_mul_fused(
+        const ggml_compute_params * params,
+        ggml_tensor * dst_rms_norm,
+        ggml_tensor * dst_mul) {
+    GGML_ASSERT(dst_mul != nullptr);
+    GGML_ASSERT(dst_mul->src[0] == dst_rms_norm || dst_mul->src[1] == dst_rms_norm);
+
+    const ggml_tensor * src0 = dst_rms_norm->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rms_norm_f32<GGML_RMS_NORM_FUSE_OP_MUL>(params, dst_rms_norm, dst_mul);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
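The enum comment above invites future variants. A self-contained demo of the same dispatch pattern (all names here are invented, and the third variant is hypothetical, not part of this commit): because the fuse kind is a template parameter, each instantiation compiles to a branch-free inner loop, so adding a variant touches only the enum and one if-constexpr arm while leaving the existing paths untouched.

#include <cstdint>

enum class fuse_demo { none, mul, mul_add /* hypothetical */ };

template <fuse_demo FUSE_OP>
static void scale_row_demo(int64_t n, float * y, const float * x, float scale,
                           const float * w, const float * b) {
    for (int64_t i = 0; i < n; i++) {
        if constexpr (FUSE_OP == fuse_demo::none) {
            y[i] = x[i] * scale;              // plain rms_norm-style scaling
        } else if constexpr (FUSE_OP == fuse_demo::mul) {
            y[i] = x[i] * scale * w[i];       // fused element-wise weight
        } else {
            y[i] = x[i] * scale * w[i] + b[i]; // hypothetical fused add
        }
    }
}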
View file

@@ -44,6 +44,7 @@ void ggml_compute_forward_concat(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_silu_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_rms_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_rms_norm_mul_fused(const struct ggml_compute_params * params, struct ggml_tensor * dst_rms_norm, struct ggml_tensor * dst_mul);
 void ggml_compute_forward_rms_norm_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_group_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_l2_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst);
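Taken together, the three files wire the fusion in end to end. A usage sketch of a graph that exercises it (the function name and tensor sizes are invented illustrations; the ggml calls themselves are the public API):

#include "ggml.h"

// sketch: build the fusable pattern; ctx is a ggml_context from ggml_init()
static struct ggml_cgraph * build_rms_mul_graph(struct ggml_context * ctx) {
    struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4096, 8);
    struct ggml_tensor * w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4096);

    // RMS_NORM immediately followed by MUL on its result is exactly the
    // pattern ggml_cpu_try_fuse_ops() recognizes
    struct ggml_tensor * cur = ggml_rms_norm(ctx, x, 1e-6f);
    cur = ggml_mul(ctx, cur, w);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, cur);
    return gf;
}

Running the same graph once normally and once with GGML_CPU_DISABLE_FUSION=1 gives a quick fused-vs-unfused consistency check.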