llama: add llm_graph_input_mtp (#23643)

* llama: add llm_graph_input_mtp

* rename input_mtp -> input_token_embd

* add TODO about mtmd embedding

* cont : clean-up

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
Aman Gupta 2026-05-29 14:17:32 +08:00 committed by GitHub
parent 98e480a32e
commit eef59a7642
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 91 additions and 16 deletions

View file

@ -102,6 +102,39 @@ bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
return res;
}
void llm_graph_input_embd_h::set_input(const llama_ubatch * ubatch) {
const int64_t n_tokens = ubatch->n_tokens;
if (ubatch->token) {
ggml_backend_tensor_set(tokens, ubatch->token, 0, n_tokens*ggml_element_size(tokens));
} else {
// note: mtmd embedding input goes through here
GGML_ASSERT(ubatch->embd);
GGML_ASSERT(n_embd == embd->ne[0]);
ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(h));
}
// TODO: extend llama_ubatch to differentiate between token embeddings and hidden states
// for now, we assume that the hidden state is always provided as an embedding
// ref: https://github.com/ggml-org/llama.cpp/pull/23643
if (ubatch->embd) {
GGML_ASSERT(n_embd == h->ne[0]);
ggml_backend_tensor_set(h, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(h));
}
}
bool llm_graph_input_embd_h::can_reuse(const llm_graph_params & params) {
bool res = true;
res &= (!params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
res &= (!params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens);
res &= (!params.ubatch.embd) || (h && h->ne[1] == params.ubatch.n_tokens);
return res;
}
void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
if (ubatch->pos && pos) {
const int64_t n_tokens = ubatch->n_tokens;

View file

@ -121,6 +121,23 @@ public:
const int64_t n_embd = 0;
};
// similar to llm_graph_input_embd but with an additional hidden state input
class llm_graph_input_embd_h : public llm_graph_input_i {
public:
llm_graph_input_embd_h(int64_t n_embd) : n_embd(n_embd) {}
virtual ~llm_graph_input_embd_h() = default;
void set_input(const llama_ubatch * ubatch) override;
bool can_reuse(const llm_graph_params & params) override;
ggml_tensor * tokens = nullptr; // I32 [n_batch]
ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch]
ggml_tensor * h = nullptr; // F32 [n_embd, n_batch]
const int64_t n_embd = 0;
};
class llm_graph_input_pos : public llm_graph_input_i {
public:
llm_graph_input_pos(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}

View file

@ -508,28 +508,41 @@ llama_model_qwen35::graph_mtp::graph_mtp(const llama_model & model, const llm_gr
int sections[4];
std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
auto inp = std::make_unique<llm_graph_input_embd>(hparams.n_embd);
// TODO: extract in a common llm_graph_context::build_inp_embd_h()
auto inp = std::make_unique<llm_graph_input_embd_h>(hparams.n_embd);
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
ggml_set_input(inp->tokens);
inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens);
inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd_inp(), n_tokens);
ggml_set_input(inp->embd);
ggml_set_name(inp->embd, "mtp_h_input");
ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd;
// TODO: make static using `ggml_build_forward_select()`
// see llm_graph_context::build_inp_embd() for reference
ggml_tensor * tok_embd;
if (ubatch.token) {
ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd;
ggml_tensor * h_input = inp->embd;
ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens);
tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens);
} else {
tok_embd = inp->embd;
}
cb(tok_embd, "mtp_tok_embd", il);
inp->h = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens);
ggml_set_input(inp->h);
ggml_set_name(inp->h, "mtp_h_input");
ggml_tensor * h_embd = inp->h;
res->add_input(std::move(inp));
ggml_tensor * inp_pos = build_inp_pos();
ggml_tensor * inp_out_ids = build_inp_out_ids();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il);
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * h_norm = build_norm(h_embd, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il);
cb(h_norm, "mtp_hnorm", il);
ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il);

View file

@ -571,29 +571,41 @@ llama_model_qwen35moe::graph_mtp::graph_mtp(const llama_model & model, const llm
int sections[4];
std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
auto inp = std::make_unique<llm_graph_input_embd>(hparams.n_embd);
// TODO: extract in a common llm_graph_context::build_inp_embd_h()
auto inp = std::make_unique<llm_graph_input_embd_h>(hparams.n_embd);
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
ggml_set_input(inp->tokens);
inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens);
inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd_inp(), n_tokens);
ggml_set_input(inp->embd);
ggml_set_name(inp->embd, "mtp_h_input");
ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd;
// TODO: make static using `ggml_build_forward_select()`
// see llm_graph_context::build_inp_embd() for reference
ggml_tensor * tok_embd;
if (ubatch.token) {
ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd;
ggml_tensor * h_input = inp->embd;
ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens);
tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens);
} else {
tok_embd = inp->embd;
}
cb(tok_embd, "mtp_tok_embd", il);
inp->h = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens);
ggml_set_input(inp->h);
ggml_set_name(inp->h, "mtp_h_input");
ggml_tensor * h_embd = inp->h;
res->add_input(std::move(inp));
ggml_tensor * inp_pos = build_inp_pos();
ggml_tensor * inp_out_ids = build_inp_out_ids();
auto * inp_attn = build_attn_inp_kv();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il);
ggml_tensor * h_norm = build_norm(h_embd, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il);
cb(h_norm, "mtp_hnorm", il);
ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il);