server : speculative checkpointing (#19493)

* server : speculative decoding using checkpoints

* server : fix draft check with checkpoints

* server : rename spec vars

* server : log levels

* server : refactored spec logic to speculative.cpp

* server : renamed spec checkpoints option

* server : fix spec checkpoints, logging

* speculative : checkpoints with draft model, logging

* server : n_tokens_cur and create_checkpoint in draft

* server : fix server_speculative_callback (slot.id)

* spec : fix ngram-map/begin idx_last_check

* spec : init ckpt (begin() wasn't called)

* chore: update webui build output

* server : restore sampler in spec checkpoint and clear mem

* cont : avoid --spec-use-checkpoints argument

* cont : remove server_prompt_checkpoint_with_size

* spec : rename (leave_draft_state)

* cont : clean-up

* cont : do not ignore partial drafts even if they are short

* cont : spec callback owned by session

* cont : simplify

* cont : avoid empty speculative session

* cont : simplify

* cont : simplify

* cont : enable mtmd speculative decoding

* cont : keep the spec sampler alive

* cont : simplify

* cont : fix nullptr deref + draft checkpoints

* cont : remove common_speculative_accept_response

* cont : remove callback

* cont : simplify

* cont : minor

* cont : simplify

* cont : fix accepted number

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Sascha Rogmann 2026-04-19 09:24:06 +02:00 committed by GitHub
parent 91fef95362
commit 455d8e4be8
10 changed files with 421 additions and 180 deletions
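
The mechanism at the heart of this PR: before evaluating a draft, the server snapshots the partial per-sequence state; when verification rejects part of the draft, it restores the snapshot and trims the positions that were decoded past it. A minimal sketch of that round trip, assuming a valid llama_context * ctx and a sequence id of 0 (all API calls appear in the diff below; error handling is elided):

#include "llama.h"
#include <vector>

// snapshot only the partial (per-sequence) state, as the server code below does
static std::vector<uint8_t> save_partial_state(llama_context * ctx, llama_seq_id id) {
    const size_t size = llama_state_seq_get_size_ext(ctx, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
    std::vector<uint8_t> buf(size);
    llama_state_seq_get_data_ext(ctx, buf.data(), size, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
    return buf;
}

// restore the snapshot and drop everything decoded past the checkpoint
static void restore_partial_state(llama_context * ctx, llama_seq_id id,
                                  const std::vector<uint8_t> & buf, llama_pos pos_max) {
    llama_state_seq_set_data_ext(ctx, buf.data(), buf.size(), id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
    llama_memory_seq_rm(llama_get_memory(ctx), id, pos_max + 1, -1);
}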


@@ -1,3 +1,4 @@
#include "server-context.h"
#include "server-common.h"
#include "server-http.h"
@@ -19,6 +20,7 @@
#include <exception>
#include <memory>
#include <filesystem>
#include <utility>
// fix problem with std::min and std::max
#if defined(_WIN32)
@@ -33,6 +35,31 @@ using json = nlohmann::ordered_json;
constexpr int HTTP_POLLING_SECONDS = 1;
static server_prompt_checkpoint server_get_checkpoint(llama_context * ctx, int id, int64_t n_tokens, llama_pos pos_min = -1, llama_pos pos_max = -1) {
if (pos_min == -1) {
pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), id);
}
if (pos_max == -1) {
pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), id);
}
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
auto cur = server_prompt_checkpoint {
/*.pos_min = */ pos_min,
/*.pos_max = */ pos_max,
/*.n_tokens = */ n_tokens,
/*.data = */ std::vector<uint8_t>(checkpoint_size),
};
const size_t n = llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
if (n != checkpoint_size) {
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n);
}
return cur;
}
// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
enum slot_state {
SLOT_STATE_IDLE,
@@ -57,7 +84,12 @@ struct server_slot {
// multimodal
mtmd_context * mctx = nullptr;
common_speculative * spec = nullptr;
// speculative decoding
llama_tokens spec_draft;
std::vector<int32_t> spec_i_batch;
server_prompt_checkpoint spec_ckpt;
common_speculative_ptr spec;
// TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state
// see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837
@@ -83,11 +115,6 @@ struct server_slot {
std::string debug_generated_text;
llama_tokens generated_tokens;
// idx of draft tokens in the main batch
// non-empty if we went to evaluate draft tokens
// ref: https://github.com/ggml-org/llama.cpp/pull/17808
std::vector<int32_t> i_batch_dft;
std::vector<completion_token_output> generated_token_probs;
bool has_next_token = true;
@@ -147,8 +174,7 @@ struct server_slot {
common_sampler_ptr smpl;
llama_token sampled; // in speculative mode, this is the last accepted token
llama_tokens drafted;
llama_token sampled; // in speculative mode, this is the last accepted token
// stats
size_t n_sent_text = 0; // number of sent text characters
@@ -178,8 +204,11 @@ struct server_slot {
stopping_word = "";
n_sent_text = 0;
drafted.clear();
i_batch_dft.clear();
if (can_speculate()) {
spec_draft.clear();
spec_i_batch.clear();
spec_ckpt.clear();
}
generated_tokens.clear();
generated_token_probs.clear();
json_schema = json();
@@ -300,6 +329,85 @@ struct server_slot {
return n_draft_max;
}
void update_batch(llama_batch & batch) {
const int n_draft_max = get_n_draft_max();
if (n_draft_max > 0) {
GGML_ASSERT(can_speculate());
// generate draft tokens in speculative decoding mode
// TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
// perform the speculative drafting for all sequences at the same time in a single batch
const llama_tokens & tokens = prompt.tokens.get_text_tokens();
const auto & params_spec = task->params.speculative;
if (!spec_draft.empty()) {
// we have a previous (partial) draft to reuse
if (task->params.speculative.use_checkpoints) {
GGML_ASSERT(!spec_ckpt.empty());
}
} else {
GGML_ASSERT(spec_i_batch.empty());
// generate a new draft
spec_draft = common_speculative_draft(spec.get(), params_spec, tokens, sampled);
if (spec_draft.size() > (size_t) n_draft_max) {
SLT_WRN(*this, "draft size %d exceeds max %d, truncating\n", (int) spec_draft.size(), n_draft_max);
spec_draft.resize(n_draft_max);
}
if (spec_draft.size() < (size_t) params_spec.n_min) {
SLT_DBG(*this, "ignoring small draft: %d < %d\n", (int) spec_draft.size(), params_spec.n_min);
spec_draft.clear();
}
if (!spec_draft.empty() && params_spec.use_checkpoints) {
const auto n_tokens = prompt.tokens.size();
auto & ckpt = spec_ckpt;
ckpt = server_get_checkpoint(ctx, this->id, n_tokens);
SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n",
ckpt.pos_min, ckpt.pos_max, n_tokens, (float) ckpt.data.size() / 1024 / 1024);
}
}
GGML_ASSERT(spec_draft.size() <= (size_t) n_draft_max);
}
if (spec_draft.empty()) {
// no speculative decoding
i_batch = batch.n_tokens;
common_batch_add(batch, sampled, prompt.tokens.pos_next(), { this->id }, true);
SLT_DBG(*this, "slot decode token, id=%d, n_ctx = %d, n_tokens = %d, truncated = %d\n",
sampled, n_ctx, prompt.n_tokens(), truncated);
} else {
SLT_DBG(*this, "generate_draft: id=%d, #tokens=%zu, #draft=%zu, pos_next=%d\n",
sampled, prompt.tokens.size(), spec_draft.size(), prompt.tokens.pos_next());
GGML_ASSERT(spec_i_batch.empty());
spec_i_batch.push_back(batch.n_tokens);
for (size_t i = 0; i < spec_draft.size(); i++) {
spec_i_batch.push_back(batch.n_tokens + i + 1);
}
auto pos0 = prompt.tokens.pos_next();
common_batch_add(batch, sampled, pos0++, { this->id }, true);
for (auto token : spec_draft) {
common_batch_add(batch, token, pos0++, { this->id }, true);
}
}
prompt.tokens.push_back(sampled);
prompt.tokens.insert(spec_draft);
}
void release() {
if (is_processing()) {
GGML_ASSERT(task);
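
To make the batch layout produced by update_batch() above concrete, here is a hedged, self-contained sketch (token, position, and sequence-id values are hypothetical): the last accepted token is added first, then every draft token, all with logits enabled so a single llama_decode() scores the whole draft, while spec_i_batch records the logit row of each token:

    llama_batch batch = llama_batch_init(512, 0, 1);

    llama_token sampled = 42;                      // hypothetical last accepted token
    std::vector<llama_token> draft = { 7, 8, 9 };  // hypothetical draft
    llama_pos pos0 = 100;                          // hypothetical prompt.tokens.pos_next()

    std::vector<int32_t> i_batch;                  // plays the role of spec_i_batch
    i_batch.push_back(batch.n_tokens);             // logit row of the sampled token
    common_batch_add(batch, sampled, pos0++, { 0 }, true);
    for (llama_token t : draft) {
        i_batch.push_back(batch.n_tokens);         // logit row of each draft token
        common_batch_add(batch, t, pos0++, { 0 }, true);
    }

    // common_sampler_sample_and_accept_n() later walks these rows to verify
    // the whole draft in one pass
    llama_batch_free(batch);
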
@@ -400,7 +508,7 @@ struct server_slot {
);
}
common_speculative_print_stats(spec);
common_speculative_print_stats(spec.get());
}
json to_json(bool only_metrics = false) const {
@@ -591,16 +699,17 @@ private:
void destroy() {
llama_init.reset();
ctx = nullptr;
model = nullptr;
mtmd_free(mctx);
mctx = nullptr;
// Clear any sampling context
for (server_slot & slot : slots) {
common_speculative_free(slot.spec);
slot.spec = nullptr;
if (slot.can_speculate()) {
slot.spec.reset();
}
}
llama_batch_free(batch);
@@ -642,9 +751,6 @@ private:
llama_init = common_init_from_params(params_base);
// propagate model-metadata sampling defaults back to caller
params.sampling = params_base.sampling;
model = llama_init->model();
ctx = llama_init->context();
@@ -660,6 +766,7 @@ private:
add_bos_token = llama_vocab_get_add_bos(vocab);
if (params_base.speculative.has_dft()) {
// TODO speculative: move to common/speculative.cpp?
SRV_INF("loading draft model '%s'\n", params_base.speculative.mparams_dft.path.c_str());
const auto & params_spec = params_base.speculative;
@@ -727,11 +834,6 @@ private:
params_base.n_cache_reuse = 0;
SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled");
}
if (params_base.speculative.type != COMMON_SPECULATIVE_TYPE_NONE) {
params_base.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
SRV_WRN("%s\n", "speculative decoding is not supported by multimodal, it will be disabled");
}
}
if (!llama_memory_can_shift(llama_get_memory(ctx))) {
@@ -769,14 +871,23 @@ private:
slots.clear();
const bool can_spec = common_speculative_is_compat(ctx);
if (!can_spec) {
const auto spec_type = common_speculative_is_compat(ctx);
if (spec_type == COMMON_SPECULATIVE_COMPAT_TYPE_NO) {
SRV_WRN("%s", "speculative decoding not supported by this context\n");
}
if (spec_type == COMMON_SPECULATIVE_COMPAT_TYPE_CKPT) {
SRV_WRN("%s", "speculative decoding will use checkpoints\n");
params_base.speculative.use_checkpoints = true;
}
// initialize slots
for (int i = 0; i < params_base.n_parallel; i++) {
server_slot slot;
slots.emplace_back();
}
for (int i = 0; i < params_base.n_parallel; i++) {
server_slot & slot = slots[i];
slot.id = i;
slot.ctx = ctx;
@@ -786,16 +897,11 @@ private:
slot.prompt.tokens.has_mtmd = mctx != nullptr;
// try speculative decoding
if (can_spec) {
slot.spec = common_speculative_init(params_base.speculative, slot.ctx);
if (spec_type != COMMON_SPECULATIVE_COMPAT_TYPE_NO) {
slot.spec.reset(common_speculative_init(params_base.speculative, slot.ctx));
if (slot.spec) {
if (mctx) {
SRV_ERR("%s\n", "speculative decoding is not supported with multimodal");
return false;
}
SLT_INF(slot, "%s", "speculative decoding context initialized\n");
} else {
SLT_INF(slot, "%s", "speculative decoding context not initialized\n");
}
}
@@ -806,8 +912,6 @@ private:
};
slot.reset();
slots.push_back(std::move(slot));
}
{
@@ -854,6 +958,9 @@ private:
model_aliases = params_base.model_alias;
model_tags = params_base.model_tags;
// propagate new defaults back to caller
params = params_base;
if (!is_resume) {
return init();
}
@@ -1197,7 +1304,7 @@ private:
backend_sampling &= task.params.sampling.backend_sampling;
// TODO: speculative decoding requires multiple samples per batch - not supported yet
backend_sampling &= !(slot.spec && task.params.speculative.n_max > 0);
backend_sampling &= !(slot.can_speculate() && task.params.speculative.n_max > 0);
// TODO: getting post/pre sampling logits is not yet supported with backend sampling
backend_sampling &= !need_logits;
@@ -1703,6 +1810,26 @@ private:
return true;
}
// n_tokens_cur: the number of tokens added to the batch for the current slot
void create_checkpoint(server_slot & slot, const int64_t n_tokens_cur, llama_pos pos_min, llama_pos pos_max) {
while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
// make room for the new checkpoint, if needed
const auto & cur = slot.prompt.checkpoints.front();
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
}
const auto & cur = slot.prompt.checkpoints.emplace_back(server_get_checkpoint(ctx, slot.id, slot.prompt.n_tokens() - n_tokens_cur, pos_min, pos_max));
SLT_WRN(slot,
"created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
(int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min,
cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
}
void process_single_task(server_task && task) {
switch (task.type) {
case SERVER_TASK_TYPE_COMPLETION:
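
The create_checkpoint() helper above is called from the prompt-processing loop later in this diff; as a usage sketch (the spacing condition is copied from that call site), a checkpoint is taken before llama_decode() and only when the previous one is at least 64 tokens behind:

    const bool do_checkpoint = slot.prompt.checkpoints.empty() ||
        slot.prompt.n_tokens() - n_tokens_cur > slot.prompt.checkpoints.back().n_tokens + 64;

    if (do_checkpoint) {
        // created before llama_decode(), so the current batch is not part of it
        create_checkpoint(slot, n_tokens_cur, pos_min, pos_max);
    }
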
@@ -1854,7 +1981,7 @@ private:
std::string filename = task.slot_action.filename;
std::string filepath = task.slot_action.filepath;
const llama_tokens & tokens = slot->prompt.tokens.get_text_tokens();
const llama_tokens & tokens = slot->prompt.tokens.get_tokens();
const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count);
const int64_t t_end = ggml_time_us();
@@ -2061,7 +2188,7 @@ private:
{
GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
llama_tokens new_tokens = slot.prompt.tokens.get_text_tokens(); // copy
llama_tokens new_tokens = slot.prompt.tokens.get_tokens(); // copy
for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
new_tokens[i - n_discard] = new_tokens[i];
}
@@ -2100,61 +2227,7 @@ private:
continue;
}
// generate draft tokens in speculative decoding mode
// TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
// perform the speculative drafting for all sequences at the same time in a single batch
const int n_draft_max = slot.get_n_draft_max();
if (n_draft_max > 0) {
if (mctx) {
// we should never reach this, as speculative is automatically disabled if mmproj is loaded
GGML_ABORT("not supported by multimodal");
}
const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
const auto & params_spec = slot.task->params.speculative;
llama_tokens draft = common_speculative_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);
if (draft.size() > (size_t) n_draft_max) {
SLT_WRN(slot, "draft size %d exceeds max %d, truncating\n", (int) draft.size(), n_draft_max);
draft.resize(n_draft_max);
}
// add the sampled token to the batch
slot.i_batch_dft.push_back(batch.n_tokens);
common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
slot.prompt.tokens.push_back(slot.sampled);
if (slot.task->params.speculative.n_min > (int) draft.size()) {
SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min);
// fallback to normal decoding
slot.i_batch = slot.i_batch_dft[0];
slot.drafted.clear();
slot.i_batch_dft.clear();
} else {
// keep track of total number of drafted tokens tested
slot.n_draft_total += draft.size();
// add all drafted tokens to the batch
for (size_t i = 0; i < draft.size(); i++) {
slot.i_batch_dft.push_back(batch.n_tokens);
common_batch_add(batch, draft[i], slot.prompt.tokens.pos_next(), { slot.id }, true);
slot.prompt.tokens.push_back(draft[i]);
}
slot.drafted = std::move(draft);
}
} else {
// no speculative decoding
slot.i_batch = batch.n_tokens;
common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
slot.prompt.tokens.push_back(slot.sampled);
SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
slot.n_ctx, slot.prompt.n_tokens(), slot.truncated);
}
slot.update_batch(batch);
}
// process in chunks of params.n_batch
@@ -2651,40 +2724,12 @@ private:
// no need to create checkpoints that are too close together
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || slot.prompt.n_tokens() - n_tokens_cur > slot.prompt.checkpoints.back().n_tokens + 64);
SLT_DBG(slot, "main/do_checkpoint = %s, pos_min = %d, pos_max = %d\n", do_checkpoint ? "yes" : "no", pos_min, pos_max);
// note: we create the checkpoint before calling llama_decode(), so the current batch is not
// yet processed and therefore it is not part of the checkpoint.
if (do_checkpoint) {
while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
// make room for the new checkpoint, if needed
const auto & cur = slot.prompt.checkpoints.front();
SLT_WRN(slot,
"erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64
", size = %.3f MiB)\n",
cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
}
const size_t checkpoint_size =
llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{
/*.pos_min = */ pos_min,
/*.pos_max = */ pos_max,
/*.n_tokens = */ slot.prompt.n_tokens() - n_tokens_cur,
/*.data = */ std::vector<uint8_t>(checkpoint_size),
});
llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id,
LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
SLT_WRN(slot,
"created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64
", size = %.3f MiB)\n",
(int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min,
cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
create_checkpoint(slot, n_tokens_cur, pos_min, pos_max);
}
}
@@ -2856,19 +2901,19 @@ private:
slot.state = SLOT_STATE_GENERATING;
if (slot.can_speculate()) {
common_speculative_begin(slot.spec, slot.prompt.tokens.get_text_tokens());
common_speculative_begin(slot.spec.get(), slot.prompt.tokens.get_text_tokens());
}
} else if (slot.state != SLOT_STATE_GENERATING) {
continue; // continue loop of slots
}
if (slot.i_batch_dft.size() > 0) {
if (slot.can_speculate() && !slot.spec_draft.empty()) {
continue; // sample using speculative decoding
}
const int tok_idx = slot.i_batch - i;
llama_token id = common_sampler_sample(slot.smpl.get(), ctx, tok_idx);
llama_token id = common_sampler_sample(slot.smpl.get(), slot.ctx, tok_idx);
slot.i_batch = -1;
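
Putting these hunks together, a hedged sketch of the common_speculative session lifecycle as the slot drives it (signatures taken from the calls in this diff; n_accepted is an assumed local):

    // once, when the slot enters SLOT_STATE_GENERATING:
    common_speculative_begin(slot.spec.get(), slot.prompt.tokens.get_text_tokens());

    // each step: propose a draft continuing from the last accepted token
    llama_tokens draft = common_speculative_draft(slot.spec.get(), params_spec,
            slot.prompt.tokens.get_text_tokens(), slot.sampled);

    // after verification: report how many draft tokens survived
    common_speculative_accept(slot.spec.get(), n_accepted);
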
@@ -2889,7 +2934,7 @@ private:
completion_token_output result;
result.tok = id;
result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
result.text_to_send = common_token_to_piece(slot.ctx, result.tok, accept_special_token(slot, result.tok));
result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
if (slot.task->params.sampling.n_probs > 0) {
@@ -2909,43 +2954,86 @@ private:
// speculative decoding - main model sample and accept
for (auto & slot : slots) {
if (slot.state != SLOT_STATE_GENERATING || slot.i_batch_dft.empty()) {
if (slot.state != SLOT_STATE_GENERATING || !slot.can_speculate() || slot.spec_draft.empty()) {
continue;
}
const size_t n_draft = slot.drafted.size();
// save the original draft size
const size_t n_draft = slot.spec_draft.size();
// the accepted tokens from the speculation
const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx, slot.i_batch_dft, slot.drafted);
slot.i_batch_dft.clear();
slot.drafted.clear();
GGML_ASSERT(n_draft > 0);
// verify and try to accept the draft
{
const auto & params_spec = slot.task->params.speculative;
common_sampler_ptr smpl_save(common_sampler_clone(slot.smpl.get()));
GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1);
auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx, slot.spec_i_batch, slot.spec_draft);
slot.spec_i_batch.clear();
SLT_DBG(slot, "%s: n_draft=%zu, accepted=%zu\n", __func__, slot.spec_draft.size(), accepted.size());
GGML_ASSERT(accepted.size() >= 1);
// check for partial draft acceptance
if (accepted.size() < slot.spec_draft.size() + 1) {
if (params_spec.use_checkpoints) {
// partial acceptance is not supported by the context -> truncate the draft and restore the state
slot.spec_draft = std::move(accepted);
auto & ckpt = slot.spec_ckpt;
SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", ckpt.pos_min, ckpt.pos_max, ckpt.size());
const size_t n = llama_state_seq_set_data_ext(slot.ctx, ckpt.data.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
if (n != ckpt.size()) {
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt.size(), n);
}
llama_memory_seq_rm(llama_get_memory(slot.ctx), slot.id, ckpt.pos_max + 1, -1);
slot.prompt.tokens.keep_first(ckpt.n_tokens);
slot.smpl = std::move(smpl_save);
continue;
}
LOG_DBG("%s: partial acceptance: %zu < %zu\n", __func__, accepted.size(), slot.spec_draft.size());
}
common_speculative_accept(slot.spec.get(), accepted.size() - 1);
slot.spec_draft = std::move(accepted);
}
const int64_t t_current = ggml_time_us();
slot.n_decoded += ids.size();
const auto ids = std::move(slot.spec_draft);
slot.n_decoded += ids.size();
slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
// update how many tokens out of those tested were accepted
slot.n_draft_accepted += ids.size() - 1;
// inform the speculative decoding about the number of accepted tokens
common_speculative_accept(slot.spec, ids.size() - 1);
// rollback to the state before sampling the draft tokens
slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);
slot.n_draft_total += n_draft;
// add accepted tokens to the prompt
slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);
slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});
slot.sampled = ids.back(); // last accepted token
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1);
slot.sampled = ids.back(); // last accepted token
SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft);
llama_memory_seq_rm(llama_get_memory(slot.ctx), slot.id, slot.prompt.n_tokens(), -1);
for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result;
result.tok = ids[i];
result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
result.text_to_send = common_token_to_piece(slot.ctx, result.tok, accept_special_token(slot, result.tok));
result.prob = 1.0f; // set later
// TODO: set result.probs
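
One subtlety in the hunk above: common_sampler_sample_and_accept_n() mutates the sampler as it accepts draft tokens, so the code clones the sampler up front and swaps the clone back in whenever the checkpoint path rolls the context back. A hedged distillation of that pattern (names as in this diff):

    common_sampler_ptr smpl_save(common_sampler_clone(slot.smpl.get()));

    auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx,
            slot.spec_i_batch, slot.spec_draft);

    if (accepted.size() < slot.spec_draft.size() + 1 && params_spec.use_checkpoints) {
        // the sampler already ingested tokens that are about to be un-decoded:
        // restore the pre-draft sampler state along with the context checkpoint
        slot.smpl = std::move(smpl_save);
    }
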
@@ -3665,7 +3753,7 @@ void server_routes::init_routes() {
params.n_predict,
meta->slot_n_ctx,
params.spm_infill,
tokenized_prompts[0].get_text_tokens() // TODO: this could maybe be multimodal.
tokenized_prompts[0].get_tokens() // TODO: this could maybe be multimodal.
);
std::vector<raw_buffer> files; // dummy