From 48e7078ee08d56d22b81fc9aaefd8cb9fac4c3e2 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Thu, 28 May 2026 06:18:43 -0500 Subject: [PATCH] vulkan: fast path for walsh-hadamard transform (#23687) * vulkan: fast path for walsh-hadamard transform * disable for intel due to segfault --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 91 +++++++++++++++++++ ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp | 69 ++++++++++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 1 + tests/test-backend-ops.cpp | 1 + 4 files changed, 162 insertions(+) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 238ee8223..c9f906d79 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -860,6 +860,7 @@ struct vk_device_struct { vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines]; vk_pipeline pipeline_topk_f32[num_topk_pipelines]; vk_pipeline pipeline_sum_rows_f32; + vk_pipeline pipeline_fwht_f32[4]; vk_pipeline pipeline_cumsum_f32; vk_pipeline pipeline_cumsum_small_f32; vk_pipeline pipeline_cumsum_multipass1_f32; @@ -1150,6 +1151,13 @@ struct vk_op_push_constants { float param4; }; +struct vk_op_fwht_push_constants { + uint32_t n_rows; + uint32_t src_offset; + uint32_t dst_offset; + float scale; +}; + struct vk_op_count_experts_push_constants { uint32_t ne00; uint32_t ne01; @@ -2055,6 +2063,15 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk GGML_UNUSED(src3); } +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_fwht_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) { + p.src_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + p.dst_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + GGML_UNUSED(src1); + GGML_UNUSED(src2); + GGML_UNUSED(src3); +} + struct ggml_backend_vk_buffer_context { vk_device_ref device; vk_buffer dev_buffer; @@ -4982,6 +4999,16 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + // Intel Arc B390 was observed segfaulting with this shader. + if (device->subgroup_basic && device->subgroup_shuffle && device->vendor_id != VK_VENDOR_ID_INTEL) { + int idx = 0; + for (uint32_t n : {64, 128, 256, 512}) { + if (device->subgroup_size <= n) { + ggml_vk_create_pipeline(device, device->pipeline_fwht_f32[idx], "fwht_f32", fwht_f32_len, fwht_f32_data, "main", 2, sizeof(vk_op_fwht_push_constants), {1, 1, 1}, { device->subgroup_size, n }, 1, true, true, device->subgroup_size); + } + ++idx; + } + } const uint32_t cumsum_elem_per_thread = (device->vendor_id == VK_VENDOR_ID_AMD || device->vendor_id == VK_VENDOR_ID_INTEL) ? 2 : 4; ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 256, device->subgroup_size, cumsum_elem_per_thread }, 1, true, true, device->subgroup_size); @@ -8741,6 +8768,68 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con }, pc, { (uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 }); } +static int ggml_vk_fwht_pipeline_idx(int64_t n) { + switch (n) { + case 64: return 0; + case 128: return 1; + case 256: return 2; + case 512: return 3; + default: return -1; + } +} + +static bool ggml_vk_can_use_fwht(const ggml_backend_vk_context * ctx, const ggml_tensor * src1, const ggml_tensor * dst) { + if (ctx->num_additional_fused_ops != 0) { + return false; + } + + if (ggml_get_op_params_i32(dst, 1) != GGML_HINT_SRC0_IS_HADAMARD) { + return false; + } + + const int idx = ggml_vk_fwht_pipeline_idx(src1->ne[0]); + if (idx < 0 || ctx->device->pipeline_fwht_f32[idx] == nullptr) { + return false; + } + + if (src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + return false; + } + + if (!ggml_is_contiguous(src1)) { + return false; + } + GGML_ASSERT(ggml_is_contiguous(dst)); + + return true; +} + +static void ggml_vk_fwht(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src, ggml_tensor * dst) { + const int idx = ggml_vk_fwht_pipeline_idx(src->ne[0]); + vk_pipeline pipeline = ctx->device->pipeline_fwht_f32[idx]; + + const uint32_t rows_per_workgroup = 4; + const uint32_t n_rows = (uint32_t)ggml_nrows(src); + const uint32_t max_workgroups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0]; + + const uint32_t total_workgroups = CEIL_DIV(n_rows, rows_per_workgroup); + const uint32_t workgroups_x = std::min(total_workgroups, max_workgroups_x); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + + const vk_subbuffer src_buf = ggml_vk_tensor_subbuffer(ctx, src, true); + const vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, true); + + vk_op_fwht_push_constants pc = { + n_rows, + 0, + 0, + 1.0f / std::sqrt((float)src->ne[0]), + }; + init_pushconst_tensor_offsets(ctx, pc, src, nullptr, nullptr, nullptr, dst); + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src_buf, dst_buf }, pc, { workgroups_x, 1, 1 }); +} + static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) { ggml_tensor * dst = cgraph->nodes[node_idx]; ggml_tensor * src0 = dst->src[0]; @@ -8774,6 +8863,8 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c m_offset += cur_M_size; } + } else if (ggml_vk_can_use_fwht(ctx, src1, dst)) { + ggml_vk_fwht(ctx, subctx, src1, dst); } else if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 && // detect 0213 permutation, and batch size of 1 src0->nb[0] <= src0->nb[2] && diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp b/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp new file mode 100644 index 000000000..72059d4af --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp @@ -0,0 +1,69 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require +#extension GL_KHR_shader_subgroup_basic : enable +#extension GL_KHR_shader_subgroup_shuffle : enable + +layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in; + +layout(constant_id = 0) const uint WARP_SIZE = 32; +layout(constant_id = 1) const uint N = 128; + +layout(push_constant) uniform parameter +{ + uint n_rows; + uint src_offset; + uint dst_offset; + float scale; +}; + +layout(binding = 0, std430) readonly buffer A { float data_a[]; }; +layout(binding = 1, std430) writeonly buffer D { float data_d[]; }; + +const uint EL_W = N / WARP_SIZE; + +void main() { + const uint lane = gl_SubgroupInvocationID; + for (uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_SubgroupID; + row < n_rows; + row += gl_NumWorkGroups.x * gl_WorkGroupSize.y) { + const uint row_offset = row * N; + + float reg[EL_W]; + + [[unroll]] + for (uint i = 0; i < EL_W; ++i) { + reg[i] = data_a[src_offset + row_offset + i * WARP_SIZE + lane] * scale; + } + + [[unroll]] + for (uint h = 1; h < WARP_SIZE; h <<= 1) { + [[unroll]] + for (uint j = 0; j < EL_W; ++j) { + const float val = reg[j]; + const float val2 = subgroupShuffleXor(val, h); + reg[j] = (lane & h) == 0 ? val + val2 : val2 - val; + } + } + + [[unroll]] + for (uint h = WARP_SIZE; h < N; h <<= 1) { + const uint step = h / WARP_SIZE; + [[unroll]] + for (uint j = 0; j < EL_W; j += 2 * step) { + [[unroll]] + for (uint k = 0; k < step; ++k) { + const float x = reg[j + k]; + const float y = reg[j + k + step]; + reg[j + k] = x + y; + reg[j + k + step] = x - y; + } + } + } + + [[unroll]] + for (uint i = 0; i < EL_W; ++i) { + data_d[dst_offset + row_offset + i * WARP_SIZE + lane] = reg[i]; + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 24b9d25f7..fa9b938e4 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -934,6 +934,7 @@ void process_shaders() { string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}})); string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("fwht_f32", "fwht.comp", {}); string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}})); string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("cumsum_multipass1_f32", "cumsum_multipass1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 3853f0329..19f8558d8 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -8318,6 +8318,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_mul_mat_hadamard(GGML_TYPE_F32, GGML_TYPE_F32, 128, 1, 128)); test_cases.emplace_back(new test_mul_mat_hadamard(GGML_TYPE_F32, GGML_TYPE_F32, 64, 1, 64)); test_cases.emplace_back(new test_mul_mat_hadamard(GGML_TYPE_F32, GGML_TYPE_F32, 256, 1, 256)); + test_cases.emplace_back(new test_mul_mat_hadamard(GGML_TYPE_F32, GGML_TYPE_F32, 512, 1, 512)); test_cases.emplace_back(new test_mul_mat_hadamard(GGML_TYPE_F32, GGML_TYPE_F32, 128, 32, 128)); test_cases.emplace_back(new test_mul_mat_hadamard(GGML_TYPE_F32, GGML_TYPE_F32, 128, 4, 128, {2, 3}));