vulkan: fast path for walsh-hadamard transform (#23687)

* vulkan: fast path for walsh-hadamard transform

* disable on Intel due to a segfault
This commit is contained in:
Jeff Bolz 2026-05-28 06:18:43 -05:00 committed by GitHub
parent bb771cbd2b
commit 48e7078ee0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 162 additions and 0 deletions

View file

@ -860,6 +860,7 @@ struct vk_device_struct {
vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
vk_pipeline pipeline_topk_f32[num_topk_pipelines];
vk_pipeline pipeline_sum_rows_f32;
vk_pipeline pipeline_fwht_f32[4];
vk_pipeline pipeline_cumsum_f32;
vk_pipeline pipeline_cumsum_small_f32;
vk_pipeline pipeline_cumsum_multipass1_f32;
@ -1150,6 +1151,13 @@ struct vk_op_push_constants {
float param4;
};
// Push constants for the fwht_f32 pipeline. The members are listed in the same
// order as the push_constant block of the FWHT compute shader
// (n_rows, src_offset, dst_offset, scale).
struct vk_op_fwht_push_constants {
// Number of rows to transform; the shader stops once its row index reaches this value.
uint32_t n_rows;
// Offset added to every source index, in elements (filled in by init_pushconst_tensor_offsets).
uint32_t src_offset;
// Offset added to every destination index, in elements (filled in by init_pushconst_tensor_offsets).
uint32_t dst_offset;
// Factor each input element is multiplied by when it is loaded.
float scale;
};
struct vk_op_count_experts_push_constants {
uint32_t ne00;
uint32_t ne01;
@ -2055,6 +2063,15 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
GGML_UNUSED(src3);
}
// FWHT specialization: fills in the source and destination offsets of the push
// constants. Each offset is the value returned by get_misalign_bytes() for the
// tensor, divided by the size of the tensor's type, i.e. expressed in elements.
// Only src0 and dst are used; src1, src2 and src3 are ignored.
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_fwht_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
    // This op has a single input, so the remaining sources play no role here.
    GGML_UNUSED(src1);
    GGML_UNUSED(src2);
    GGML_UNUSED(src3);

    const auto src_misalign_bytes = get_misalign_bytes(ctx, src0);
    p.src_offset = src_misalign_bytes / ggml_type_size(src0->type);

    const auto dst_misalign_bytes = get_misalign_bytes(ctx, dst);
    p.dst_offset = dst_misalign_bytes / ggml_type_size(dst->type);
}
struct ggml_backend_vk_buffer_context {
vk_device_ref device;
vk_buffer dev_buffer;
@ -4982,6 +4999,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
// Fast-path pipelines for the Walsh-Hadamard transform, one per supported row
// length (64, 128, 256, 512). They are only created when the device reports
// basic subgroup and subgroup shuffle support.
// Intel is excluded: Intel Arc B390 was observed segfaulting with this shader.
if (device->subgroup_basic && device->subgroup_shuffle && device->vendor_id != VK_VENDOR_ID_INTEL) {
// idx advances on every iteration, even when a size is skipped, so each row
// length keeps a fixed slot (64 -> 0, 128 -> 1, 256 -> 2, 512 -> 3), matching
// ggml_vk_fwht_pipeline_idx(). No pipeline is created for a skipped slot.
int idx = 0;
for (uint32_t n : {64, 128, 256, 512}) {
// Row lengths shorter than the subgroup size are skipped.
if (device->subgroup_size <= n) {
// { subgroup_size, n } presumably feed the shader's WARP_SIZE and N constants — TODO confirm.
ggml_vk_create_pipeline(device, device->pipeline_fwht_f32[idx], "fwht_f32", fwht_f32_len, fwht_f32_data, "main", 2, sizeof(vk_op_fwht_push_constants), {1, 1, 1}, { device->subgroup_size, n }, 1, true, true, device->subgroup_size);
}
++idx;
}
}
const uint32_t cumsum_elem_per_thread = (device->vendor_id == VK_VENDOR_ID_AMD || device->vendor_id == VK_VENDOR_ID_INTEL) ? 2 : 4;
ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 256, device->subgroup_size, cumsum_elem_per_thread }, 1, true, true, device->subgroup_size);
@ -8741,6 +8768,68 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
}, pc, { (uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 });
}
// Maps a row length to its slot in the pipeline_fwht_f32 array.
// Returns 0, 1, 2 or 3 for n = 64, 128, 256 or 512 respectively, and -1 for
// any other value.
static int ggml_vk_fwht_pipeline_idx(int64_t n) {
    // Supported row lengths, listed in slot order.
    constexpr int64_t supported_sizes[] = { 64, 128, 256, 512 };

    int slot = 0;
    for (const int64_t size : supported_sizes) {
        if (size == n) {
            return slot;
        }
        ++slot;
    }

    // Not a supported row length.
    return -1;
}
// Decides whether the node `dst` can be computed with the FWHT fast path
// instead of a regular matrix multiplication.
//
// Returns true only when all of the following hold:
//   - no additional ops are fused with this node,
//   - op param 1 of dst equals GGML_HINT_SRC0_IS_HADAMARD,
//   - a fwht pipeline exists for the row length src1->ne[0],
//   - src1 and dst are both F32,
//   - src1 is contiguous.
// Once these checks pass, dst is asserted to be contiguous.
static bool ggml_vk_can_use_fwht(const ggml_backend_vk_context * ctx, const ggml_tensor * src1, const ggml_tensor * dst) {
    // The node must stand alone and carry the Hadamard hint.
    const bool hinted_standalone_node =
        ctx->num_additional_fused_ops == 0 &&
        ggml_get_op_params_i32(dst, 1) == GGML_HINT_SRC0_IS_HADAMARD;
    if (!hinted_standalone_node) {
        return false;
    }

    // The row length must map to a slot, and that slot must hold a pipeline
    // (slots can be empty, e.g. when the pipeline was not created for this device).
    const int slot = ggml_vk_fwht_pipeline_idx(src1->ne[0]);
    const bool pipeline_available = slot >= 0 && ctx->device->pipeline_fwht_f32[slot] != nullptr;
    if (!pipeline_available) {
        return false;
    }

    // Only contiguous F32 input with F32 output is handled.
    const bool f32_in_and_out = src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
    if (!f32_in_and_out || !ggml_is_contiguous(src1)) {
        return false;
    }

    GGML_ASSERT(ggml_is_contiguous(dst));
    return true;
}
// Records a dispatch of the FWHT pipeline that transforms every row of `src`
// into `dst`.
//
// The caller is expected to have checked ggml_vk_can_use_fwht() first: the
// pipeline slot derived from src->ne[0] is used without further validation.
static void ggml_vk_fwht(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src, ggml_tensor * dst) {
    const int64_t row_len = src->ne[0];
    vk_pipeline pipeline = ctx->device->pipeline_fwht_f32[ggml_vk_fwht_pipeline_idx(row_len)];

    const uint32_t n_rows = (uint32_t)ggml_nrows(src);

    // Each workgroup handles 4 rows per pass. The workgroup count is capped at
    // the device limit for the x dimension; the shader loops over rows, so a
    // capped dispatch still reaches every row.
    const uint32_t rows_per_wg = 4;
    const uint32_t wg_limit    = ctx->device->properties.limits.maxComputeWorkGroupCount[0];
    const uint32_t wg_needed   = CEIL_DIV(n_rows, rows_per_wg);
    const uint32_t wg_count    = wg_limit < wg_needed ? wg_limit : wg_needed;

    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);

    const vk_subbuffer src_buf = ggml_vk_tensor_subbuffer(ctx, src, true);
    const vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, true);

    // Offsets start at zero and are filled in by init_pushconst_tensor_offsets().
    vk_op_fwht_push_constants pc {};
    pc.n_rows = n_rows;
    // Scaling by 1/sqrt(row length) is applied by the shader when loading the input.
    pc.scale  = 1.0f / std::sqrt((float)row_len);
    init_pushconst_tensor_offsets(ctx, pc, src, nullptr, nullptr, nullptr, dst);

    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src_buf, dst_buf }, pc, { wg_count, 1, 1 });
}
static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
ggml_tensor * dst = cgraph->nodes[node_idx];
ggml_tensor * src0 = dst->src[0];
@ -8774,6 +8863,8 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
m_offset += cur_M_size;
}
} else if (ggml_vk_can_use_fwht(ctx, src1, dst)) {
ggml_vk_fwht(ctx, subctx, src1, dst);
} else if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 &&
// detect 0213 permutation, and batch size of 1
src0->nb[0] <= src0->nb[2] &&

View file

@ -0,0 +1,69 @@
#version 450
// Fast Walsh-Hadamard transform over rows of N floats.
// Each subgroup transforms one row at a time: every invocation holds EL_W
// elements of the row in a local array, butterflies between invocations use
// subgroupShuffleXor, and the remaining butterflies are done within the array.
#extension GL_EXT_control_flow_attributes : require
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_shuffle : enable
// local_size_x comes from specialization constant 0 (shared with WARP_SIZE);
// local_size_y is fixed at 4.
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
// NOTE(review): rows are assigned by gl_SubgroupID below, which assumes the
// actual subgroup size equals WARP_SIZE — confirm the pipeline enforces this.
layout(constant_id = 0) const uint WARP_SIZE = 32;
// Row length. The stride loops below double h until it reaches N, and
// EL_W = N / WARP_SIZE, so N is expected to be a power-of-two multiple of WARP_SIZE.
layout(constant_id = 1) const uint N = 128;
layout(push_constant) uniform parameter
{
// Number of rows to process.
uint n_rows;
// Offsets (in floats) added to every index into data_a / data_d.
uint src_offset;
uint dst_offset;
// Factor applied to every input element on load.
float scale;
};
layout(binding = 0, std430) readonly buffer A { float data_a[]; };
layout(binding = 1, std430) writeonly buffer D { float data_d[]; };
// Number of row elements held by each invocation.
const uint EL_W = N / WARP_SIZE;
void main() {
const uint lane = gl_SubgroupInvocationID;
// One row per subgroup. The stride covers all rows handled by the whole
// dispatch in one pass, so every row is reached even if fewer workgroups
// than ceil(n_rows / 4) were launched.
for (uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_SubgroupID;
row < n_rows;
row += gl_NumWorkGroups.x * gl_WorkGroupSize.y) {
const uint row_offset = row * N;
float reg[EL_W];
// Load: reg[i] holds row element i * WARP_SIZE + lane, pre-multiplied by scale.
[[unroll]]
for (uint i = 0; i < EL_W; ++i) {
reg[i] = data_a[src_offset + row_offset + i * WARP_SIZE + lane] * scale;
}
// Butterflies with stride h < WARP_SIZE: the partner element sits in the same
// reg slot of invocation (lane ^ h). The invocation with bit h clear keeps
// own + partner, the one with bit h set keeps partner - own.
[[unroll]]
for (uint h = 1; h < WARP_SIZE; h <<= 1) {
[[unroll]]
for (uint j = 0; j < EL_W; ++j) {
const float val = reg[j];
const float val2 = subgroupShuffleXor(val, h);
reg[j] = (lane & h) == 0 ? val + val2 : val2 - val;
}
}
// Butterflies with stride h >= WARP_SIZE: both elements belong to the same
// invocation, `step` slots apart in reg, so no shuffle is needed.
[[unroll]]
for (uint h = WARP_SIZE; h < N; h <<= 1) {
const uint step = h / WARP_SIZE;
[[unroll]]
for (uint j = 0; j < EL_W; j += 2 * step) {
[[unroll]]
for (uint k = 0; k < step; ++k) {
const float x = reg[j + k];
const float y = reg[j + k + step];
reg[j + k] = x + y;
reg[j + k + step] = x - y;
}
}
}
// Store using the same element layout as the load.
[[unroll]]
for (uint i = 0; i < EL_W; ++i) {
data_d[dst_offset + row_offset + i * WARP_SIZE + lane] = reg[i];
}
}
}

View file

@ -934,6 +934,7 @@ void process_shaders() {
string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("fwht_f32", "fwht.comp", {});
string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("cumsum_multipass1_f32", "cumsum_multipass1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));

View file

@ -8318,6 +8318,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_mul_mat_hadamard(GGML_TYPE_F32, GGML_TYPE_F32, 128, 1, 128));
test_cases.emplace_back(new test_mul_mat_hadamard(GGML_TYPE_F32, GGML_TYPE_F32, 64, 1, 64));
test_cases.emplace_back(new test_mul_mat_hadamard(GGML_TYPE_F32, GGML_TYPE_F32, 256, 1, 256));
test_cases.emplace_back(new test_mul_mat_hadamard(GGML_TYPE_F32, GGML_TYPE_F32, 512, 1, 512));
test_cases.emplace_back(new test_mul_mat_hadamard(GGML_TYPE_F32, GGML_TYPE_F32, 128, 32, 128));
test_cases.emplace_back(new test_mul_mat_hadamard(GGML_TYPE_F32, GGML_TYPE_F32, 128, 4, 128, {2, 3}));