llama-server: fix K-shift when the output exceeds the context length

Li, Zonghang 2025-07-17 21:03:41 +08:00
parent f032680cab
commit bdf9d8e74b
3 changed files with 21 additions and 14 deletions


@@ -2037,6 +2037,13 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
params.chat_template = value;
}
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE"));
+ add_opt(llama_arg(
+     {"-sys", "--system-prompt"}, "PROMPT",
+     "system prompt to use with model (if applicable, depending on chat template)",
+     [](gpt_params & params, const std::string & value) {
+         params.system_prompt = value;
+     }
+ ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_SYSTEM_PROMPT"));
add_opt(llama_arg(
{"-sps", "--slot-prompt-similarity"}, "SIMILARITY",
format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity),


@@ -1013,7 +1013,7 @@ struct server_context {
slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
slot.params.n_keep = json_value(data, "n_keep", params.n_keep);
slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
@@ -1215,7 +1215,8 @@ struct server_context {
// assign the system KV cache to all parallel sequences
for (int32_t i = 1; i <= params.n_parallel; ++i) {
- llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
+ llama_kv_cache_seq_cp (ctx, 0, i, -1, -1);
+ llama_send_kv_cache_seq_cp(ctx, 0, i - 1, -1, -1);
}
}
@@ -2029,7 +2030,6 @@ struct server_context {
}
// apply context-shift if needed
// TODO: simplify and improve
for (server_slot & slot : slots) {
if (slot.ga_n == 1) {
if (slot.is_processing() && (int) system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {
@@ -2204,14 +2204,14 @@ struct server_context {
} else {
if (!params.ctx_shift) {
// if context shift is disabled, we make sure prompt size is smaller than KV size
- if ((int) system_tokens.size() + slot.n_prompt_tokens >= slot.n_ctx) {
+ if ((int)system_tokens.size() + slot.n_prompt_tokens >= slot.n_ctx) {
slot.release();
send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
continue;
}
}
if (slot.params.n_keep < 0) {
- slot.params.n_keep = slot.n_prompt_tokens;
+ slot.params.n_keep = (int)system_tokens.size() + slot.n_prompt_tokens + 3; // +3 for <think> tag
}
slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
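
For reference, a standalone sketch of the arithmetic that slot.params.n_keep feeds into: the server pins the first n_keep tokens and, on a context shift, discards half of the shiftable region after them (the upstream llama.cpp default when n_discard is not set). Apart from n_keep and n_ctx, which appear in the hunk above, the variable names and all token counts below are made-up illustrations, not this repository's code.

```cpp
// Illustrative sketch only: how the new n_keep default interacts with a context shift.
#include <algorithm>
#include <cstdio>

int main() {
    const int n_ctx           = 4096; // slot context size (example value)
    const int n_system_tokens = 128;  // shared system prompt tokens (example value)
    const int n_prompt_tokens = 900;  // request prompt tokens (example value)
    const int n_past          = 4000; // tokens currently held in the slot's KV cache

    // After this commit, the default n_keep covers the system prompt and the full
    // request prompt (+3 reserved for a <think> tag), clamped to n_ctx - 4.
    int n_keep = n_system_tokens + n_prompt_tokens + 3;
    n_keep     = std::min(n_ctx - 4, n_keep);

    // Context shift: drop half of what lies past n_keep, then slide the rest back.
    const int n_left    = n_system_tokens + n_past - n_keep;
    const int n_discard = n_left / 2;

    std::printf("n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
    std::printf("KV cells [%d, %d) are dropped; cells after them shift back by %d positions\n",
                n_keep, n_keep + n_discard, n_discard);
    return 0;
}
```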
@@ -3590,7 +3590,7 @@ int main(int argc, char ** argv) {
}
// print sample chat example to make it clear which template is used
LOG_INF("%s: chat template, built_in: %d, chat_example: '%s'\n", __func__, params.chat_template.empty(), llama_chat_format_example(ctx_server.model, params.chat_template).c_str());
// LOG_INF("%s: chat template, built_in: %d, chat_example: '%s'\n", __func__, params.chat_template.empty(), llama_chat_format_example(ctx_server.model, params.chat_template).c_str());
ctx_server.queue_tasks.on_new_task(std::bind(
&server_context::process_single_task, &ctx_server, std::placeholders::_1));


@@ -19108,19 +19108,19 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
GGML_ABORT("Deepseek2 does not support K-shift");
}
- for (size_t i = 0; i < lctx.sched.size(); ++i) {
-     ggml_backend_sched_reset(lctx.sched[i]);
-     ggml_cgraph * gf = llama_build_graph_k_shift(lctx);
-     ggml_backend_sched_alloc_graph(lctx.sched[i], gf);
-     llama_set_k_shift(lctx);
-     llama_graph_compute(lctx, gf, lctx.sched[i], lctx.cparams.n_threads, lctx.threadpool);
-     need_reserve = true;
- }
+ auto * sched = lctx.sched.at(0);
+ ggml_backend_sched_reset(sched);
+ ggml_cgraph * gf = llama_build_graph_k_shift(lctx);
+ ggml_backend_sched_alloc_graph(sched, gf);
+ llama_set_k_shift(lctx);
+ llama_graph_compute(lctx, gf, sched, lctx.cparams.n_threads, lctx.threadpool);
+ need_reserve = true;
{
auto & kv_self = lctx.kv_self;
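
The hunk above is the K-shift fix itself: the shift graph is now built once and run on a single scheduler (lctx.sched.at(0)) rather than being rebuilt, re-allocated, and re-run once per scheduler inside the loop. As background, here is a toy model of the bookkeeping that makes a K-shift necessary, following the general llama.cpp scheme of per-cell position deltas; the struct and the numbers are illustrative only and are not this fork's types or code.

```cpp
// Toy model (illustrative only): after a context shift, surviving KV cells carry a
// position delta; the K-shift pass re-applies RoPE to K using those deltas and then
// clears them.
#include <cstdio>
#include <vector>

struct toy_kv_cell {
    int  pos;   // current position of the cached token
    int  delta; // accumulated shift not yet folded into K (consumed by the K-shift)
    bool used;
};

int main() {
    const int n_keep = 8, n_discard = 4, n_past = 16;
    std::vector<toy_kv_cell> cells(n_past);
    for (int i = 0; i < n_past; ++i) cells[i] = { i, 0, true };

    // seq_rm [n_keep, n_keep + n_discard): free the discarded cells
    for (int i = n_keep; i < n_keep + n_discard; ++i) cells[i].used = false;

    // seq_add [n_keep + n_discard, n_past) by -n_discard: slide the survivors back
    bool has_shift = false;
    for (int i = n_keep + n_discard; i < n_past; ++i) {
        cells[i].pos   -= n_discard;
        cells[i].delta -= n_discard;
        has_shift = true;
    }

    // K-shift: a single graph pass consumes the deltas (re-roping K), then resets them
    if (has_shift) {
        for (auto & c : cells) c.delta = 0;
    }

    for (int i = 0; i < n_past; ++i) {
        if (cells[i].used) std::printf("cell %2d -> pos %2d\n", i, cells[i].pos);
    }
    return 0;
}
```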