Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	.devops/vulkan.Dockerfile
#	ggml/src/ggml-sycl/ggml-sycl.cpp
#	ggml/src/ggml-webgpu/CMakeLists.txt
#	ggml/src/ggml-webgpu/ggml-webgpu.cpp
#	ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py
#	ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl
#	tests/test-backend-ops.cpp
#	tests/test-opt.cpp
This commit is contained in:
Concedo 2025-08-23 17:49:24 +08:00
commit 4828d0e148
6 changed files with 443 additions and 137 deletions

View file

@ -3498,11 +3498,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_CONV_TRANSPOSE_2D:
case GGML_OP_POOL_2D:
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_ARGSORT:
case GGML_OP_ACC:
return true;
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_GROUP_NORM:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_UPSCALE:

View file

@ -1031,6 +1031,39 @@ struct vk_op_upscale_push_constants {
float sf0; float sf1; float sf2; float sf3;
};
struct vk_op_sum_rows_push_constants
{
uint32_t n_cols;
uint32_t ne01, ne02;
uint32_t nb01, nb02, nb03;
uint32_t nb11, nb12, nb13;
float weight;
uint32_t misalign_offsets;
uint32_t ne0_12mp, ne0_12L;
uint32_t ne0_1mp, ne0_1L;
};
vk_op_sum_rows_push_constants vk_op_sum_rows_push_constants_init(const ggml_tensor * src, const ggml_tensor * dst, int64_t n_cols) {
uint32_t type_size = (uint32_t)ggml_type_size(src->type);
vk_op_sum_rows_push_constants p = {};
p.n_cols = (uint32_t)n_cols;
p.ne01 = (uint32_t)src->ne[1];
p.ne02 = (uint32_t)src->ne[2];
p.nb01 = (uint32_t)src->nb[1] / type_size;
p.nb02 = (uint32_t)src->nb[2] / type_size;
p.nb03 = (uint32_t)src->nb[3] / type_size;
p.nb11 = (uint32_t)dst->nb[1] / type_size;
p.nb12 = (uint32_t)dst->nb[2] / type_size;
p.nb13 = (uint32_t)dst->nb[3] / type_size;
p.weight = 1.0f;
return p;
}
template <> void init_pushconst_fastdiv(vk_op_sum_rows_push_constants &p) {
init_fastdiv_values(p.ne01*p.ne02, p.ne0_12mp, p.ne0_12L);
init_fastdiv_values(p.ne01, p.ne0_1mp, p.ne0_1L);
}
// Allow pre-recording command buffers
struct vk_staging_memcpy {
vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@ -1214,6 +1247,14 @@ struct ggml_backend_vk_context {
vk_pipeline_struct * prealloc_y_last_pipeline_used {};
const ggml_tensor * prealloc_y_last_tensor_used {};
// Track which nodes have been used since the last sync, and whether they were written to
std::vector<const ggml_tensor *> unsynced_nodes_written;
std::vector<const ggml_tensor *> unsynced_nodes_read;
// Track which prealloc buffers have pending reads that need to be synchronized.
// These are checked before writing to the buffer (and call ggml_vk_sync_buffers if set),
// and set to true after the buffer contents are consumed.
bool prealloc_x_need_sync, prealloc_y_need_sync, prealloc_split_k_need_sync;
vk_buffer buffer_pool[MAX_VK_BUFFERS];
vk_context_ref compute_ctx;
@ -1889,14 +1930,18 @@ static vk_subbuffer ggml_vk_subbuffer(vk_buffer& buf) {
return { buf, 0, VK_WHOLE_SIZE };
}
static void ggml_vk_sync_buffers(vk_context& ctx) {
static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subctx) {
VK_LOG_DEBUG("ggml_vk_sync_buffers()");
const bool transfer_queue = ctx->p->q->transfer_only;
const bool transfer_queue = subctx->p->q->transfer_only;
ctx->s->buffer.pipelineBarrier(
ctx->p->q->stage_flags,
ctx->p->q->stage_flags,
if (ctx) {
ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false;
}
subctx->s->buffer.pipelineBarrier(
subctx->p->q->stage_flags,
subctx->p->q->stage_flags,
{},
{ {
{ !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) },
@ -2184,9 +2229,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
s_mmq_wg_denoms_k = { 32, 64, 1 };
// spec constants and tile sizes for quant matmul_id
l_warptile_mmqid = { 256, 128, 128, 16, 0 };
m_warptile_mmqid = { 256, 128, 64, 16, 0 };
s_warptile_mmqid = { 256, 128, 64, 16, 0 };
l_warptile_mmqid = { 256, 128, 128, 16, 0, device->subgroup_size };
m_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size };
s_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size };
l_mmqid_wg_denoms = { 128, 128, 1 };
m_mmqid_wg_denoms = { 128, 64, 1 };
s_mmqid_wg_denoms = { 128, 64, 1 };
@ -3144,7 +3189,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
@ -4895,7 +4940,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont
}
}
ggml_vk_sync_buffers(subctx);
ggml_vk_sync_buffers(ctx, subctx);
subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices);
return;
}
@ -4910,7 +4955,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont
ggml_vk_ensure_sync_staging_buffer(ctx->device, copy_size);
VkBufferCopy buf_copy{ 0, offset, copy_size };
ggml_vk_sync_buffers(subctx);
ggml_vk_sync_buffers(ctx, subctx);
vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);
for (uint64_t i3 = 0; i3 < ne3; i3++) {
@ -4964,7 +5009,7 @@ static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
}
}
ggml_vk_sync_buffers(subctx);
ggml_vk_sync_buffers(nullptr, subctx);
subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices);
return;
}
@ -4985,7 +5030,7 @@ static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
offset,
copy_size};
ggml_vk_sync_buffers(subctx);
ggml_vk_sync_buffers(nullptr, subctx);
vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);
if (width == spitch) {
@ -5065,7 +5110,7 @@ static void ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size
if (buf != nullptr) {
// Memory is pinned, use as staging buffer
ggml_vk_sync_buffers(subctx);
ggml_vk_sync_buffers(nullptr, subctx);
subctx->s->buffer.copyBuffer(src->buffer, buf->buffer, slices);
return;
@ -5082,7 +5127,7 @@ static void ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size
vk_buffer& staging_buffer = src->device->sync_staging;
ggml_vk_sync_buffers(subctx);
ggml_vk_sync_buffers(nullptr, subctx);
subctx->s->buffer.copyBuffer(src->buffer, staging_buffer->buffer, slices);
deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys);
@ -5272,13 +5317,16 @@ static void ggml_vk_matmul(
uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,
uint32_t padded_n) {
VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")");
ggml_vk_sync_buffers(subctx);
if (split_k == 1) {
const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, batch });
return;
}
if (ctx->prealloc_split_k_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
GGML_ASSERT(batch_stride_d == m * n);
// Round the split size up to a multiple of 256 (k-quant alignment)
@ -5288,9 +5336,10 @@ static void ggml_vk_matmul(
const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
// Make sure enough workgroups get assigned for split k to work
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
ggml_vk_sync_buffers(subctx);
ggml_vk_sync_buffers(ctx, subctx);
const std::array<uint32_t, 2> pc2 = { (uint32_t)(m * n * batch), split_k };
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 });
ctx->prealloc_split_k_need_sync = true;
}
static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) {
@ -5335,7 +5384,6 @@ static void ggml_vk_matmul_id(
"m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " <<
"batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " <<
"n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")");
ggml_vk_sync_buffers(subctx);
const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d,
nei0, nei1, nbi1, ne11, padded_n };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, pc, { m, nei1, n_as });
@ -5466,8 +5514,8 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
};
init_pushconst_fastdiv(pc);
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, elements);
ggml_vk_sync_buffers(ctx, subctx);
}
static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) {
@ -5485,8 +5533,8 @@ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& sub
vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array<uint32_t, 1>{ne}, { ne, 1, 1 });
ggml_vk_sync_buffers(ctx, subctx);
}
static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@ -5681,12 +5729,23 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
GGML_ASSERT(qy_sz == y_sz);
}
if (x_non_contig || qx_needs_dequant) {
if (ctx->prealloc_x_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
}
if (y_non_contig || quantize_y) {
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
}
if (x_non_contig) {
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
} else if (qx_needs_dequant) {
const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
ggml_vk_sync_buffers(ctx, subctx);
}
if (y_non_contig) {
if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
@ -5725,6 +5784,13 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21,
split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n
); // NOLINT
if (x_non_contig || qx_needs_dequant) {
ctx->prealloc_x_need_sync = true;
}
if (y_non_contig || quantize_y) {
ctx->prealloc_y_need_sync = true;
}
}
static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@ -5871,6 +5937,17 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
GGML_ASSERT(qy_sz == y_sz);
}
if (x_non_contig) {
if (ctx->prealloc_x_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
}
if (y_non_contig) {
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
}
if (x_non_contig) {
GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
@ -5914,10 +5991,16 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
stride_batch_x, stride_batch_y, stride_batch_d,
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
};
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
{ vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23} },
pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
if (x_non_contig) {
ctx->prealloc_x_need_sync = true;
}
if (y_non_contig) {
ctx->prealloc_y_need_sync = true;
}
}
static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@ -6004,7 +6087,6 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
workgroups_z /= gqa_ratio;
}
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { 1, (uint32_t)ne01, workgroups_z });
}
@ -6091,7 +6173,6 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
// compute
const std::array<uint32_t, 12> pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)), nb03, nb13, nb23 };
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32,
{ vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { (uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 });
}
@ -6303,13 +6384,24 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
GGML_ASSERT(qy_sz == y_sz);
}
if (x_non_contig || qx_needs_dequant) {
if (ctx->prealloc_x_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
}
if (y_non_contig) {
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
}
if (x_non_contig) {
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
} else if (qx_needs_dequant) {
const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0,
{ vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
ggml_vk_sync_buffers(ctx, subctx);
}
if (y_non_contig) {
if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
@ -6340,6 +6432,13 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
stride_batch_x, stride_batch_y, ne20*ne21,
n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n
); // NOLINT
if (x_non_contig || qx_needs_dequant) {
ctx->prealloc_x_need_sync = true;
}
if (y_non_contig) {
ctx->prealloc_y_need_sync = true;
}
}
static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, bool dryrun = false) {
@ -6499,6 +6598,17 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
GGML_ASSERT(qy_sz == y_sz);
}
if (x_non_contig) {
if (ctx->prealloc_x_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
}
if (y_non_contig) {
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
}
if (x_non_contig) {
GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
@ -6535,11 +6645,17 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
(uint32_t)x_ne, stride_batch_y, (uint32_t)(ne20*ne21),
(uint32_t)nei0, (uint32_t)ne11,
};
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
{ vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 },
vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23}, vk_subbuffer{ d_ids, ids_buf_offset, ids_sz } },
pc, { groups_x, (uint32_t)nei0, groups_z });
if (x_non_contig) {
ctx->prealloc_x_need_sync = true;
}
if (y_non_contig) {
ctx->prealloc_y_need_sync = true;
}
}
static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
@ -6922,9 +7038,11 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
mask_n_head_log2, m0, m1,
gqa_ratio, split_kv, split_k };
ggml_vk_sync_buffers(subctx);
if (split_k > 1) {
if (ctx->prealloc_split_k_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{
vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
@ -6940,7 +7058,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
// cancel out the divide by wg_denoms[0].
pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
ggml_vk_sync_buffers(subctx);
ggml_vk_sync_buffers(ctx, subctx);
const std::array<uint32_t, 5> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) };
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
{
@ -6949,6 +7067,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
},
pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
ctx->prealloc_split_k_need_sync = true;
} else {
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{
@ -7279,6 +7398,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return nullptr;
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_sum_rows_f32;
}
@ -7417,6 +7537,9 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
case GGML_OP_CONV_2D_DW:
case GGML_OP_IM2COL:
case GGML_OP_SET_ROWS:
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
return true;
default:
return false;
@ -7451,6 +7574,16 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
GGML_UNUSED(src2);
}
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_sum_rows_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
p.misalign_offsets = (a_offset << 16) | d_offset;
GGML_UNUSED(src1);
GGML_UNUSED(src2);
}
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
@ -7601,10 +7734,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
d_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
if (op_supports_incontiguous) {
x_sz = ggml_nbytes(src0);
y_sz = use_src1 ? ggml_nbytes(src1) : 0;
z_sz = use_src2 ? ggml_nbytes(src2) : 0;
d_sz = ggml_nbytes(dst);
x_sz = ggml_nbytes(src0) + get_misalign_bytes(ctx, src0);
y_sz = use_src1 ? ggml_nbytes(src1) + get_misalign_bytes(ctx, src1) : 0;
z_sz = use_src2 ? ggml_nbytes(src2) + get_misalign_bytes(ctx, src2) : 0;
d_sz = ggml_nbytes(dst) + get_misalign_bytes(ctx, dst);
if (x_buf_offset + x_sz >= d_X->size) {
x_sz = VK_WHOLE_SIZE;
@ -7632,6 +7765,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK:
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_ARGMAX:
{
const uint32_t nr = ggml_nrows(src0);
@ -7802,7 +7936,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
subbuf_y = { d_X, 0, x_sz };
}
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (op == GGML_OP_SOFT_MAX) {
// Empty src1 and src2 is possible in soft_max, but the shader needs a buffer
@ -7820,7 +7953,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
subbuf_z = { d_X, 0, x_sz };
}
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
// Empty src2 is possible in rope, but the shader needs a buffer
@ -7831,30 +7963,23 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
subbuf_z = { d_X, 0, x_sz };
}
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (op == GGML_OP_IM2COL) {
// im2col uses only src1 and dst buffers
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (op == GGML_OP_COUNT_EQUAL) {
ggml_vk_sync_buffers(subctx);
// count_equal assumes that destination buffer is initialized with zeroes
ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz);
ggml_vk_sync_buffers(subctx);
ggml_vk_sync_buffers(ctx, subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (op == GGML_OP_OPT_STEP_SGD) {
// OPT_STEP_SGD works on src0, it does not need dst
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz } }, pc, elements);
} else if (use_src2) {
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (use_src1) {
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else {
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
}
}
@ -7981,7 +8106,6 @@ static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx,
elements = { ne, 1, 1 };
}
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{
vk_subbuffer{ buf[0], offset[0], VK_WHOLE_SIZE },
@ -8094,8 +8218,6 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx
src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context;
}
ggml_vk_sync_buffers(subctx);
vk_buffer d_D = nullptr, d_srcs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
size_t dst_offset = 0, src_offsets[7] = { 0, 0, 0, 0, 0, 0, 0 };
bool dst_uma = false, srcs_uma[7] = { false, false, false, false, false, false, false };
@ -8233,8 +8355,6 @@ static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_cont
ggml_backend_vk_buffer_context * gv_buf_ctx = (ggml_backend_vk_buffer_context *)gv->buffer->context;
ggml_backend_vk_buffer_context * p_buf_ctx = (ggml_backend_vk_buffer_context *)p->buffer->context;
ggml_vk_sync_buffers(subctx);
vk_buffer d_X = nullptr, d_G = nullptr, d_GM = nullptr, d_GV = nullptr, d_P = nullptr;
size_t x_offset = 0, g_offset = 0, gm_offset = 0, gv_offset = 0, p_offset = 0;
bool X_uma = false, G_uma = false, GM_uma = false, GV_uma = false, P_uma = false;
@ -8618,11 +8738,19 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
}
static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0));
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, p, dryrun);
}
static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun);
vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, p, dryrun);
}
static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
p.weight = 1.0f / (float)src0->ne[0];
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_MEAN, p, dryrun);
}
static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
@ -9845,6 +9973,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_ARGSORT:
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_ARGMAX:
case GGML_OP_COUNT_EQUAL:
case GGML_OP_IM2COL:
@ -9914,6 +10043,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_ARGSORT:
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_ARGMAX:
case GGML_OP_COUNT_EQUAL:
case GGML_OP_IM2COL:
@ -9936,6 +10066,83 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
}
}
if (!dryrun) {
// This logic detects dependencies between modes in the graph and calls ggml_vk_sync_buffers
// to synchronize them. This handles most "normal" synchronization when computing the graph, and when
// there is no auxiliary memory use, it shouldn't be necessary to call ggml_vk_sync_buffers
// outside of this logic. When a node uses one of the prealloc buffers for something like
// dequantization or split_k, additional synchronization is needed between those passes.
bool need_sync = false;
// Check whether "node" requires synchronization. The node requires synchronization if it
// overlaps in memory with another unsynchronized node and at least one of them is a write.
// Destination nodes are checked against both the written/read lists. Source nodes are only
// checked against the written list. Two nodes overlap in memory if they come from the same
// buffer and the tensor or view ranges overlap.
auto const &overlaps_unsynced = [&](const ggml_tensor *node, const std::vector<const ggml_tensor *> &unsynced_nodes) -> bool {
if (unsynced_nodes.size() == 0) {
return false;
}
auto n_base = vk_tensor_offset(node) + node->view_offs;
auto n_size = ggml_nbytes(node);
ggml_backend_vk_buffer_context * a_buf_ctx = (ggml_backend_vk_buffer_context *)node->buffer->context;
vk_buffer a_buf = a_buf_ctx->dev_buffer;
for (auto &other : unsynced_nodes) {
ggml_backend_vk_buffer_context * o_buf_ctx = (ggml_backend_vk_buffer_context *)other->buffer->context;
vk_buffer o_buf = o_buf_ctx->dev_buffer;
if (a_buf == o_buf) {
auto o_base = vk_tensor_offset(other) + other->view_offs;
auto o_size = ggml_nbytes(other);
if ((o_base <= n_base && n_base < o_base + o_size) ||
(n_base <= o_base && o_base < n_base + n_size)) {
return true;
}
}
}
return false;
};
// For all fused ops, check if the destination node or any of the source
// nodes require synchronization.
for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1 && !need_sync; ++i) {
const ggml_tensor *cur_node = cgraph->nodes[node_idx + i];
if (overlaps_unsynced(cur_node, ctx->unsynced_nodes_read) || overlaps_unsynced(cur_node, ctx->unsynced_nodes_written)) {
need_sync = true;
break;
}
for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) {
if (!cur_node->src[j]) {
continue;
}
if (overlaps_unsynced(cur_node->src[j], ctx->unsynced_nodes_written)) {
need_sync = true;
break;
}
}
}
if (need_sync) {
VK_LOG_DEBUG("node_idx=" << i << " sync");
ctx->unsynced_nodes_written.clear();
ctx->unsynced_nodes_read.clear();
ggml_vk_sync_buffers(ctx, compute_ctx);
} else {
VK_LOG_DEBUG("node_idx=" << i << " unsynced");
}
// Add the last fused node and all fused source nodes to the unsynchronized list.
const ggml_tensor * last_node = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
ctx->unsynced_nodes_written.push_back(last_node);
for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
const ggml_tensor *cur_node = cgraph->nodes[node_idx + i];
for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) {
if (!cur_node->src[j]) {
continue;
}
ctx->unsynced_nodes_read.push_back(cur_node->src[j]);
}
}
}
switch (node->op) {
case GGML_OP_REPEAT:
ggml_vk_repeat(ctx, compute_ctx, src0, node, dryrun);
@ -10117,6 +10324,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_SUM_ROWS:
ggml_vk_sum_rows(ctx, compute_ctx, src0, node, dryrun);
break;
case GGML_OP_MEAN:
ggml_vk_mean(ctx, compute_ctx, src0, node, dryrun);
break;
case GGML_OP_ARGMAX:
ggml_vk_argmax(ctx, compute_ctx, src0, node, dryrun);
@ -10276,6 +10487,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
case GGML_OP_ARGSORT:
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_ARGMAX:
case GGML_OP_COUNT_EQUAL:
case GGML_OP_IM2COL:
@ -10394,6 +10606,10 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
ctx->gc.temp_buffers.clear();
ctx->prealloc_y_last_pipeline_used = {};
ctx->unsynced_nodes_written.clear();
ctx->unsynced_nodes_read.clear();
ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false;
ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool);
ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool);
@ -11513,8 +11729,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK:
return true;
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_ARGMAX:
case GGML_OP_COUNT_EQUAL:
case GGML_OP_IM2COL:
@ -12073,6 +12292,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
tensor_clone = ggml_sum(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_SUM_ROWS) {
tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_MEAN) {
tensor_clone = ggml_mean(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_ARGMAX) {
tensor_clone = ggml_argmax(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_COUNT_EQUAL) {

View file

@ -103,16 +103,74 @@ layout (constant_id = 10) const uint WARP = 32;
shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE];
shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
#define NUM_WARPS (BLOCK_SIZE / WARP)
#ifdef MUL_MAT_ID
shared u16vec2 row_ids[4096];
uint _ne1;
#ifdef COOPMAT
shared uint _ne1_sh;
shared uvec4 ballots_sh[NUM_WARPS];
void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
_ne1 = 0;
uint num_elements = p.nei1 * p.nei0;
uint nei0shift = findLSB(p.nei0);
uint ids[16];
uint iter = 0;
for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
// prefetch up to 16 elements
if (iter == 0) {
[[unroll]] for (uint k = 0; k < 16; ++k) {
uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
bool in_range = i < num_elements;
uint ii1;
if (nei0_is_pow2) {
ii1 = i >> nei0shift;
} else {
ii1 = i / p.nei0;
}
uint ii0 = i - ii1 * p.nei0;
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
}
}
uint i = j + gl_LocalInvocationIndex;
bool in_range = i < num_elements;
uint ii1;
if (nei0_is_pow2) {
ii1 = i >> nei0shift;
} else {
ii1 = i / p.nei0;
}
uint ii0 = i - ii1 * p.nei0;
uint id = ids[iter++];
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
ballots_sh[gl_SubgroupID] = ballot;
barrier();
uint subgroup_base = 0;
uint total = 0;
for (uint k = 0; k < gl_NumSubgroups; ++k) {
if (k == gl_SubgroupID) {
subgroup_base = total;
}
total += subgroupBallotBitCount(ballots_sh[k]);
}
barrier();
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
if (in_range && id == expert_idx) {
row_ids[_ne1 + idx] = u16vec2(ii0, ii1);
}
_ne1 += total;
iter &= 15;
}
barrier();
}
#endif
#endif // MUL_MAT_ID
#define NUM_WARPS (BLOCK_SIZE / WARP)
#ifdef COOPMAT
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
#endif
@ -178,44 +236,11 @@ void main() {
#ifdef MUL_MAT_ID
#ifdef COOPMAT
// Spread the search across all elements in the first subgroup
if (gl_SubgroupID == 0) {
_ne1 = 0;
uint num_elements = p.nei1 * p.nei0;
uint ids[16];
uint iter = 0;
for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
// prefetch up to 16 elements
if (iter == 0) {
[[unroll]] for (uint k = 0; k < 16; ++k) {
uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
bool in_range = i < num_elements;
uint ii1 = i / p.nei0;
uint ii0 = i % p.nei0;
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
}
}
uint i = j + gl_SubgroupInvocationID;
bool in_range = i < num_elements;
uint ii1 = i / p.nei0;
uint ii0 = i % p.nei0;
uint id = ids[iter++];
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
uint idx = subgroupBallotExclusiveBitCount(ballot);
if (in_range && id == expert_idx) {
row_ids[_ne1 + idx] = u16vec2(ii0, ii1);
}
_ne1 += subgroupBallotBitCount(ballot);
iter &= 15;
}
_ne1_sh = _ne1;
if (bitCount(p.nei0) == 1) {
load_row_ids(expert_idx, true);
} else {
load_row_ids(expert_idx, false);
}
barrier();
_ne1 = _ne1_sh;
#else
_ne1 = 0;
for (uint ii1 = 0; ii1 < p.nei1; ii1++) {

View file

@ -19,6 +19,7 @@
#endif
#include "types.comp"
#include "utils.comp"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
@ -99,7 +100,8 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
};
uint _ne1;
shared uint _ne1_sh;
layout (constant_id = 5) const uint subgroup_size = 32;
shared uvec4 ballots_sh[BLOCK_SIZE / subgroup_size];
B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
@ -128,6 +130,64 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem
return elem;
}
void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
_ne1 = 0;
uint num_elements = p.nei1 * p.nei0;
uint nei0shift = findLSB(p.nei0);
uint ids[16];
uint iter = 0;
for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
// prefetch up to 16 elements
if (iter == 0) {
[[unroll]] for (uint k = 0; k < 16; ++k) {
uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
bool in_range = i < num_elements;
uint ii1;
if (nei0_is_pow2) {
ii1 = i >> nei0shift;
} else {
ii1 = i / p.nei0;
}
uint ii0 = i - ii1 * p.nei0;
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
}
}
uint i = j + gl_LocalInvocationIndex;
bool in_range = i < num_elements;
uint ii1;
if (nei0_is_pow2) {
ii1 = i >> nei0shift;
} else {
ii1 = i / p.nei0;
}
uint ii0 = i - ii1 * p.nei0;
uint id = ids[iter++];
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
ballots_sh[gl_SubgroupID] = ballot;
barrier();
uint subgroup_base = 0;
uint total = 0;
for (uint k = 0; k < gl_NumSubgroups; ++k) {
if (k == gl_SubgroupID) {
subgroup_base = total;
}
total += subgroupBallotBitCount(ballots_sh[k]);
}
barrier();
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
if (in_range && id == expert_idx) {
row_ids[_ne1 + idx] = u16vec4(fastmod(ii0, p.ne11), ii1, ii0, 0);
}
_ne1 += total;
iter &= 15;
}
barrier();
}
#endif
void main() {
@ -157,45 +217,12 @@ void main() {
const uint ic = gl_WorkGroupID.y;
#ifdef MUL_MAT_ID
// Spread the search across all elements in the first subgroup
if (gl_SubgroupID == 0) {
_ne1 = 0;
uint num_elements = p.nei1 * p.nei0;
uint ids[16];
uint iter = 0;
for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
// prefetch up to 16 elements
if (iter == 0) {
[[unroll]] for (uint k = 0; k < 16; ++k) {
uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
bool in_range = i < num_elements;
uint ii1 = i / p.nei0;
uint ii0 = i % p.nei0;
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
}
}
uint i = j + gl_SubgroupInvocationID;
bool in_range = i < num_elements;
uint ii1 = i / p.nei0;
uint ii0 = i % p.nei0;
uint id = ids[iter++];
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
uint idx = subgroupBallotExclusiveBitCount(ballot);
if (in_range && id == expert_idx) {
row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0);
}
_ne1 += subgroupBallotBitCount(ballot);
iter &= 15;
}
_ne1_sh = _ne1;
if (bitCount(p.nei0) == 1) {
load_row_ids(expert_idx, true);
} else {
load_row_ids(expert_idx, false);
}
barrier();
_ne1 = _ne1_sh;
// Workgroup has no work
if (ic * BN >= _ne1) return;
#endif

View file

@ -1,9 +1,9 @@
#version 450
#include "generic_head.comp"
#include "types.comp"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
@ -11,16 +11,49 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
layout (push_constant) uniform parameter
{
uint n_cols;
uint ne01, ne02;
uint nb01, nb02, nb03;
uint nb11, nb12, nb13;
float weight;
uint misalign_offsets;
uint ne0_12mp, ne0_12L;
uint ne0_1mp, ne0_1L;
} p;
uint get_aoffset() { return p.misalign_offsets >> 16; }
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
// see init_fastdiv_values in ggml-vulkan.cpp
uint fastdiv(uint n, uint mp, uint L) {
uint msbs, lsbs;
// msbs = mulhi(n, mp)
umulExtended(n, mp, msbs, lsbs);
return (msbs + n) >> L;
}
shared FLOAT_TYPE tmp[BLOCK_SIZE];
void main() {
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint col = gl_LocalInvocationID.x;
const float weight = p.weight;
tmp[col] = FLOAT_TYPE(0.0f);
const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);
const uint i03_offset = i03 * p.ne01*p.ne02;
const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);
const uint i01 = row - i03_offset - i02*p.ne01;
for (uint i = col; i < p.KX; i += BLOCK_SIZE) {
tmp[col] += FLOAT_TYPE(data_a[row*p.KX + i]);
const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;
const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;
tmp[col] = FLOAT_TYPE(0.0);
for (uint i = col; i < p.n_cols; i += BLOCK_SIZE) {
tmp[col] += FLOAT_TYPE(data_a[src_idx + i]);
}
barrier();
@ -32,6 +65,6 @@ void main() {
}
if (col == 0) {
data_d[row] = D_TYPE(tmp[0]);
data_d[dst_idx] = D_TYPE(tmp[0] * weight);
}
}

View file

@ -16,10 +16,10 @@
static std::string trim(const std::string & str) {
size_t start = 0;
size_t end = str.size();
while (start < end && isspace(str[start])) {
while (start < end && isspace(static_cast<unsigned char>(str[start]))) {
start += 1;
}
while (end > start && isspace(str[end - 1])) {
while (end > start && isspace(static_cast<unsigned char>(str[end - 1]))) {
end -= 1;
}
return str.substr(start, end - start);