Add speculative decoding support to the server and command-line interfaces

DeEMO 2025-06-23 20:36:32 +08:00 committed by DeEMO
parent 1ea2d61a97
commit 2e8e42a5ad
11 changed files with 591 additions and 31 deletions


@@ -6,6 +6,7 @@
#include "sampling.h"
#include "json-schema-to-grammar.h"
#include "llama.h"
#include "speculative.h"
// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
@@ -133,6 +134,9 @@ struct slot_params {
int32_t n_predict = -1; // new tokens to predict
std::vector<std::string> antiprompt;
struct gpt_sampler_params sampling;
struct common_params_speculative speculative;
json input_prefix;
json input_suffix;
@@ -142,6 +146,12 @@ struct server_slot {
int id;
int id_task = -1;
llama_batch batch_spec;
llama_context * ctx_dft = nullptr;
common_speculative * spec = nullptr;
// the index relative to completion multi-task request
size_t index = 0;
@@ -231,7 +241,7 @@ struct server_slot {
generated_token_probs.clear();
}
bool has_budget(gpt_params &global_params) {
bool has_budget(const gpt_params &global_params) {
if (params.n_predict == -1 && global_params.n_predict == -1) {
return true; // limitless
}
@@ -251,6 +261,10 @@ struct server_slot {
return state != SLOT_STATE_IDLE;
}
bool can_speculate() const {
return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt;
}
void add_token(const completion_token_output & token) {
if (!is_processing()) {
SLT_WRN(*this, "%s", "slot is not processing\n");
@@ -615,6 +629,9 @@ struct server_context {
gpt_params params;
llama_model * model_dft = nullptr;
llama_context_params cparams_dft;
llama_batch batch = {};
bool clean_kv_cache = true;
@@ -652,17 +669,33 @@ struct server_context {
model = nullptr;
}
if (model_dft) {
llama_free_model(model_dft);
model_dft = nullptr;
}
// Clear any sampling context
for (server_slot & slot : slots) {
if (slot.smpl != nullptr) {
gpt_sampler_free(slot.smpl);
}
slot.smpl = nullptr;
llama_free(slot.ctx_dft);
slot.ctx_dft = nullptr;
common_speculative_free(slot.spec);
slot.spec = nullptr;
llama_batch_free(slot.batch_spec);
}
llama_batch_free(batch);
}
bool load_model(const gpt_params & params_) {
SRV_INF("loading model '%s'\n", params.model.c_str());
params = params_;
// dedicate one sequence to the system prompt
@@ -685,6 +718,44 @@ struct server_context {
add_bos_token = llama_add_bos_token(model);
has_eos_token = !llama_add_eos_token(model);
if (!params.speculative.model.empty()) {
SRV_INF("loading draft model '%s'\n", params.speculative.model.c_str());
auto params_dft = params;
params_dft.model = params.speculative.model;
params_dft.n_ctx = params.speculative.n_ctx;
params_dft.n_gpu_layers = params.speculative.n_gpu_layers;
params_dft.n_world = 1; // do not split the draft model across devices
params_dft.rank = 0; // always load the draft model on the head device
std::fill_n(params_dft.n_layer_window, params.n_world, 0);
llama_init_result llama_init_dft = llama_init_from_gpt_params(params_dft);
model_dft = llama_init_dft.model;
if (model_dft == nullptr) {
SRV_ERR("failed to load draft model, '%s'\n", params.speculative.model.c_str());
return false;
}
if (!common_speculative_are_compatible(ctx, llama_init_dft.context)) {
SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params.speculative.model.c_str(), params.model.c_str());
llama_free (llama_init_dft.context);
llama_free_model(llama_init_dft.model);
return false;
}
cparams_dft = llama_context_params_from_gpt_params(params);
cparams_dft.n_batch = llama_n_ctx(llama_init_dft.context);
// the context is not needed - we will create one for each slot
llama_free(llama_init_dft.context);
}
return true;
}
@@ -708,6 +779,30 @@ struct server_context {
slot.id = i;
slot.n_ctx = n_ctx_slot;
slot.n_predict = params.n_predict;
if (model_dft) {
slot.batch_spec = llama_batch_init(params.speculative.n_max + 1, 0, 1);
slot.ctx_dft = llama_new_context_with_model(model_dft, cparams_dft);
if (llama_context_setup_backend(model, cparams_dft, slot.ctx_dft) == nullptr) {
SRV_ERR("%s: failed to setup context with model '%s'\n", __func__, params.model.c_str());
llama_free(slot.ctx_dft);
llama_free_model(model);
return;
}
if (slot.ctx_dft == nullptr) {
SRV_ERR("%s", "failed to create draft context\n");
return;
}
slot.spec = common_speculative_init(slot.ctx_dft);
if (slot.spec == nullptr) {
SRV_ERR("%s", "failed to create speculator\n");
return;
}
}
SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
@@ -875,6 +970,8 @@ struct server_context {
slot_params default_params;
// Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
auto default_sparams = params.sparams;
default_params.speculative = params.speculative;
const auto & data = task.data;
if (data.count("__oaicompat") != 0) {
@@ -909,6 +1006,12 @@ struct server_context {
slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
slot.params.speculative.n_min = json_value(data, "speculative.n_min", default_params.speculative.n_min);
slot.params.speculative.n_max = json_value(data, "speculative.n_max", default_params.speculative.n_max);
slot.params.speculative.p_min = json_value(data, "speculative.p_min", default_params.speculative.p_min);
slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min);
// process "json_schema" and "grammar"
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
@@ -1049,6 +1152,12 @@ struct server_context {
return false;
}
}
if (slot.ctx_dft) {
llama_batch_free(slot.batch_spec);
slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
}
slot.state = SLOT_STATE_PROCESSING_PROMPT;
slot.prompt_tokens.clear();
@@ -2357,38 +2466,100 @@ struct server_context {
continue; // continue loop of slots
}
completion_token_output result;
const llama_token id = gpt_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
llama_token id;
gpt_sampler_accept(slot.smpl, id, true);
{
completion_token_output result;
slot.n_decoded += 1;
if (slot.n_decoded == 1) {
slot.t_start_generation = ggml_time_us();
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
metrics.on_prompt_eval(slot);
id = gpt_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
slot.i_batch = -1;
gpt_sampler_accept(slot.smpl, id, true);
slot.n_decoded += 1;
if (slot.n_decoded == 1) {
slot.t_start_generation = ggml_time_us();
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
metrics.on_prompt_eval(slot);
}
result.tok = id;
const auto * cur_p = gpt_sampler_get_candidates(slot.smpl);
for (size_t i = 0; i < (size_t) slot.params.sampling.n_probs; ++i) {
result.probs.push_back({
cur_p->data[i].id,
i >= cur_p->size ? 0.0f : cur_p->data[i].p,
});
}
if (!process_token(result, slot)) {
// release slot because of stop condition
slot.release();
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
continue;
}
}
result.tok = id;
const auto * cur_p = gpt_sampler_get_candidates(slot.smpl);
for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
result.probs.push_back({
cur_p->data[i].id,
i >= cur_p->size ? 0.0f : cur_p->data[i].p,
});
// check if the slot supports speculative decoding
if (!slot.can_speculate()) {
continue;
}
if (!process_token(result, slot)) {
// release slot because of stop condition
slot.release();
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
struct common_speculative_params params_spec;
params_spec.n_draft = slot.params.speculative.n_max;
params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
params_spec.p_min = slot.params.speculative.p_min;
llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
// ignore small drafts
if (slot.params.speculative.n_min > (int) draft.size()) {
continue;
}
slot.i_batch = -1;
// construct the speculation batch
llama_batch_clear(slot.batch_spec);
llama_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true);
for (size_t i = 0; i < draft.size(); ++i) {
llama_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
}
llama_decode(ctx, slot.batch_spec);
// the accepted tokens from the speculation
const auto ids = gpt_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
slot.n_past += ids.size();
slot.n_decoded += ids.size();
slot.cache_tokens.push_back(id);
slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result;
result.tok = ids[i];
if (!process_token(result, slot)) {
// release slot because of stop condition
slot.release();
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
break;
}
}
SRV_DBG("accepted %d/%d draft tokens\n", (int) ids.size() - 1, (int) draft.size());
}
}
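Because the hunk above interleaves removed and added lines, the shape of the new speculative path can be hard to follow. The fragment below condenses its accepted-token flow into plain C++ as a reading aid, using only identifiers that appear in the diff (slot, ctx and id come from the surrounding loop; cache_tokens updates, per-token process_token() calls, probability output and error handling are left out). It is a sketch of the committed logic, not the verbatim code.

// Reading aid: condensed accepted-token path of the speculative loop above.
if (slot.can_speculate()) {
    struct common_speculative_params params_spec;
    params_spec.n_draft = slot.params.speculative.n_max;
    params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
    params_spec.p_min   = slot.params.speculative.p_min;

    // 1. let the draft model propose a continuation after the sampled token id
    llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);

    // 2. skip speculation when the draft is too short to be worth verifying
    if ((int) draft.size() >= slot.params.speculative.n_min) {
        // 3. verify the sampled token plus the whole draft in one target-model batch
        llama_batch_clear(slot.batch_spec);
        llama_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true);
        for (size_t i = 0; i < draft.size(); ++i) {
            llama_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
        }
        llama_decode(ctx, slot.batch_spec);

        // 4. keep the longest draft prefix that the target sampler agrees with
        const auto ids = gpt_sampler_sample_and_accept_n(slot.smpl, ctx, draft);

        // 5. advance the slot and drop the rejected tail from the KV cache
        slot.n_past    += ids.size();
        slot.n_decoded += ids.size();
        llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
    }
}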