diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 3f2b03bd6..abdadee28 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -276,27 +276,22 @@ void main() { Lf[r] = eMf[r]*Lf[r] + rowsumf[r]; } - uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; - - vec4 PVf[Br][D_per_thread / 4]; [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - PVf[r][d] = vec4(0.0); + Of[r][d] = eMf[r] * Of[r][d]; } } + + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - PVf[r][d] += Pf[r][c] * Vf; + Of[r][d] += Pf[r][c] * Vf; } } } - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - Of[r][d] = eMf[r] * Of[r][d] + PVf[r][d]; - } - } barrier(); }