vulkan: Check shared memory size for mmq shaders (#22693)

This commit is contained in:
Jeff Bolz 2026-05-12 04:41:58 -05:00 committed by GitHub
parent fa62042af9
commit 706fbd8ab6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -681,6 +681,15 @@ struct vk_device_struct {
bool mul_mat_id_m[GGML_TYPE_COUNT];
bool mul_mat_id_s[GGML_TYPE_COUNT];
// Separate flags for the q8_1 (integer dot) mmq path, whose shader uses
// a different shared-memory layout than the float matmul shaders.
bool mul_mat_l_int[GGML_TYPE_COUNT];
bool mul_mat_m_int[GGML_TYPE_COUNT];
bool mul_mat_s_int[GGML_TYPE_COUNT];
bool mul_mat_id_l_int[GGML_TYPE_COUNT];
bool mul_mat_id_m_int[GGML_TYPE_COUNT];
bool mul_mat_id_s_int[GGML_TYPE_COUNT];
vk::DescriptorSetLayout dsl;
vk_matmul_pipeline pipeline_matmul_f32 {};
@ -3207,6 +3216,70 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
return supported;
}
// Shmem usage for the q8_1 mmq shader (mul_mmq.comp), which uses
// block_a_cache / block_b_cache layouts (see mul_mmq_shmem_types.glsl) rather
// than the float load buffers checked by ggml_vk_matmul_shmem_support.
// Sizes follow std430 rules. Returns false for types without a q8_1 pipeline.
static bool ggml_vk_matmul_int_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
// FLOAT_TYPE in the shader is float16_t with fp16 support, otherwise float.
const uint32_t fp_size = device->fp16 ? 2u : 4u;
const uint32_t fp_align = fp_size;
const uint32_t fp2_size = 2u * fp_size;
const uint32_t fp2_align = device->fp16 ? 4u : 8u;
struct member { uint32_t size, align; };
auto std430_size = [](std::initializer_list<member> members) {
uint32_t off = 0, struct_align = 1;
for (const auto &m : members) {
off = (off + m.align - 1) & ~(m.align - 1);
off += m.size;
struct_align = std::max(struct_align, m.align);
}
return (off + struct_align - 1) & ~(struct_align - 1);
};
uint32_t block_a_size = 0;
switch (src0_type) {
case GGML_TYPE_Q4_0: block_a_size = std430_size({{16, 4}, {fp_size, fp_align}}); break; // qs[16/4] + dm
case GGML_TYPE_Q4_1: block_a_size = std430_size({{16, 4}, {fp2_size, fp2_align}}); break; // qs[16/4] + dm(vec2)
case GGML_TYPE_Q5_0: block_a_size = std430_size({{16, 4}, {4, 4}, {fp_size, fp_align}}); break; // qs[16/4] + qh + dm
case GGML_TYPE_Q5_1: block_a_size = std430_size({{16, 4}, {4, 4}, {fp2_size, fp2_align}}); break; // qs[16/4] + qh + dm(vec2)
case GGML_TYPE_Q8_0: block_a_size = std430_size({{32, 4}, {fp_size, fp_align}}); break; // qs[8] + dm
case GGML_TYPE_MXFP4: block_a_size = std430_size({{32, 4}, {fp_size, fp_align}}); break; // qs[8] + d
case GGML_TYPE_Q2_K: block_a_size = std430_size({{ 8, 4}, {2, 2}, {fp2_size, fp2_align}}); break; // qs[2] + scales(u8vec2) + dm(vec2)
case GGML_TYPE_Q3_K: block_a_size = std430_size({{16, 4}, {fp2_size, fp2_align}}); break; // qs[4] + d_scales(vec2)
case GGML_TYPE_Q4_K: block_a_size = std430_size({{16, 4}, {fp2_size, fp2_align}}); break; // qs[4] + dm(vec2)
case GGML_TYPE_Q5_K: block_a_size = std430_size({{32, 4}, {fp2_size, fp2_align}}); break; // qs[8] + dm(vec2)
case GGML_TYPE_Q6_K: block_a_size = std430_size({{32, 4}, {fp2_size, fp2_align}}); break; // qs[8] + d_scales(vec2)
default:
return false;
}
// block_b_cache: { int32_t qs[8]; FLOAT_TYPEV2 ds; }
const uint32_t block_b_size = std430_size({{32, 4}, {fp2_size, fp2_align}});
const uint32_t BM = warptile[1];
const uint32_t BN = warptile[2];
// mul_mmq.comp: BK_STEP=1 for MUL_MAT_ID, 4 otherwise.
const uint32_t BK_STEP = mul_mat_id ? 1u : 4u;
const uint32_t buf_a_size = BM * BK_STEP * block_a_size;
const uint32_t buf_b_size = BN * BK_STEP * block_b_size;
const uint32_t mmid_row_ids = mul_mat_id ? (BN * 2u * (uint32_t)sizeof(uint16_t)) : 0u;
const uint32_t warps = warptile[0] / warptile[10];
const uint32_t ballots_sh = mul_mat_id ? (warps * 4u * (uint32_t)sizeof(uint32_t)) : 0u;
const uint32_t total_size = buf_a_size + buf_b_size + mmid_row_ids + ballots_sh;
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
VK_LOG_DEBUG("ggml_vk_matmul_int_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), "
"mul_mat_id=" << mul_mat_id << ", src0_type=" << ggml_type_name(src0_type) << ", total=" << total_size << ", supported=" << supported);
return supported;
}
struct GpuPipelineConfig {
// GPU architecture identifier.
// Example: vk_device_architecture::AMD_GCN
@ -3453,6 +3526,40 @@ static void ggml_vk_load_shaders(vk_device& device) {
} else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmqid, true, t)) {
device->mul_mat_id_l[i] = false;
}
// The q8_1 mmq path has its own (larger) shmem layout, check it separately.
// K-quants use the _int_k warptiles, others use _int.
const bool is_k_quant = (t == GGML_TYPE_Q2_K || t == GGML_TYPE_Q3_K ||
t == GGML_TYPE_Q4_K || t == GGML_TYPE_Q5_K ||
t == GGML_TYPE_Q6_K);
const auto & s_int = is_k_quant ? s_warptile_mmq_int_k : s_warptile_mmq_int;
const auto & m_int = is_k_quant ? m_warptile_mmq_int_k : m_warptile_mmq_int;
const auto & l_int = is_k_quant ? l_warptile_mmq_int_k : l_warptile_mmq_int;
const auto & s_intid = is_k_quant ? s_warptile_mmqid_int_k : s_warptile_mmqid_int;
const auto & m_intid = is_k_quant ? m_warptile_mmqid_int_k : m_warptile_mmqid_int;
const auto & l_intid = is_k_quant ? l_warptile_mmqid_int_k : l_warptile_mmqid_int;
if (!ggml_vk_matmul_int_shmem_support(device, s_int, false, t)) {
device->mul_mat_s_int[i] = false;
device->mul_mat_m_int[i] = false;
device->mul_mat_l_int[i] = false;
} else if (!ggml_vk_matmul_int_shmem_support(device, m_int, false, t)) {
device->mul_mat_m_int[i] = false;
device->mul_mat_l_int[i] = false;
} else if (!ggml_vk_matmul_int_shmem_support(device, l_int, false, t)) {
device->mul_mat_l_int[i] = false;
}
if (!ggml_vk_matmul_int_shmem_support(device, s_intid, true, t)) {
device->mul_mat_id_s_int[i] = false;
device->mul_mat_id_m_int[i] = false;
device->mul_mat_id_l_int[i] = false;
} else if (!ggml_vk_matmul_int_shmem_support(device, m_intid, true, t)) {
device->mul_mat_id_m_int[i] = false;
device->mul_mat_id_l_int[i] = false;
} else if (!ggml_vk_matmul_int_shmem_support(device, l_intid, true, t)) {
device->mul_mat_id_l_int[i] = false;
}
}
}
@ -5613,6 +5720,13 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->mul_mat_id_s[i] = true;
break;
}
device->mul_mat_l_int[i] = true;
device->mul_mat_m_int[i] = true;
device->mul_mat_s_int[i] = true;
device->mul_mat_id_l_int[i] = true;
device->mul_mat_id_m_int[i] = true;
device->mul_mat_id_s_int[i] = true;
}
@ -7220,6 +7334,13 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m,
static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) {
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
// The q8_1 (integer dot) mmq path uses a different shader with its own
// shared-memory layout, so use the int-specific availability flags.
const bool is_q8_1 = (src1_type == GGML_TYPE_Q8_1);
const bool mm_l = is_q8_1 ? ctx->device->mul_mat_l_int[src0_type] : ctx->device->mul_mat_l[src0_type];
const bool mm_m = is_q8_1 ? ctx->device->mul_mat_m_int[src0_type] : ctx->device->mul_mat_m[src0_type];
const bool mm_s = is_q8_1 ? ctx->device->mul_mat_s_int[src0_type] : ctx->device->mul_mat_s[src0_type];
if (ctx->device->coopmat2) {
const uint32_t shader_core_count = ctx->device->shader_core_count;
const uint32_t tiles_l = CEIL_DIV(m, mmp->a_l->wg_denoms[0]) * CEIL_DIV(n, mmp->a_l->wg_denoms[1]);
@ -7236,26 +7357,24 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
// split_k==3 with large tiles likely better than medium tiles with no split_k.
(tiles_l <= shader_core_count / 3 && tiles_m > shader_core_count / 2);
if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large && prefer_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
if ((mm_l && (n > crossover_large && prefer_large)) || (!mm_m && !mm_s)) {
return aligned ? mmp->a_l : mmp->l;
}
// Use medium shader when the N dimension is greater than the small shader's tile size
uint32_t crossover_medium = mmp->s->wg_denoms[1];
if ((ctx->device->mul_mat_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_s[src0_type]) {
if ((mm_m && (n > crossover_medium)) || !mm_s) {
return aligned ? mmp->a_m : mmp->m;
}
return aligned ? mmp->a_s : mmp->s;
}
if ((ctx->device->mul_mat_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type])) {
if ((mm_s && (m <= 32 || n <= 32)) || (!mm_m && !mm_l)) {
return aligned ? mmp->a_s : mmp->s;
}
if ((ctx->device->mul_mat_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l[src0_type]) {
if ((mm_m && (m <= 64 || n <= 64)) || !mm_l) {
return aligned ? mmp->a_m : mmp->m;
}
return aligned ? mmp->a_l : mmp->l;
GGML_UNUSED(src1_type);
}
static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) {
@ -7312,35 +7431,42 @@ static void ggml_vk_matmul(
ctx->prealloc_split_k_need_sync = true;
}
static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) {
VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) {
VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
// The q8_1 (integer dot) mmq path uses a different shader with its own
// shared-memory layout, so use the int-specific availability flags.
const bool is_q8_1 = (src1_type == GGML_TYPE_Q8_1);
const bool mm_l = is_q8_1 ? ctx->device->mul_mat_id_l_int[src0_type] : ctx->device->mul_mat_id_l[src0_type];
const bool mm_m = is_q8_1 ? ctx->device->mul_mat_id_m_int[src0_type] : ctx->device->mul_mat_id_m[src0_type];
const bool mm_s = is_q8_1 ? ctx->device->mul_mat_id_s_int[src0_type] : ctx->device->mul_mat_id_s[src0_type];
if (ctx->device->coopmat2) {
// Use large shader when the N dimension is greater than the medium shader's tile size
uint32_t crossover_large = mmp->m->wg_denoms[1];
if ((ctx->device->mul_mat_id_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) {
if ((mm_l && (n > crossover_large)) || (!mm_m && !mm_s)) {
return aligned ? mmp->a_l : mmp->l;
}
// Use medium shader when the N dimension is greater than the small shader's tile size
uint32_t crossover_medium = mmp->s->wg_denoms[1];
if ((ctx->device->mul_mat_id_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_id_s[src0_type]) {
if ((mm_m && (n > crossover_medium)) || !mm_s) {
return aligned ? mmp->a_m : mmp->m;
}
return aligned ? mmp->a_s : mmp->s;
}
if ((ctx->device->mul_mat_id_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_l[src0_type])) {
if ((mm_s && (m <= 32 || n <= 32)) || (!mm_m && !mm_l)) {
return aligned ? mmp->a_s : mmp->s;
}
if ((ctx->device->mul_mat_id_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l[src0_type]) {
if ((mm_m && (m <= 64 || n <= 64)) || !mm_l) {
return aligned ? mmp->a_m : mmp->m;
}
return aligned ? mmp->a_l : mmp->l;
}
static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) {
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")");
return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true, src0_type)->align;
static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) {
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true, src0_type, src1_type)->align;
}
static void ggml_vk_matmul_id(
@ -7636,10 +7762,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
// Not implemented
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)));
const ggml_type effective_src1_type = quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type);
const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, effective_src1_type));
const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8;
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type));
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, effective_src1_type);
if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);
@ -8471,10 +8599,12 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
// Not implemented
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type));
const ggml_type effective_src1_type = quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type);
const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type, effective_src1_type));
const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && nei1 > 8;
vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type);
vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type, effective_src1_type);
if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);