llama: add llm_graph_input_mtp (#23643)

* llama: add llm_graph_input_mtp

* rename input_mtp -> input_token_embd

* add TODO about mtmd embedding

* cont : clean-up

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
Aman Gupta 2026-05-29 14:17:32 +08:00 committed by GitHub
parent 98e480a32e
commit eef59a7642
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 91 additions and 16 deletions

View file

@ -102,6 +102,39 @@ bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
return res;
}
// Uploads the contents of the ubatch into the graph input tensors:
//   - ubatch->token set : token ids are copied into `tokens`
//   - otherwise         : ubatch->embd must be set and is copied into `embd`
//   - additionally, whenever ubatch->embd is set, the same buffer is copied into `h`
// Every copy starts at offset 0 and its byte count is computed from the element
// size of the tensor being written.
void llm_graph_input_embd_h::set_input(const llama_ubatch * ubatch) {
    const int64_t n_tokens = ubatch->n_tokens;

    if (ubatch->token) {
        ggml_backend_tensor_set(tokens, ubatch->token, 0, n_tokens*ggml_element_size(tokens));
    } else {
        // note: mtmd embedding input goes through here
        GGML_ASSERT(ubatch->embd);
        GGML_ASSERT(n_embd == embd->ne[0]);

        // size the copy with the element size of the destination (embd), not of h
        ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd));
    }

    // TODO: extend llama_ubatch to differentiate between token embeddings and hidden states
    //       for now, we assume that the hidden state is always provided as an embedding
    //       ref: https://github.com/ggml-org/llama.cpp/pull/23643
    if (ubatch->embd) {
        GGML_ASSERT(n_embd == h->ne[0]);

        ggml_backend_tensor_set(h, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(h));
    }
}
// Reports whether this input can be reused for the ubatch in `params`.
// For each kind of data present in the ubatch, the corresponding tensor must
// exist and its compared dimension must equal the ubatch token count:
//   - token ids  -> tokens->ne[0]
//   - embeddings -> embd->ne[1] and h->ne[1]
bool llm_graph_input_embd_h::can_reuse(const llm_graph_params & params) {
    const auto & ub = params.ubatch;

    if (ub.token) {
        if (!(tokens && tokens->ne[0] == ub.n_tokens)) {
            return false;
        }
    }

    if (ub.embd) {
        // both tensors that receive the embedding buffer have to match
        if (!(embd && embd->ne[1] == ub.n_tokens)) {
            return false;
        }
        if (!(h && h->ne[1] == ub.n_tokens)) {
            return false;
        }
    }

    return true;
}
void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
if (ubatch->pos && pos) {
const int64_t n_tokens = ubatch->n_tokens;