llama-server: fix k-shift when the output exceeds the context length

This commit is contained in:
Li, Zonghang 2025-07-17 21:03:41 +08:00
parent f032680cab
commit bdf9d8e74b
3 changed files with 21 additions and 14 deletions

View file

@@ -2037,6 +2037,13 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
             params.chat_template = value;
         }
     ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE"));
+    add_opt(llama_arg(
+        {"-sys", "--system-prompt"}, "PROMPT",
+        "system prompt to use with model (if applicable, depending on chat template)",
+        [](gpt_params & params, const std::string & value) {
+            params.system_prompt = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_SYSTEM_PROMPT"));
     add_opt(llama_arg(
         {"-sps", "--slot-prompt-similarity"}, "SIMILARITY",
         format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity),

View file

@@ -1013,7 +1013,7 @@ struct server_context {
         slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
         slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
         slot.sparams.penalize_nl  = json_value(data, "penalize_nl",  default_sparams.penalize_nl);
-        slot.params.n_keep        = json_value(data, "n_keep",       slot.params.n_keep);
+        slot.params.n_keep        = json_value(data, "n_keep",       params.n_keep);
         slot.params.n_discard     = json_value(data, "n_discard",    default_params.n_discard);
         slot.sparams.seed         = json_value(data, "seed",         default_sparams.seed);
         slot.sparams.n_probs      = json_value(data, "n_probs",      default_sparams.n_probs);
@@ -1215,7 +1215,8 @@ struct server_context {
         // assign the system KV cache to all parallel sequences
         for (int32_t i = 1; i <= params.n_parallel; ++i) {
-            llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
+            llama_kv_cache_seq_cp     (ctx, 0, i,     -1, -1);
+            llama_send_kv_cache_seq_cp(ctx, 0, i - 1, -1, -1);
         }
     }
@@ -2029,7 +2030,6 @@ struct server_context {
         }

         // apply context-shift if needed
-        // TODO: simplify and improve
         for (server_slot & slot : slots) {
             if (slot.ga_n == 1) {
                 if (slot.is_processing() && (int) system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {
@@ -2204,14 +2204,14 @@ struct server_context {
                 } else {
                     if (!params.ctx_shift) {
                         // if context shift is disabled, we make sure prompt size is smaller than KV size
-                        if ((int) system_tokens.size() + slot.n_prompt_tokens >= slot.n_ctx) {
+                        if ((int)system_tokens.size() + slot.n_prompt_tokens >= slot.n_ctx) {
                             slot.release();
                             send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
                             continue;
                         }
                     }

                     if (slot.params.n_keep < 0) {
-                        slot.params.n_keep = slot.n_prompt_tokens;
+                        slot.params.n_keep = (int)system_tokens.size() + slot.n_prompt_tokens + 3; // +3 for <think> tag
                     }
                     slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
@@ -3590,7 +3590,7 @@ int main(int argc, char ** argv) {
    }

    // print sample chat example to make it clear which template is used
-    LOG_INF("%s: chat template, built_in: %d, chat_example: '%s'\n", __func__, params.chat_template.empty(), llama_chat_format_example(ctx_server.model, params.chat_template).c_str());
+    // LOG_INF("%s: chat template, built_in: %d, chat_example: '%s'\n", __func__, params.chat_template.empty(), llama_chat_format_example(ctx_server.model, params.chat_template).c_str());

    ctx_server.queue_tasks.on_new_task(std::bind(
        &server_context::process_single_task, &ctx_server, std::placeholders::_1));

View file

@@ -19108,19 +19108,19 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
            GGML_ABORT("Deepseek2 does not support K-shift");
        }

-        for (size_t i = 0; i < lctx.sched.size(); ++i) {
-            ggml_backend_sched_reset(lctx.sched[i]);
-            ggml_cgraph * gf = llama_build_graph_k_shift(lctx);
-            ggml_backend_sched_alloc_graph(lctx.sched[i], gf);
-            llama_set_k_shift(lctx);
-            llama_graph_compute(lctx, gf, lctx.sched[i], lctx.cparams.n_threads, lctx.threadpool);
-            need_reserve = true;
-        }
+        auto * sched = lctx.sched.at(0);
+        ggml_backend_sched_reset(sched);
+        ggml_cgraph * gf = llama_build_graph_k_shift(lctx);
+        ggml_backend_sched_alloc_graph(sched, gf);
+        llama_set_k_shift(lctx);
+        llama_graph_compute(lctx, gf, sched, lctx.cparams.n_threads, lctx.threadpool);
+        need_reserve = true;

        {
            auto & kv_self = lctx.kv_self;