Mirror of https://github.com/LostRuins/koboldcpp.git (synced 2026-04-28 03:30:20 +00:00)

Checkpoint every n tokens: squash (#20087)

Commit f5ddcd1696, parent f6235a41ef
3 changed files with 77 additions and 48 deletions
@@ -12,6 +12,7 @@
 #include "mtmd.h"
 #include "mtmd-helper.h"

+#include <algorithm>
 #include <cstddef>
 #include <cinttypes>
 #include <memory>
@@ -2348,8 +2349,10 @@ private:
     const auto it = std::find_if(
         slot.prompt.checkpoints.rbegin(),
         slot.prompt.checkpoints.rend(),
-        [&](const auto & cur) {
+        [&, func_name = __func__](const auto & cur) {
             // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
+            LOG_INF("slot %12.*s: id %2d | task %d | Checking checkpoint with [%d, %d] against %d...\n", 12,
+                    func_name, (slot).id, ((slot).task ? (slot).task->id : -1), cur.pos_min, cur.pos_max, pos_min_thold);
             return cur.pos_min < pos_min_thold;
         }
     );
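The lambda change above also init-captures `__func__` as `func_name`, because `__func__` evaluated inside a lambda body names the lambda's own `operator()` rather than the enclosing function. For reference, here is a minimal self-contained sketch of the reverse `std::find_if` pattern this hunk touches, using a stand-in `Checkpoint` type instead of the server's real structures:

```cpp
// Minimal sketch (stand-in types, not the server's real ones) of the reverse
// find_if above: pick the newest checkpoint that can restore state strictly
// before pos_min_thold.
#include <algorithm>
#include <cstdint>
#include <vector>

struct Checkpoint {
    int32_t pos_min;
    int32_t pos_max;
};

const Checkpoint * find_restorable(const std::vector<Checkpoint> & checkpoints,
                                   int32_t pos_min_thold) {
    // Iterating rbegin()..rend() visits the newest checkpoints first, so the
    // first match minimizes how many prompt tokens must be re-processed.
    const auto it = std::find_if(checkpoints.rbegin(), checkpoints.rend(),
                                 [&](const Checkpoint & cur) {
                                     return cur.pos_min < pos_min_thold;
                                 });
    return it == checkpoints.rend() ? nullptr : &*it;
}
```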
@@ -2533,48 +2536,66 @@ private:
         slot.i_batch = batch.n_tokens - 1;

         slot.init_sampler();

-        const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
-        const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
-
-        // no need for empty or small checkpoints
-        do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
-
-        // no need to create checkpoints that are too close together
-        do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64);
-
-        // note: we create the checkpoint before calling llama_decode(), so the current batch is not
-        // yet processed and therefore it is not part of the checkpoint.
-        if (do_checkpoint) {
-            while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
-                // make room for the new checkpoint, if needed
-                const auto & cur = slot.prompt.checkpoints.front();
-
-                SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
-                        cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
-
-                slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
-            }
-
-            const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
-
-            auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{
-                /*.pos_min = */ pos_min,
-                /*.pos_max = */ pos_max,
-                /*.n_tokens = */ slot.prompt.n_tokens() - batch.n_tokens,
-                /*.data = */ std::vector<uint8_t>(checkpoint_size),
-            });
-
-            llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
-
-            SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
-                    (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
-        }
-
         SLT_INF(slot, "prompt processing done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);
     } else {
+        // only do non-end checkpoints if the "checkpoint every n tokens" option is set
+        do_checkpoint = do_checkpoint && params_base.checkpoint_every_nt > 0;
+        if (do_checkpoint) {
+            llama_pos last_checkpoint = 0;
+            if (!slot.prompt.checkpoints.empty()) {
+                last_checkpoint = slot.prompt.checkpoints.back().n_tokens;
+            }
+            do_checkpoint = do_checkpoint && slot.prompt.n_tokens() - batch.n_tokens - last_checkpoint >= params_base.checkpoint_every_nt;
+            if (do_checkpoint) {
+                SLT_INF(slot, "%d tokens since last checkpoint at %d, creating new checkpoint during processing at position %d\n", params_base.checkpoint_every_nt, last_checkpoint, slot.prompt.n_tokens());
+            }
+        }
         SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());
     }

+    const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
+    const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
+
+    // no need for empty or small checkpoints
+    do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
+
+    // no need to create checkpoints that are too close together
+    do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64);
+
+    // note: we create the checkpoint before calling llama_decode(), so the current batch is not
+    // yet processed and therefore it is not part of the checkpoint.
+    if (do_checkpoint) {
+        while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
+            // make room for the new checkpoint, if needed
+            const auto & cur = slot.prompt.checkpoints.front();
+
+            SLT_WRN(slot,
+                    "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64
+                    ", size = %.3f MiB)\n",
+                    cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
+
+            slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
+        }
+
+        const size_t checkpoint_size =
+            llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+
+        auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{
+            /*.pos_min = */ pos_min,
+            /*.pos_max = */ pos_max,
+            /*.n_tokens = */ slot.prompt.n_tokens() - batch.n_tokens,
+            /*.data = */ std::vector<uint8_t>(checkpoint_size),
+        });
+
+        llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id,
+                                     LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+
+        SLT_WRN(slot,
+                "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64
+                ", size = %.3f MiB)\n",
+                (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min,
+                cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
+    }
 }

if (!slot_batched) {
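
The else-branch added above is the PR's core trigger: during (not just after) prompt processing, a checkpoint becomes due once the tokens accumulated since the last checkpoint, excluding the current not-yet-decoded batch, reach `checkpoint_every_nt`. A hedged sketch of that arithmetic, with illustrative parameter names standing in for the slot and params fields:

```cpp
// Sketch of the interval trigger; parameter names are illustrative stand-ins
// for slot.prompt.n_tokens(), batch.n_tokens, the newest checkpoint's
// n_tokens, and params_base.checkpoint_every_nt.
#include <cstdint>

bool checkpoint_due(int32_t n_prompt_tokens, int32_t n_batch_tokens,
                    int32_t last_checkpoint_n_tokens, int32_t every_nt) {
    if (every_nt <= 0) {
        return false; // intermediate checkpoints disabled
    }
    // The current batch is excluded: the snapshot is taken before
    // llama_decode(), so the batch is not yet part of the sequence state.
    return n_prompt_tokens - n_batch_tokens - last_checkpoint_n_tokens >= every_nt;
}
```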
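Checkpoint storage is bounded by `n_ctx_checkpoints`: before a new snapshot is appended, the oldest entries are erased from the front of the list. A minimal sketch of that FIFO eviction with a stand-in checkpoint type (the real code also logs each evicted entry's position range and size):

```cpp
// FIFO eviction sketch: drop the oldest checkpoints until there is room,
// then append the new one. PromptCheckpoint is a stand-in type.
#include <cstdint>
#include <utility>
#include <vector>

struct PromptCheckpoint {
    int32_t pos_min;
    int32_t pos_max;
    int64_t n_tokens;
    std::vector<uint8_t> data; // serialized sequence state
};

void push_checkpoint(std::vector<PromptCheckpoint> & checkpoints,
                     size_t n_ctx_checkpoints, PromptCheckpoint cp) {
    // make room for the new checkpoint, if needed (oldest entry is at front)
    while (!checkpoints.empty() && checkpoints.size() >= n_ctx_checkpoints) {
        checkpoints.erase(checkpoints.begin());
    }
    checkpoints.push_back(std::move(cp));
}
```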
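The snapshot buffer itself is filled with the usual size-then-fill pattern. The `_ext` calls and the `LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY` flag below are taken verbatim from the diff; the wrapper helper is only an illustrative sketch:

```cpp
// Size-then-fill snapshot sketch: query the serialized size of the sequence
// state, allocate a buffer, then copy the state into it. The helper function
// is hypothetical; the two _ext calls and the flag are used exactly as above.
#include <cstdint>
#include <vector>

#include "llama.h"

std::vector<uint8_t> snapshot_sequence(llama_context * ctx, llama_seq_id seq_id) {
    const size_t size = llama_state_seq_get_size_ext(ctx, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

    std::vector<uint8_t> buf(size);
    llama_state_seq_get_data_ext(ctx, buf.data(), size, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

    return buf;
}
```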