vulkan: Switch MUL_MAT_VEC to 4 K per iteration for F16/32 (#22887)

* vulkan: Switch MUL_MAT_VEC to 4 K per iteration for F16/32

Against mesa git, this shows a 4.8% performance improvement for
tg128 on Qwen3.5-9B:BF16 on Intel BMG.

Note that this breaks some tests until the last commit which fixes
OOB A reads.

* vulkan: Use aligned loads in mul_mat_vec when available

Against mesa git, this shows a 3.3% performance improvement for
tg128 on Qwen3.5-9B:BF16 on Intel BMG.

* Make explicit that `num_rows` is <= `NUM_ROWS` in mul_mat_vec

Mesa's UUB logic can't see through conditionals, limiting its
ability to understand the bounds on the `num_rows` field in the
cleanup run. Making it explicit that `num_rows` is, indeed, always
<= `NUM_ROWS` helps mesa make slightly better codegen.

Against mesa git, this currently shows a 1% performance improvement
in tg128 on Qwen3.5-9B:BF16 on Intel BMG.

* vulkan: Fix OOB A reads in MUL_MAT_VEC for odd sizes

There was a TODO to fix the OOB reads from the A matrix which we do
here.

It is within performance noise (+<0.1%) in tg128 for
Qwen3.5-9B:BF16 on Intel BMG.
This commit is contained in:
Matt Corallo 2026-05-27 15:19:23 +00:00 committed by GitHub
parent b36eefc1b3
commit c6e4088376
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 162 additions and 26 deletions

View file

@ -5,21 +5,60 @@
#include "types.glsl"
#if defined(DATA_A_F32)
FLOAT_TYPE dequantize1(uint ib, uint iqs, uint a_offset) {
return data_a[a_offset + ib];
}
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
return vec4(data_a[a_offset + ib ], data_a[a_offset + ib + 1],
data_a[a_offset + ib + 2], data_a[a_offset + ib + 3]);
}
vec4 dequantize4_2aligned(uint ib, uint iqs, uint a_offset) {
return vec4(data_a[a_offset + ib ], data_a[a_offset + ib + 1],
data_a[a_offset + ib + 2], data_a[a_offset + ib + 3]);
}
#endif
#if defined(DATA_A_F16)
FLOAT_TYPE dequantize1(uint ib, uint iqs, uint a_offset) {
return data_a[a_offset + ib];
}
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
return vec4(data_a[a_offset + ib ], data_a[a_offset + ib + 1],
data_a[a_offset + ib + 2], data_a[a_offset + ib + 3]);
}
vec4 dequantize4_2aligned(uint ib, uint iqs, uint a_offset) {
const vec2 a = data_a_packed32[(a_offset + ib)/2];
const vec2 b = data_a_packed32[(a_offset + ib)/2 + 1];
return vec4(a, b);
}
#endif
#if defined(DATA_A_BF16)
FLOAT_TYPE dequantize1(uint ib, uint iqs, uint a_offset) {
return bf16_to_fp32(data_a[a_offset + ib]);
}
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
return vec2(bf16_to_fp32(data_a[a_offset + ib]), bf16_to_fp32(data_a[a_offset + ib + 1]));
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
return vec4(bf16_to_fp32(data_a[a_offset + ib ]), bf16_to_fp32(data_a[a_offset + ib + 1]),
bf16_to_fp32(data_a[a_offset + ib + 2]), bf16_to_fp32(data_a[a_offset + ib + 3]));
}
vec4 dequantize4_2aligned(uint ib, uint iqs, uint a_offset) {
const uint a = data_a_packed32[(a_offset + ib)/2];
const uint b = data_a_packed32[(a_offset + ib)/2 + 1];
return vec4(uintBitsToFloat((a & 0x0000ffff) << 16),
uintBitsToFloat( a & 0xffff0000),
uintBitsToFloat((b & 0x0000ffff) << 16),
uintBitsToFloat( b & 0xffff0000));
}
#endif
#if defined(DATA_A_Q4_0)

View file

@ -10,12 +10,38 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
#if !defined(DATA_A_F32) && !defined(DATA_A_F16) && !defined(DATA_A_BF16)
#define K_PER_ITER 8
#else
#define K_PER_ITER 2
#define K_PER_ITER 4
#endif
uint a_offset, b_offset, d_offset, y_offset;
vec4 load_b(const uint j, const uint iybs, const uint iqs, const bool lastiter, out bool OOB_y, out bool OOB_z, out bool OOB_w) {
// Check if the latter elements are OOB, and don't fetch B or accumulate it.
OOB_y = lastiter && (iybs + iqs + y_offset >= p.ncols);
OOB_z = lastiter && (iybs + iqs + y_offset*2 >= p.ncols);
OOB_w = lastiter && (iybs + iqs + y_offset*3 >= p.ncols);
if (!OOB_w) {
return vec4(FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]),
FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]),
FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset*2]),
FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset*3]));
} else if (!OOB_z) {
return vec4(FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]),
FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]),
FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset*2]),
0);
} else if (!OOB_y) {
return vec4(FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]),
FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]),
0, 0);
} else {
return vec4(FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]),
0, 0, 0);
}
}
void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter)
{
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
@ -25,6 +51,8 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const
#if K_PER_ITER == 8
#if QUANT_R == 2
// Note that we end up fetching bogus elements here, but its fine as they'll be
// within an accessible block.
const vec4 bv02 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]);
const vec4 bv13 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4]);
const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y);
@ -34,18 +62,11 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const
const vec4 bv1 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4 + 1]);
#endif
#else
// Check if the second of the pair of elements is OOB, and don't fetch B or
// accumulate it. We still fetch a pair of elements for A, which is fine for
// quantized formats since they'll be within the same block. We should
// probably skip fetching the second element for F16/F32, but as of now we
// still do.
const bool OOB = lastiter && (iybs + iqs + y_offset >= p.ncols);
bool OOB_y;
bool OOB_z;
bool OOB_w;
FLOAT_TYPE b0 = 0, b1 = 0;
b0 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]);
if (!OOB) {
b1 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]);
}
const vec4 b = load_b(j, iybs, iqs, lastiter, OOB_y, OOB_z, OOB_w);
#endif
uint ibi = first_row*p.ncols;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
@ -71,22 +92,60 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const
temp[j][n] += rowtmp;
#else
const vec2 v = dequantize(ib, iqs, a_offset);
// matrix multiplication
temp[j][n] = fma(FLOAT_TYPE(v.x), b0, temp[j][n]);
if (!OOB) {
temp[j][n] = fma(FLOAT_TYPE(v.y), b1, temp[j][n]);
if (!OOB_w) {
const vec4 v = dequantize4(ib, iqs, a_offset);
temp[j][n] += dot(v, b);
} else if (!OOB_z) {
const vec2 v0 = dequantize(ib, iqs, a_offset);
const FLOAT_TYPE v1 = dequantize1(ib + 2/QUANT_R, iqs, a_offset);
const vec3 v = vec3(v0.x, v0.y, v1);
const vec3 b0 = vec3(b.x, b.y, b.z);
temp[j][n] += dot(v, b0);
} else if (!OOB_y) {
const vec2 v0 = dequantize(ib, iqs, a_offset);
const vec2 b0 = vec2(b.x, b.y);
temp[j][n] += dot(v0, b0);
} else {
const FLOAT_TYPE v = dequantize1(ib, iqs, a_offset);
temp[j][n] = fma(v, b.x, temp[j][n]);
}
#endif
}
}
}
#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
void iter_aligned_nonquant(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i)
{
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
const uint col = i*BLOCK_SIZE + K_PER_ITER*tid;
const uint iqs = 0; // quant index
const uint iybs = col; // y block start index
const vec4 b = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4];
uint ibi = first_row*p.ncols;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib = (ibi + col)/QUANT_K; // block index
ibi += p.ncols;
const vec4 v = dequantize4_2aligned(ib, iqs, a_offset);
// matrix multiplication
temp[j][n] += dot(v, b);
}
}
}
#endif
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
const uint tid = gl_LocalInvocationID.x;
get_offsets(a_offset, b_offset, d_offset);
const bool is_aligned_nonquant =
p.batch_stride_b % 4 == 0 && b_offset % 4 == 0 &&
p.ncols % 4 == 0 && BLOCK_SIZE % 4 == 0 &&
K_PER_ITER == 4;
y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
@ -105,17 +164,26 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
int unroll_count = 4;
uint unrolled_iters = num_iters & ~(unroll_count - 1);
#if K_PER_ITER == 2
uint i = 0;
#if K_PER_ITER == 4
// If the K dimension is odd, we need lastiter==true on the last iteration
// so OOB is computed correctly. Skip some unrolling to make that happen.
if ((p.ncols & 1) != 0 &&
if ((p.ncols & 3) != 0 &&
unrolled_iters == num_iters &&
unrolled_iters > 0) {
unrolled_iters -= unroll_count;
}
if (is_aligned_nonquant) {
while (i < unrolled_iters) {
// Manually partially unroll the loop
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
iter_aligned_nonquant(temp, first_row, num_rows, tid, i*K_PER_ITER);
i++;
}
}
} else {
#endif
uint i = 0;
while (i < unrolled_iters) {
// Manually partially unroll the loop
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
@ -123,18 +191,30 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
i++;
}
}
#if K_PER_ITER == 4
}
#endif
unroll_count = 2;
unrolled_iters = num_iters & ~(unroll_count - 1);
#if K_PER_ITER == 2
if ((p.ncols & 1) != 0 &&
#if K_PER_ITER == 4
if ((p.ncols & 3) != 0 &&
unrolled_iters == num_iters &&
unrolled_iters > 0) {
unrolled_iters -= unroll_count;
}
#endif
if (is_aligned_nonquant) {
while (i < unrolled_iters && is_aligned_nonquant) {
// Manually partially unroll the loop
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
iter_aligned_nonquant(temp, first_row, num_rows, tid, i*K_PER_ITER);
i++;
}
}
} else {
#endif
while (i < unrolled_iters) {
// Manually partially unroll the loop
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
@ -142,10 +222,25 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
i++;
}
}
#if K_PER_ITER == 4
}
#endif
#if K_PER_ITER == 4
if (is_aligned_nonquant) {
while (i < num_iters) {
iter_aligned_nonquant(temp, first_row, num_rows, tid, i*K_PER_ITER);
i++;
}
} else {
#endif
while (i < num_iters) {
iter(temp, first_row, num_rows, tid, i*K_PER_ITER, true);
i++;
}
#if K_PER_ITER == 4
}
#endif
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
@ -164,6 +259,6 @@ void main() {
if (first_row >= p.stride_d) {
return;
}
compute_outputs(first_row, p.stride_d - first_row);
compute_outputs(first_row, min(NUM_ROWS, p.stride_d - first_row));
}
}

View file

@ -31,6 +31,7 @@
#else
#define A_TYPE float16_t
#endif
#define A_TYPE_PACKED32 f16vec2
#endif
#if defined(DATA_A_BF16)
@ -44,6 +45,7 @@
#else
#define A_TYPE uint16_t
#endif
#define A_TYPE_PACKED32 uint32_t
#endif
#define QUANT_K_Q4_0 32