diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index d3fb19048..aa289220a 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -499,6 +499,12 @@ static constexpr std::initializer_list topk_moe_late_softmax { GGM GGML_OP_GET_ROWS, GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE }; +// Snake activation: y = x + sin(a*x)^2 * inv_b. Used by the optimize_graph reorder +// pass so it keeps the chain contiguous and by the dispatcher to detect the fusion. +static constexpr std::initializer_list snake_pattern { GGML_OP_MUL, GGML_OP_SIN, + GGML_OP_SQR, GGML_OP_MUL, + GGML_OP_ADD }; + //node #978 ( SOFT_MAX): ffn_moe_probs-15 ( 0K) [Vulka ] use=2: ffn_moe_logits-15 ( 0K) [Vulka ] //node #979 ( RESHAPE): ffn_moe_probs-15 (re ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ] //node #980 ( ARGSORT): ffn_moe_argsort-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ] @@ -846,6 +852,9 @@ struct vk_device_struct { vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16; vk_pipeline pipeline_timestep_embedding_f32; vk_pipeline pipeline_conv_transpose_1d_f32; + vk_pipeline pipeline_snake_f32; + vk_pipeline pipeline_snake_f16; + vk_pipeline pipeline_snake_bf16; vk_pipeline pipeline_pool2d_f32; vk_pipeline pipeline_rwkv_wkv6_f32; vk_pipeline pipeline_rwkv_wkv7_f32; @@ -1475,6 +1484,11 @@ struct vk_op_conv_transpose_1d_push_constants { int32_t s0; }; +struct vk_op_snake_push_constants { + uint32_t ne0; + uint32_t ne1; +}; + struct vk_op_pool2d_push_constants { uint32_t IW; uint32_t IH; uint32_t OW; uint32_t OH; @@ -4845,6 +4859,10 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_conv_transpose_1d_f32, "conv_transpose_1d_f32", conv_transpose_1d_f32_len, conv_transpose_1d_f32_data, "main", 3, sizeof(vk_op_conv_transpose_1d_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_snake_f32, "snake_f32", snake_f32_len, snake_f32_data, "main", 4, sizeof(vk_op_snake_push_constants), {256, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_snake_f16, "snake_f16", snake_f16_len, snake_f16_data, "main", 4, sizeof(vk_op_snake_push_constants), {256, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_snake_bf16, "snake_bf16", snake_bf16_len, snake_bf16_data, "main", 4, sizeof(vk_op_snake_push_constants), {256, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); @@ -12110,6 +12128,45 @@ static void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context& ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p)); } +// Dispatch the fused snake activation: y = x + sin^2(a * x) * inv_b. +// Match the naive mul -> sin -> sqr -> mul -> add chain and run the +// dedicated kernel directly. The pattern is validated by +// ggml_vk_can_fuse_snake before this call. +static void ggml_vk_snake_dispatch_fused(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) { + const ggml_tensor * mul0 = cgraph->nodes[node_idx + 0]; + const ggml_tensor * sqr = cgraph->nodes[node_idx + 2]; + const ggml_tensor * mul1 = cgraph->nodes[node_idx + 3]; + ggml_tensor * add = cgraph->nodes[node_idx + 4]; + + // x carries the full activation shape, a is the broadcast operand + const ggml_tensor * x = ggml_are_same_shape(mul0, mul0->src[0]) ? mul0->src[0] : mul0->src[1]; + const ggml_tensor * a = (x == mul0->src[0]) ? mul0->src[1] : mul0->src[0]; + + // mul1 reads sqr and inv_b in either operand order + const ggml_tensor * inv_b = (mul1->src[0] == sqr) ? mul1->src[1] : mul1->src[0]; + + vk_pipeline pipeline = nullptr; + switch (x->type) { + case GGML_TYPE_F32: pipeline = ctx->device->pipeline_snake_f32; break; + case GGML_TYPE_F16: pipeline = ctx->device->pipeline_snake_f16; break; + case GGML_TYPE_BF16: pipeline = ctx->device->pipeline_snake_bf16; break; + default: GGML_ABORT("unsupported type"); + } + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + + vk_subbuffer x_buf = ggml_vk_tensor_subbuffer(ctx, x); + vk_subbuffer a_buf = ggml_vk_tensor_subbuffer(ctx, a); + vk_subbuffer inv_b_buf = ggml_vk_tensor_subbuffer(ctx, inv_b); + vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, add); + + vk_op_snake_push_constants pc{}; + pc.ne0 = static_cast(x->ne[0]); + pc.ne1 = static_cast(x->ne[1]); + + std::array elements = { pc.ne0, pc.ne1, 1 }; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { x_buf, a_buf, inv_b_buf, dst_buf }, pc, elements); +} + static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { uint32_t op = static_cast(dst->op_params[0]); const int32_t k1 = dst->op_params[1]; @@ -13318,7 +13375,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr break; case GGML_OP_MUL: - ggml_vk_mul(ctx, compute_ctx, src0, src1, node); + if (ctx->num_additional_fused_ops) { + ggml_vk_snake_dispatch_fused(ctx, compute_ctx, cgraph, node_idx); + } else { + ggml_vk_mul(ctx, compute_ctx, src0, src1, node); + } break; case GGML_OP_DIV: @@ -14691,6 +14752,65 @@ static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const return true; } +// Pattern check for the 5-op Snake fusion: mul -> sin -> sqr -> mul -> add. +// Verifies the chain shape, the closure x_in_add == x_in_mul0, and that +// the broadcast operands a and inv_b share a [1, C] layout. +static bool ggml_vk_can_fuse_snake(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) { + GGML_UNUSED(ctx); + if (!ggml_can_fuse(cgraph, node_idx, snake_pattern)) { + return false; + } + + const ggml_tensor * mul0 = cgraph->nodes[node_idx + 0]; + const ggml_tensor * sin_node = cgraph->nodes[node_idx + 1]; + const ggml_tensor * sqr = cgraph->nodes[node_idx + 2]; + const ggml_tensor * mul1 = cgraph->nodes[node_idx + 3]; + const ggml_tensor * add = cgraph->nodes[node_idx + 4]; + + const ggml_tensor * x = ggml_are_same_shape(mul0, mul0->src[0]) ? mul0->src[0] : mul0->src[1]; + const ggml_tensor * a = (x == mul0->src[0]) ? mul0->src[1] : mul0->src[0]; + + const ggml_tensor * inv_b = (mul1->src[0] == sqr) ? mul1->src[1] : mul1->src[0]; + const ggml_tensor * x_in_add = (add->src[0] == mul1) ? add->src[1] : add->src[0]; + + if (x_in_add != x) { + return false; + } + if (x->type != GGML_TYPE_F32 && x->type != GGML_TYPE_F16 && x->type != GGML_TYPE_BF16) { + return false; + } + // Shader bindings: data_a is A_TYPE so it follows x's precision, while + // data_b and data_c are hardcoded float, so the broadcast operands must + // be F32 regardless of x's type. + if (a->type != GGML_TYPE_F32) return false; + if (inv_b->type != GGML_TYPE_F32) return false; + // Chain intermediates and output share x's precision (single A_TYPE / D_TYPE pipeline). + if (mul0->type != x->type) return false; + if (sin_node->type != x->type) return false; + if (sqr->type != x->type) return false; + if (mul1->type != x->type) return false; + if (add->type != x->type) return false; + if (!ggml_are_same_shape(a, inv_b)) { + return false; + } + if (a->ne[0] != 1 || a->ne[1] != x->ne[1]) { + return false; + } + // Dispatch is 2D over (ne0, ne1), so x and add must be 2D and a / inv_b + // must collapse to [1, C, 1, 1]. Higher dims are not handled by the shader. + if (x->ne[2] != 1 || x->ne[3] != 1) return false; + if (add->ne[2] != 1 || add->ne[3] != 1) return false; + if (a->ne[2] != 1 || a->ne[3] != 1) return false; + if (inv_b->ne[2] != 1 || inv_b->ne[3] != 1) return false; + // Shader uses idx = i0 + i1 * ne0 and reads data_b[i1] / data_c[i1], + // so every operand must be contiguous. + if (!ggml_is_contiguous(x) || !ggml_is_contiguous(add) || + !ggml_is_contiguous(a) || !ggml_is_contiguous(inv_b)) { + return false; + } + return true; +} + // Check whether the tensors overlap in memory. // Fusions can potentially overwrite src tensors in ways that are not prevented // by ggml-alloc. If the fusion src is being applied in a way that's elementwise @@ -14998,6 +15118,14 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg op_srcs_fused_elementwise[0] = false; op_srcs_fused_elementwise[1] = false; op_srcs_fused_elementwise[2] = false; + } else if (ggml_vk_can_fuse_snake(ctx, cgraph, i)) { + ctx->num_additional_fused_ops = 4; + fusion_string = "SNAKE"; + // elementwise=true: snake.comp is safe under exact aliasing because each + // thread reads data_x[idx] into a register before writing data_d[idx] + // with a data dependency on that register. The overlap check still + // rejects partial overlaps (different base or size). + std::fill_n(op_srcs_fused_elementwise, 5, true); } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) && ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) && ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) { @@ -15288,6 +15416,9 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * if (keep_pattern(topk_moe_late_softmax)) { continue; } + if (keep_pattern(snake_pattern)) { + continue; + } // First, grab the next unused node. current_set.push_back(first_unused); @@ -15310,7 +15441,8 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * if (match_pattern(topk_moe_early_softmax_norm, j) || match_pattern(topk_moe_sigmoid_norm_bias, j) || match_pattern(topk_moe_early_softmax, j) || - match_pattern(topk_moe_late_softmax, j)) { + match_pattern(topk_moe_late_softmax, j) || + match_pattern(snake_pattern, j)) { continue; } bool ok = true; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp b/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp new file mode 100644 index 000000000..8585538cb --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp @@ -0,0 +1,49 @@ +#version 450 + +#include "types.glsl" + +// Fused snake activation: y = x + sin(b * x)^2 * c +// data_a [ne0, ne1] per element activation x (A_TYPE) +// data_b [1, ne1] per channel multiplier (float) +// data_c [1, ne1] per channel inverse scale (float, precomputed as 1 / freq) +// data_d [ne0, ne1] output y (D_TYPE) +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {float data_b[];}; +layout (binding = 2) readonly buffer C {float data_c[];}; +layout (binding = 3) writeonly buffer D {D_TYPE data_d[];}; + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (push_constant) uniform parameter { + uint32_t ne0; + uint32_t ne1; +} p; + +// Load A_TYPE to float +float load_val(uint32_t idx) { +#if defined(DATA_A_BF16) + return bf16_to_fp32(uint32_t(data_a[idx])); +#else + return float(data_a[idx]); +#endif +} + +// Store float as D_TYPE +void store_val(uint32_t idx, float v) { +#if defined(DATA_D_BF16) + data_d[idx] = D_TYPE(fp32_to_bf16(v)); +#else + data_d[idx] = D_TYPE(v); +#endif +} + +void main() { + const uint32_t i0 = gl_GlobalInvocationID.x; + const uint32_t i1 = gl_GlobalInvocationID.y; + if (i0 >= p.ne0 || i1 >= p.ne1) return; + + const uint32_t idx = i0 + i1 * p.ne0; + const float xi = load_val(idx); + const float s = sin(data_b[i1] * xi); + store_val(idx, xi + s * s * data_c[i1]); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index e3a9d61a5..a1d735150 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -952,6 +952,10 @@ void process_shaders() { string_to_spv("conv_transpose_1d_f32", "conv_transpose_1d.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("snake_f32", "snake.comp", {{"DATA_A_F32", "1"}, {"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("snake_f16", "snake.comp", {{"DATA_A_F16", "1"}, {"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("snake_bf16", "snake.comp", {{"DATA_A_BF16", "1"}, {"DATA_D_BF16", "1"}, {"A_TYPE", "uint16_t"}, {"D_TYPE", "uint16_t"}}); + string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));