mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-31 21:39:42 +00:00
vulkan: Switch MUL_MAT_VEC to 4 K per iteration for F16/32 (#22887)
* vulkan: Switch MUL_MAT_VEC to 4 K per iteration for F16/32 Against mesa git, this shows a 4.8% performance improvement for tg128 on Qwen3.5-9B:BF16 on Intel BMG. Note that this breaks some tests until the last commit which fixes OOB A reads. * vulkan: Use aligned loads in mul_mat_vec when available Against mesa git, this shows a 3.3% performance improvement for tg128 on Qwen3.5-9B:BF16 on Intel BMG. * Make explicit that `num_rows` is <= `NUM_ROWS` in mul_mat_vec Mesa's UUB logic can't see through conditionals, limiting its ability to understand the bounds on the `num_rows` field in the cleanup run. Making it explicit that `num_rows` is, indeed, always <= `NUM_ROWS` helps mesa make slightly better codegen. Against mesa git, this currently shows a 1% performance improvement in tg128 on Qwen3.5-9B:BF16 on Intel BMG. * vulkan: Fix OOB A reads in MUL_MAT_VEC for odd sizes There was a TODO to fix the OOB reads from the A matrix which we do here. It is within performance noise (+<0.1%) in tg128 for Qwen3.5-9B:BF16 on Intel BMG.
This commit is contained in:
parent
b36eefc1b3
commit
c6e4088376
3 changed files with 162 additions and 26 deletions
|
|
@ -5,21 +5,60 @@
|
|||
#include "types.glsl"
|
||||
|
||||
#if defined(DATA_A_F32)
|
||||
FLOAT_TYPE dequantize1(uint ib, uint iqs, uint a_offset) {
|
||||
return data_a[a_offset + ib];
|
||||
}
|
||||
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||
return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);
|
||||
}
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
||||
return vec4(data_a[a_offset + ib ], data_a[a_offset + ib + 1],
|
||||
data_a[a_offset + ib + 2], data_a[a_offset + ib + 3]);
|
||||
}
|
||||
vec4 dequantize4_2aligned(uint ib, uint iqs, uint a_offset) {
|
||||
return vec4(data_a[a_offset + ib ], data_a[a_offset + ib + 1],
|
||||
data_a[a_offset + ib + 2], data_a[a_offset + ib + 3]);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_F16)
|
||||
FLOAT_TYPE dequantize1(uint ib, uint iqs, uint a_offset) {
|
||||
return data_a[a_offset + ib];
|
||||
}
|
||||
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||
return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);
|
||||
}
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
||||
return vec4(data_a[a_offset + ib ], data_a[a_offset + ib + 1],
|
||||
data_a[a_offset + ib + 2], data_a[a_offset + ib + 3]);
|
||||
}
|
||||
vec4 dequantize4_2aligned(uint ib, uint iqs, uint a_offset) {
|
||||
const vec2 a = data_a_packed32[(a_offset + ib)/2];
|
||||
const vec2 b = data_a_packed32[(a_offset + ib)/2 + 1];
|
||||
return vec4(a, b);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_BF16)
|
||||
FLOAT_TYPE dequantize1(uint ib, uint iqs, uint a_offset) {
|
||||
return bf16_to_fp32(data_a[a_offset + ib]);
|
||||
}
|
||||
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||
return vec2(bf16_to_fp32(data_a[a_offset + ib]), bf16_to_fp32(data_a[a_offset + ib + 1]));
|
||||
}
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
||||
return vec4(bf16_to_fp32(data_a[a_offset + ib ]), bf16_to_fp32(data_a[a_offset + ib + 1]),
|
||||
bf16_to_fp32(data_a[a_offset + ib + 2]), bf16_to_fp32(data_a[a_offset + ib + 3]));
|
||||
}
|
||||
vec4 dequantize4_2aligned(uint ib, uint iqs, uint a_offset) {
|
||||
const uint a = data_a_packed32[(a_offset + ib)/2];
|
||||
const uint b = data_a_packed32[(a_offset + ib)/2 + 1];
|
||||
return vec4(uintBitsToFloat((a & 0x0000ffff) << 16),
|
||||
uintBitsToFloat( a & 0xffff0000),
|
||||
uintBitsToFloat((b & 0x0000ffff) << 16),
|
||||
uintBitsToFloat( b & 0xffff0000));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_0)
|
||||
|
|
|
|||
|
|
@ -10,12 +10,38 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
|||
#if !defined(DATA_A_F32) && !defined(DATA_A_F16) && !defined(DATA_A_BF16)
|
||||
#define K_PER_ITER 8
|
||||
#else
|
||||
#define K_PER_ITER 2
|
||||
#define K_PER_ITER 4
|
||||
#endif
|
||||
|
||||
|
||||
uint a_offset, b_offset, d_offset, y_offset;
|
||||
|
||||
vec4 load_b(const uint j, const uint iybs, const uint iqs, const bool lastiter, out bool OOB_y, out bool OOB_z, out bool OOB_w) {
|
||||
// Check if the latter elements are OOB, and don't fetch B or accumulate it.
|
||||
OOB_y = lastiter && (iybs + iqs + y_offset >= p.ncols);
|
||||
OOB_z = lastiter && (iybs + iqs + y_offset*2 >= p.ncols);
|
||||
OOB_w = lastiter && (iybs + iqs + y_offset*3 >= p.ncols);
|
||||
|
||||
if (!OOB_w) {
|
||||
return vec4(FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]),
|
||||
FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]),
|
||||
FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset*2]),
|
||||
FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset*3]));
|
||||
} else if (!OOB_z) {
|
||||
return vec4(FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]),
|
||||
FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]),
|
||||
FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset*2]),
|
||||
0);
|
||||
} else if (!OOB_y) {
|
||||
return vec4(FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]),
|
||||
FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]),
|
||||
0, 0);
|
||||
} else {
|
||||
return vec4(FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]),
|
||||
0, 0, 0);
|
||||
}
|
||||
}
|
||||
|
||||
void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter)
|
||||
{
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
|
|
@ -25,6 +51,8 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const
|
|||
|
||||
#if K_PER_ITER == 8
|
||||
#if QUANT_R == 2
|
||||
// Note that we end up fetching bogus elements here, but its fine as they'll be
|
||||
// within an accessible block.
|
||||
const vec4 bv02 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]);
|
||||
const vec4 bv13 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4]);
|
||||
const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y);
|
||||
|
|
@ -34,18 +62,11 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const
|
|||
const vec4 bv1 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4 + 1]);
|
||||
#endif
|
||||
#else
|
||||
// Check if the second of the pair of elements is OOB, and don't fetch B or
|
||||
// accumulate it. We still fetch a pair of elements for A, which is fine for
|
||||
// quantized formats since they'll be within the same block. We should
|
||||
// probably skip fetching the second element for F16/F32, but as of now we
|
||||
// still do.
|
||||
const bool OOB = lastiter && (iybs + iqs + y_offset >= p.ncols);
|
||||
bool OOB_y;
|
||||
bool OOB_z;
|
||||
bool OOB_w;
|
||||
|
||||
FLOAT_TYPE b0 = 0, b1 = 0;
|
||||
b0 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]);
|
||||
if (!OOB) {
|
||||
b1 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]);
|
||||
}
|
||||
const vec4 b = load_b(j, iybs, iqs, lastiter, OOB_y, OOB_z, OOB_w);
|
||||
#endif
|
||||
uint ibi = first_row*p.ncols;
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
|
|
@ -71,22 +92,60 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const
|
|||
|
||||
temp[j][n] += rowtmp;
|
||||
#else
|
||||
const vec2 v = dequantize(ib, iqs, a_offset);
|
||||
|
||||
// matrix multiplication
|
||||
temp[j][n] = fma(FLOAT_TYPE(v.x), b0, temp[j][n]);
|
||||
if (!OOB) {
|
||||
temp[j][n] = fma(FLOAT_TYPE(v.y), b1, temp[j][n]);
|
||||
if (!OOB_w) {
|
||||
const vec4 v = dequantize4(ib, iqs, a_offset);
|
||||
temp[j][n] += dot(v, b);
|
||||
} else if (!OOB_z) {
|
||||
const vec2 v0 = dequantize(ib, iqs, a_offset);
|
||||
const FLOAT_TYPE v1 = dequantize1(ib + 2/QUANT_R, iqs, a_offset);
|
||||
const vec3 v = vec3(v0.x, v0.y, v1);
|
||||
const vec3 b0 = vec3(b.x, b.y, b.z);
|
||||
temp[j][n] += dot(v, b0);
|
||||
} else if (!OOB_y) {
|
||||
const vec2 v0 = dequantize(ib, iqs, a_offset);
|
||||
const vec2 b0 = vec2(b.x, b.y);
|
||||
temp[j][n] += dot(v0, b0);
|
||||
} else {
|
||||
const FLOAT_TYPE v = dequantize1(ib, iqs, a_offset);
|
||||
temp[j][n] = fma(v, b.x, temp[j][n]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
|
||||
void iter_aligned_nonquant(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i)
|
||||
{
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
const uint col = i*BLOCK_SIZE + K_PER_ITER*tid;
|
||||
const uint iqs = 0; // quant index
|
||||
const uint iybs = col; // y block start index
|
||||
|
||||
const vec4 b = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4];
|
||||
|
||||
uint ibi = first_row*p.ncols;
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
const uint ib = (ibi + col)/QUANT_K; // block index
|
||||
ibi += p.ncols;
|
||||
|
||||
const vec4 v = dequantize4_2aligned(ib, iqs, a_offset);
|
||||
|
||||
// matrix multiplication
|
||||
temp[j][n] += dot(v, b);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
|
||||
get_offsets(a_offset, b_offset, d_offset);
|
||||
const bool is_aligned_nonquant =
|
||||
p.batch_stride_b % 4 == 0 && b_offset % 4 == 0 &&
|
||||
p.ncols % 4 == 0 && BLOCK_SIZE % 4 == 0 &&
|
||||
K_PER_ITER == 4;
|
||||
|
||||
y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
|
||||
|
||||
|
|
@ -105,17 +164,26 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
|||
int unroll_count = 4;
|
||||
uint unrolled_iters = num_iters & ~(unroll_count - 1);
|
||||
|
||||
#if K_PER_ITER == 2
|
||||
uint i = 0;
|
||||
|
||||
#if K_PER_ITER == 4
|
||||
// If the K dimension is odd, we need lastiter==true on the last iteration
|
||||
// so OOB is computed correctly. Skip some unrolling to make that happen.
|
||||
if ((p.ncols & 1) != 0 &&
|
||||
if ((p.ncols & 3) != 0 &&
|
||||
unrolled_iters == num_iters &&
|
||||
unrolled_iters > 0) {
|
||||
unrolled_iters -= unroll_count;
|
||||
}
|
||||
if (is_aligned_nonquant) {
|
||||
while (i < unrolled_iters) {
|
||||
// Manually partially unroll the loop
|
||||
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
|
||||
iter_aligned_nonquant(temp, first_row, num_rows, tid, i*K_PER_ITER);
|
||||
i++;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#endif
|
||||
|
||||
uint i = 0;
|
||||
while (i < unrolled_iters) {
|
||||
// Manually partially unroll the loop
|
||||
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
|
||||
|
|
@ -123,18 +191,30 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
|||
i++;
|
||||
}
|
||||
}
|
||||
#if K_PER_ITER == 4
|
||||
}
|
||||
#endif
|
||||
|
||||
unroll_count = 2;
|
||||
unrolled_iters = num_iters & ~(unroll_count - 1);
|
||||
|
||||
#if K_PER_ITER == 2
|
||||
if ((p.ncols & 1) != 0 &&
|
||||
#if K_PER_ITER == 4
|
||||
if ((p.ncols & 3) != 0 &&
|
||||
unrolled_iters == num_iters &&
|
||||
unrolled_iters > 0) {
|
||||
unrolled_iters -= unroll_count;
|
||||
}
|
||||
#endif
|
||||
|
||||
if (is_aligned_nonquant) {
|
||||
while (i < unrolled_iters && is_aligned_nonquant) {
|
||||
// Manually partially unroll the loop
|
||||
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
|
||||
iter_aligned_nonquant(temp, first_row, num_rows, tid, i*K_PER_ITER);
|
||||
i++;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#endif
|
||||
while (i < unrolled_iters) {
|
||||
// Manually partially unroll the loop
|
||||
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
|
||||
|
|
@ -142,10 +222,25 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
|||
i++;
|
||||
}
|
||||
}
|
||||
#if K_PER_ITER == 4
|
||||
}
|
||||
#endif
|
||||
|
||||
#if K_PER_ITER == 4
|
||||
if (is_aligned_nonquant) {
|
||||
while (i < num_iters) {
|
||||
iter_aligned_nonquant(temp, first_row, num_rows, tid, i*K_PER_ITER);
|
||||
i++;
|
||||
}
|
||||
} else {
|
||||
#endif
|
||||
while (i < num_iters) {
|
||||
iter(temp, first_row, num_rows, tid, i*K_PER_ITER, true);
|
||||
i++;
|
||||
}
|
||||
#if K_PER_ITER == 4
|
||||
}
|
||||
#endif
|
||||
|
||||
reduce_result(temp, d_offset, first_row, num_rows, tid);
|
||||
}
|
||||
|
|
@ -164,6 +259,6 @@ void main() {
|
|||
if (first_row >= p.stride_d) {
|
||||
return;
|
||||
}
|
||||
compute_outputs(first_row, p.stride_d - first_row);
|
||||
compute_outputs(first_row, min(NUM_ROWS, p.stride_d - first_row));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@
|
|||
#else
|
||||
#define A_TYPE float16_t
|
||||
#endif
|
||||
#define A_TYPE_PACKED32 f16vec2
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_BF16)
|
||||
|
|
@ -44,6 +45,7 @@
|
|||
#else
|
||||
#define A_TYPE uint16_t
|
||||
#endif
|
||||
#define A_TYPE_PACKED32 uint32_t
|
||||
#endif
|
||||
|
||||
#define QUANT_K_Q4_0 32
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue