mirror of
https://github.com/Lizonghang/prima.cpp.git
synced 2025-09-09 19:44:40 +00:00
fix context shifting
This commit is contained in:
parent
07c4966a80
commit
c54a6a0132
8 changed files with 397 additions and 73 deletions
|
@ -116,7 +116,7 @@ struct server_task {
|
|||
};
|
||||
|
||||
struct server_task_result {
|
||||
int id = -1;
|
||||
int id = -1;
|
||||
|
||||
json data;
|
||||
|
||||
|
@ -1063,6 +1063,9 @@ struct server_context {
|
|||
|
||||
// clear the entire KV cache
|
||||
llama_kv_cache_clear(ctx);
|
||||
|
||||
llama_send_kv_cache_clear(ctx);
|
||||
|
||||
clean_kv_cache = false;
|
||||
}
|
||||
|
||||
|
@ -1191,7 +1194,7 @@ struct server_context {
|
|||
SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
|
||||
}
|
||||
|
||||
// if context shift is disabled, we stop when it reaches the context limit
|
||||
// we stop when it reaches the context limit, otherwise it may run forever
|
||||
if (slot.n_decoded >= slot.n_ctx) {
|
||||
slot.truncated = true;
|
||||
slot.stopped_limit = true;
|
||||
|
@ -1917,8 +1920,11 @@ struct server_context {
|
|||
|
||||
SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
|
||||
|
||||
llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard);
|
||||
llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
|
||||
llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard);
|
||||
llama_kv_cache_seq_add (ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
|
||||
|
||||
llama_send_kv_cache_seq_rm (ctx, slot.id , n_keep , n_keep + n_discard);
|
||||
llama_send_kv_cache_seq_add(ctx, slot.id , n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
|
||||
|
||||
if (slot.params.cache_prompt) {
|
||||
for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
|
||||
|
@ -2084,7 +2090,6 @@ struct server_context {
|
|||
// if input prompt is too big, truncate it (if group attention self-extend is disabled)
|
||||
if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
|
||||
const int n_left = slot.n_ctx - slot.params.n_keep;
|
||||
|
||||
const int n_block_size = n_left / 2;
|
||||
const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
|
||||
|
||||
|
@ -2161,12 +2166,14 @@ struct server_context {
|
|||
int p0 = (int) system_tokens.size() + slot.n_past;
|
||||
if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
|
||||
// could not partially delete (likely using a non-Transformer model)
|
||||
llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
|
||||
llama_kv_cache_seq_rm (ctx, slot.id + 1, -1, -1);
|
||||
llama_send_kv_cache_seq_rm(ctx, slot.id , -1, -1);
|
||||
|
||||
p0 = (int) system_tokens.size();
|
||||
if (p0 != 0) {
|
||||
// copy over the system prompt when there is one
|
||||
llama_kv_cache_seq_cp(ctx, 0, slot.id + 1, -1, -1);
|
||||
llama_kv_cache_seq_cp (ctx, 0, slot.id + 1, -1, -1);
|
||||
llama_send_kv_cache_seq_cp(ctx, 0, slot.id , -1, -1);
|
||||
}
|
||||
|
||||
// there is no common part left (except for the system prompt)
|
||||
|
@ -2175,6 +2182,8 @@ struct server_context {
|
|||
slot.ga_i = 0;
|
||||
// TODO: is the system prompt ever in the sampling context?
|
||||
gpt_sampler_reset(slot.smpl);
|
||||
} else {
|
||||
llama_send_kv_cache_seq_rm(ctx, slot.id, p0, -1);
|
||||
}
|
||||
|
||||
// remove the non-common part from the cache
|
||||
|
@ -2260,9 +2269,14 @@ struct server_context {
|
|||
SLT_DBG(slot, "div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
|
||||
SLT_DBG(slot, "shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
|
||||
|
||||
llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
|
||||
llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
|
||||
llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
|
||||
llama_kv_cache_seq_add (ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
|
||||
llama_send_kv_cache_seq_add(ctx, slot.id , slot.ga_i, slot.n_past_se, ib * bd);
|
||||
|
||||
llama_kv_cache_seq_div (ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
|
||||
llama_send_kv_cache_seq_div(ctx, slot.id , slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
|
||||
|
||||
llama_kv_cache_seq_add (ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
|
||||
llama_send_kv_cache_seq_add(ctx, slot.id , slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
|
||||
|
||||
slot.n_past_se -= bd;
|
||||
|
||||
|
@ -3329,10 +3343,6 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// bind HTTP listen port, run the HTTP server in a thread
|
||||
if (!svr->bind_to_port(params.hostname, params.port)) {
|
||||
//LOG_ERROR("couldn't bind HTTP server socket", {
|
||||
// {"hostname", params.hostname},
|
||||
// {"port", params.port},
|
||||
//});
|
||||
LOG_ERR("%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, params.hostname.c_str(), params.port);
|
||||
clean_up();
|
||||
return 1;
|
||||
|
@ -3377,10 +3387,6 @@ int main(int argc, char ** argv) {
|
|||
ctx_server.queue_tasks.terminate();
|
||||
};
|
||||
|
||||
LOG_INF("%s: server is listening on %s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port);
|
||||
|
||||
ctx_server.queue_tasks.start_loop();
|
||||
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||
struct sigaction sigint_action;
|
||||
sigint_action.sa_handler = signal_handler;
|
||||
|
@ -3395,6 +3401,13 @@ int main(int argc, char ** argv) {
|
|||
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
|
||||
#endif
|
||||
|
||||
LOG_INF("%s: server is listening on %s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port);
|
||||
|
||||
ctx_server.queue_tasks.start_loop();
|
||||
|
||||
char * stop_signal = nullptr;
|
||||
llama_free_sockets(ctx_server.ctx, &stop_signal);
|
||||
|
||||
clean_up();
|
||||
t.join();
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue