mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-23 04:19:08 +00:00
llama: avoid copying logits during prompt decode in MTP (#23198)
* llama: avoid copying logits during prompt decode in MTP * review: update comment * llama-graph: call set_output for t_h_pre_norm
This commit is contained in:
parent
39cf5d6191
commit
3e12fbdea5
10 changed files with 91 additions and 27 deletions
|
|
@ -146,8 +146,11 @@ struct common_speculative_impl {
|
|||
|
||||
virtual void accept(llama_seq_id seq_id, uint16_t n_accepted) = 0;
|
||||
|
||||
// true if this implementation requires the target context to extract embeddings
|
||||
// true if this implementation requires the target context to extract post-norm embeddings
|
||||
virtual bool need_embd() const = 0;
|
||||
|
||||
// true if this implementation requires the target context to extract pre-norm embeddings
|
||||
virtual bool need_embd_pre_norm() const { return false; }
|
||||
};
|
||||
|
||||
struct common_speculative_impl_draft_simple : public common_speculative_impl {
|
||||
|
|
@ -429,8 +432,8 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl {
|
|||
s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams));
|
||||
}
|
||||
|
||||
llama_set_embeddings_pre_norm(ctx_tgt, true);
|
||||
llama_set_embeddings_pre_norm(ctx_dft, true);
|
||||
llama_set_embeddings_pre_norm(ctx_tgt, true, /*masked*/ false);
|
||||
llama_set_embeddings_pre_norm(ctx_dft, true, /*masked*/ true);
|
||||
|
||||
pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));
|
||||
|
||||
|
|
@ -691,6 +694,10 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl {
|
|||
}
|
||||
|
||||
bool need_embd() const override {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool need_embd_pre_norm() const override {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
|
@ -1408,6 +1415,20 @@ bool common_speculative_need_embd(common_speculative * spec) {
|
|||
return false;
|
||||
}
|
||||
|
||||
bool common_speculative_need_embd_pre_norm(common_speculative * spec) {
|
||||
if (spec == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (auto & impl : spec->impls) {
|
||||
if (impl->need_embd_pre_norm()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void common_speculative_draft(common_speculative * spec) {
|
||||
if (spec == nullptr) {
|
||||
return;
|
||||
|
|
|
|||
|
|
@ -53,9 +53,12 @@ void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, co
|
|||
// process the batch and update the internal state of the speculative context
|
||||
bool common_speculative_process(common_speculative * spec, const llama_batch & batch);
|
||||
|
||||
// true if any implementation requires target embeddings to be extracted
|
||||
// true if any implementation requires target post-norm embeddings to be extracted
|
||||
bool common_speculative_need_embd(common_speculative * spec);
|
||||
|
||||
// true if any implementation requires target pre-norm embeddings to be extracted
|
||||
bool common_speculative_need_embd_pre_norm(common_speculative * spec);
|
||||
|
||||
// generate drafts for the sequences specified with `common_speculative_get_draft_params`
|
||||
void common_speculative_draft(common_speculative * spec);
|
||||
|
||||
|
|
|
|||
|
|
@ -895,8 +895,17 @@ float * llama_context::get_embeddings_pre_norm_ith(int32_t i) {
|
|||
throw std::runtime_error("no pre-norm embeddings");
|
||||
}
|
||||
|
||||
const int64_t j = output_resolve_row(i);
|
||||
const uint32_t n_embd = model.hparams.n_embd;
|
||||
|
||||
if (!cparams.embeddings_pre_norm_masked) {
|
||||
// unmasked: pre-norm rows are stored densely, indexed by raw token position.
|
||||
if (i < 0 || (size_t)(i + 1) * n_embd > embd_pre_norm.size) {
|
||||
throw std::runtime_error(format("out of range [0, %zu)", embd_pre_norm.size / n_embd));
|
||||
}
|
||||
return embd_pre_norm.data + (size_t) i * n_embd;
|
||||
}
|
||||
|
||||
const int64_t j = output_resolve_row(i);
|
||||
return embd_pre_norm.data + j*n_embd;
|
||||
} catch (const std::exception & err) {
|
||||
LLAMA_LOG_ERROR("%s: invalid pre-norm embeddings id %d, reason: %s\n", __func__, i, err.what());
|
||||
|
|
@ -1088,10 +1097,11 @@ void llama_context::set_embeddings(bool value) {
|
|||
//sched_need_reserve = true;
|
||||
}
|
||||
|
||||
void llama_context::set_embeddings_pre_norm(bool value) {
|
||||
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
|
||||
void llama_context::set_embeddings_pre_norm(bool value, bool masked) {
|
||||
LLAMA_LOG_DEBUG("%s: value = %d, masked = %d\n", __func__, value, masked);
|
||||
|
||||
cparams.embeddings_pre_norm = value;
|
||||
cparams.embeddings_pre_norm = value;
|
||||
cparams.embeddings_pre_norm_masked = masked;
|
||||
}
|
||||
|
||||
void llama_context::set_causal_attn(bool value) {
|
||||
|
|
@ -1737,6 +1747,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||
};
|
||||
|
||||
int64_t n_outputs_prev = 0;
|
||||
int64_t n_tokens_prev = 0;
|
||||
|
||||
do {
|
||||
const auto & ubatch = mctx->get_ubatch();
|
||||
|
|
@ -1882,16 +1893,21 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||
|
||||
// extract pre-norm embeddings (hidden state before the final output norm)
|
||||
// only meaningful in LLAMA_POOLING_TYPE_NONE (per-token); other pooling modes are ignored.
|
||||
if (embd_pre_norm.data && t_h_pre_norm && n_outputs > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
|
||||
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm);
|
||||
GGML_ASSERT(backend_h != nullptr);
|
||||
{
|
||||
const bool masked = cparams.embeddings_pre_norm_masked;
|
||||
const int64_t n_rows = masked ? n_outputs : (int64_t) ubatch.n_tokens;
|
||||
const int64_t offset = masked ? n_outputs_prev : n_tokens_prev;
|
||||
|
||||
const uint32_t n_embd = hparams.n_embd;
|
||||
float * embd_pre_norm_out = embd_pre_norm.data + n_outputs_prev*n_embd;
|
||||
if (embd_pre_norm.data && t_h_pre_norm && n_rows > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
|
||||
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm);
|
||||
GGML_ASSERT(backend_h != nullptr);
|
||||
|
||||
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
|
||||
GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_pre_norm.size);
|
||||
ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm_out, 0, n_outputs*n_embd*sizeof(float));
|
||||
const uint32_t n_embd = hparams.n_embd;
|
||||
float * embd_pre_norm_out = embd_pre_norm.data + offset*n_embd;
|
||||
|
||||
GGML_ASSERT((offset + n_rows)*n_embd <= (int64_t) embd_pre_norm.size);
|
||||
ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm_out, 0, n_rows*n_embd*sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
// Copy backend sampling output if this ubatch produced any sampling tensors.
|
||||
|
|
@ -1908,6 +1924,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||
}
|
||||
|
||||
n_outputs_prev += n_outputs;
|
||||
n_tokens_prev += ubatch.n_tokens;
|
||||
} while (mctx->next());
|
||||
|
||||
// set to total number of outputs in the batch, for use in llama_get_logits_ith
|
||||
|
|
@ -1999,6 +2016,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
|||
embd.size = has_embd ? n_embd_out*n_outputs_max : 0;
|
||||
embd_pre_norm.size = has_embd_pre_norm ? n_embd*n_outputs_max : 0;
|
||||
|
||||
if (has_embd_pre_norm && !cparams.embeddings_pre_norm_masked) {
|
||||
// unmasked: pre-norm row exists for every token in the batch, not just
|
||||
// those flagged via batch.logits[i] -> size by token count instead.
|
||||
embd_pre_norm.size = (size_t) n_embd * n_batch;
|
||||
}
|
||||
|
||||
// Allocate backend sampling output buffers if there are backend samplers configured.
|
||||
const bool has_sampling = !sampling.samplers.empty();
|
||||
if (has_sampling) {
|
||||
|
|
@ -3547,8 +3570,8 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
|
|||
return ctx->get_embeddings_seq(seq_id);
|
||||
}
|
||||
|
||||
void llama_set_embeddings_pre_norm(llama_context * ctx, bool value) {
|
||||
ctx->set_embeddings_pre_norm(value);
|
||||
void llama_set_embeddings_pre_norm(llama_context * ctx, bool value, bool masked) {
|
||||
ctx->set_embeddings_pre_norm(value, masked);
|
||||
}
|
||||
|
||||
float * llama_get_embeddings_pre_norm(llama_context * ctx) {
|
||||
|
|
|
|||
|
|
@ -110,7 +110,7 @@ struct llama_context {
|
|||
void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data);
|
||||
|
||||
void set_embeddings (bool value);
|
||||
void set_embeddings_pre_norm(bool value);
|
||||
void set_embeddings_pre_norm(bool value, bool masked);
|
||||
void set_causal_attn(bool value);
|
||||
void set_warmup(bool value);
|
||||
|
||||
|
|
|
|||
|
|
@ -28,7 +28,8 @@ struct llama_cparams {
|
|||
float yarn_beta_slow;
|
||||
|
||||
bool embeddings;
|
||||
bool embeddings_pre_norm; // also extract the hidden state before the final output norm
|
||||
bool embeddings_pre_norm; // also extract the hidden state before the final output norm
|
||||
bool embeddings_pre_norm_masked; // extract for only rows where batch.logits != 0
|
||||
bool causal_attn;
|
||||
bool offload_kqv;
|
||||
bool flash_attn;
|
||||
|
|
|
|||
|
|
@ -93,14 +93,14 @@ LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_c
|
|||
// pre-norm embeddings (hidden state before the final output norm)
|
||||
//
|
||||
|
||||
// mirrors:
|
||||
// LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
|
||||
LLAMA_API void llama_set_embeddings_pre_norm(struct llama_context * ctx, bool value);
|
||||
// Set whether the context outputs pre-norm embeddings or not
|
||||
// If masked == true, output the embeddings only for the tokens with batch.logits != 0
|
||||
// If masked == false, output the embeddings for all tokens in the batch regardless of batch.logits
|
||||
LLAMA_API void llama_set_embeddings_pre_norm(struct llama_context * ctx, bool value, bool masked);
|
||||
|
||||
// mirrors:
|
||||
// LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
|
||||
LLAMA_API float * llama_get_embeddings_pre_norm(struct llama_context * ctx);
|
||||
LLAMA_API float * llama_get_embeddings_pre_norm (struct llama_context * ctx);
|
||||
|
||||
// mirrors:
|
||||
// LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
|
||||
LLAMA_API float * llama_get_embeddings_pre_norm_ith(struct llama_context * ctx, int32_t i);
|
||||
|
|
|
|||
|
|
@ -848,6 +848,9 @@ void llm_graph_result::set_outputs() {
|
|||
if (t_embd_pooled != nullptr) {
|
||||
ggml_set_output(t_embd_pooled);
|
||||
}
|
||||
if (t_h_pre_norm != nullptr) {
|
||||
ggml_set_output(t_h_pre_norm);
|
||||
}
|
||||
for (auto & [seq_id, t] : t_sampled) {
|
||||
if (t != nullptr) {
|
||||
ggml_set_output(t);
|
||||
|
|
|
|||
|
|
@ -176,7 +176,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para
|
|||
cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il);
|
||||
}
|
||||
|
||||
if (il == n_transformer_layers - 1 && inp_out_ids) {
|
||||
if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
|
@ -211,6 +211,10 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para
|
|||
cb(cur, "h_pre_norm", -1);
|
||||
res->t_h_pre_norm = cur;
|
||||
|
||||
if (!cparams.embeddings_pre_norm_masked && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
}
|
||||
|
||||
// Final norm
|
||||
cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
|
||||
|
||||
|
|
|
|||
|
|
@ -199,7 +199,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p
|
|||
cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il);
|
||||
}
|
||||
|
||||
if (il == n_transformer_layers - 1 && inp_out_ids) {
|
||||
if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
|
@ -234,6 +234,10 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p
|
|||
cb(cur, "h_pre_norm", -1);
|
||||
res->t_h_pre_norm = cur;
|
||||
|
||||
if (!cparams.embeddings_pre_norm_masked && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
}
|
||||
|
||||
// Final norm
|
||||
cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
|
||||
|
||||
|
|
|
|||
|
|
@ -243,6 +243,11 @@ struct server_slot {
|
|||
return task->need_embd() || (spec && common_speculative_need_embd(spec));
|
||||
}
|
||||
|
||||
bool need_embd_pre_norm() const {
|
||||
GGML_ASSERT(task);
|
||||
return spec && common_speculative_need_embd_pre_norm(spec);
|
||||
}
|
||||
|
||||
// if the context does not have a memory module then all embeddings have to be computed within a single ubatch
|
||||
// also we cannot split if the pooling would require any past tokens
|
||||
// (MTP supports splitting — uses task->need_embd() not need_embd())
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue