From acd604fb277044e07c2bff01f4c169167b45f478 Mon Sep 17 00:00:00 2001 From: Daniele <57776841+daniandtheweb@users.noreply.github.com> Date: Wed, 20 May 2026 17:15:13 +0200 Subject: [PATCH] vulkan: optimize operations in the IM2COL shader (#22685) * vulkan: optimize operations in the IM2COL shader * Add comments and improve the code formatting --- .../ggml-vulkan/vulkan-shaders/im2col.comp | 73 +++++++++++++++---- 1 file changed, 59 insertions(+), 14 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp index ba4c2103f..f4130d223 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp @@ -44,36 +44,81 @@ void im2col(const uint ow, const uint z_idx) { const uint KHKW = p.KH * p.KW; + // Precompute base input coordinates + const int base_iw = int(ow * p.s0) - p.p0; + const int base_ih = int(oh * p.s1) - p.p1; + + // Precompute step deltas + const uint delta_ic = BLOCK_SIZE / KHKW; + const uint delta_rem = BLOCK_SIZE % KHKW; + + const uint delta_ky = delta_rem / p.KW; + const uint delta_kx = delta_rem % p.KW; + + const uint delta_ic_offset = delta_ic * p.offset_delta; + + // If using BDA mode, precompute the base pointer and step size +#if BDA + const BDA_STORAGE_T base_dst_addr = p.dst_addr + D_SIZE * dst_row; + const uint bda_step = D_SIZE * BLOCK_SIZE; +#endif + uint wg_x = gl_WorkGroupID.x; do { const uint wg_offset = wg_x * 512; - [[unroll]] for (uint i = 0; i < NUM_ITER; ++i) { - const uint chw_idx = wg_offset + gidx + i * BLOCK_SIZE; + uint chw_idx = wg_offset + gidx; + uint ic = chw_idx / KHKW; + uint rem = chw_idx % KHKW; + + uint ky = rem / p.KW; + uint kx = rem % p.KW; + + uint ic_offset = src_batch + ic * p.offset_delta; + + // Initialize running pointer/index for the destination buffer +#if BDA + BDA_STORAGE_T current_dst_addr = base_dst_addr + D_SIZE * chw_idx; +#else + uint current_dst_idx = dst_row + chw_idx; +#endif + + [[unroll]] for (uint i = 0; i < NUM_ITER; ++i) { if (chw_idx >= p.CHW) { return; } - const uint ic = chw_idx / KHKW; - const uint rem = chw_idx - ic * KHKW; - const uint ky = rem / p.KW; - const uint kx = rem - ky * p.KW; - - const uint iiw = ow * p.s0 + kx * p.d0 - p.p0; - const uint iih = oh * p.s1 + ky * p.d1 - p.p1; + const int iiw = base_iw + int(kx * p.d0); + const int iih = base_ih + int(ky * p.d1); A_TYPE val = A_TYPE(0); - if (iih < p.IH && iiw < p.IW) { - val = data_a[src_batch + ic * p.offset_delta + iih * p.IW + iiw]; + if (uint(iih) < p.IH && uint(iiw) < p.IW) { + val = data_a[ic_offset + uint(iih) * p.IW + uint(iiw)]; } #if BDA - D_ptr out_ptr = D_ptr(p.dst_addr + D_SIZE * (dst_row + chw_idx)); - out_ptr.d = D_TYPE(val); + D_ptr(current_dst_addr).d = D_TYPE(val); + current_dst_addr += bda_step; #else - data_d[dst_row + chw_idx] = D_TYPE(val); + data_d[current_dst_idx] = D_TYPE(val); + current_dst_idx += BLOCK_SIZE; #endif + + chw_idx += BLOCK_SIZE; + ic_offset += delta_ic_offset; + kx += delta_kx; + ky += delta_ky; + + // Handle X axis wrap + uint kx_wrap = uint(kx >= p.KW); + kx -= kx_wrap * p.KW; + ky += kx_wrap; + + // Handle Y axis wrap + uint ky_wrap = uint(ky >= p.KH); + ky -= ky_wrap * p.KH; + ic_offset += ky_wrap * p.offset_delta; } wg_x += gl_NumWorkGroups.x;