vulkan: fuse SSM_CONV + BIAS + SILU (#22653)

This commit is contained in:
Jeff Bolz 2026-05-17 03:25:50 -05:00 committed by GitHub
parent 1a68ec9378
commit 3fbadb06dc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 129 additions and 9 deletions

View file

@ -854,6 +854,8 @@ struct vk_device_struct {
vk_pipeline pipeline_ssm_scan_f32_d128;
vk_pipeline pipeline_ssm_scan_f32_d256;
vk_pipeline pipeline_ssm_conv_f32;
vk_pipeline pipeline_ssm_conv_silu_f32;
vk_pipeline pipeline_ssm_conv_bias_silu_f32;
vk_pipeline pipeline_opt_step_adamw_f32;
vk_pipeline pipeline_opt_step_sgd_f32;
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv2d_f32[CONV_SHAPE_COUNT];
@ -4900,7 +4902,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
}
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16}, 1);
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 0, 0}, 1);
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_silu_f32, "ssm_conv_silu_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 0, 1}, 1);
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_bias_silu_f32, "ssm_conv_bias_silu_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 1, 1}, 1);
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
@ -9936,7 +9940,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return nullptr;
case GGML_OP_SSM_CONV:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_ssm_conv_f32;
switch (ctx->num_additional_fused_ops) {
case 0: return ctx->device->pipeline_ssm_conv_f32;
case 1: return ctx->device->pipeline_ssm_conv_silu_f32;
case 2: return ctx->device->pipeline_ssm_conv_bias_silu_f32;
default: return nullptr;
}
}
return nullptr;
case GGML_OP_OPT_STEP_ADAMW:
@ -10877,11 +10886,28 @@ static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx,
pc, elements);
}
static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
ggml_tensor * conv = cgraph->nodes[node_idx];
const ggml_tensor * src0 = conv->src[0];
const ggml_tensor * src1 = conv->src[1];
ggml_vk_op_f32<vk_op_ssm_conv_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SSM_CONV, {
// Pick the destination tensor (last node in the fused chain) and the optional bias.
// Fusion modes: 0 = ssm_conv, 1 = ssm_conv+silu, 2 = ssm_conv+add(bias)+silu.
ggml_tensor * dst = conv;
const ggml_tensor * bias = nullptr;
if (ctx->num_additional_fused_ops == 1) {
dst = cgraph->nodes[node_idx + 1]; // silu
} else if (ctx->num_additional_fused_ops == 2) {
ggml_tensor * add = cgraph->nodes[node_idx + 1];
bias = (add->src[0] == conv) ? add->src[1] : add->src[0];
dst = cgraph->nodes[node_idx + 2]; // silu
}
// The shader always declares 4 bindings; bind src0 as a dummy when bias isn't fused.
const ggml_tensor * src2 = bias ? bias : src0;
ggml_vk_op_f32<vk_op_ssm_conv_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SSM_CONV, {
(uint32_t)src0->nb[1], (uint32_t)src0->nb[2],
(uint32_t)src1->nb[1],
(uint32_t)dst->nb[0], (uint32_t)dst->nb[1], (uint32_t)dst->nb[2],
@ -13556,7 +13582,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
break;
case GGML_OP_SSM_CONV:
ggml_vk_ssm_conv(ctx, compute_ctx, node);
ggml_vk_ssm_conv(ctx, compute_ctx, cgraph, node_idx);
break;
@ -14453,6 +14479,62 @@ static bool ggml_vk_can_fuse(const ggml_backend_vk_context * ctx, const struct g
return true;
}
// Match SSM_CONV + UNARY(SILU) or SSM_CONV + ADD + UNARY(SILU). num_extra is 1 or 2.
static bool ggml_vk_can_fuse_ssm_conv(const ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
int node_idx, int num_extra) {
const ggml_tensor * conv = cgraph->nodes[node_idx];
if (conv->op != GGML_OP_SSM_CONV) {
return false;
}
const ggml_tensor * silu = nullptr;
const ggml_tensor * bias = nullptr;
if (num_extra == 1) {
if (!ggml_can_fuse(cgraph, node_idx, { GGML_OP_SSM_CONV, GGML_OP_UNARY })) {
return false;
}
silu = cgraph->nodes[node_idx + 1];
} else if (num_extra == 2) {
if (!ggml_can_fuse(cgraph, node_idx, { GGML_OP_SSM_CONV, GGML_OP_ADD, GGML_OP_UNARY })) {
return false;
}
const ggml_tensor * add = cgraph->nodes[node_idx + 1];
silu = cgraph->nodes[node_idx + 2];
bias = (add->src[0] == conv) ? add->src[1] : add->src[0];
if (bias->type != GGML_TYPE_F32 || !ggml_is_contiguous(bias)) {
return false;
}
// bias must be channel-wise (one element per channel of the conv output)
if (ggml_nelements(bias) != conv->ne[0] || bias->ne[0] != conv->ne[0]) {
return false;
}
if (add->type != GGML_TYPE_F32) {
return false;
}
// The shader doesn't apply per-tensor offsets, so reject misaligned bias.
if (get_misalign_bytes(ctx, bias) != 0) {
return false;
}
} else {
return false;
}
if (ggml_get_unary_op(silu) != GGML_UNARY_OP_SILU) {
return false;
}
if (conv->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) {
return false;
}
// The shader writes to the fused dst using its own strides, but the push constants don't
// carry a per-tensor offset, so the binding must be naturally aligned.
if (get_misalign_bytes(ctx, silu) != 0) {
return false;
}
return true;
}
static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
int node_idx, topk_moe_mode mode) {
@ -14869,6 +14951,19 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
// they are overwritten, and one workgroup per row. So close enough.
op_srcs_fused_elementwise[0] = true;
op_srcs_fused_elementwise[1] = true;
} else if (ggml_vk_can_fuse_ssm_conv(ctx, cgraph, i, 2)) {
ctx->num_additional_fused_ops = 2;
fusion_string = "SSM_CONV_BIAS_SILU";
// ssm_conv reads multiple input tokens per output, so it's not elementwise w.r.t. its srcs.
// The downstream add and silu are elementwise on the conv output.
op_srcs_fused_elementwise[0] = false;
op_srcs_fused_elementwise[1] = true;
op_srcs_fused_elementwise[2] = true;
} else if (ggml_vk_can_fuse_ssm_conv(ctx, cgraph, i, 1)) {
ctx->num_additional_fused_ops = 1;
fusion_string = "SSM_CONV_SILU";
op_srcs_fused_elementwise[0] = false;
op_srcs_fused_elementwise[1] = true;
} else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) &&
ggml_check_edges(cgraph, i, rope_view_set_rows_edges) &&
ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) {
@ -15200,7 +15295,9 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT && graph->nodes[j]->op == GGML_OP_ADD) &&
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID) &&
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL) &&
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_ADD && graph->nodes[j]->op == GGML_OP_ADD)) {
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_ADD && graph->nodes[j]->op == GGML_OP_ADD) &&
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_SSM_CONV && graph->nodes[j]->op == GGML_OP_ADD) &&
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_SSM_CONV && graph->nodes[j]->op == GGML_OP_UNARY)) {
ok = false;
break;
}
@ -15283,6 +15380,19 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
}
}
}
// SSM_CONV + ADD + UNARY: pull the consuming UNARY forward
if (j > 0 &&
graph->nodes[j]->op == GGML_OP_ADD &&
graph->nodes[j-1]->op == GGML_OP_SSM_CONV) {
for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) {
if (graph->nodes[k]->op == GGML_OP_UNARY &&
graph->nodes[k]->src[0] == graph->nodes[j]) {
current_set.push_back(k);
used[k] = true;
break;
}
}
}
}
}
// Second pass grabs view nodes.

View file

@ -6,12 +6,15 @@
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
layout(constant_id = 1) const uint TOKENS_PER_WG = 16;
layout(constant_id = 2) const bool APPLY_BIAS = false;
layout(constant_id = 3) const bool APPLY_SILU = false;
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z = 1) in;
layout(binding = 0) readonly buffer Src0 { float src0[]; };
layout(binding = 1) readonly buffer Src1 { float src1[]; };
layout(binding = 2) buffer Dst { float dst[]; };
layout(binding = 2) readonly buffer Bias { float bias[]; };
layout(binding = 3) buffer Dst { float dst[]; };
layout(push_constant) uniform PushConstants {
uint nb01; uint nb02;
@ -45,6 +48,13 @@ void main() {
}
}
if (APPLY_BIAS) {
sum += bias[i1];
}
if (APPLY_SILU) {
sum = sum / (1.0f + exp(-sum));
}
const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1;
dst[dst_idx] = sum;
}