Mirror of https://github.com/Lizonghang/prima.cpp.git
Synced 2025-09-04 03:19:17 +00:00
llama-server: fix k-shift when output overlength
This commit is contained in:
parent f032680cab
commit bdf9d8e74b
3 changed files with 21 additions and 14 deletions
@@ -2037,6 +2037,13 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
             params.chat_template = value;
         }
     ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE"));
+    add_opt(llama_arg(
+        {"-sys", "--system-prompt"}, "PROMPT",
+        "system prompt to use with model (if applicable, depending on chat template)",
+        [](gpt_params & params, const std::string & value) {
+            params.system_prompt = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_SYSTEM_PROMPT"));
     add_opt(llama_arg(
         {"-sps", "--slot-prompt-similarity"}, "SIMILARITY",
         format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity),

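The hunk above registers a new -sys / --system-prompt option (also settable through the LLAMA_ARG_SYSTEM_PROMPT environment variable) whose value is stored in params.system_prompt. Only the option registration is part of this diff; how the server consumes the value is not shown here. For orientation, a system prompt is normally prepended as a role="system" message before the chat template is applied. The sketch below is illustrative only, with made-up names, not prima.cpp code:

#include <string>
#include <vector>

struct chat_msg {
    std::string role;
    std::string content;
};

// Hypothetical helper: put the system prompt (if any) in front of the user messages.
static std::vector<chat_msg> build_messages(const std::string & system_prompt,
                                            const std::vector<chat_msg> & user_msgs) {
    std::vector<chat_msg> msgs;
    if (!system_prompt.empty()) {
        msgs.push_back({"system", system_prompt});
    }
    msgs.insert(msgs.end(), user_msgs.begin(), user_msgs.end());
    return msgs;
}
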
@@ -1013,7 +1013,7 @@ struct server_context {
         slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
         slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
         slot.sparams.penalize_nl  = json_value(data, "penalize_nl",  default_sparams.penalize_nl);
-        slot.params.n_keep        = json_value(data, "n_keep",       slot.params.n_keep);
+        slot.params.n_keep        = json_value(data, "n_keep",       params.n_keep);
         slot.params.n_discard     = json_value(data, "n_discard",    default_params.n_discard);
         slot.sparams.seed         = json_value(data, "seed",         default_sparams.seed);
         slot.sparams.n_probs      = json_value(data, "n_probs",      default_sparams.n_probs);

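n_keep is the number of tokens at the front of the context that a context shift (K-shift) must never discard. The change above makes a request that does not send its own n_keep fall back to the server-wide value params.n_keep instead of the slot's previous value. As a reference point, the sketch below shows the discard arithmetic a context shift typically performs (simplified from the upstream llama-server behavior; this exact code is not part of the diff):

#include <cstdio>

// Keep the first n_keep tokens, drop a chunk right after them, slide the rest left.
struct shift_plan {
    int n_keep;    // tokens preserved at the front of the cache
    int n_discard; // tokens removed starting at position n_keep
};

static shift_plan plan_context_shift(int n_past, int n_keep, int n_discard_req) {
    const int n_left    = n_past - n_keep;                                // tokens eligible for discarding
    const int n_discard = n_discard_req > 0 ? n_discard_req : n_left / 2; // default: drop half of them
    return { n_keep, n_discard };
}

int main() {
    // hypothetical numbers: 4096 cached tokens, keep the first 512, no explicit n_discard
    const shift_plan p = plan_context_shift(4096, 512, 0);
    std::printf("keep [0, %d), discard [%d, %d)\n", p.n_keep, p.n_keep, p.n_keep + p.n_discard);
    return 0;
}
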
@@ -1215,7 +1215,8 @@ struct server_context {

             // assign the system KV cache to all parallel sequences
             for (int32_t i = 1; i <= params.n_parallel; ++i) {
-                llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
+                llama_kv_cache_seq_cp (ctx, 0, i, -1, -1);
+                llama_send_kv_cache_seq_cp(ctx, 0, i - 1, -1, -1);
             }
         }

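This hunk is part of the system-prompt setup: the KV cells computed for sequence 0 (the shared system prompt) are copied into every parallel slot's sequence, and the added llama_send_kv_cache_seq_cp call is prima.cpp-specific, presumably propagating the same copy to the other nodes of the distributed pipeline so their caches stay consistent (an assumption; that function is not shown in this diff). The p0/p1 arguments of -1, -1 select the whole position range of the source sequence, as in the minimal sketch below:

#include "llama.h"

// Fan the shared prefix held in sequence 0 out to n_parallel slot sequences.
// Passing p0 = -1 and p1 = -1 copies the entire position range.
static void share_system_kv(llama_context * ctx, int32_t n_parallel) {
    for (int32_t i = 1; i <= n_parallel; ++i) {
        llama_kv_cache_seq_cp(ctx, /*seq_id_src=*/0, /*seq_id_dst=*/i, /*p0=*/-1, /*p1=*/-1);
    }
}
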
@@ -2029,7 +2030,6 @@ struct server_context {
         }
-
         // apply context-shift if needed
         // TODO: simplify and improve
         for (server_slot & slot : slots) {
             if (slot.ga_n == 1) {
                 if (slot.is_processing() && (int) system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {

@@ -2204,14 +2204,14 @@ struct server_context {
                 } else {
                     if (!params.ctx_shift) {
                         // if context shift is disabled, we make sure prompt size is smaller than KV size
-                        if ((int) system_tokens.size() + slot.n_prompt_tokens >= slot.n_ctx) {
+                        if ((int)system_tokens.size() + slot.n_prompt_tokens >= slot.n_ctx) {
                             slot.release();
                             send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
                             continue;
                         }
                     }
                     if (slot.params.n_keep < 0) {
-                        slot.params.n_keep = slot.n_prompt_tokens;
+                        slot.params.n_keep = (int)system_tokens.size() + slot.n_prompt_tokens + 3; // +3 for <think> tag
                     }
                     slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);

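This is the core of the overlength fix. Previously a request without an explicit n_keep kept only slot.n_prompt_tokens positions, which ignores the system tokens sitting in front of the prompt, so once generation overran the context a K-shift could discard the tail of the prompt itself. The new default keeps the system tokens plus the whole prompt plus 3 positions reserved for a <think> tag, and the existing clamp to slot.n_ctx - 4 still bounds it. A small worked example with hypothetical numbers:

#include <algorithm>
#include <cstdio>

int main() {
    const int n_ctx           = 8192; // slot context size
    const int n_system_tokens = 64;   // shared system prompt in front of every request
    const int n_prompt_tokens = 1500; // this request's prompt

    // old default: prompt length only
    int n_keep_old = n_prompt_tokens;
    // new default: system tokens + prompt + 3 positions for a <think> tag
    int n_keep_new = n_system_tokens + n_prompt_tokens + 3;

    // both values are still clamped so some room is left for generation
    n_keep_old = std::min(n_ctx - 4, n_keep_old);
    n_keep_new = std::min(n_ctx - 4, n_keep_new);

    std::printf("n_keep: old=%d new=%d\n", n_keep_old, n_keep_new); // old=1500 new=1567
    return 0;
}
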
@@ -3590,7 +3590,7 @@ int main(int argc, char ** argv) {
     }

     // print sample chat example to make it clear which template is used
-    LOG_INF("%s: chat template, built_in: %d, chat_example: '%s'\n", __func__, params.chat_template.empty(), llama_chat_format_example(ctx_server.model, params.chat_template).c_str());
+    // LOG_INF("%s: chat template, built_in: %d, chat_example: '%s'\n", __func__, params.chat_template.empty(), llama_chat_format_example(ctx_server.model, params.chat_template).c_str());

     ctx_server.queue_tasks.on_new_task(std::bind(
         &server_context::process_single_task, &ctx_server, std::placeholders::_1));

@@ -19108,19 +19108,19 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
             GGML_ABORT("Deepseek2 does not support K-shift");
         }

-        for (size_t i = 0; i < lctx.sched.size(); ++i) {
-            ggml_backend_sched_reset(lctx.sched[i]);
+        auto * sched = lctx.sched.at(0);

-            ggml_cgraph * gf = llama_build_graph_k_shift(lctx);
+        ggml_backend_sched_reset(sched);

-            ggml_backend_sched_alloc_graph(lctx.sched[i], gf);
+        ggml_cgraph * gf = llama_build_graph_k_shift(lctx);

-            llama_set_k_shift(lctx);
+        ggml_backend_sched_alloc_graph(sched, gf);

-            llama_graph_compute(lctx, gf, lctx.sched[i], lctx.cparams.n_threads, lctx.threadpool);
+        llama_set_k_shift(lctx);

-            need_reserve = true;
-        }
+        llama_graph_compute(lctx, gf, sched, lctx.cparams.n_threads, lctx.threadpool);
+
+        need_reserve = true;

         {
             auto & kv_self = lctx.kv_self;

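In llama_kv_cache_update_internal the K-shift graph was previously reset, built, allocated, and computed once per scheduler in lctx.sched (a prima.cpp-specific vector of backend schedulers); after this change it is built once and run only on lctx.sched.at(0). For context, a K-shift is the re-rotation of cached K vectors after a context shift has moved KV cells to new positions: each cell carries an accumulated position delta that the shift graph consumes, roughly as in the conceptual sketch below (not the actual llama.cpp data structures):

#include <vector>

// Conceptual model of a KV cache cell; the real implementation stores more state.
struct kv_cell {
    int pos   = -1; // logical position of the cached token
    int delta = 0;  // accumulated position shift, later applied to K via RoPE in the K-shift graph
};

// After a context shift removes n_discard tokens starting at position p0,
// every later cell slides left and records the shift it still owes.
static void shift_cells(std::vector<kv_cell> & cells, int p0, int n_discard) {
    for (auto & c : cells) {
        if (c.pos >= p0 + n_discard) {
            c.pos   -= n_discard;
            c.delta -= n_discard;
        }
    }
}
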