tests: add long-sequence cases and fix inputs for gated_delta_net (#22794)

* tests : add long-seq + tail cases for gated_delta_net

* tests : realistic input ranges for gated_delta_net
This commit is contained in:
HaoJun ZHANG 2026-05-08 00:23:36 +08:00 committed by GitHub
parent ad09224658
commit deab41ec68
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -3763,13 +3763,37 @@ struct test_gated_delta_net : public test_case {
k = ggml_new_tensor_4d(ctx, type, head_size, head_count, n_seq_tokens, n_seqs);
v = ggml_new_tensor_4d(ctx, type, head_size, head_count * v_repeat, n_seq_tokens, n_seqs);
}
ggml_set_name(q, "q");
ggml_set_name(k, "k");
ggml_set_name(v, "v");
const int64_t g_ne0 = kda ? head_size : 1;
ggml_tensor * g = ggml_new_tensor_4d(ctx, type, g_ne0, head_count * v_repeat, n_seq_tokens, n_seqs);
ggml_tensor * beta = ggml_new_tensor_4d(ctx, type, 1, head_count * v_repeat, n_seq_tokens, n_seqs);
ggml_tensor * state = ggml_new_tensor_2d(ctx, type, head_size * v_repeat * head_size * head_count, n_seqs);
ggml_set_name(g, "g");
ggml_set_name(beta, "beta");
ggml_set_name(state, "state");
// q/k are L2-normalised in qwen35/kimi-linear before delta_net
q = ggml_l2_norm(ctx, q, 1e-6f);
k = ggml_l2_norm(ctx, k, 1e-6f);
ggml_tensor * out = ggml_gated_delta_net(ctx, q, k, v, g, beta, state);
return out;
}
void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
if (ggml_is_view_op(t->op)) { continue; }
if (strcmp(t->name, "g") == 0) {
init_tensor_uniform(t, -20.0f, -1e-4f);
} else if (strcmp(t->name, "beta") == 0) {
init_tensor_uniform(t, 0.0f, 1.0f);
} else if (strcmp(t->name, "v") == 0) {
init_tensor_uniform(t, -0.3f, 5.0f);
} else {
init_tensor_uniform(t);
}
}
}
};
// GGML_OP_GATED_LINEAR_ATTN
@ -8871,6 +8895,17 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 8, 32, 4, 2, 2, false, true));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2, 1, true, true));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 16, 4, 2, 1, true, true));
// chunked path: multi-chunk and non-multiple-of-chunk-size (chunk_size=64 GDN, 16 KDA)
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 64, 1));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 127, 1));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 256, 1));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 65, 1));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 100, 1));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 200, 1));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 127, 2));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 64, 1, 1, false, true));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 33, 1, 1, false, true));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 100, 1, 1, false, true));
#if 0
// these tests are disabled to save execution time, sbut they can be handy for debugging