diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 82e29d5ad..f6ffb2b3a 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4881,15 +4881,32 @@ kernel void kernel_conv_transpose_1d( uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]]) { + // For output position j on the time axis, only input positions + // i such that i*s0 <= j < i*s0 + K + // contribute -- i.e. i in [ceil((j - K + 1)/s0), floor(j/s0)] + // intersected with [0, IL-1]. That's at most ceil(K/s0) values + // (typically 2 for stride==K/2 transposed convs). + const int32_t j = tgpig[0]; + const int32_t s0 = args.s0; + const int32_t K = args.K; + const int32_t IL = args.IL; + + int32_t i_min; + { + int32_t a = j - K + 1; + i_min = a <= 0 ? 0 : (a + s0 - 1) / s0; // ceil(a/s0) for a>0 + } + int32_t i_max = j / s0; + if (i_max > IL - 1) i_max = IL - 1; + float v = 0.0f; + if (i_min <= i_max) { + for (int64_t c = 0; c < args.IC; c++) { + const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1]; + const int32_t input_offset = c * IL; - for (int64_t c = 0; c < args.IC; c++) { - const int32_t kernel_offset = c * tgpg[1] * args.K + args.K * tgpig[1]; - const int32_t input_offset = c * args.IL; - - for (int64_t i = 0; i < args.IL; i++) { - if (tgpig[0] >= i * args.s0 && tgpig[0] < i * args.s0 + args.K) { - v += src0[kernel_offset + tgpig[0] - i * args.s0] * src1[input_offset + i]; + for (int32_t i = i_min; i <= i_max; i++) { + v += float(src0[kernel_offset + j - i * s0]) * src1[input_offset + i]; } } }