fix context shifting

This commit is contained in:
Lizonghang 2025-05-19 16:58:35 +04:00
parent 07c4966a80
commit c54a6a0132
8 changed files with 397 additions and 73 deletions

View file

@ -116,7 +116,7 @@ struct server_task {
};
struct server_task_result {
int id = -1;
int id = -1;
json data;
@ -1063,6 +1063,9 @@ struct server_context {
// clear the entire KV cache
llama_kv_cache_clear(ctx);
llama_send_kv_cache_clear(ctx);
clean_kv_cache = false;
}
@ -1191,7 +1194,7 @@ struct server_context {
SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
}
// if context shift is disabled, we stop when it reaches the context limit
// we stop when it reaches the context limit, otherwise it may run forever
if (slot.n_decoded >= slot.n_ctx) {
slot.truncated = true;
slot.stopped_limit = true;
@ -1917,8 +1920,11 @@ struct server_context {
SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add (ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
llama_send_kv_cache_seq_rm (ctx, slot.id , n_keep , n_keep + n_discard);
llama_send_kv_cache_seq_add(ctx, slot.id , n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
if (slot.params.cache_prompt) {
for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
@ -2084,7 +2090,6 @@ struct server_context {
// if input prompt is too big, truncate it (if group attention self-extend is disabled)
if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
const int n_left = slot.n_ctx - slot.params.n_keep;
const int n_block_size = n_left / 2;
const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
@ -2161,12 +2166,14 @@ struct server_context {
int p0 = (int) system_tokens.size() + slot.n_past;
if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
// could not partially delete (likely using a non-Transformer model)
llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
llama_kv_cache_seq_rm (ctx, slot.id + 1, -1, -1);
llama_send_kv_cache_seq_rm(ctx, slot.id , -1, -1);
p0 = (int) system_tokens.size();
if (p0 != 0) {
// copy over the system prompt when there is one
llama_kv_cache_seq_cp(ctx, 0, slot.id + 1, -1, -1);
llama_kv_cache_seq_cp (ctx, 0, slot.id + 1, -1, -1);
llama_send_kv_cache_seq_cp(ctx, 0, slot.id , -1, -1);
}
// there is no common part left (except for the system prompt)
@ -2175,6 +2182,8 @@ struct server_context {
slot.ga_i = 0;
// TODO: is the system prompt ever in the sampling context?
gpt_sampler_reset(slot.smpl);
} else {
llama_send_kv_cache_seq_rm(ctx, slot.id, p0, -1);
}
// remove the non-common part from the cache
@ -2260,9 +2269,14 @@ struct server_context {
SLT_DBG(slot, "div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
SLT_DBG(slot, "shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
llama_kv_cache_seq_add (ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
llama_send_kv_cache_seq_add(ctx, slot.id , slot.ga_i, slot.n_past_se, ib * bd);
llama_kv_cache_seq_div (ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
llama_send_kv_cache_seq_div(ctx, slot.id , slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
llama_kv_cache_seq_add (ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
llama_send_kv_cache_seq_add(ctx, slot.id , slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
slot.n_past_se -= bd;
@ -3329,10 +3343,6 @@ int main(int argc, char ** argv) {
// bind HTTP listen port, run the HTTP server in a thread
if (!svr->bind_to_port(params.hostname, params.port)) {
//LOG_ERROR("couldn't bind HTTP server socket", {
// {"hostname", params.hostname},
// {"port", params.port},
//});
LOG_ERR("%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, params.hostname.c_str(), params.port);
clean_up();
return 1;
@ -3377,10 +3387,6 @@ int main(int argc, char ** argv) {
ctx_server.queue_tasks.terminate();
};
LOG_INF("%s: server is listening on %s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port);
ctx_server.queue_tasks.start_loop();
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
struct sigaction sigint_action;
sigint_action.sa_handler = signal_handler;
@ -3395,6 +3401,13 @@ int main(int argc, char ** argv) {
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif
LOG_INF("%s: server is listening on %s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port);
ctx_server.queue_tasks.start_loop();
char * stop_signal = nullptr;
llama_free_sockets(ctx_server.ctx, &stop_signal);
clean_up();
t.join();