mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-22 19:47:49 +00:00
vulkan: fuse SSM_CONV + BIAS + SILU (#22653)
This commit is contained in:
parent
1a68ec9378
commit
3fbadb06dc
2 changed files with 129 additions and 9 deletions
|
|
@ -854,6 +854,8 @@ struct vk_device_struct {
|
|||
vk_pipeline pipeline_ssm_scan_f32_d128;
|
||||
vk_pipeline pipeline_ssm_scan_f32_d256;
|
||||
vk_pipeline pipeline_ssm_conv_f32;
|
||||
vk_pipeline pipeline_ssm_conv_silu_f32;
|
||||
vk_pipeline pipeline_ssm_conv_bias_silu_f32;
|
||||
vk_pipeline pipeline_opt_step_adamw_f32;
|
||||
vk_pipeline pipeline_opt_step_sgd_f32;
|
||||
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv2d_f32[CONV_SHAPE_COUNT];
|
||||
|
|
@ -4900,7 +4902,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
|
||||
}
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 0, 0}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_silu_f32, "ssm_conv_silu_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 0, 1}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_bias_silu_f32, "ssm_conv_bias_silu_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 1, 1}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
|
|
@ -9936,7 +9940,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|||
return nullptr;
|
||||
case GGML_OP_SSM_CONV:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_ssm_conv_f32;
|
||||
switch (ctx->num_additional_fused_ops) {
|
||||
case 0: return ctx->device->pipeline_ssm_conv_f32;
|
||||
case 1: return ctx->device->pipeline_ssm_conv_silu_f32;
|
||||
case 2: return ctx->device->pipeline_ssm_conv_bias_silu_f32;
|
||||
default: return nullptr;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
|
|
@ -10877,11 +10886,28 @@ static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx,
|
|||
pc, elements);
|
||||
}
|
||||
|
||||
static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
|
||||
ggml_tensor * conv = cgraph->nodes[node_idx];
|
||||
const ggml_tensor * src0 = conv->src[0];
|
||||
const ggml_tensor * src1 = conv->src[1];
|
||||
|
||||
ggml_vk_op_f32<vk_op_ssm_conv_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SSM_CONV, {
|
||||
// Pick the destination tensor (last node in the fused chain) and the optional bias.
|
||||
// Fusion modes: 0 = ssm_conv, 1 = ssm_conv+silu, 2 = ssm_conv+add(bias)+silu.
|
||||
ggml_tensor * dst = conv;
|
||||
const ggml_tensor * bias = nullptr;
|
||||
|
||||
if (ctx->num_additional_fused_ops == 1) {
|
||||
dst = cgraph->nodes[node_idx + 1]; // silu
|
||||
} else if (ctx->num_additional_fused_ops == 2) {
|
||||
ggml_tensor * add = cgraph->nodes[node_idx + 1];
|
||||
bias = (add->src[0] == conv) ? add->src[1] : add->src[0];
|
||||
dst = cgraph->nodes[node_idx + 2]; // silu
|
||||
}
|
||||
|
||||
// The shader always declares 4 bindings; bind src0 as a dummy when bias isn't fused.
|
||||
const ggml_tensor * src2 = bias ? bias : src0;
|
||||
|
||||
ggml_vk_op_f32<vk_op_ssm_conv_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SSM_CONV, {
|
||||
(uint32_t)src0->nb[1], (uint32_t)src0->nb[2],
|
||||
(uint32_t)src1->nb[1],
|
||||
(uint32_t)dst->nb[0], (uint32_t)dst->nb[1], (uint32_t)dst->nb[2],
|
||||
|
|
@ -13556,7 +13582,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|||
break;
|
||||
|
||||
case GGML_OP_SSM_CONV:
|
||||
ggml_vk_ssm_conv(ctx, compute_ctx, node);
|
||||
ggml_vk_ssm_conv(ctx, compute_ctx, cgraph, node_idx);
|
||||
|
||||
break;
|
||||
|
||||
|
|
@ -14453,6 +14479,62 @@ static bool ggml_vk_can_fuse(const ggml_backend_vk_context * ctx, const struct g
|
|||
return true;
|
||||
}
|
||||
|
||||
// Match SSM_CONV + UNARY(SILU) or SSM_CONV + ADD + UNARY(SILU). num_extra is 1 or 2.
|
||||
static bool ggml_vk_can_fuse_ssm_conv(const ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
|
||||
int node_idx, int num_extra) {
|
||||
const ggml_tensor * conv = cgraph->nodes[node_idx];
|
||||
if (conv->op != GGML_OP_SSM_CONV) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const ggml_tensor * silu = nullptr;
|
||||
const ggml_tensor * bias = nullptr;
|
||||
|
||||
if (num_extra == 1) {
|
||||
if (!ggml_can_fuse(cgraph, node_idx, { GGML_OP_SSM_CONV, GGML_OP_UNARY })) {
|
||||
return false;
|
||||
}
|
||||
silu = cgraph->nodes[node_idx + 1];
|
||||
} else if (num_extra == 2) {
|
||||
if (!ggml_can_fuse(cgraph, node_idx, { GGML_OP_SSM_CONV, GGML_OP_ADD, GGML_OP_UNARY })) {
|
||||
return false;
|
||||
}
|
||||
const ggml_tensor * add = cgraph->nodes[node_idx + 1];
|
||||
silu = cgraph->nodes[node_idx + 2];
|
||||
bias = (add->src[0] == conv) ? add->src[1] : add->src[0];
|
||||
|
||||
if (bias->type != GGML_TYPE_F32 || !ggml_is_contiguous(bias)) {
|
||||
return false;
|
||||
}
|
||||
// bias must be channel-wise (one element per channel of the conv output)
|
||||
if (ggml_nelements(bias) != conv->ne[0] || bias->ne[0] != conv->ne[0]) {
|
||||
return false;
|
||||
}
|
||||
if (add->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
// The shader doesn't apply per-tensor offsets, so reject misaligned bias.
|
||||
if (get_misalign_bytes(ctx, bias) != 0) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ggml_get_unary_op(silu) != GGML_UNARY_OP_SILU) {
|
||||
return false;
|
||||
}
|
||||
if (conv->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
// The shader writes to the fused dst using its own strides, but the push constants don't
|
||||
// carry a per-tensor offset, so the binding must be naturally aligned.
|
||||
if (get_misalign_bytes(ctx, silu) != 0) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
|
||||
int node_idx, topk_moe_mode mode) {
|
||||
|
||||
|
|
@ -14869,6 +14951,19 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|||
// they are overwritten, and one workgroup per row. So close enough.
|
||||
op_srcs_fused_elementwise[0] = true;
|
||||
op_srcs_fused_elementwise[1] = true;
|
||||
} else if (ggml_vk_can_fuse_ssm_conv(ctx, cgraph, i, 2)) {
|
||||
ctx->num_additional_fused_ops = 2;
|
||||
fusion_string = "SSM_CONV_BIAS_SILU";
|
||||
// ssm_conv reads multiple input tokens per output, so it's not elementwise w.r.t. its srcs.
|
||||
// The downstream add and silu are elementwise on the conv output.
|
||||
op_srcs_fused_elementwise[0] = false;
|
||||
op_srcs_fused_elementwise[1] = true;
|
||||
op_srcs_fused_elementwise[2] = true;
|
||||
} else if (ggml_vk_can_fuse_ssm_conv(ctx, cgraph, i, 1)) {
|
||||
ctx->num_additional_fused_ops = 1;
|
||||
fusion_string = "SSM_CONV_SILU";
|
||||
op_srcs_fused_elementwise[0] = false;
|
||||
op_srcs_fused_elementwise[1] = true;
|
||||
} else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) &&
|
||||
ggml_check_edges(cgraph, i, rope_view_set_rows_edges) &&
|
||||
ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) {
|
||||
|
|
@ -15200,7 +15295,9 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
|
|||
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT && graph->nodes[j]->op == GGML_OP_ADD) &&
|
||||
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID) &&
|
||||
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL) &&
|
||||
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_ADD && graph->nodes[j]->op == GGML_OP_ADD)) {
|
||||
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_ADD && graph->nodes[j]->op == GGML_OP_ADD) &&
|
||||
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_SSM_CONV && graph->nodes[j]->op == GGML_OP_ADD) &&
|
||||
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_SSM_CONV && graph->nodes[j]->op == GGML_OP_UNARY)) {
|
||||
ok = false;
|
||||
break;
|
||||
}
|
||||
|
|
@ -15283,6 +15380,19 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
|
|||
}
|
||||
}
|
||||
}
|
||||
// SSM_CONV + ADD + UNARY: pull the consuming UNARY forward
|
||||
if (j > 0 &&
|
||||
graph->nodes[j]->op == GGML_OP_ADD &&
|
||||
graph->nodes[j-1]->op == GGML_OP_SSM_CONV) {
|
||||
for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) {
|
||||
if (graph->nodes[k]->op == GGML_OP_UNARY &&
|
||||
graph->nodes[k]->src[0] == graph->nodes[j]) {
|
||||
current_set.push_back(k);
|
||||
used[k] = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Second pass grabs view nodes.
|
||||
|
|
|
|||
|
|
@ -6,12 +6,15 @@
|
|||
|
||||
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
|
||||
layout(constant_id = 1) const uint TOKENS_PER_WG = 16;
|
||||
layout(constant_id = 2) const bool APPLY_BIAS = false;
|
||||
layout(constant_id = 3) const bool APPLY_SILU = false;
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z = 1) in;
|
||||
|
||||
layout(binding = 0) readonly buffer Src0 { float src0[]; };
|
||||
layout(binding = 1) readonly buffer Src1 { float src1[]; };
|
||||
layout(binding = 2) buffer Dst { float dst[]; };
|
||||
layout(binding = 2) readonly buffer Bias { float bias[]; };
|
||||
layout(binding = 3) buffer Dst { float dst[]; };
|
||||
|
||||
layout(push_constant) uniform PushConstants {
|
||||
uint nb01; uint nb02;
|
||||
|
|
@ -45,6 +48,13 @@ void main() {
|
|||
}
|
||||
}
|
||||
|
||||
if (APPLY_BIAS) {
|
||||
sum += bias[i1];
|
||||
}
|
||||
if (APPLY_SILU) {
|
||||
sum = sum / (1.0f + exp(-sum));
|
||||
}
|
||||
|
||||
const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1;
|
||||
dst[dst_idx] = sum;
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue