From c6e40883767cafcad79e54de62dfe48e373f77e2 Mon Sep 17 00:00:00 2001 From: Matt Corallo <649246+TheBlueMatt@users.noreply.github.com> Date: Wed, 27 May 2026 15:19:23 +0000 Subject: [PATCH] vulkan: Switch MUL_MAT_VEC to 4 K per iteration for F16/32 (#22887) * vulkan: Switch MUL_MAT_VEC to 4 K per iteration for F16/32 Against mesa git, this shows a 4.8% performance improvement for tg128 on Qwen3.5-9B:BF16 on Intel BMG. Note that this breaks some tests until the last commit which fixes OOB A reads. * vulkan: Use aligned loads in mul_mat_vec when available Against mesa git, this shows a 3.3% performance improvement for tg128 on Qwen3.5-9B:BF16 on Intel BMG. * Make explicit that `num_rows` is <= `NUM_ROWS` in mul_mat_vec Mesa's UUB logic can't see through conditionals, limiting its ability to understand the bounds on the `num_rows` field in the cleanup run. Making it explicit that `num_rows` is, indeed, always <= `NUM_ROWS` helps mesa make slightly better codegen. Against mesa git, this currently shows a 1% performance improvement in tg128 on Qwen3.5-9B:BF16 on Intel BMG. * vulkan: Fix OOB A reads in MUL_MAT_VEC for odd sizes There was a TODO to fix the OOB reads from the A matrix which we do here. It is within performance noise (+<0.1%) in tg128 for Qwen3.5-9B:BF16 on Intel BMG. --- .../vulkan-shaders/dequant_funcs.glsl | 39 +++++ .../vulkan-shaders/mul_mat_vec.comp | 147 ++++++++++++++---- .../src/ggml-vulkan/vulkan-shaders/types.glsl | 2 + 3 files changed, 162 insertions(+), 26 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl index 88d07d2df..e67299fde 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl @@ -5,21 +5,60 @@ #include "types.glsl" #if defined(DATA_A_F32) +FLOAT_TYPE dequantize1(uint ib, uint iqs, uint a_offset) { + return data_a[a_offset + ib]; +} vec2 dequantize(uint ib, uint iqs, uint a_offset) { return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]); } +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + return vec4(data_a[a_offset + ib ], data_a[a_offset + ib + 1], + data_a[a_offset + ib + 2], data_a[a_offset + ib + 3]); +} +vec4 dequantize4_2aligned(uint ib, uint iqs, uint a_offset) { + return vec4(data_a[a_offset + ib ], data_a[a_offset + ib + 1], + data_a[a_offset + ib + 2], data_a[a_offset + ib + 3]); +} + #endif #if defined(DATA_A_F16) +FLOAT_TYPE dequantize1(uint ib, uint iqs, uint a_offset) { + return data_a[a_offset + ib]; +} vec2 dequantize(uint ib, uint iqs, uint a_offset) { return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]); } +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + return vec4(data_a[a_offset + ib ], data_a[a_offset + ib + 1], + data_a[a_offset + ib + 2], data_a[a_offset + ib + 3]); +} +vec4 dequantize4_2aligned(uint ib, uint iqs, uint a_offset) { + const vec2 a = data_a_packed32[(a_offset + ib)/2]; + const vec2 b = data_a_packed32[(a_offset + ib)/2 + 1]; + return vec4(a, b); +} #endif #if defined(DATA_A_BF16) +FLOAT_TYPE dequantize1(uint ib, uint iqs, uint a_offset) { + return bf16_to_fp32(data_a[a_offset + ib]); +} vec2 dequantize(uint ib, uint iqs, uint a_offset) { return vec2(bf16_to_fp32(data_a[a_offset + ib]), bf16_to_fp32(data_a[a_offset + ib + 1])); } +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + return vec4(bf16_to_fp32(data_a[a_offset + ib ]), bf16_to_fp32(data_a[a_offset + ib + 1]), + bf16_to_fp32(data_a[a_offset + ib + 2]), bf16_to_fp32(data_a[a_offset + ib + 3])); +} +vec4 dequantize4_2aligned(uint ib, uint iqs, uint a_offset) { + const uint a = data_a_packed32[(a_offset + ib)/2]; + const uint b = data_a_packed32[(a_offset + ib)/2 + 1]; + return vec4(uintBitsToFloat((a & 0x0000ffff) << 16), + uintBitsToFloat( a & 0xffff0000), + uintBitsToFloat((b & 0x0000ffff) << 16), + uintBitsToFloat( b & 0xffff0000)); +} #endif #if defined(DATA_A_Q4_0) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp index 2271be402..5a9d0e778 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp @@ -10,12 +10,38 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; #if !defined(DATA_A_F32) && !defined(DATA_A_F16) && !defined(DATA_A_BF16) #define K_PER_ITER 8 #else -#define K_PER_ITER 2 +#define K_PER_ITER 4 #endif uint a_offset, b_offset, d_offset, y_offset; +vec4 load_b(const uint j, const uint iybs, const uint iqs, const bool lastiter, out bool OOB_y, out bool OOB_z, out bool OOB_w) { + // Check if the latter elements are OOB, and don't fetch B or accumulate it. + OOB_y = lastiter && (iybs + iqs + y_offset >= p.ncols); + OOB_z = lastiter && (iybs + iqs + y_offset*2 >= p.ncols); + OOB_w = lastiter && (iybs + iqs + y_offset*3 >= p.ncols); + + if (!OOB_w) { + return vec4(FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]), + FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]), + FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset*2]), + FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset*3])); + } else if (!OOB_z) { + return vec4(FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]), + FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]), + FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset*2]), + 0); + } else if (!OOB_y) { + return vec4(FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]), + FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]), + 0, 0); + } else { + return vec4(FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]), + 0, 0, 0); + } +} + void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter) { [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { @@ -25,6 +51,8 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const #if K_PER_ITER == 8 #if QUANT_R == 2 + // Note that we end up fetching bogus elements here, but its fine as they'll be + // within an accessible block. const vec4 bv02 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]); const vec4 bv13 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4]); const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y); @@ -34,18 +62,11 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const const vec4 bv1 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4 + 1]); #endif #else - // Check if the second of the pair of elements is OOB, and don't fetch B or - // accumulate it. We still fetch a pair of elements for A, which is fine for - // quantized formats since they'll be within the same block. We should - // probably skip fetching the second element for F16/F32, but as of now we - // still do. - const bool OOB = lastiter && (iybs + iqs + y_offset >= p.ncols); + bool OOB_y; + bool OOB_z; + bool OOB_w; - FLOAT_TYPE b0 = 0, b1 = 0; - b0 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]); - if (!OOB) { - b1 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]); - } + const vec4 b = load_b(j, iybs, iqs, lastiter, OOB_y, OOB_z, OOB_w); #endif uint ibi = first_row*p.ncols; [[unroll]] for (uint n = 0; n < num_rows; ++n) { @@ -71,22 +92,60 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const temp[j][n] += rowtmp; #else - const vec2 v = dequantize(ib, iqs, a_offset); - - // matrix multiplication - temp[j][n] = fma(FLOAT_TYPE(v.x), b0, temp[j][n]); - if (!OOB) { - temp[j][n] = fma(FLOAT_TYPE(v.y), b1, temp[j][n]); + if (!OOB_w) { + const vec4 v = dequantize4(ib, iqs, a_offset); + temp[j][n] += dot(v, b); + } else if (!OOB_z) { + const vec2 v0 = dequantize(ib, iqs, a_offset); + const FLOAT_TYPE v1 = dequantize1(ib + 2/QUANT_R, iqs, a_offset); + const vec3 v = vec3(v0.x, v0.y, v1); + const vec3 b0 = vec3(b.x, b.y, b.z); + temp[j][n] += dot(v, b0); + } else if (!OOB_y) { + const vec2 v0 = dequantize(ib, iqs, a_offset); + const vec2 b0 = vec2(b.x, b.y); + temp[j][n] += dot(v0, b0); + } else { + const FLOAT_TYPE v = dequantize1(ib, iqs, a_offset); + temp[j][n] = fma(v, b.x, temp[j][n]); } #endif } } } +#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16) +void iter_aligned_nonquant(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) +{ + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + const uint col = i*BLOCK_SIZE + K_PER_ITER*tid; + const uint iqs = 0; // quant index + const uint iybs = col; // y block start index + + const vec4 b = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]; + + uint ibi = first_row*p.ncols; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib = (ibi + col)/QUANT_K; // block index + ibi += p.ncols; + + const vec4 v = dequantize4_2aligned(ib, iqs, a_offset); + + // matrix multiplication + temp[j][n] += dot(v, b); + } + } +} +#endif + void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint tid = gl_LocalInvocationID.x; get_offsets(a_offset, b_offset, d_offset); + const bool is_aligned_nonquant = + p.batch_stride_b % 4 == 0 && b_offset % 4 == 0 && + p.ncols % 4 == 0 && BLOCK_SIZE % 4 == 0 && + K_PER_ITER == 4; y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; @@ -105,17 +164,26 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { int unroll_count = 4; uint unrolled_iters = num_iters & ~(unroll_count - 1); -#if K_PER_ITER == 2 + uint i = 0; + +#if K_PER_ITER == 4 // If the K dimension is odd, we need lastiter==true on the last iteration // so OOB is computed correctly. Skip some unrolling to make that happen. - if ((p.ncols & 1) != 0 && + if ((p.ncols & 3) != 0 && unrolled_iters == num_iters && unrolled_iters > 0) { unrolled_iters -= unroll_count; } + if (is_aligned_nonquant) { + while (i < unrolled_iters) { + // Manually partially unroll the loop + [[unroll]] for (uint k = 0; k < unroll_count; ++k) { + iter_aligned_nonquant(temp, first_row, num_rows, tid, i*K_PER_ITER); + i++; + } + } + } else { #endif - - uint i = 0; while (i < unrolled_iters) { // Manually partially unroll the loop [[unroll]] for (uint k = 0; k < unroll_count; ++k) { @@ -123,18 +191,30 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { i++; } } +#if K_PER_ITER == 4 + } +#endif unroll_count = 2; unrolled_iters = num_iters & ~(unroll_count - 1); -#if K_PER_ITER == 2 - if ((p.ncols & 1) != 0 && +#if K_PER_ITER == 4 + if ((p.ncols & 3) != 0 && unrolled_iters == num_iters && unrolled_iters > 0) { unrolled_iters -= unroll_count; } -#endif + if (is_aligned_nonquant) { + while (i < unrolled_iters && is_aligned_nonquant) { + // Manually partially unroll the loop + [[unroll]] for (uint k = 0; k < unroll_count; ++k) { + iter_aligned_nonquant(temp, first_row, num_rows, tid, i*K_PER_ITER); + i++; + } + } + } else { +#endif while (i < unrolled_iters) { // Manually partially unroll the loop [[unroll]] for (uint k = 0; k < unroll_count; ++k) { @@ -142,10 +222,25 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { i++; } } +#if K_PER_ITER == 4 + } +#endif + +#if K_PER_ITER == 4 + if (is_aligned_nonquant) { + while (i < num_iters) { + iter_aligned_nonquant(temp, first_row, num_rows, tid, i*K_PER_ITER); + i++; + } + } else { +#endif while (i < num_iters) { iter(temp, first_row, num_rows, tid, i*K_PER_ITER, true); i++; } +#if K_PER_ITER == 4 + } +#endif reduce_result(temp, d_offset, first_row, num_rows, tid); } @@ -164,6 +259,6 @@ void main() { if (first_row >= p.stride_d) { return; } - compute_outputs(first_row, p.stride_d - first_row); + compute_outputs(first_row, min(NUM_ROWS, p.stride_d - first_row)); } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl index 06eff6f21..f84d6f873 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl @@ -31,6 +31,7 @@ #else #define A_TYPE float16_t #endif +#define A_TYPE_PACKED32 f16vec2 #endif #if defined(DATA_A_BF16) @@ -44,6 +45,7 @@ #else #define A_TYPE uint16_t #endif +#define A_TYPE_PACKED32 uint32_t #endif #define QUANT_K_Q4_0 32