From bdf9d8e74bc6e3b9433f655ecb60f1e537bd41e1 Mon Sep 17 00:00:00 2001
From: "Li, Zonghang" <870644199@qq.com>
Date: Thu, 17 Jul 2025 21:03:41 +0800
Subject: [PATCH] llama-server: fix k-shift when output overlength

---
 common/arg.cpp             |  7 +++++++
 examples/server/server.cpp | 12 ++++++------
 src/llama.cpp              | 16 ++++++++--------
 3 files changed, 21 insertions(+), 14 deletions(-)

diff --git a/common/arg.cpp b/common/arg.cpp
index 45954b52..5e4bb56c 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -2037,6 +2037,13 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
             params.chat_template = value;
         }
     ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE"));
+    add_opt(llama_arg(
+        {"-sys", "--system-prompt"}, "PROMPT",
+        "system prompt to use with model (if applicable, depending on chat template)",
+        [](gpt_params & params, const std::string & value) {
+            params.system_prompt = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_SYSTEM_PROMPT"));
     add_opt(llama_arg(
         {"-sps", "--slot-prompt-similarity"}, "SIMILARITY",
         format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity),
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index a1cfa90c..2a899ed5 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1013,7 +1013,7 @@ struct server_context {
         slot.sparams.mirostat_tau    = json_value(data, "mirostat_tau",    default_sparams.mirostat_tau);
         slot.sparams.mirostat_eta    = json_value(data, "mirostat_eta",    default_sparams.mirostat_eta);
         slot.sparams.penalize_nl     = json_value(data, "penalize_nl",     default_sparams.penalize_nl);
-        slot.params.n_keep           = json_value(data, "n_keep",          slot.params.n_keep);
+        slot.params.n_keep           = json_value(data, "n_keep",          params.n_keep);
         slot.params.n_discard        = json_value(data, "n_discard",       default_params.n_discard);
         slot.sparams.seed            = json_value(data, "seed",            default_sparams.seed);
         slot.sparams.n_probs         = json_value(data, "n_probs",         default_sparams.n_probs);
@@ -1215,7 +1215,8 @@ struct server_context {
 
             // assign the system KV cache to all parallel sequences
             for (int32_t i = 1; i <= params.n_parallel; ++i) {
-                llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
+                llama_kv_cache_seq_cp     (ctx, 0, i,     -1, -1);
+                llama_send_kv_cache_seq_cp(ctx, 0, i - 1, -1, -1);
             }
         }
 
@@ -2029,7 +2030,6 @@ struct server_context {
         }
 
         // apply context-shift if needed
-        // TODO: simplify and improve
         for (server_slot & slot : slots) {
             if (slot.ga_n == 1) {
                 if (slot.is_processing() && (int) system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {
@@ -2204,14 +2204,14 @@ struct server_context {
                    } else {
                        if (!params.ctx_shift) {
                            // if context shift is disabled, we make sure prompt size is smaller than KV size
-                            if ((int) system_tokens.size() + slot.n_prompt_tokens >= slot.n_ctx) {
+                            if ((int)system_tokens.size() + slot.n_prompt_tokens >= slot.n_ctx) {
                                slot.release();
                                send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
                                continue;
                            }
                        }
 
                        if (slot.params.n_keep < 0) {
-                            slot.params.n_keep = slot.n_prompt_tokens;
+                            slot.params.n_keep = (int)system_tokens.size() + slot.n_prompt_tokens + 3; // +3 for tag
                        }
                        slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
@@ -3590,7 +3590,7 @@ int main(int argc, char ** argv) {
     }
 
     // print sample chat example to make it clear which template is used
-    LOG_INF("%s: chat template, built_in: %d, chat_example: '%s'\n", __func__, params.chat_template.empty(), llama_chat_format_example(ctx_server.model, params.chat_template).c_str());
+    // LOG_INF("%s: chat template, built_in: %d, chat_example: '%s'\n", __func__, params.chat_template.empty(), llama_chat_format_example(ctx_server.model, params.chat_template).c_str());
 
     ctx_server.queue_tasks.on_new_task(std::bind(
         &server_context::process_single_task, &ctx_server, std::placeholders::_1));
diff --git a/src/llama.cpp b/src/llama.cpp
index 9aa9cd82..9015c550 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -19108,19 +19108,19 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
             GGML_ABORT("Deepseek2 does not support K-shift");
         }
 
-        for (size_t i = 0; i < lctx.sched.size(); ++i) {
-            ggml_backend_sched_reset(lctx.sched[i]);
+        auto * sched = lctx.sched.at(0);
 
-            ggml_cgraph * gf = llama_build_graph_k_shift(lctx);
+        ggml_backend_sched_reset(sched);
 
-            ggml_backend_sched_alloc_graph(lctx.sched[i], gf);
+        ggml_cgraph * gf = llama_build_graph_k_shift(lctx);
 
-            llama_set_k_shift(lctx);
+        ggml_backend_sched_alloc_graph(sched, gf);
 
-            llama_graph_compute(lctx, gf, lctx.sched[i], lctx.cparams.n_threads, lctx.threadpool);
+        llama_set_k_shift(lctx);
 
-            need_reserve = true;
-        }
+        llama_graph_compute(lctx, gf, sched, lctx.cparams.n_threads, lctx.threadpool);
+
+        need_reserve = true;
 
         {
             auto & kv_self = lctx.kv_self;