mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-31 21:39:42 +00:00
llama: add llm_graph_input_mtp (#23643)
* llama: add llm_graph_input_mtp
* rename input_mtp -> input_token_embd
* add TODO about mtmd embedding
* cont : clean-up

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
98e480a32e
commit
eef59a7642
4 changed files with 91 additions and 16 deletions
|
|
@ -102,6 +102,39 @@ bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
|
|||
return res;
|
||||
}
|
||||
|
||||
void llm_graph_input_embd_h::set_input(const llama_ubatch * ubatch) {
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
|
||||
if (ubatch->token) {
|
||||
ggml_backend_tensor_set(tokens, ubatch->token, 0, n_tokens*ggml_element_size(tokens));
|
||||
} else {
|
||||
// note: mtmd embedding input goes through here
|
||||
GGML_ASSERT(ubatch->embd);
|
||||
GGML_ASSERT(n_embd == embd->ne[0]);
|
||||
|
||||
ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(h));
|
||||
}
|
||||
|
||||
// TODO: extend llama_ubatch to differentiate between token embeddings and hidden states
|
||||
// for now, we assume that the hidden state is always provided as an embedding
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/23643
|
||||
if (ubatch->embd) {
|
||||
GGML_ASSERT(n_embd == h->ne[0]);
|
||||
|
||||
ggml_backend_tensor_set(h, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(h));
|
||||
}
|
||||
}
|
||||
|
||||
// Reports whether this input object is compatible with the ubatch in `params`.
//
// For every input the ubatch provides, the matching tensor must exist and its
// size must equal the ubatch's n_tokens:
//  - token ids  -> `tokens`, compared on ne[0]
//  - embeddings -> both `embd` and `h`, compared on ne[1]
// Inputs that the ubatch does not provide impose no constraint.
bool llm_graph_input_embd_h::can_reuse(const llm_graph_params & params) {
    const auto & ub = params.ubatch;

    if (ub.token) {
        if (!tokens || tokens->ne[0] != ub.n_tokens) {
            return false;
        }
    }

    if (ub.embd) {
        if (!embd || embd->ne[1] != ub.n_tokens) {
            return false;
        }
        if (!h || h->ne[1] != ub.n_tokens) {
            return false;
        }
    }

    return true;
}
|
||||
|
||||
void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
|
||||
if (ubatch->pos && pos) {
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue