vulkan: optimize rms_norm, and allow the work to spread across multiple SMs (#15281)

* vulkan: optimize rms_norm, and allow the work to spread across multiple SMs

There are really two parts to this change:
(1) Some optimizations similar to what we have in soft_max, to unroll with
different numbers of iterations.
(2) A fusion optimization where we detect add followed by rms_norm, and make
the add shader atomically accumulate the values^2 into memory. Then the
rms_norm shader can just load that sum. This allows the rms_norm to be
parallelized across multiple workgroups, it just becomes a simple per-element
multiply.

The fusion optimization is currently only applied when the rms_norm is on a
single vector. This previously always ran on a single SM. It could apply more
broadly, but when there are other dimensions the work can already spread across
SMs, and there would be some complexity to tracking multiple atomic sums.

* Change add+rms_norm optimization to write out an array of partial sums
rather than using atomic add, to make it deterministic. The rms_norm
shader fetches a subgroup's worth in parallel and uses subgroupAdd to
add them up.

* complete rebase against fused adds - multi_add shader can also compute partial sums

* fix validation errors

* disable add_rms_fusion for Intel due to possible driver bug

* resolve against #15489, sync after clearing partial sums
This commit is contained in:
Jeff Bolz 2025-08-23 13:16:17 -05:00 committed by GitHub
parent b1afcab804
commit 611f419cff
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 379 additions and 50 deletions

View file

@ -2858,6 +2858,7 @@ struct test_rms_norm_mul_add : public test_case {
const std::array<int64_t, 4> ne;
const float eps;
const bool broadcast;
const bool multi_add; // test a sequence of adds feeding into rms_norm
std::string op_desc(ggml_tensor * t) override {
GGML_UNUSED(t);
@ -2867,13 +2868,13 @@ struct test_rms_norm_mul_add : public test_case {
bool run_whole_graph() override { return true; }
std::string vars() override {
return VARS_TO_STR4(type, ne, eps, broadcast);
return VARS_TO_STR5(type, ne, eps, broadcast, multi_add);
}
test_rms_norm_mul_add(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {64, 5, 4, 3},
float eps = 1e-6f, bool broadcast = false)
: type(type), ne(ne), eps(eps), broadcast(broadcast) {}
float eps = 1e-6f, bool broadcast = false, bool multi_add = false)
: type(type), ne(ne), eps(eps), broadcast(broadcast), multi_add(multi_add) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
std::array<int64_t, 4> broadcast_dims = {ne[0]*2, ne[1]*3, ne[2]*3, ne[3]*4};
@ -2891,6 +2892,9 @@ struct test_rms_norm_mul_add : public test_case {
// Use a, b and c early, so we don't end up with an OP_NONE between rms_norm and mul
a = ggml_add(ctx, ggml_add(ctx, a, b), c);
if (multi_add) {
a = ggml_add(ctx, ggml_add(ctx, a, b), c);
}
ggml_tensor * out = ggml_add(ctx, ggml_mul(ctx, ggml_rms_norm(ctx, a, eps), b), c);
ggml_set_name(out, "out");
@ -5842,6 +5846,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
}
for (uint32_t n : {1, 511, 1025, 8192, 33*512}) {
for (bool multi_add : {false, true}) {
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {n, 1, 1, 1}, 1e-6f, false, multi_add));
}
}
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));