model : NvFP4 quantized LM head support (#23046)

* NvFP4 quantized LM head support

Signed-off-by: ynankani <ynankani@nvidia.com>

* Address review comments

Signed-off-by: ynankani <ynankani@nvidia.com>

* Add assert for NvFP4 LM head and tied embeddings

Signed-off-by: ynankani <ynankani@nvidia.com>

* Address review comments

Signed-off-by: ynankani <ynankani@nvidia.com>

* Create output_s tensor only when the LM head is NvFP4

Signed-off-by: ynankani <ynankani@nvidia.com>

---------

Signed-off-by: ynankani <ynankani@nvidia.com>
This commit is contained in:
ynankani 2026-05-16 09:09:27 +00:00 committed by GitHub
parent 59778f0196
commit 42928bc14d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
103 changed files with 121 additions and 101 deletions

View file

@ -393,6 +393,8 @@ void llama_model_saver::add_tensors_from_model() {
add_tensor(model->output);
add_tensor(model->output_b);
add_tensor(model->output_norm_enc);
add_tensor(model->output_s);
add_tensor(model->output_in_s);
add_tensor(model->cls);
add_tensor(model->cls_b);
add_tensor(model->cls_out);

View file

@ -1394,10 +1394,23 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) {
layer.ssm_beta_in_s = create_tensor(tn(LLM_TENSOR_SSM_BETA, "input_scale", i), {1}, TENSOR_NOT_REQUIRED);
}
}
// output scales
if (output && output->type == GGML_TYPE_NVFP4) {
// weight scale
if (!output_s) {
output_s = create_tensor(tn(LLM_TENSOR_OUTPUT, "scale"), {1}, TENSOR_NOT_REQUIRED);
}
// input scale
if (!output_in_s) {
output_in_s = create_tensor(tn(LLM_TENSOR_OUTPUT, "input_scale"), {1}, TENSOR_NOT_REQUIRED);
}
}
}
ml.done_getting_tensors();
GGML_ASSERT(!(output && tok_embd &&
strcmp(output->name, tok_embd->name) == 0 &&
output->type == GGML_TYPE_NVFP4));
// populate tensors_by_name
for (auto & [_, ctx_ptr] : ml.ctx_map) {
for (auto * cur = ggml_get_first_tensor(ctx_ptr.get()); cur != NULL; cur = ggml_get_next_tensor(ctx_ptr.get(), cur)) {

View file

@ -533,6 +533,11 @@ struct llama_model {
struct ggml_tensor * output_b = nullptr;
struct ggml_tensor * output_norm_enc = nullptr;
// NVFP4 per-tensor scale2, input_scale for LM head
struct ggml_tensor * output_s = nullptr;
struct ggml_tensor * output_in_s = nullptr;
// classifier
struct ggml_tensor * cls = nullptr;
struct ggml_tensor * cls_b = nullptr;

View file

@ -277,7 +277,7 @@ llama_model_afmoe::graph::graph(const llama_model & model, const llm_graph_param
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -160,7 +160,7 @@ llama_model_apertus::graph::graph(const llama_model & model, const llm_graph_par
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -148,7 +148,7 @@ llama_model_arcee::graph::graph(const llama_model & model, const llm_graph_param
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -171,7 +171,7 @@ llama_model_arctic::graph::graph(const llama_model & model, const llm_graph_para
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -193,7 +193,7 @@ llama_model_arwkv7::graph::graph(const llama_model & model, const llm_graph_para
cb(cur, "result_norm", -1);
res->t_embd = cur;
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -146,7 +146,7 @@ llama_model_baichuan::graph::graph(const llama_model & model, const llm_graph_pa
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -171,7 +171,7 @@ llama_model_bailingmoe::graph::graph(const llama_model & model, const llm_graph_
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -210,7 +210,7 @@ llama_model_bailingmoe2::graph::graph(const llama_model & model, const llm_graph
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -142,7 +142,7 @@ llama_model_bloom::graph::graph(const llama_model & model, const llm_graph_param
cb(cur, "result_norm", -1);
res->t_embd = cur;
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -181,7 +181,7 @@ llama_model_chameleon::graph::graph(const llama_model & model, const llm_graph_p
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output_with_img_logits", -1);
// TODO: this suppresses the output of image tokens, which is required to enable text-only outputs.

View file

@ -151,7 +151,7 @@ llama_model_chatglm::graph::graph(const llama_model & model, const llm_graph_par
cb(cur, "result_norm", -1);
res->t_embd = cur;
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -143,7 +143,7 @@ llama_model_codeshell::graph::graph(const llama_model & model, const llm_graph_p
cb(cur, "result_norm", -1);
res->t_embd = cur;
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -150,7 +150,7 @@ llama_model_cogvlm::graph::graph(const llama_model & model, const llm_graph_para
cb(cur, "result_norm", -1);
res->t_embd = cur;
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);

View file

@ -146,7 +146,7 @@ llama_model_cohere2::graph::graph(const llama_model & model, const llm_graph_par
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
if (f_logit_scale) {
cur = ggml_scale(ctx0, cur, f_logit_scale);

View file

@ -131,7 +131,7 @@ llama_model_command_r::graph::graph(const llama_model & model, const llm_graph_p
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
if (f_logit_scale) {
cur = ggml_scale(ctx0, cur, f_logit_scale);

View file

@ -145,7 +145,7 @@ llama_model_dbrx::graph::graph(const llama_model & model, const llm_graph_params
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -181,7 +181,7 @@ llama_model_deci::graph::graph(const llama_model & model, const llm_graph_params
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -185,7 +185,7 @@ llama_model_deepseek::graph::graph(const llama_model & model, const llm_graph_pa
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -183,7 +183,7 @@ llama_model_dots1::graph::graph(const llama_model & model, const llm_graph_param
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -128,7 +128,7 @@ llama_model_dream::graph::graph(const llama_model & model, const llm_graph_param
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -124,7 +124,7 @@ llama_model_ernie4_5_moe::graph::graph(const llama_model & model, const llm_grap
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -155,7 +155,7 @@ llama_model_ernie4_5::graph::graph(const llama_model & model, const llm_graph_pa
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -237,7 +237,7 @@ llama_model_exaone_moe::graph::graph(const llama_model & model, const llm_graph_
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -127,7 +127,7 @@ llama_model_exaone::graph::graph(const llama_model & model, const llm_graph_para
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -163,7 +163,7 @@ llama_model_exaone4::graph<iswa>::graph(const llama_model & model, const llm_gra
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -200,7 +200,7 @@ llama_model_falcon_h1::graph::graph(const llama_model & model, const llm_graph_p
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -152,7 +152,7 @@ llama_model_falcon::graph::graph(const llama_model & model, const llm_graph_para
cb(cur, "result_norm", -1);
res->t_embd = cur;
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -130,7 +130,7 @@ llama_model_gemma::graph::graph(const llama_model & model, const llm_graph_param
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -163,7 +163,7 @@ llama_model_gemma2::graph::graph(const llama_model & model, const llm_graph_para
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
// final logit soft-capping
cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);

View file

@ -207,7 +207,7 @@ llama_model_gemma3::graph<iswa>::graph(const llama_model & model, const llm_grap
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
if (hparams.f_final_logit_softcapping) {
cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);

View file

@ -296,7 +296,7 @@ llama_model_gemma3n::graph::graph(const llama_model & model, const llm_graph_par
cb(cur, "result_norm", -1);
res->t_embd = cur;
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
{
// final logit soft-capping

View file

@ -380,7 +380,7 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
if (hparams.f_final_logit_softcapping) {
cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);

View file

@ -275,7 +275,7 @@ llama_model_glm4_moe::graph::graph(const llama_model & model, const llm_graph_pa
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -185,7 +185,7 @@ llama_model_glm4::graph::graph(const llama_model & model, const llm_graph_params
res->t_embd = cur;
// Output projection
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -138,7 +138,7 @@ llama_model_gpt2::graph::graph(const llama_model & model, const llm_graph_params
cb(cur, "result_norm", -1);
res->t_embd = cur;
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -209,7 +209,7 @@ llama_model_gptneox::graph::graph(const llama_model & model, const llm_graph_par
cb(cur, "result_norm", -1);
res->t_embd = cur;
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -186,7 +186,7 @@ llama_model_granite_hybrid::graph::graph(const llama_model & model, const llm_gr
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
// For Granite architectures - scale logits
if (hparams.f_logit_scale) {

View file

@ -145,7 +145,7 @@ llama_model_granite::graph::graph(
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
// For Granite architectures - scale logits
cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);

View file

@ -206,7 +206,7 @@ llama_model_grok::graph::graph(const llama_model & model, const llm_graph_params
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cur = ggml_scale(ctx0, cur, hparams.f_logit_scale);

View file

@ -184,7 +184,7 @@ llama_model_grovemoe::graph::graph(const llama_model & model, const llm_graph_pa
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -179,7 +179,7 @@ llama_model_hunyuan_moe::graph::graph(const llama_model & model, const llm_graph
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -181,7 +181,7 @@ llama_model_hunyuan_vl::graph::graph(const llama_model & model, const llm_graph_
cb(cur, "result_norm", -1);
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -129,7 +129,7 @@ llama_model_internlm2::graph::graph(const llama_model & model, const llm_graph_p
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -123,7 +123,7 @@ llama_model_jais::graph::graph(const llama_model & model, const llm_graph_params
cb(cur, "result_norm", -1);
res->t_embd = cur;
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -152,7 +152,7 @@ llama_model_jais2::graph::graph(const llama_model & model, const llm_graph_param
res->t_embd = cur;
// Output projection
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -189,7 +189,7 @@ llama_model_jamba::graph::graph(const llama_model & model, const llm_graph_param
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -262,7 +262,7 @@ llama_model_lfm2::graph<iswa>::graph(const llama_model & model, const llm_graph_
cb(cur, "result_norm", -1);
res->t_embd = cur;
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -153,7 +153,7 @@ llama_model_llada_moe::graph::graph(const llama_model & model, const llm_graph_p
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -147,7 +147,7 @@ llama_model_llada::graph::graph(const llama_model & model, const llm_graph_param
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -235,7 +235,7 @@ llama_model_llama::graph<embed>::graph(const llama_model & model, const llm_grap
if constexpr (!embed) {
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -260,7 +260,7 @@ llama_model_llama4::graph<iswa>::graph(const llama_model & model, const llm_grap
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -141,7 +141,7 @@ llama_model_maincoder::graph::graph(const llama_model & model, const llm_graph_p
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -128,7 +128,7 @@ llama_model_mamba::graph::graph(const llama_model & model, const llm_graph_param
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -231,7 +231,7 @@ llama_model_mimo2::graph::graph(const llama_model & model, const llm_graph_param
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -251,7 +251,7 @@ llama_model_minicpm3::graph::graph(const llama_model & model, const llm_graph_pa
cb(cur, "lmhead_scaling", -1);
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -158,7 +158,7 @@ llama_model_minimax_m2::graph::graph(const llama_model & model, const llm_graph_
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -222,7 +222,7 @@ llama_model_mistral3::graph::graph(const llama_model & model, const llm_graph_pa
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -161,7 +161,7 @@ llama_model_mpt::graph::graph(const llama_model & model, const llm_graph_params
cb(cur, "result_norm", -1);
res->t_embd = cur;
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -174,7 +174,7 @@ llama_model_nemotron_h::graph::graph(const llama_model & model, const llm_graph_
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -140,7 +140,7 @@ llama_model_nemotron::graph::graph(const llama_model & model, const llm_graph_pa
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -133,7 +133,7 @@ llama_model_olmo::graph::graph(const llama_model & model, const llm_graph_params
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -198,7 +198,7 @@ llama_model_olmo2::graph<iswa>::graph(const llama_model & model, const llm_graph
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -164,7 +164,7 @@ llama_model_olmoe::graph::graph(const llama_model & model, const llm_graph_param
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -160,7 +160,7 @@ llama_model_openai_moe::graph::graph(const llama_model & model, const llm_graph_
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -162,7 +162,7 @@ llama_model_openelm::graph::graph(const llama_model & model, const llm_graph_par
cb(cur, "result_norm", -1);
res->t_embd = cur;
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -132,7 +132,7 @@ llama_model_orion::graph::graph(const llama_model & model, const llm_graph_param
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -98,7 +98,7 @@ llama_model_paddleocr::graph::graph(const llama_model & model, const llm_graph_p
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -148,7 +148,7 @@ llama_model_pangu_embed::graph::graph(const llama_model & model, const llm_graph
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
if (model.output_b != nullptr) {
cur = ggml_add(ctx0, cur, model.output_b);

View file

@ -130,7 +130,7 @@ llama_model_phi2::graph::graph(const llama_model & model, const llm_graph_params
cb(cur, "result_norm", -1);
res->t_embd = cur;
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output_no_bias", -1);
cur = ggml_add(ctx0, cur, model.output_b);

View file

@ -179,7 +179,7 @@ llama_model_phi3::graph<iswa>::graph(const llama_model & model, const llm_graph_
cb(cur, "result_norm", -1);
res->t_embd = cur;
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
if (model.output_b != nullptr) {
cb(cur, "result_output_no_bias", -1);

View file

@ -127,7 +127,7 @@ llama_model_plamo::graph::graph(const llama_model & model, const llm_graph_param
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -185,7 +185,7 @@ llama_model_plamo2::graph::graph(const llama_model & model, const llm_graph_para
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
// Explicitly mark as output tensor to ensure proper backend assignment

View file

@ -186,7 +186,7 @@ llama_model_plamo3::graph<iswa>::graph(const llama_model & model, const llm_grap
cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
res->t_embd = cur;
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);

View file

@ -204,7 +204,7 @@ llama_model_plm::graph::graph(const llama_model & model, const llm_graph_params
cb(cur, "result_norm", -1);
res->t_embd = cur;
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -131,7 +131,7 @@ llama_model_qwen::graph::graph(const llama_model & model, const llm_graph_params
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -141,7 +141,7 @@ llama_model_qwen2::graph::graph(const llama_model & model, const llm_graph_param
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
if (model.output_b != nullptr) {
cur = ggml_add(ctx0, cur, model.output_b);

View file

@ -184,7 +184,7 @@ llama_model_qwen2moe::graph::graph(const llama_model & model, const llm_graph_pa
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -134,7 +134,7 @@ llama_model_qwen2vl::graph::graph(const llama_model & model, const llm_graph_par
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -147,7 +147,7 @@ llama_model_qwen3::graph::graph(const llama_model & model, const llm_graph_param
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -167,7 +167,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para
res->t_embd = cur;
// LM head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -180,7 +180,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p
res->t_embd = cur;
// LM head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -168,7 +168,7 @@ llama_model_qwen3moe::graph::graph(const llama_model & model, const llm_graph_pa
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -176,7 +176,7 @@ llama_model_qwen3next::graph::graph(const llama_model & model, const llm_graph_p
res->t_embd = cur;
// LM head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -163,7 +163,7 @@ llama_model_qwen3vl::graph::graph(const llama_model & model, const llm_graph_par
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -180,7 +180,7 @@ llama_model_qwen3vlmoe::graph::graph(const llama_model & model, const llm_graph_
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -150,7 +150,7 @@ llama_model_refact::graph::graph(const llama_model & model, const llm_graph_para
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -167,7 +167,7 @@ llama_model_rnd1::graph::graph(const llama_model & model, const llm_graph_params
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -176,7 +176,7 @@ llama_model_rwkv6::graph::graph(const llama_model & model, const llm_graph_param
cb(cur, "result_norm", -1);
res->t_embd = cur;
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -158,7 +158,7 @@ llama_model_rwkv6qwen2::graph::graph(const llama_model & model, const llm_graph_
cb(cur, "result_norm", -1);
res->t_embd = cur;
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -202,7 +202,7 @@ llama_model_rwkv7::graph::graph(const llama_model & model, const llm_graph_param
cb(cur, "result_norm", -1);
res->t_embd = cur;
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -141,7 +141,7 @@ llama_model_seed_oss::graph::graph(const llama_model & model, const llm_graph_pa
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -178,7 +178,7 @@ llama_model_smallthinker::graph<iswa>::graph(const llama_model & model, const ll
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -143,7 +143,7 @@ llama_model_smollm3::graph::graph(const llama_model & model, const llm_graph_par
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -163,7 +163,7 @@ llama_model_stablelm::graph::graph(const llama_model & model, const llm_graph_pa
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -135,7 +135,7 @@ llama_model_starcoder::graph::graph(const llama_model & model, const llm_graph_p
cb(cur, "result_norm", -1);
res->t_embd = cur;
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -148,7 +148,7 @@ llama_model_starcoder2::graph::graph(const llama_model & model, const llm_graph_
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

View file

@ -261,7 +261,7 @@ llama_model_step35::graph::graph(const llama_model & model, const llm_graph_para
cb(cur, "result_norm", -1);
res->t_embd = cur;
cur = build_lora_mm(model.output, cur);
cur = build_lora_mm(model.output, cur, model.output_s);
cb(cur, "result_output", -1);
res->t_logits = cur;

Some files were not shown because too many files have changed in this diff Show more