From 989bfb18fc9bd28c12c143f28bdab31eb3a2dc09 Mon Sep 17 00:00:00 2001
From: Jeff Bolz
Date: Wed, 7 May 2025 15:57:38 -0500
Subject: [PATCH] vulkan: load each Q value once. optimize O reduction. more
 tuning

---
 ggml/src/ggml-vulkan/ggml-vulkan.cpp           |  6 ++--
 .../vulkan-shaders/flash_attn.comp             | 31 ++++++++++---------
 2 files changed, 19 insertions(+), 18 deletions(-)

diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index b285f43e9..6e0d24bec 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -1590,7 +1590,7 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
 
 // number of rows/cols for flash attention shader
 static constexpr uint32_t flash_attention_num_small_rows = 32;
-static constexpr uint32_t scalar_flash_attention_num_small_rows = 4;
+static constexpr uint32_t scalar_flash_attention_num_small_rows = 8;
 
 static uint32_t get_fa_num_small_rows(bool scalar) {
     return scalar ? scalar_flash_attention_num_small_rows : flash_attention_num_small_rows;
@@ -1601,7 +1601,7 @@ static std::array<uint32_t, 2> fa_rows_cols(bool scalar, uint32_t D, uint32_t cl
 
     // small rows, large cols
     if (small_rows || scalar) {
-        return {get_fa_num_small_rows(scalar), 64};
+        return {get_fa_num_small_rows(scalar), 32};
     }
 
     // small cols to reduce register count
@@ -1913,7 +1913,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
         // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
         const uint32_t D_lsb = D ^ (D & (D-1));
-        uint32_t D_split = std::min(std::min(device->subgroup_size, 16u), D_lsb / 4);
+        uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);
 
         // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
         GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
index cf3b72a4a..607a19c7f 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
@@ -105,6 +105,7 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
 }
 
 shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
+shared vec4 tmpshv4[gl_WorkGroupSize.x];
 
 shared float16_t masksh[Bc][Br];
 shared vec4 Qf[Br][D / 4];
@@ -168,13 +169,15 @@ void main() {
 
     uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
 
-    [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-        if (i * Br + r < N) {
-            [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
-                Qf[r][d * D_split + d_tid] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d * D_split + d_tid]) * p.scale;
-            }
+    [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
+        uint32_t d = (idx + tid) % (D / 4);
+        uint32_t r = (idx + tid) / (D / 4);
+        if (r < Br && d < D / 4 &&
+            i * Br + r < N) {
+            Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale;
         }
     }
+    barrier();
 
     vec4 Of[Br][D_per_thread / 4];
     [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
@@ -350,20 +353,18 @@
 
         [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
             Of[r][d] = eMf * Of[r][d];
-            [[unroll]] for (uint32_t c = 0; c < 4; ++c) {
-                tmpsh[tid] = Of[r][d][c];
+            tmpshv4[tid] = Of[r][d];
 
-                barrier();
-                [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
-                    if (tid < s) {
-                        Of[r][d][c] += tmpsh[tid + s];
-                        tmpsh[tid] = Of[r][d][c];
-                    }
-                    barrier();
+            barrier();
+            [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
+                if (tid < s) {
+                    Of[r][d] += tmpshv4[tid + s];
+                    tmpshv4[tid] = Of[r][d];
                 }
-                Of[r][d][c] = tmpsh[d_tid];
                 barrier();
             }
+            Of[r][d] = tmpshv4[d_tid];
+            barrier();
         }
     }
 
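
The two sketches below illustrate the indexing changes; they are standalone
C++ checks written for this description, not part of the commit.

The new Q load flattens the Br x (D/4) vec4 elements and strides them across
the workgroup, so each Q vec4 is fetched from global memory exactly once; in
the old loop, every thread sharing a d_tid redundantly re-fetched the same
values. A minimal check of that mapping, with assumed example values for Br,
D, and the workgroup size (in the shader these are specialization constants):

    // Checks the flattened (idx + tid) -> (r, d) mapping used by the new
    // Q load: every (r, d) pair is visited by exactly one thread, once.
    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
        const uint32_t Br = 8, D = 128, workgroup_size = 64; // assumed sizes
        std::vector<int> loads(Br * (D / 4), 0);

        for (uint32_t tid = 0; tid < workgroup_size; ++tid) {        // all threads
            for (uint32_t idx = 0; idx < Br * D / 4; idx += workgroup_size) {
                uint32_t d = (idx + tid) % (D / 4);
                uint32_t r = (idx + tid) / (D / 4);
                if (r < Br && d < D / 4) {    // same guard as the shader
                    loads[r * (D / 4) + d]++; // stands in for the Qf[r][d] store
                }
            }
        }
        for (int n : loads) assert(n == 1);   // each Q vec4 loaded exactly once
        return 0;
    }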
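
The O reduction now moves whole vec4 accumulators through shared memory
instead of one component at a time, so a single barrier sequence replaces
four. Each halving pass folds the top half of tmpshv4 into the bottom half
until D_split partial sums remain, one per d_tid; each thread then reads its
slice's total from slot d_tid. A small serial simulation of that folding,
assuming a workgroup of 32 and D_split = 8:

    #include <array>
    #include <cassert>
    #include <cstdint>

    // Minimal stand-in for GLSL vec4.
    struct vec4 {
        float v[4];
        vec4 operator+(const vec4 &o) const {
            vec4 r;
            for (int c = 0; c < 4; ++c) r.v[c] = v[c] + o.v[c];
            return r;
        }
    };

    int main() {
        const uint32_t workgroup_size = 32, D_split = 8; // assumed sizes
        std::array<vec4, workgroup_size> tmpshv4;

        // Each "thread" deposits its partial O accumulator; using tid as the
        // value makes the expected sums easy to verify.
        for (uint32_t tid = 0; tid < workgroup_size; ++tid)
            for (int c = 0; c < 4; ++c) tmpshv4[tid].v[c] = float(tid);

        // Fold the top half into the bottom half until D_split slots remain.
        // The shader needs a barrier() after each pass; serial code does not.
        for (uint32_t s = workgroup_size / 2; s >= D_split; s >>= 1)
            for (uint32_t tid = 0; tid < s; ++tid)
                tmpshv4[tid] = tmpshv4[tid] + tmpshv4[tid + s];

        // Slot d_tid holds the sum over threads with tid % D_split == d_tid:
        // d_tid + (d_tid+8) + (d_tid+16) + (d_tid+24) = 4*d_tid + 48.
        for (uint32_t d_tid = 0; d_tid < D_split; ++d_tid)
            assert(tmpshv4[d_tid].v[0] == 4.0f * d_tid + 48.0f);
        return 0;
    }

Stopping the halving at D_split rather than at 1 is what leaves one reduced
vec4 per d_tid, matching the per-thread D slice that Of[r][d] = tmpshv4[d_tid]
reads back in the shader.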