mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-22 11:16:08 +00:00
vulkan: fuse snake activation (mul, sin, sqr, mul, add) (#22855)
* vulkan: fuse snake activation (mul, sin, sqr, mul, add) Add snake.comp shader with F32 / F16 / BF16 pipelines and ggml_vk_snake_dispatch_fused. The matcher recognizes the naive 5 op decomposition emitted by audio decoders (BigVGAN, Vocos) for snake activation y = x + sin(a*x)^2 * inv_b and rewrites it to a single elementwise kernel. test_snake_fuse from the CUDA PR now also compares CPU naive vs Vulkan fused across F32 / F16 / BF16. * vulkan: address jeffbolznv review for fused snake activation Rename T / C to ne0 / ne1 in the shader and push constants to match the standard naming convention used across the Vulkan backend. Tighten ggml_vk_can_fuse_snake: require x and dst to be contiguous (the shader uses idx = i0 + i1 * ne0) and require a / inv_b to be tightly packed on the broadcast dim (the shader reads data_a[i1]). * vulkan: tighten snake fusion type checks for all operands (address jeffbolznv review) * vulkan: reject snake fusion when ne[2] or ne[3] > 1 (address jeffbolznv review) * vulkan: address 0cc4m review for fused snake activation snake.comp is renamed to follow the ggml DATA_A_* / A_TYPE convention. A_TYPE now applies to the activation tensor data_a instead of the broadcast multiplier, and the bindings become data_a (A_TYPE), data_b (float), data_c (float) and data_d (D_TYPE). A header at the top of the shader maps each buffer to its role in y = x + sin(b * x)^2 * c. On the C++ side, ggml_vk_can_fuse_snake reuses the existing snake_pattern constant instead of duplicating the op list, sin_node is extracted as a named local alongside the other chain nodes, and the broadcast operands a and inv_b are now required to be GGML_TYPE_F32 to match the hardcoded float bindings on data_b and data_c (the previous a->type == x->type would silently reject any future BF16 or F16 chain once the supports_op gate for SIN / SQR is lifted). 
ggml_vk_snake_dispatch_fused gets an explicit GGML_TYPE_F32 case and GGML_ABORT on default in place of the silent f32 fallback, and a stale comment about data_a[i1] / data_inv_b[i1] is refreshed to match the new binding names.
This commit is contained in:
parent
5306f4b3b5
commit
47c0eda9d4
3 changed files with 187 additions and 2 deletions
|
|
@ -499,6 +499,12 @@ static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax { GGM
|
|||
GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
|
||||
GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
|
||||
|
||||
// Snake activation: y = x + sin(a*x)^2 * inv_b. Used by the optimize_graph reorder
|
||||
// pass so it keeps the chain contiguous and by the dispatcher to detect the fusion.
|
||||
static constexpr std::initializer_list<ggml_op> snake_pattern { GGML_OP_MUL, GGML_OP_SIN,
|
||||
GGML_OP_SQR, GGML_OP_MUL,
|
||||
GGML_OP_ADD };
|
||||
|
||||
//node #978 ( SOFT_MAX): ffn_moe_probs-15 ( 0K) [Vulka ] use=2: ffn_moe_logits-15 ( 0K) [Vulka ]
|
||||
//node #979 ( RESHAPE): ffn_moe_probs-15 (re ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ]
|
||||
//node #980 ( ARGSORT): ffn_moe_argsort-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ]
|
||||
|
|
@ -846,6 +852,9 @@ struct vk_device_struct {
|
|||
vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16;
|
||||
vk_pipeline pipeline_timestep_embedding_f32;
|
||||
vk_pipeline pipeline_conv_transpose_1d_f32;
|
||||
vk_pipeline pipeline_snake_f32;
|
||||
vk_pipeline pipeline_snake_f16;
|
||||
vk_pipeline pipeline_snake_bf16;
|
||||
vk_pipeline pipeline_pool2d_f32;
|
||||
vk_pipeline pipeline_rwkv_wkv6_f32;
|
||||
vk_pipeline pipeline_rwkv_wkv7_f32;
|
||||
|
|
@ -1475,6 +1484,11 @@ struct vk_op_conv_transpose_1d_push_constants {
|
|||
int32_t s0;
|
||||
};
|
||||
|
||||
// Push constants for the fused snake pipelines (snake.comp). The two fields
// match the shader's push-constant block in name and order.
struct vk_op_snake_push_constants {
    uint32_t ne0;  // size of dim 0 of the activation tensor (filled from x->ne[0])
    uint32_t ne1;  // size of dim 1 of the activation tensor (filled from x->ne[1])
};
|
||||
|
||||
struct vk_op_pool2d_push_constants {
|
||||
uint32_t IW; uint32_t IH;
|
||||
uint32_t OW; uint32_t OH;
|
||||
|
|
@ -4845,6 +4859,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_conv_transpose_1d_f32, "conv_transpose_1d_f32", conv_transpose_1d_f32_len, conv_transpose_1d_f32_data, "main", 3, sizeof(vk_op_conv_transpose_1d_push_constants), {1, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_snake_f32, "snake_f32", snake_f32_len, snake_f32_data, "main", 4, sizeof(vk_op_snake_push_constants), {256, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_snake_f16, "snake_f16", snake_f16_len, snake_f16_data, "main", 4, sizeof(vk_op_snake_push_constants), {256, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_snake_bf16, "snake_bf16", snake_bf16_len, snake_bf16_data, "main", 4, sizeof(vk_op_snake_push_constants), {256, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
|
||||
|
|
@ -12110,6 +12128,45 @@ static void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context&
|
|||
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p));
|
||||
}
|
||||
|
||||
// Dispatch the fused snake activation: y = x + sin^2(a * x) * inv_b.
|
||||
// Match the naive mul -> sin -> sqr -> mul -> add chain and run the
|
||||
// dedicated kernel directly. The pattern is validated by
|
||||
// ggml_vk_can_fuse_snake before this call.
|
||||
static void ggml_vk_snake_dispatch_fused(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) {
|
||||
const ggml_tensor * mul0 = cgraph->nodes[node_idx + 0];
|
||||
const ggml_tensor * sqr = cgraph->nodes[node_idx + 2];
|
||||
const ggml_tensor * mul1 = cgraph->nodes[node_idx + 3];
|
||||
ggml_tensor * add = cgraph->nodes[node_idx + 4];
|
||||
|
||||
// x carries the full activation shape, a is the broadcast operand
|
||||
const ggml_tensor * x = ggml_are_same_shape(mul0, mul0->src[0]) ? mul0->src[0] : mul0->src[1];
|
||||
const ggml_tensor * a = (x == mul0->src[0]) ? mul0->src[1] : mul0->src[0];
|
||||
|
||||
// mul1 reads sqr and inv_b in either operand order
|
||||
const ggml_tensor * inv_b = (mul1->src[0] == sqr) ? mul1->src[1] : mul1->src[0];
|
||||
|
||||
vk_pipeline pipeline = nullptr;
|
||||
switch (x->type) {
|
||||
case GGML_TYPE_F32: pipeline = ctx->device->pipeline_snake_f32; break;
|
||||
case GGML_TYPE_F16: pipeline = ctx->device->pipeline_snake_f16; break;
|
||||
case GGML_TYPE_BF16: pipeline = ctx->device->pipeline_snake_bf16; break;
|
||||
default: GGML_ABORT("unsupported type");
|
||||
}
|
||||
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
|
||||
|
||||
vk_subbuffer x_buf = ggml_vk_tensor_subbuffer(ctx, x);
|
||||
vk_subbuffer a_buf = ggml_vk_tensor_subbuffer(ctx, a);
|
||||
vk_subbuffer inv_b_buf = ggml_vk_tensor_subbuffer(ctx, inv_b);
|
||||
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, add);
|
||||
|
||||
vk_op_snake_push_constants pc{};
|
||||
pc.ne0 = static_cast<uint32_t>(x->ne[0]);
|
||||
pc.ne1 = static_cast<uint32_t>(x->ne[1]);
|
||||
|
||||
std::array<uint32_t, 3> elements = { pc.ne0, pc.ne1, 1 };
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { x_buf, a_buf, inv_b_buf, dst_buf }, pc, elements);
|
||||
}
|
||||
|
||||
static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||
uint32_t op = static_cast<uint32_t>(dst->op_params[0]);
|
||||
const int32_t k1 = dst->op_params[1];
|
||||
|
|
@ -13318,7 +13375,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|||
|
||||
break;
|
||||
case GGML_OP_MUL:
|
||||
ggml_vk_mul(ctx, compute_ctx, src0, src1, node);
|
||||
if (ctx->num_additional_fused_ops) {
|
||||
ggml_vk_snake_dispatch_fused(ctx, compute_ctx, cgraph, node_idx);
|
||||
} else {
|
||||
ggml_vk_mul(ctx, compute_ctx, src0, src1, node);
|
||||
}
|
||||
|
||||
break;
|
||||
case GGML_OP_DIV:
|
||||
|
|
@ -14691,6 +14752,65 @@ static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const
|
|||
return true;
|
||||
}
|
||||
|
||||
// Pattern check for the 5-op Snake fusion: mul -> sin -> sqr -> mul -> add.
|
||||
// Verifies the chain shape, the closure x_in_add == x_in_mul0, and that
|
||||
// the broadcast operands a and inv_b share a [1, C] layout.
|
||||
static bool ggml_vk_can_fuse_snake(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) {
|
||||
GGML_UNUSED(ctx);
|
||||
if (!ggml_can_fuse(cgraph, node_idx, snake_pattern)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const ggml_tensor * mul0 = cgraph->nodes[node_idx + 0];
|
||||
const ggml_tensor * sin_node = cgraph->nodes[node_idx + 1];
|
||||
const ggml_tensor * sqr = cgraph->nodes[node_idx + 2];
|
||||
const ggml_tensor * mul1 = cgraph->nodes[node_idx + 3];
|
||||
const ggml_tensor * add = cgraph->nodes[node_idx + 4];
|
||||
|
||||
const ggml_tensor * x = ggml_are_same_shape(mul0, mul0->src[0]) ? mul0->src[0] : mul0->src[1];
|
||||
const ggml_tensor * a = (x == mul0->src[0]) ? mul0->src[1] : mul0->src[0];
|
||||
|
||||
const ggml_tensor * inv_b = (mul1->src[0] == sqr) ? mul1->src[1] : mul1->src[0];
|
||||
const ggml_tensor * x_in_add = (add->src[0] == mul1) ? add->src[1] : add->src[0];
|
||||
|
||||
if (x_in_add != x) {
|
||||
return false;
|
||||
}
|
||||
if (x->type != GGML_TYPE_F32 && x->type != GGML_TYPE_F16 && x->type != GGML_TYPE_BF16) {
|
||||
return false;
|
||||
}
|
||||
// Shader bindings: data_a is A_TYPE so it follows x's precision, while
|
||||
// data_b and data_c are hardcoded float, so the broadcast operands must
|
||||
// be F32 regardless of x's type.
|
||||
if (a->type != GGML_TYPE_F32) return false;
|
||||
if (inv_b->type != GGML_TYPE_F32) return false;
|
||||
// Chain intermediates and output share x's precision (single A_TYPE / D_TYPE pipeline).
|
||||
if (mul0->type != x->type) return false;
|
||||
if (sin_node->type != x->type) return false;
|
||||
if (sqr->type != x->type) return false;
|
||||
if (mul1->type != x->type) return false;
|
||||
if (add->type != x->type) return false;
|
||||
if (!ggml_are_same_shape(a, inv_b)) {
|
||||
return false;
|
||||
}
|
||||
if (a->ne[0] != 1 || a->ne[1] != x->ne[1]) {
|
||||
return false;
|
||||
}
|
||||
// Dispatch is 2D over (ne0, ne1), so x and add must be 2D and a / inv_b
|
||||
// must collapse to [1, C, 1, 1]. Higher dims are not handled by the shader.
|
||||
if (x->ne[2] != 1 || x->ne[3] != 1) return false;
|
||||
if (add->ne[2] != 1 || add->ne[3] != 1) return false;
|
||||
if (a->ne[2] != 1 || a->ne[3] != 1) return false;
|
||||
if (inv_b->ne[2] != 1 || inv_b->ne[3] != 1) return false;
|
||||
// Shader uses idx = i0 + i1 * ne0 and reads data_b[i1] / data_c[i1],
|
||||
// so every operand must be contiguous.
|
||||
if (!ggml_is_contiguous(x) || !ggml_is_contiguous(add) ||
|
||||
!ggml_is_contiguous(a) || !ggml_is_contiguous(inv_b)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check whether the tensors overlap in memory.
|
||||
// Fusions can potentially overwrite src tensors in ways that are not prevented
|
||||
// by ggml-alloc. If the fusion src is being applied in a way that's elementwise
|
||||
|
|
@ -14998,6 +15118,14 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|||
op_srcs_fused_elementwise[0] = false;
|
||||
op_srcs_fused_elementwise[1] = false;
|
||||
op_srcs_fused_elementwise[2] = false;
|
||||
} else if (ggml_vk_can_fuse_snake(ctx, cgraph, i)) {
|
||||
ctx->num_additional_fused_ops = 4;
|
||||
fusion_string = "SNAKE";
|
||||
// elementwise=true: snake.comp is safe under exact aliasing because each
|
||||
// thread reads data_a[idx] into a register before writing data_d[idx]
|
||||
// with a data dependency on that register. The overlap check still
|
||||
// rejects partial overlaps (different base or size).
|
||||
std::fill_n(op_srcs_fused_elementwise, 5, true);
|
||||
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
|
||||
ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
|
||||
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
|
||||
|
|
@ -15288,6 +15416,9 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
|
|||
if (keep_pattern(topk_moe_late_softmax)) {
|
||||
continue;
|
||||
}
|
||||
if (keep_pattern(snake_pattern)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// First, grab the next unused node.
|
||||
current_set.push_back(first_unused);
|
||||
|
|
@ -15310,7 +15441,8 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
|
|||
if (match_pattern(topk_moe_early_softmax_norm, j) ||
|
||||
match_pattern(topk_moe_sigmoid_norm_bias, j) ||
|
||||
match_pattern(topk_moe_early_softmax, j) ||
|
||||
match_pattern(topk_moe_late_softmax, j)) {
|
||||
match_pattern(topk_moe_late_softmax, j) ||
|
||||
match_pattern(snake_pattern, j)) {
|
||||
continue;
|
||||
}
|
||||
bool ok = true;
|
||||
|
|
|
|||
49
ggml/src/ggml-vulkan/vulkan-shaders/snake.comp
Normal file
49
ggml/src/ggml-vulkan/vulkan-shaders/snake.comp
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
#version 450
|
||||
|
||||
#include "types.glsl"
|
||||
|
||||
// Fused snake activation: y = x + sin(b * x)^2 * c
//
//   binding 0  data_a  [ne0, ne1]  activation x, one value per element     (A_TYPE)
//   binding 1  data_b  [1,   ne1]  multiplier b, one value per channel     (float)
//   binding 2  data_c  [1,   ne1]  inverse scale c, one value per channel  (float,
//                                  precomputed as 1 / freq)
//   binding 3  data_d  [ne0, ne1]  output y                                (D_TYPE)

layout (binding = 0) readonly  buffer A { A_TYPE data_a[]; };
layout (binding = 1) readonly  buffer B { float  data_b[]; };
layout (binding = 2) readonly  buffer C { float  data_c[]; };
layout (binding = 3) writeonly buffer D { D_TYPE data_d[]; };

layout (local_size_x = 256, local_size_y = 1, local_size_z = 1) in;

// Extents used for the bounds check and for the flat index i0 + i1 * ne0.
layout (push_constant) uniform parameter {
    uint32_t ne0;
    uint32_t ne1;
} p;
|
||||
|
||||
// Read element idx of the activation buffer (data_a) and return it as float.
// When DATA_A_BF16 is defined the stored value is converted with bf16_to_fp32;
// otherwise a plain float() conversion is used.
float load_val(uint32_t idx) {
#if !defined(DATA_A_BF16)
    return float(data_a[idx]);
#else
    return bf16_to_fp32(uint32_t(data_a[idx]));
#endif
}
|
||||
|
||||
// Write v to element idx of the output buffer (data_d) as D_TYPE.
// When DATA_D_BF16 is defined the value is converted with fp32_to_bf16 first;
// otherwise it is converted with a plain D_TYPE() constructor.
void store_val(uint32_t idx, float v) {
#if !defined(DATA_D_BF16)
    data_d[idx] = D_TYPE(v);
#else
    data_d[idx] = D_TYPE(fp32_to_bf16(v));
#endif
}
|
||||
|
||||
// One invocation per output element: x indexes dim 0 (ne0), y indexes dim 1 (ne1).
void main() {
    const uint32_t elem = gl_GlobalInvocationID.x;
    const uint32_t chan = gl_GlobalInvocationID.y;

    // Invocations beyond the tensor extents have nothing to do.
    if (elem >= p.ne0 || chan >= p.ne1) {
        return;
    }

    // Flat index into data_a / data_d; data_b and data_c are indexed by dim 1 only.
    const uint32_t idx = elem + chan * p.ne0;

    const float x = load_val(idx);
    const float b = data_b[chan];
    const float c = data_c[chan];

    // y = x + sin(b * x)^2 * c
    const float s = sin(b * x);
    store_val(idx, x + s * s * c);
}
|
||||
|
|
@ -952,6 +952,10 @@ void process_shaders() {
|
|||
|
||||
string_to_spv("conv_transpose_1d_f32", "conv_transpose_1d.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
string_to_spv("snake_f32", "snake.comp", {{"DATA_A_F32", "1"}, {"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("snake_f16", "snake.comp", {{"DATA_A_F16", "1"}, {"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("snake_bf16", "snake.comp", {{"DATA_A_BF16", "1"}, {"DATA_D_BF16", "1"}, {"A_TYPE", "uint16_t"}, {"D_TYPE", "uint16_t"}});
|
||||
|
||||
string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue