llama: avoid copying logits during prompt decode in MTP (#23198)

* llama: avoid copying logits during prompt decode in MTP

* review: update comment

* llama-graph: call set_output for t_h_pre_norm
This commit is contained in:
Aman Gupta 2026-05-17 23:30:25 +08:00 committed by GitHub
parent 39cf5d6191
commit 3e12fbdea5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 91 additions and 27 deletions

View file

@ -146,8 +146,11 @@ struct common_speculative_impl {
virtual void accept(llama_seq_id seq_id, uint16_t n_accepted) = 0;
// true if this implementation requires the target context to extract embeddings
// true if this implementation requires the target context to extract post-norm embeddings
virtual bool need_embd() const = 0;
// true if this implementation requires the target context to extract pre-norm embeddings
virtual bool need_embd_pre_norm() const { return false; }
};
struct common_speculative_impl_draft_simple : public common_speculative_impl {
@ -429,8 +432,8 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl {
s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams));
}
llama_set_embeddings_pre_norm(ctx_tgt, true);
llama_set_embeddings_pre_norm(ctx_dft, true);
llama_set_embeddings_pre_norm(ctx_tgt, true, /*masked*/ false);
llama_set_embeddings_pre_norm(ctx_dft, true, /*masked*/ true);
pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));
@ -691,6 +694,10 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl {
}
bool need_embd() const override {
return false;
}
bool need_embd_pre_norm() const override {
return true;
}
};
@ -1408,6 +1415,20 @@ bool common_speculative_need_embd(common_speculative * spec) {
return false;
}
bool common_speculative_need_embd_pre_norm(common_speculative * spec) {
if (spec == nullptr) {
return false;
}
for (auto & impl : spec->impls) {
if (impl->need_embd_pre_norm()) {
return true;
}
}
return false;
}
void common_speculative_draft(common_speculative * spec) {
if (spec == nullptr) {
return;

View file

@ -53,9 +53,12 @@ void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, co
// process the batch and update the internal state of the speculative context
bool common_speculative_process(common_speculative * spec, const llama_batch & batch);
// true if any implementation requires target embeddings to be extracted
// true if any implementation requires target post-norm embeddings to be extracted
bool common_speculative_need_embd(common_speculative * spec);
// true if any implementation requires target pre-norm embeddings to be extracted
bool common_speculative_need_embd_pre_norm(common_speculative * spec);
// generate drafts for the sequences specified with `common_speculative_get_draft_params`
void common_speculative_draft(common_speculative * spec);

View file

@ -895,8 +895,17 @@ float * llama_context::get_embeddings_pre_norm_ith(int32_t i) {
throw std::runtime_error("no pre-norm embeddings");
}
const int64_t j = output_resolve_row(i);
const uint32_t n_embd = model.hparams.n_embd;
if (!cparams.embeddings_pre_norm_masked) {
// unmasked: pre-norm rows are stored densely, indexed by raw token position.
if (i < 0 || (size_t)(i + 1) * n_embd > embd_pre_norm.size) {
throw std::runtime_error(format("out of range [0, %zu)", embd_pre_norm.size / n_embd));
}
return embd_pre_norm.data + (size_t) i * n_embd;
}
const int64_t j = output_resolve_row(i);
return embd_pre_norm.data + j*n_embd;
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: invalid pre-norm embeddings id %d, reason: %s\n", __func__, i, err.what());
@ -1088,10 +1097,11 @@ void llama_context::set_embeddings(bool value) {
//sched_need_reserve = true;
}
void llama_context::set_embeddings_pre_norm(bool value) {
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
void llama_context::set_embeddings_pre_norm(bool value, bool masked) {
LLAMA_LOG_DEBUG("%s: value = %d, masked = %d\n", __func__, value, masked);
cparams.embeddings_pre_norm = value;
cparams.embeddings_pre_norm = value;
cparams.embeddings_pre_norm_masked = masked;
}
void llama_context::set_causal_attn(bool value) {
@ -1737,6 +1747,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
};
int64_t n_outputs_prev = 0;
int64_t n_tokens_prev = 0;
do {
const auto & ubatch = mctx->get_ubatch();
@ -1882,16 +1893,21 @@ int llama_context::decode(const llama_batch & batch_inp) {
// extract pre-norm embeddings (hidden state before the final output norm)
// only meaningful in LLAMA_POOLING_TYPE_NONE (per-token); other pooling modes are ignored.
if (embd_pre_norm.data && t_h_pre_norm && n_outputs > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm);
GGML_ASSERT(backend_h != nullptr);
{
const bool masked = cparams.embeddings_pre_norm_masked;
const int64_t n_rows = masked ? n_outputs : (int64_t) ubatch.n_tokens;
const int64_t offset = masked ? n_outputs_prev : n_tokens_prev;
const uint32_t n_embd = hparams.n_embd;
float * embd_pre_norm_out = embd_pre_norm.data + n_outputs_prev*n_embd;
if (embd_pre_norm.data && t_h_pre_norm && n_rows > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm);
GGML_ASSERT(backend_h != nullptr);
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_pre_norm.size);
ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm_out, 0, n_outputs*n_embd*sizeof(float));
const uint32_t n_embd = hparams.n_embd;
float * embd_pre_norm_out = embd_pre_norm.data + offset*n_embd;
GGML_ASSERT((offset + n_rows)*n_embd <= (int64_t) embd_pre_norm.size);
ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm_out, 0, n_rows*n_embd*sizeof(float));
}
}
// Copy backend sampling output if this ubatch produced any sampling tensors.
@ -1908,6 +1924,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
}
n_outputs_prev += n_outputs;
n_tokens_prev += ubatch.n_tokens;
} while (mctx->next());
// set to total number of outputs in the batch, for use in llama_get_logits_ith
@ -1999,6 +2016,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
embd.size = has_embd ? n_embd_out*n_outputs_max : 0;
embd_pre_norm.size = has_embd_pre_norm ? n_embd*n_outputs_max : 0;
if (has_embd_pre_norm && !cparams.embeddings_pre_norm_masked) {
// unmasked: pre-norm row exists for every token in the batch, not just
// those flagged via batch.logits[i] -> size by token count instead.
embd_pre_norm.size = (size_t) n_embd * n_batch;
}
// Allocate backend sampling output buffers if there are backend samplers configured.
const bool has_sampling = !sampling.samplers.empty();
if (has_sampling) {
@ -3547,8 +3570,8 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
return ctx->get_embeddings_seq(seq_id);
}
void llama_set_embeddings_pre_norm(llama_context * ctx, bool value) {
ctx->set_embeddings_pre_norm(value);
void llama_set_embeddings_pre_norm(llama_context * ctx, bool value, bool masked) {
ctx->set_embeddings_pre_norm(value, masked);
}
float * llama_get_embeddings_pre_norm(llama_context * ctx) {

View file

@ -110,7 +110,7 @@ struct llama_context {
void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data);
void set_embeddings (bool value);
void set_embeddings_pre_norm(bool value);
void set_embeddings_pre_norm(bool value, bool masked);
void set_causal_attn(bool value);
void set_warmup(bool value);

View file

@ -28,7 +28,8 @@ struct llama_cparams {
float yarn_beta_slow;
bool embeddings;
bool embeddings_pre_norm; // also extract the hidden state before the final output norm
bool embeddings_pre_norm; // also extract the hidden state before the final output norm
bool embeddings_pre_norm_masked; // extract for only rows where batch.logits != 0
bool causal_attn;
bool offload_kqv;
bool flash_attn;

View file

@ -93,14 +93,14 @@ LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_c
// pre-norm embeddings (hidden state before the final output norm)
//
// mirrors:
// LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
LLAMA_API void llama_set_embeddings_pre_norm(struct llama_context * ctx, bool value);
// Set whether the context outputs pre-norm embeddings or not
// If masked == true, output the embeddings only for the tokens with batch.logits != 0
// If masked == false, output the embeddings for all tokens in the batch regardless of batch.logits
LLAMA_API void llama_set_embeddings_pre_norm(struct llama_context * ctx, bool value, bool masked);
// mirrors:
// LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
LLAMA_API float * llama_get_embeddings_pre_norm(struct llama_context * ctx);
LLAMA_API float * llama_get_embeddings_pre_norm (struct llama_context * ctx);
// mirrors:
// LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
LLAMA_API float * llama_get_embeddings_pre_norm_ith(struct llama_context * ctx, int32_t i);

View file

@ -848,6 +848,9 @@ void llm_graph_result::set_outputs() {
if (t_embd_pooled != nullptr) {
ggml_set_output(t_embd_pooled);
}
if (t_h_pre_norm != nullptr) {
ggml_set_output(t_h_pre_norm);
}
for (auto & [seq_id, t] : t_sampled) {
if (t != nullptr) {
ggml_set_output(t);

View file

@ -176,7 +176,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para
cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il);
}
if (il == n_transformer_layers - 1 && inp_out_ids) {
if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
@ -211,6 +211,10 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para
cb(cur, "h_pre_norm", -1);
res->t_h_pre_norm = cur;
if (!cparams.embeddings_pre_norm_masked && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
}
// Final norm
cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);

View file

@ -199,7 +199,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p
cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il);
}
if (il == n_transformer_layers - 1 && inp_out_ids) {
if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
@ -234,6 +234,10 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p
cb(cur, "h_pre_norm", -1);
res->t_h_pre_norm = cur;
if (!cparams.embeddings_pre_norm_masked && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
}
// Final norm
cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);

View file

@ -243,6 +243,11 @@ struct server_slot {
return task->need_embd() || (spec && common_speculative_need_embd(spec));
}
bool need_embd_pre_norm() const {
GGML_ASSERT(task);
return spec && common_speculative_need_embd_pre_norm(spec);
}
// if the context does not have a memory module then all embeddings have to be computed within a single ubatch
// also we cannot split if the pooling would require any past tokens
// (MTP supports splitting — uses task->need_embd() not need_embd())