CUDA: Optimize reduce_rows_f32 kernel, leading up to 25x perf improvement on kernel-level and 10% perf increase for Gemma3n (#15132)

* Factor out `reduce_rows_f32` from common.cuh

This increases iteration cycle speed by not having to recompile
every kernel all the time

* Hide memory-latency by loop unrolling in reduce_rows_f32

* Further optimizations to `reduce_rows_f32`

1. Increase threadblock size to better hide latency of memory requests.
   As a consequence of bigger threadblocks, do 2-step summation, using
   shared memory to communicate results between invocations
2. Use sum_temp array to reduce waits on sum
3. Adjust num_unroll to reflext bigger threadblock
4. Improve default block_dims, increase support for more block_dims

* Add perf tests for `reduce_rows_f32` kernel

* Add heuristic to toggle 128/512 threads based on sm count

Break even point was the minimum of the following multiples.

| GPU Model                     | Nrow SM Count Multiple |
| -----------                   | -----------            |
| RTX 4000 SFF ADA              | 2.0x                   |
| RTX 6000 ADA                  | 2.5x                   |
| RTX PRO 6000 Blackwell Max-Q  | 3.04x                  |
| RTX PRO 4500 Blackwell	| 3.15x                  |

* Ensure perf gains also for small ncols and large nrows

Alternative to this, one could have also made the number of unrollings
template-able, but that would require compiling the kernel multiple
times, increasing binary size unnecessarily

* Modify perf and unit-tests

* Apply auto-formatting by clang

* Fix CI build failure

See https://github.com/ggml-org/llama.cpp/actions/runs/16798370266/job/47573716079?pr=15132#step:7:486
Building with VS generator worked though.

* Remove sm_count property from `ggml_backend_cuda_context`

Requested by @JohannesGaessler, and should fix remaining CI issues as a
side-effect

* Add CUB-based implementation for GGML_OP_MEAN

Currently this branch is only executed for nrows==1

* Add heuristics to execute CUB branch only when it brings perf

Heuristics were determined on the following HW:

* RTX 4000 SFF ADA
* RTX 6000 ADA
* RTX PRO 6000 Blackwell Max-Q
* RTX PRO 4500 Blackwell

* Add unit-test for CUB-based mean

Tests should run with CUDA Graphs enabled per default on NVGPUs

* Rename `USE_CUB` to `GGML_CUDA_USE_CUB`

Suggested by @JohannesGaessler

* Unindent Preprocessor directives

See
https://github.com/ggml-org/llama.cpp/pull/15132#discussion_r2269213506
This commit is contained in:
Oliver Simons 2025-08-13 10:04:46 +02:00 committed by GitHub
parent bc5182272c
commit 6028bf7435
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 155 additions and 36 deletions

View file

@ -5998,6 +5998,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_sum());
test_cases.emplace_back(new test_sum_rows());
test_cases.emplace_back(new test_mean());
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1, 1, 1 }));
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1, 1, 1 }));
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 1, 1, 1 }));
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }));
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 256, 1, 1 }));
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 256, 1, 1 }));
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32769, 1, 1, 1 }));
test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {64, 64, 320, 1}));
test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1}));
test_cases.emplace_back(new test_acc());
@ -6179,6 +6188,18 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_add_id(GGML_TYPE_F32, GGML_TYPE_F32, 2880, 32, 4, n_token));
}
std::vector<std::array<int64_t, 4>> reduce_rows_cases = {
{ 8192, 1, 1, 1 },
{ 8192, 8192, 1, 1 },
{ 128, 8192, 1, 1 },
};
for (auto it: reduce_rows_cases){
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, it));
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, it));
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, it));
}
return test_cases;
}