mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-30 03:43:40 +00:00
vulkan: fast path for walsh-hadamard transform (#23687)
* vulkan: fast path for walsh-hadamard transform * disable for intel due to segfault
This commit is contained in:
parent
bb771cbd2b
commit
48e7078ee0
4 changed files with 162 additions and 0 deletions
|
|
@ -860,6 +860,7 @@ struct vk_device_struct {
|
|||
vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
|
||||
vk_pipeline pipeline_topk_f32[num_topk_pipelines];
|
||||
vk_pipeline pipeline_sum_rows_f32;
|
||||
vk_pipeline pipeline_fwht_f32[4];
|
||||
vk_pipeline pipeline_cumsum_f32;
|
||||
vk_pipeline pipeline_cumsum_small_f32;
|
||||
vk_pipeline pipeline_cumsum_multipass1_f32;
|
||||
|
|
@ -1150,6 +1151,13 @@ struct vk_op_push_constants {
|
|||
float param4;
|
||||
};
|
||||
|
||||
struct vk_op_fwht_push_constants {
|
||||
uint32_t n_rows;
|
||||
uint32_t src_offset;
|
||||
uint32_t dst_offset;
|
||||
float scale;
|
||||
};
|
||||
|
||||
struct vk_op_count_experts_push_constants {
|
||||
uint32_t ne00;
|
||||
uint32_t ne01;
|
||||
|
|
@ -2055,6 +2063,15 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
|
|||
GGML_UNUSED(src3);
|
||||
}
|
||||
|
||||
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_fwht_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
|
||||
p.src_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
|
||||
p.dst_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(src2);
|
||||
GGML_UNUSED(src3);
|
||||
}
|
||||
|
||||
struct ggml_backend_vk_buffer_context {
|
||||
vk_device_ref device;
|
||||
vk_buffer dev_buffer;
|
||||
|
|
@ -4982,6 +4999,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
|
||||
// Intel Arc B390 was observed segfaulting with this shader.
|
||||
if (device->subgroup_basic && device->subgroup_shuffle && device->vendor_id != VK_VENDOR_ID_INTEL) {
|
||||
int idx = 0;
|
||||
for (uint32_t n : {64, 128, 256, 512}) {
|
||||
if (device->subgroup_size <= n) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_fwht_f32[idx], "fwht_f32", fwht_f32_len, fwht_f32_data, "main", 2, sizeof(vk_op_fwht_push_constants), {1, 1, 1}, { device->subgroup_size, n }, 1, true, true, device->subgroup_size);
|
||||
}
|
||||
++idx;
|
||||
}
|
||||
}
|
||||
|
||||
const uint32_t cumsum_elem_per_thread = (device->vendor_id == VK_VENDOR_ID_AMD || device->vendor_id == VK_VENDOR_ID_INTEL) ? 2 : 4;
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 256, device->subgroup_size, cumsum_elem_per_thread }, 1, true, true, device->subgroup_size);
|
||||
|
|
@ -8741,6 +8768,68 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
|
|||
}, pc, { (uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 });
|
||||
}
|
||||
|
||||
static int ggml_vk_fwht_pipeline_idx(int64_t n) {
|
||||
switch (n) {
|
||||
case 64: return 0;
|
||||
case 128: return 1;
|
||||
case 256: return 2;
|
||||
case 512: return 3;
|
||||
default: return -1;
|
||||
}
|
||||
}
|
||||
|
||||
static bool ggml_vk_can_use_fwht(const ggml_backend_vk_context * ctx, const ggml_tensor * src1, const ggml_tensor * dst) {
|
||||
if (ctx->num_additional_fused_ops != 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ggml_get_op_params_i32(dst, 1) != GGML_HINT_SRC0_IS_HADAMARD) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const int idx = ggml_vk_fwht_pipeline_idx(src1->ne[0]);
|
||||
if (idx < 0 || ctx->device->pipeline_fwht_f32[idx] == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!ggml_is_contiguous(src1)) {
|
||||
return false;
|
||||
}
|
||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static void ggml_vk_fwht(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src, ggml_tensor * dst) {
|
||||
const int idx = ggml_vk_fwht_pipeline_idx(src->ne[0]);
|
||||
vk_pipeline pipeline = ctx->device->pipeline_fwht_f32[idx];
|
||||
|
||||
const uint32_t rows_per_workgroup = 4;
|
||||
const uint32_t n_rows = (uint32_t)ggml_nrows(src);
|
||||
const uint32_t max_workgroups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0];
|
||||
|
||||
const uint32_t total_workgroups = CEIL_DIV(n_rows, rows_per_workgroup);
|
||||
const uint32_t workgroups_x = std::min(total_workgroups, max_workgroups_x);
|
||||
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
|
||||
|
||||
const vk_subbuffer src_buf = ggml_vk_tensor_subbuffer(ctx, src, true);
|
||||
const vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, true);
|
||||
|
||||
vk_op_fwht_push_constants pc = {
|
||||
n_rows,
|
||||
0,
|
||||
0,
|
||||
1.0f / std::sqrt((float)src->ne[0]),
|
||||
};
|
||||
init_pushconst_tensor_offsets(ctx, pc, src, nullptr, nullptr, nullptr, dst);
|
||||
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src_buf, dst_buf }, pc, { workgroups_x, 1, 1 });
|
||||
}
|
||||
|
||||
static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
|
||||
ggml_tensor * dst = cgraph->nodes[node_idx];
|
||||
ggml_tensor * src0 = dst->src[0];
|
||||
|
|
@ -8774,6 +8863,8 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
|
|||
|
||||
m_offset += cur_M_size;
|
||||
}
|
||||
} else if (ggml_vk_can_use_fwht(ctx, src1, dst)) {
|
||||
ggml_vk_fwht(ctx, subctx, src1, dst);
|
||||
} else if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 &&
|
||||
// detect 0213 permutation, and batch size of 1
|
||||
src0->nb[0] <= src0->nb[2] &&
|
||||
|
|
|
|||
69
ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp
Normal file
69
ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
#version 450
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : require
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#extension GL_KHR_shader_subgroup_shuffle : enable
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
|
||||
|
||||
layout(constant_id = 0) const uint WARP_SIZE = 32;
|
||||
layout(constant_id = 1) const uint N = 128;
|
||||
|
||||
layout(push_constant) uniform parameter
|
||||
{
|
||||
uint n_rows;
|
||||
uint src_offset;
|
||||
uint dst_offset;
|
||||
float scale;
|
||||
};
|
||||
|
||||
layout(binding = 0, std430) readonly buffer A { float data_a[]; };
|
||||
layout(binding = 1, std430) writeonly buffer D { float data_d[]; };
|
||||
|
||||
const uint EL_W = N / WARP_SIZE;
|
||||
|
||||
void main() {
|
||||
const uint lane = gl_SubgroupInvocationID;
|
||||
for (uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_SubgroupID;
|
||||
row < n_rows;
|
||||
row += gl_NumWorkGroups.x * gl_WorkGroupSize.y) {
|
||||
const uint row_offset = row * N;
|
||||
|
||||
float reg[EL_W];
|
||||
|
||||
[[unroll]]
|
||||
for (uint i = 0; i < EL_W; ++i) {
|
||||
reg[i] = data_a[src_offset + row_offset + i * WARP_SIZE + lane] * scale;
|
||||
}
|
||||
|
||||
[[unroll]]
|
||||
for (uint h = 1; h < WARP_SIZE; h <<= 1) {
|
||||
[[unroll]]
|
||||
for (uint j = 0; j < EL_W; ++j) {
|
||||
const float val = reg[j];
|
||||
const float val2 = subgroupShuffleXor(val, h);
|
||||
reg[j] = (lane & h) == 0 ? val + val2 : val2 - val;
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]]
|
||||
for (uint h = WARP_SIZE; h < N; h <<= 1) {
|
||||
const uint step = h / WARP_SIZE;
|
||||
[[unroll]]
|
||||
for (uint j = 0; j < EL_W; j += 2 * step) {
|
||||
[[unroll]]
|
||||
for (uint k = 0; k < step; ++k) {
|
||||
const float x = reg[j + k];
|
||||
const float y = reg[j + k + step];
|
||||
reg[j + k] = x + y;
|
||||
reg[j + k + step] = x - y;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]]
|
||||
for (uint i = 0; i < EL_W; ++i) {
|
||||
data_d[dst_offset + row_offset + i * WARP_SIZE + lane] = reg[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -934,6 +934,7 @@ void process_shaders() {
|
|||
|
||||
string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
|
||||
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("fwht_f32", "fwht.comp", {});
|
||||
string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
|
||||
string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("cumsum_multipass1_f32", "cumsum_multipass1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
|
|
|||
|
|
@ -8318,6 +8318,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
test_cases.emplace_back(new test_mul_mat_hadamard(GGML_TYPE_F32, GGML_TYPE_F32, 128, 1, 128));
|
||||
test_cases.emplace_back(new test_mul_mat_hadamard(GGML_TYPE_F32, GGML_TYPE_F32, 64, 1, 64));
|
||||
test_cases.emplace_back(new test_mul_mat_hadamard(GGML_TYPE_F32, GGML_TYPE_F32, 256, 1, 256));
|
||||
test_cases.emplace_back(new test_mul_mat_hadamard(GGML_TYPE_F32, GGML_TYPE_F32, 512, 1, 512));
|
||||
test_cases.emplace_back(new test_mul_mat_hadamard(GGML_TYPE_F32, GGML_TYPE_F32, 128, 32, 128));
|
||||
test_cases.emplace_back(new test_mul_mat_hadamard(GGML_TYPE_F32, GGML_TYPE_F32, 128, 4, 128, {2, 3}));
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue