ggml: add GATED_DELTA_NET op (#19504)

* ggml: add GATED_DELTA_NET op

* remove the transpose

* add KDA

* add qwen35 dense

* llama : check for fused gated delta net backend support

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
Aman Gupta 2026-03-07 15:41:10 +08:00 committed by GitHub
parent 6fce5c6a7d
commit c5a778891b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 627 additions and 10 deletions

View file

@ -150,6 +150,9 @@ llama_context::llama_context(
cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED;
cparams.auto_fa = params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO;
cparams.fused_gdn_ar = true;
cparams.fused_gdn_ch = false; // TODO: implement
// with causal attention, the batch size is limited by the context size
cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
@ -422,7 +425,7 @@ void llama_context::sched_reserve() {
if (cparams.auto_fa) {
auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
if (!gf) {
throw std::runtime_error("failed to split graph for Flash Attention check");
throw std::runtime_error("failed to reserve graph for Flash Attention check");
}
const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
@ -432,8 +435,7 @@ void llama_context::sched_reserve() {
if (n->op != GGML_OP_FLASH_ATTN_EXT) {
continue;
}
ggml_backend_dev_t device_fa = ggml_backend_get_device(
ggml_backend_sched_get_tensor_backend(sched.get(), n));
ggml_backend_dev_t device_fa = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n));
// TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
@ -448,6 +450,7 @@ void llama_context::sched_reserve() {
break;
}
}
if (fa_device_mismatch) {
cparams.flash_attn = false;
LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
@ -459,6 +462,39 @@ void llama_context::sched_reserve() {
cparams.auto_fa = false;
}
// Check whether the fused (autoregressive) Gated Delta Net op can be used:
// reserve a graph, look at every GGML_OP_GATED_DELTA_NET node in it and compare
// the device the scheduler assigned to that node with the device of the layer
// it belongs to. On the first mismatch the fused path is disabled.
if (cparams.fused_gdn_ar) {
auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
if (!gf) {
throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check");
}
// +1 covers the "-" separator that follows the name prefix (see the strncmp below)
const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDNAR) + 1;
bool gdn_device_mismatch = false;
for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
ggml_tensor * n = ggml_graph_node(gf, i);
// only the fused Gated Delta Net nodes are of interest
if (n->op != GGML_OP_GATED_DELTA_NET) {
continue;
}
// device of the backend that the scheduler reports for this tensor
ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n));
// the node name must start with "<LLAMA_TENSOR_NAME_FGDNAR>-"; the text after it is parsed as the layer index
GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDNAR "-", prefix_len) == 0);
const int il = std::stoi(n->name + prefix_len);
// device that the model assigns to layer il
ggml_backend_dev_t device_kv = model.dev_layer(il);
if (device_gdn != device_kv) {
LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor "
"is assigned to device %s (usually due to missing support)\n",
__func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn));
// one mismatch is enough to disable the fused path, no need to scan further
gdn_device_mismatch = true;
break;
}
}
if (gdn_device_mismatch) {
// turn the fused op off in the context parameters
cparams.fused_gdn_ar = false;
LLAMA_LOG_WARN("%s: fused Gated Delta Net not supported, set to disabled\n", __func__);
}
}
// reserve worst-case graph
int n_splits_pp = -1;
int n_nodes_pp = -1;

View file

@ -31,6 +31,8 @@ struct llama_cparams {
bool offload_kqv;
bool flash_attn;
bool auto_fa;
bool fused_gdn_ar; // use fused gated delta net (autoregressive)
bool fused_gdn_ch; // use fused gated delta net (chunked)
bool no_perf;
bool warmup;
bool op_offload;

View file

@ -70,4 +70,6 @@ std::string llama_format_tensor_shape(const struct ggml_tensor * t);
std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i);
#define LLAMA_TENSOR_NAME_FATTN "__fattn__"
#define LLAMA_TENSOR_NAME_FATTN "__fattn__"
#define LLAMA_TENSOR_NAME_FGDNAR "__fgdnar__"
#define LLAMA_TENSOR_NAME_FGDNCH "__fgdnch__"

View file

@ -1,5 +1,7 @@
#include "models.h"
#include "llama-impl.h"
// utility to get one slice from the third dimension
// input dim: [x, y, c, b]
// output dim: [x, y, 1, b]
@ -39,6 +41,13 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);
// Fused chunked Gated Delta Net path: not available yet.
// Enabling cparams.fused_gdn_ch makes graph construction abort here; the
// commented-out lines sketch the intended call and tensor naming.
if (cparams.fused_gdn_ch) {
//ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s);
//cb(result, LLAMA_TENSOR_NAME_FGDNCH, il);
GGML_ABORT("not implemented yet");
}
const float scale = 1.0f / sqrtf(S_k);
q = ggml_scale(ctx0, q, scale);
@ -316,6 +325,26 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);
// Fused autoregressive path: a single GATED_DELTA_NET op replaces the
// step-by-step graph below and the function returns early with
// (output, new_state), both being views into the op's result tensor.
if (cparams.fused_gdn_ar) {
ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s);
// callback receives the FGDNAR name and the layer index
// (presumably this is what names the node "<LLAMA_TENSOR_NAME_FGDNAR>-<il>" - TODO confirm)
cb(result, LLAMA_TENSOR_NAME_FGDNAR, il);
// first part of the result, starting at byte offset 0:
// attention output viewed as [S_v, H_v, n_tokens, n_seqs] with contiguous strides
ggml_tensor * output = ggml_view_4d(ctx0, result,
S_v, H_v, n_tokens, n_seqs,
ggml_row_size(result->type, S_v),
ggml_row_size(result->type, S_v * H_v),
ggml_row_size(result->type, S_v * H_v * n_tokens), 0);
// second part of the result: updated state viewed as [S_v, S_v, H_v, n_seqs];
// its byte offset equals the size of the output region above
// (S_v * H_v * n_tokens * n_seqs elements), so it starts right after it
ggml_tensor * new_state = ggml_view_4d(ctx0, result,
S_v, S_v, H_v, n_seqs,
ggml_row_size(result->type, S_v),
ggml_row_size(result->type, S_v * S_v),
ggml_row_size(result->type, S_v * S_v * H_v),
ggml_row_size(result->type, S_v * H_v * n_tokens * n_seqs));
return {output, new_state};
}
const float scale = 1.0f / sqrtf(S_k);
q = ggml_scale(ctx0, q, scale);

View file

@ -332,8 +332,7 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear(
cb(k_conv, "k_conv_predelta", il);
cb(v_conv, "v_conv_predelta", il);
// Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens
std::pair<ggml_tensor *, ggml_tensor *> attn_out; // pair of (output, new_state)
std::pair<ggml_tensor *, ggml_tensor *> attn_out;
if (n_seq_tokens == 1) {
attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
} else {

View file

@ -332,8 +332,7 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear(
cb(k_conv, "k_conv_predelta", il);
cb(v_conv, "v_conv_predelta", il);
// Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens
std::pair<ggml_tensor *, ggml_tensor *> attn_out; // pair of (output, new_state)
std::pair<ggml_tensor *, ggml_tensor *> attn_out;
if (n_seq_tokens == 1) {
attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
} else {