mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-31 05:03:44 +00:00
server: fix checkpoints creation (#22929)
* common : add common_chat_split_by_role * cont : fix spans to reach end of message * server: fix checkpoints creation - extract message_spans from chat templates - find the prompt token position before the latest user message - split prompt batching at that position - create a context checkpoint before the latest user input - avoid periodic mid-prompt checkpoints when that position is known - handle multimodal prompts when mapping text/template positions to server prompt tokens - add --checkpoint-min-step to control minimum spacing between checkpoints * cont : clean-up * Support autoparser detection for message barriers * server: fix message span delimiter and update docs --------- Co-authored-by: Alde Rojas <hello@alde.dev> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> Co-authored-by: Piotr Wilkin <piotr.wilkin@syndatis.com>
This commit is contained in:
parent
6d57c26ef8
commit
e2ef8fe42c
15 changed files with 586 additions and 37 deletions
|
|
@ -1103,6 +1103,13 @@ private:
|
|||
}
|
||||
SRV_INF("%s", "for more info see https://github.com/ggml-org/llama.cpp/pull/16391\n");
|
||||
|
||||
if (params_base.n_ctx_checkpoints > 0) {
|
||||
SRV_INF("context checkpoints enabled, max = %d, min spacing = %d\n",
|
||||
params_base.n_ctx_checkpoints, params_base.checkpoint_min_step);
|
||||
} else {
|
||||
SRV_INF("%s", "context checkpoints disabled\n");
|
||||
}
|
||||
|
||||
if (!params_base.model_alias.empty()) {
|
||||
// backward compat: use first alias as model name
|
||||
model_name = *params_base.model_alias.begin();
|
||||
|
|
@ -2758,8 +2765,6 @@ private:
|
|||
}
|
||||
|
||||
if (pos_min >= pos_min_thold) {
|
||||
SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);
|
||||
|
||||
// search for a context checkpoint
|
||||
const auto it = std::find_if(
|
||||
slot.prompt.checkpoints.rbegin(),
|
||||
|
|
@ -2776,7 +2781,6 @@ private:
|
|||
|
||||
if (!do_reset) {
|
||||
// restore the context checkpoint
|
||||
|
||||
it->load_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
it->load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
|
|
@ -2912,6 +2916,9 @@ private:
|
|||
has_mtmd = true;
|
||||
}
|
||||
|
||||
const int32_t n_before_user = slot.task->params.n_before_user;
|
||||
const bool n_before_user_known = n_before_user > 0;
|
||||
|
||||
// add prompt tokens for processing in the current batch
|
||||
while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) {
|
||||
// get next token to process
|
||||
|
|
@ -2940,6 +2947,13 @@ private:
|
|||
|
||||
slot.n_prompt_tokens_processed++;
|
||||
|
||||
// stop the prompt batch exactly before the latest user input, so a checkpoint
|
||||
// can be created after the previous messages
|
||||
if (n_before_user_known &&
|
||||
slot.prompt.n_tokens() == n_before_user) {
|
||||
break;
|
||||
}
|
||||
|
||||
// process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
|
||||
// create checkpoints that many tokens before the end of the prompt:
|
||||
// - 4 + n_ubatch
|
||||
|
|
@ -2965,6 +2979,8 @@ private:
|
|||
// the number of tokens added to the batch for the current slot
|
||||
const auto n_tokens_cur = batch.n_tokens - n_tokens_prev;
|
||||
|
||||
const bool near_prompt_end = slot.task->n_tokens() < slot.prompt.n_tokens() + n_ubatch;
|
||||
|
||||
// entire prompt has been processed
|
||||
if (slot.prompt.n_tokens() == slot.task->n_tokens()) {
|
||||
slot.state = SLOT_STATE_DONE_PROMPT;
|
||||
|
|
@ -2979,39 +2995,49 @@ private:
|
|||
|
||||
slot.init_sampler();
|
||||
} else {
|
||||
if (slot.task->n_tokens() < slot.prompt.n_tokens() + n_ubatch) {
|
||||
// near the end of the prompt
|
||||
do_checkpoint = do_checkpoint && true;
|
||||
} else {
|
||||
// only do non-end checkpoints if the "checkpoint every n tokens" option is set
|
||||
do_checkpoint = do_checkpoint && params_base.checkpoint_every_nt > 0;
|
||||
|
||||
if (do_checkpoint) {
|
||||
llama_pos last_checkpoint = 0;
|
||||
if (!slot.prompt.checkpoints.empty()) {
|
||||
last_checkpoint = slot.prompt.checkpoints.back().n_tokens;
|
||||
}
|
||||
|
||||
do_checkpoint = do_checkpoint && slot.prompt.n_tokens() - batch.n_tokens - last_checkpoint >= params_base.checkpoint_every_nt;
|
||||
|
||||
if (do_checkpoint) {
|
||||
SLT_INF(slot, "%d tokens since last checkpoint at %d, creating new checkpoint during processing at position %d\n", params_base.checkpoint_every_nt, last_checkpoint, slot.prompt.n_tokens());
|
||||
}
|
||||
}
|
||||
// skip ordinary mid-prompt checkpoints
|
||||
if (!n_before_user_known && !near_prompt_end) {
|
||||
do_checkpoint = false;
|
||||
}
|
||||
}
|
||||
|
||||
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id);
|
||||
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id);
|
||||
|
||||
// no need for empty or small checkpoints
|
||||
do_checkpoint = do_checkpoint && (pos_min >= 0 && slot.prompt.n_tokens() >= 64);
|
||||
// checkpoints are created before the current batch is decoded, so
|
||||
// their token position is the batch start rather than the prompt end
|
||||
const int32_t n_tokens_start = slot.prompt.n_tokens() - n_tokens_cur;
|
||||
|
||||
{
|
||||
const bool is_on_user =
|
||||
n_before_user_known &&
|
||||
n_tokens_start == n_before_user;
|
||||
|
||||
const bool is_after_user =
|
||||
n_before_user_known &&
|
||||
n_tokens_start > n_before_user;
|
||||
|
||||
const bool is_allowed =
|
||||
!n_before_user_known ||
|
||||
is_on_user ||
|
||||
(is_after_user && near_prompt_end);
|
||||
|
||||
if (do_checkpoint && !is_allowed) {
|
||||
do_checkpoint = false;
|
||||
}
|
||||
}
|
||||
|
||||
// nothing to checkpoint yet
|
||||
// TODO: is this check needed?
|
||||
if (do_checkpoint && pos_min < 0) {
|
||||
do_checkpoint = false;
|
||||
}
|
||||
|
||||
// do not checkpoint after mtmd chunks
|
||||
do_checkpoint = do_checkpoint && !has_mtmd;
|
||||
|
||||
// no need to create checkpoints that are too close together
|
||||
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || slot.prompt.n_tokens() - n_tokens_cur > slot.prompt.checkpoints.back().n_tokens + 64);
|
||||
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || n_tokens_start > slot.prompt.checkpoints.back().n_tokens + params_base.checkpoint_min_step);
|
||||
SLT_DBG(slot, "main/do_checkpoint = %s, pos_min = %d, pos_max = %d\n", do_checkpoint ? "yes" : "no", pos_min, pos_max);
|
||||
|
||||
// note: we create the checkpoint before calling llama_decode(), so the current batch is not
|
||||
|
|
@ -3528,6 +3554,53 @@ void server_context::on_sleeping_changed(std::function<void(bool)> callback) {
|
|||
impl->queue_tasks.on_sleeping_state(std::move(callback));
|
||||
}
|
||||
|
||||
// compute the number of tokens before the last user message in the prompt
|
||||
static int32_t prompt_get_n_before_user(
|
||||
const json & message_spans,
|
||||
const std::string & prompt,
|
||||
const std::vector<raw_buffer> & files,
|
||||
const llama_vocab * vocab,
|
||||
mtmd_context * mctx) {
|
||||
int32_t result = -1;
|
||||
int32_t byte_pos = -1;
|
||||
|
||||
for (const auto & span : message_spans) {
|
||||
const std::string role = json_value(span, "role", std::string());
|
||||
|
||||
if (role == "user") {
|
||||
byte_pos = json_value(span, "pos", -1);
|
||||
}
|
||||
}
|
||||
|
||||
if (byte_pos >= 0) {
|
||||
GGML_ASSERT((size_t) byte_pos <= prompt.size());
|
||||
|
||||
const std::string prefix = prompt.substr(0, (size_t) byte_pos);
|
||||
|
||||
const std::string marker = get_media_marker();
|
||||
size_t n_prefix_media = 0;
|
||||
for (size_t pos = 0; (pos = prefix.find(marker, pos)) != std::string::npos; pos += marker.size()) {
|
||||
n_prefix_media++;
|
||||
}
|
||||
|
||||
GGML_ASSERT(n_prefix_media <= files.size());
|
||||
|
||||
if (mctx != nullptr && n_prefix_media > 0) {
|
||||
// TODO: this makes a copy - avoid it
|
||||
std::vector<raw_buffer> prefix_files(files.begin(), files.begin() + n_prefix_media);
|
||||
|
||||
result = (int32_t) process_mtmd_prompt(mctx, prefix, prefix_files).size();
|
||||
} else {
|
||||
result = (int32_t) tokenize_input_prompts(vocab, nullptr, prefix, true, true)[0].size();
|
||||
}
|
||||
|
||||
SRV_TRC("message_spans: last user message: byte_pos=%d, media=%zu, n_before_user=%d\n",
|
||||
byte_pos, n_prefix_media, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// server_routes
|
||||
|
|
@ -3577,6 +3650,18 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
|
|||
meta->slot_n_ctx,
|
||||
meta->logit_bias_eog,
|
||||
data);
|
||||
|
||||
const auto message_spans = json_value(data, "message_spans", json::array());
|
||||
if (prompt.is_string() && message_spans.is_array()) {
|
||||
task.params.n_before_user =
|
||||
prompt_get_n_before_user(
|
||||
message_spans,
|
||||
prompt.get<std::string>(),
|
||||
files,
|
||||
ctx_server.vocab,
|
||||
ctx_server.mctx);
|
||||
}
|
||||
|
||||
task.id_slot = json_value(data, "id_slot", -1);
|
||||
|
||||
// OAI-compat
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue