cuda: fuse snake activation (mul, sin, sqr, mul, add) (#22667)

* cuda: fuse snake activation (mul, sin, sqr, mul, add)

Add ggml_cuda_op_snake_fused with F32 / F16 / BF16 templates. The
matcher recognizes the naive 5 op decomposition emitted by audio
decoders (BigVGAN, Vocos) for snake activation
y = x + sin(a*x)^2 * inv_b and rewrites it to a single elementwise
kernel.

Add test_snake_fuse comparing CPU naive vs CUDA fused across
F32 / F16 / BF16.

* cuda: address review feedback from @am17an

Use ggml_cuda_cast for F32/F16/BF16 conversions and rename
kernel_snake to snake_kernel to match upstream conventions.

* cuda: snake fusion fastdiv on T_len, Suggested-by: @am17an

* Update tests/test-backend-ops.cpp

Co-authored-by: Aman Gupta <amangupta052@gmail.com>

* cuda: snake fusion check add->type matches x->type

Address review feedback from @am17an

* cuda: snake fusion check add->type matches x->type

Moved for readability (equivalent)
Address review feedback from @am17an

---------

Co-authored-by: Aman Gupta <amangupta052@gmail.com>
This commit is contained in:
Pascal 2026-05-08 11:44:09 +02:00 committed by GitHub
parent 9b2925e1e0
commit 58e68df0f9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 191 additions and 0 deletions

View file

@ -3556,6 +3556,73 @@ struct test_relu_sqr : public test_case {
}
};
// SNAKE activation fusion: y = x + sin(a*x)^2 * inv_b
// CUDA backend matches the naive 5-op chain (mul, sin, sqr, mul, add)
// and dispatches a single fused kernel.
struct test_snake_fuse : public test_case {
const ggml_type type;
const std::array<int64_t, 2> ne; // [T, C]
std::string op_desc(ggml_tensor * t) override {
GGML_UNUSED(t);
return "SNAKE_FUSE";
}
bool run_whole_graph() override { return true; }
double max_nmse_err() override {
// BF16 epsilon ~ 7.8e-3, F16 epsilon ~ 9.7e-4: relax tolerance to match
// the natural roundoff drift between the naive CPU chain and the fused
// CUDA kernel. F32 keeps the default tight bound.
switch (type) {
case GGML_TYPE_BF16: return 5e-3;
case GGML_TYPE_F16: return 5e-5;
default: return 1e-7;
}
}
std::string vars() override {
return VARS_TO_STR2(type, ne);
}
test_snake_fuse(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 2> ne = {256, 192})
: type(type), ne(ne) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * x = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]);
ggml_set_name(x, "x");
ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, ne[1]);
ggml_set_name(a, "a");
ggml_tensor * inv_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, ne[1]);
ggml_set_name(inv_b, "inv_b");
// exact 5-op chain that BigVGAN / Vocos frontends emit
ggml_tensor * ax = ggml_mul(ctx, x, a);
ggml_tensor * sin_ax = ggml_sin(ctx, ax);
ggml_tensor * sin_sq = ggml_sqr(ctx, sin_ax);
ggml_tensor * scaled = ggml_mul(ctx, sin_sq, inv_b);
ggml_tensor * out = ggml_add(ctx, x, scaled);
ggml_set_name(out, "out");
return out;
}
void initialize_tensors(ggml_context * ctx) override {
// x in [-pi, pi] to exercise sin periodicity, params in default [-1, 1]
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
const std::string name = ggml_get_name(t);
if (name == "x") {
init_tensor_uniform(t, -3.14159f, 3.14159f);
} else {
init_tensor_uniform(t);
}
}
}
};
// GGML_OP_SSM_CONV
struct test_ssm_conv : public test_case {
const ggml_type type;
@ -7489,6 +7556,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_relu_sqr(type, { 5, 7, 11, 13 }));
}
// SNAKE activation fusion: x + sin(a*x)^2 * inv_b
for (ggml_type type : { GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16 }) {
test_cases.emplace_back(new test_snake_fuse(type, { 5, 7})); // primes sub-block
test_cases.emplace_back(new test_snake_fuse(type, { 33, 32})); // boundary
test_cases.emplace_back(new test_snake_fuse(type, {1025, 13})); // large prime, grid-stride
test_cases.emplace_back(new test_snake_fuse(type, { 128, 16})); // power-of-two
test_cases.emplace_back(new test_snake_fuse(type, { 256, 192})); // BigVGAN-ish
}
// glu ops
for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
for (int v : {0, 1}) {
@ -9014,6 +9090,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 1, 1}));
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));
// SNAKE activation fusion at BigVGAN scale (T=7680 = 24 kHz x 320 ms, C=192)
test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F32, {7680, 192}));
test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F16, {7680, 192}));
test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_BF16, {7680, 192}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, 2*16416));