Add speculative decoding support to the server and command-line interfaces

This commit is contained in:
DeEMO 2025-06-23 20:36:32 +08:00 committed by DeEMO
parent 1ea2d61a97
commit 2e8e42a5ad
11 changed files with 591 additions and 31 deletions

View file

@ -963,6 +963,7 @@ OBJ_COMMON = \
common/console.o \
common/ngram-cache.o \
common/sampling.o \
common/speculative.o \
common/train.o \
common/build-info.o \
common/json-schema-to-grammar.o
@ -1239,6 +1240,13 @@ common/json-schema-to-grammar.o: \
common/json-schema-to-grammar.h
$(CXX) $(CXXFLAGS) -c $< -o $@
# speculative
common/speculative.o: \
common/speculative.cpp \
common/speculative.h \
include/llama.h
$(CXX) $(CXXFLAGS) -c $< -o $@
common/train.o: \
common/train.cpp \
common/train.h

View file

@ -1704,9 +1704,9 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
{"-md", "--model-draft"}, "FNAME",
"draft model for speculative decoding (default: unused)",
[](gpt_params & params, const std::string & value) {
params.model_draft = value;
params.speculative.model = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
add_opt(llama_arg(
{"-mu", "--model-url"}, "MODEL_URL",
"model download url (default: unused)",

View file

@ -3111,7 +3111,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta);
fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false");
fprintf(stream, "model: %s # default: %s\n", params.model.c_str(), DEFAULT_MODEL_PATH);
fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str());
fprintf(stream, "model_draft: %s # default:\n", params.speculative.model.c_str());
fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false");
fprintf(stream, "n_gpu_layers: %d # default: -1\n", params.n_gpu_layers);
fprintf(stream, "n_predict: %d # default: -1 (unlimited)\n", params.n_predict);

View file

@ -33,6 +33,8 @@ struct llama_lora_adapter_container : llama_lora_adapter_info {
struct llama_lora_adapter * adapter;
};
using llama_tokens = std::vector<llama_token>;
// build info
extern int LLAMA_BUILD_NUMBER;
extern char const * LLAMA_COMMIT;
@ -141,6 +143,20 @@ struct gpt_sampler_params {
std::string print() const;
};
struct common_params_speculative {
int32_t n_ctx = 0; // draft context size
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
int32_t n_min = 5; // minimum number of draft tokens to use for speculative decoding
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.9f; // minimum speculative decoding probability (greedy)
struct cpu_params cpuparams;
struct cpu_params cpuparams_batch;
std::string model = ""; // draft model for speculative decoding // NOLINT
};
struct gpt_params {
int32_t n_world = 1; // number of devices to use
int32_t rank = 0; // my rank for distributed inference
@ -198,9 +214,9 @@ struct gpt_params {
enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
struct gpt_sampler_params sparams;
struct common_params_speculative speculative;
std::string model = ""; // model path // NOLINT
std::string model_draft = ""; // draft model for speculative decoding // NOLINT
std::string model_alias = "unknown"; // model alias // NOLINT
std::string model_url = ""; // model url to download // NOLINT
std::string hf_token = ""; // HF token // NOLINT

View file

@ -318,6 +318,45 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
return cur_p.data[cur_p.selected].id;
}
std::vector<llama_token> gpt_sampler_sample_and_accept_n(struct gpt_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
std::vector<llama_token> result;
result.reserve(idxs.size());
size_t i = 0;
for (; i < draft.size(); i++) {
const llama_token id = gpt_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
gpt_sampler_accept(gsmpl, id, true);
result.push_back(id);
if (draft[i] != id) {
break;
}
}
if (i == draft.size()) {
const llama_token id = gpt_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
gpt_sampler_accept(gsmpl, id, true);
result.push_back(id);
}
return result;
}
std::vector<llama_token> gpt_sampler_sample_and_accept_n(struct gpt_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
std::vector<int> idxs(draft.size() + 1);
for (size_t i = 0; i < idxs.size(); ++i) {
idxs[i] = i;
}
return gpt_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
}
uint32_t gpt_sampler_get_seed(const struct gpt_sampler * gsmpl) {
return llama_sampler_get_seed(gsmpl->chain);
}

View file

@ -60,6 +60,27 @@ void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler *
//
llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
// generalized version of gpt_sampler_sample
//
// will cross-reference the sampled tokens with a batch of draft tokens and accept those that match
// if the sampler disagrees at some point, we stop and return the accepted tokens up to now
//
// gpt_sampler_sample_n(gsmpl, ctx, { idx }, {});
//
// is equivalent to
//
// gpt_sampler_sample(gsmpl, ctx, idx);
// gpt_sampler_accept(gsmpl, token, true);
//
// requires: idxs.size() == draft.size() + 1
//
// returns at least 1 token, up to idxs.size()
//
std::vector<llama_token> gpt_sampler_sample_and_accept_n(struct gpt_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
// assume idxs == [ 0, 1, 2, ..., draft.size() ]
std::vector<llama_token> gpt_sampler_sample_and_accept_n(struct gpt_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
uint32_t gpt_sampler_get_seed(const struct gpt_sampler * gsmpl);
// helpers

271
common/speculative.cpp Normal file
View file

@ -0,0 +1,271 @@
#include "speculative.h"
#include "log.h"
#include "common.h"
#include "sampling.h"
#include <cstring>
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
struct common_speculative {
struct llama_context * ctx;
struct gpt_sampler * smpl;
llama_batch batch;
llama_tokens prompt;
};
struct common_speculative * common_speculative_init(
struct llama_context * ctx_dft) {
auto * result = new common_speculative {
/* .ctx = */ ctx_dft,
/* .smpl = */ nullptr,
/* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
/* .prompt = */ {},
};
// TODO: optimize or pass from outside?
#if 0
{
common_params_sampling params;
params.no_perf = false;
params.top_k = 40;
params.top_p = 0.9;
params.samplers = {
COMMON_SAMPLER_TYPE_TOP_K,
COMMON_SAMPLER_TYPE_TOP_P,
COMMON_SAMPLER_TYPE_INFILL,
};
result->smpl = gpt_sampler_init(llama_get_model(ctx_dft), params);
}
#else
{
gpt_sampler_params params;
params.no_perf = false;
params.top_k = 10;
params.samplers = {
GPT_SAMPLER_TYPE_TOP_K,
};
result->smpl = gpt_sampler_init(llama_get_model(ctx_dft), params);
}
#endif
llama_update_context_with_rankworld(result->ctx, 0, 1, 0, 1);
return result;
}
void common_speculative_free(struct common_speculative * spec) {
gpt_sampler_free(spec->smpl);
llama_batch_free(spec->batch);
delete spec;
}
bool common_speculative_are_compatible(
const struct llama_context * ctx_tgt,
const struct llama_context * ctx_dft) {
const struct llama_model * model_tgt = llama_get_model(ctx_tgt);
const struct llama_model * model_dft = llama_get_model(ctx_dft);
const bool vocab_type_tgt = llama_vocab_type(model_tgt);
LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);
const bool vocab_type_dft = llama_vocab_type(model_dft);
LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
if (vocab_type_tgt != vocab_type_dft) {
LOG_ERR("%s: draft model vocab type must match target model to use speculation but "
"vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt);
return false;
}
if (llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) ||
llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) ||
llama_token_bos(model_tgt) != llama_token_bos(model_dft) ||
llama_token_eos(model_tgt) != llama_token_eos(model_dft)
) {
LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__);
return false;
}
{
const int n_vocab_tgt = llama_n_vocab(model_tgt);
const int n_vocab_dft = llama_n_vocab(model_dft);
const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft);
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
LOG_ERR("%s: draft model vocab must closely match target model to use speculation but "
"target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
__func__, n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
return false;
}
for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
const char * token_text_tgt = llama_token_get_text(model_tgt, i);
const char * token_text_dft = llama_token_get_text(model_dft, i);
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
LOG_ERR("%s: draft model vocab must match target model to use speculation but "
"token %d content differs - target '%s', draft '%s'\n", __func__, i,
llama_token_to_piece(ctx_tgt, i).c_str(),
llama_token_to_piece(ctx_dft, i).c_str());
return false;
}
}
}
return true;
}
llama_tokens common_speculative_gen_draft(
struct common_speculative * spec,
struct common_speculative_params params,
const llama_tokens & prompt_tgt,
llama_token id_last) {
auto & batch = spec->batch;
auto & ctx = spec->ctx;
auto & smpl = spec->smpl;
auto & prompt = spec->prompt;
int reuse_i = 0;
int reuse_n = 0;
const int n_ctx = llama_n_ctx(ctx) - params.n_draft;
const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);
// reuse as much as possible from the old draft context
// ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
for (int i = 0; i < (int) prompt.size(); ++i) {
int cur = 0;
while (i_start + cur < (int) prompt_tgt.size() &&
i + cur < (int) prompt.size() &&
prompt_tgt[i_start + cur] == prompt[i + cur]) {
cur++;
}
if ((cur >= params.n_reuse || n_ctx >= (int) prompt_tgt.size()) && cur > reuse_n) {
reuse_i = i;
reuse_n = cur;
}
}
LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt.size());
llama_tokens result;
result.reserve(params.n_draft);
if (reuse_n == 0) {
llama_kv_cache_clear(ctx);
prompt.clear();
} else {
// this happens when a previous draft has been discarded (for example, due to being too small), but the
// target model agreed with it. in this case, we simply pass back the previous results to save compute
if (reuse_i + reuse_n < (int) prompt.size() && prompt[reuse_i + reuse_n] == id_last) {
for (int i = reuse_i + reuse_n + 1; i < (int) prompt.size(); ++i) {
result.push_back(prompt[i]);
if (params.n_draft <= (int) result.size()) {
break;
}
}
return result;
}
if (reuse_i > 0) {
llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i);
llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
}
if (reuse_n < (int) prompt.size()) {
llama_kv_cache_seq_rm (ctx, 0, reuse_n, -1);
prompt.erase(prompt.begin() + reuse_n, prompt.end());
}
}
// prepare a batch to evaluate any new tokens in the prompt
llama_batch_clear(batch);
for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) {
//LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
llama_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);
prompt.push_back(prompt_tgt[i]);
}
// we should rarely end-up here during normal decoding
if (batch.n_tokens > 0) {
//LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
llama_decode(ctx, batch);
}
const llama_pos n_past = prompt.size();
LOG_DBG("%s: n_past = %d\n", __func__, n_past);
llama_batch_clear(batch);
llama_batch_add (batch, id_last, n_past, { 0 }, true);
prompt.push_back(id_last);
//LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str());
llama_decode(ctx, batch);
gpt_sampler_reset(smpl);
// sample n_draft tokens from the draft model
for (int i = 0; i < params.n_draft; ++i) {
llama_batch_clear(batch);
gpt_sampler_sample(smpl, ctx, 0, true);
const auto * cur_p = gpt_sampler_get_candidates(smpl);
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
k, i, cur_p->data[k].id, cur_p->data[k].p, llama_token_to_piece(ctx, cur_p->data[k].id).c_str());
}
// add drafted token for each sequence
const llama_token id = cur_p->data[0].id;
// only collect very high-confidence draft tokens
if (cur_p->data[0].p < params.p_min) {
break;
}
gpt_sampler_accept(smpl, id, true);
result.push_back(id);
if (params.n_draft <= (int) result.size()) {
break;
}
llama_batch_add(batch, id, n_past + i + 1, { 0 }, true);
// evaluate the drafted tokens on the draft model
llama_decode(ctx, batch);
prompt.push_back(id);
}
return result;
}

28
common/speculative.h Normal file
View file

@ -0,0 +1,28 @@
#pragma once
#include "llama.h"
#include "common.h"
struct common_speculative;
struct common_speculative_params {
int n_draft = 16; // max drafted tokens
int n_reuse = 256;
float p_min = 0.9f; // min probabiliy required to accept a token in the draft
};
struct common_speculative * common_speculative_init(struct llama_context * ctx_dft);
void common_speculative_free(struct common_speculative * spec);
bool common_speculative_are_compatible(
const struct llama_context * ctx_tgt,
const struct llama_context * ctx_dft);
// sample up to n_draft tokens and add them to the batch using the draft model
llama_tokens common_speculative_gen_draft(
struct common_speculative * spec,
struct common_speculative_params params,
const llama_tokens & prompt,
llama_token id_last);

View file

@ -6,6 +6,7 @@
#include "sampling.h"
#include "json-schema-to-grammar.h"
#include "llama.h"
#include "speculative.h"
// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
@ -133,6 +134,9 @@ struct slot_params {
int32_t n_predict = -1; // new tokens to predict
std::vector<std::string> antiprompt;
struct gpt_sampler_params sampling;
struct common_params_speculative speculative;
json input_prefix;
json input_suffix;
@ -142,6 +146,12 @@ struct server_slot {
int id;
int id_task = -1;
llama_batch batch_spec;
llama_context * ctx_dft = nullptr;
common_speculative * spec = nullptr;
// the index relative to completion multi-task request
size_t index = 0;
@ -231,7 +241,7 @@ struct server_slot {
generated_token_probs.clear();
}
bool has_budget(gpt_params &global_params) {
bool has_budget(const gpt_params &global_params) {
if (params.n_predict == -1 && global_params.n_predict == -1) {
return true; // limitless
}
@ -251,6 +261,10 @@ struct server_slot {
return state != SLOT_STATE_IDLE;
}
bool can_speculate() const {
return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt;
}
void add_token(const completion_token_output & token) {
if (!is_processing()) {
SLT_WRN(*this, "%s", "slot is not processing\n");
@ -615,6 +629,9 @@ struct server_context {
gpt_params params;
llama_model * model_dft = nullptr;
llama_context_params cparams_dft;
llama_batch batch = {};
bool clean_kv_cache = true;
@ -652,17 +669,33 @@ struct server_context {
model = nullptr;
}
if (model_dft) {
llama_free_model(model_dft);
model_dft = nullptr;
}
// Clear any sampling context
for (server_slot & slot : slots) {
if (slot.smpl != nullptr) {
gpt_sampler_free(slot.smpl);
}
slot.smpl = nullptr;
llama_free(slot.ctx_dft);
slot.ctx_dft = nullptr;
common_speculative_free(slot.spec);
slot.spec = nullptr;
llama_batch_free(slot.batch_spec);
}
llama_batch_free(batch);
}
bool load_model(const gpt_params & params_) {
SRV_INF("loading model '%s'\n", params.model.c_str());
params = params_;
// dedicate one sequence to the system prompt
@ -685,6 +718,44 @@ struct server_context {
add_bos_token = llama_add_bos_token(model);
has_eos_token = !llama_add_eos_token(model);
if (!params.speculative.model.empty()) {
SRV_INF("loading draft model '%s'\n", params.speculative.model.c_str());
auto params_dft = params;
params_dft.model = params.speculative.model;
params_dft.n_ctx = params.speculative.n_ctx;
params_dft.n_gpu_layers = params.speculative.n_gpu_layers;
params_dft.n_world = 1; // do not split the draft model across devicesAdd commentMore actions
params_dft.rank = 0; // always load the draft model on the head device
std::fill_n(params_dft.n_layer_window, params.n_world, 0);
llama_init_result llama_init_dft = llama_init_from_gpt_params(params_dft);
model_dft = llama_init_dft.model;
if (model_dft == nullptr) {
SRV_ERR("failed to load draft model, '%s'\n", params.speculative.model.c_str());
return false;
}
if (!common_speculative_are_compatible(ctx, llama_init_dft.context)) {
SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params.speculative.model.c_str(), params.model.c_str());
llama_free (llama_init_dft.context);
llama_free_model(llama_init_dft.model);
return false;
}
cparams_dft = llama_context_params_from_gpt_params(params);
cparams_dft.n_batch = llama_n_ctx(llama_init_dft.context);
// the context is not needed - we will create one for each slot
llama_free(llama_init_dft.context);
}
return true;
}
@ -708,6 +779,30 @@ struct server_context {
slot.id = i;
slot.n_ctx = n_ctx_slot;
slot.n_predict = params.n_predict;
if (model_dft) {
slot.batch_spec = llama_batch_init(params.speculative.n_max + 1, 0, 1);
slot.ctx_dft = llama_new_context_with_model(model_dft, cparams_dft);
if (llama_context_setup_backend(model, cparams_dft, slot.ctx_dft) == nullptr) {
SRV_ERR("%s: failed to setup context with model '%s'\n", __func__, params.model.c_str());
llama_free(slot.ctx_dft);
llama_free_model(model);
return;
}
if (slot.ctx_dft == nullptr) {
SRV_ERR("%s", "failed to create draft context\n");
return;
}
slot.spec = common_speculative_init(slot.ctx_dft);
if (slot.spec == nullptr) {
SRV_ERR("%s", "failed to create speculator\n");
return;
}
}
SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
@ -875,6 +970,8 @@ struct server_context {
slot_params default_params;
// Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
auto default_sparams = params.sparams;
default_params.speculative = params.speculative;
const auto & data = task.data;
if (data.count("__oaicompat") != 0) {
@ -909,6 +1006,12 @@ struct server_context {
slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
slot.params.speculative.n_min = json_value(data, "speculative.n_min", default_params.speculative.n_min);
slot.params.speculative.n_max = json_value(data, "speculative.n_max", default_params.speculative.n_max);
slot.params.speculative.p_min = json_value(data, "speculative.p_min", default_params.speculative.p_min);
slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min);
// process "json_schema" and "grammar"
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
@ -1049,6 +1152,12 @@ struct server_context {
return false;
}
}
if (slot.ctx_dft) {
llama_batch_free(slot.batch_spec);
slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
}
slot.state = SLOT_STATE_PROCESSING_PROMPT;
slot.prompt_tokens.clear();
@ -2357,38 +2466,100 @@ struct server_context {
continue; // continue loop of slots
}
completion_token_output result;
const llama_token id = gpt_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
llama_token id;
gpt_sampler_accept(slot.smpl, id, true);
{
completion_token_output result;
slot.n_decoded += 1;
if (slot.n_decoded == 1) {
slot.t_start_generation = ggml_time_us();
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
metrics.on_prompt_eval(slot);
id = gpt_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
slot.i_batch = -1;
gpt_sampler_accept(slot.smpl, id, true);
slot.n_decoded += 1;
if (slot.n_decoded == 1) {
slot.t_start_generation = ggml_time_us();
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
metrics.on_prompt_eval(slot);
}
result.tok = id;
const auto * cur_p = gpt_sampler_get_candidates(slot.smpl);
for (size_t i = 0; i < (size_t) slot.params.sampling.n_probs; ++i) {
result.probs.push_back({
cur_p->data[i].id,
i >= cur_p->size ? 0.0f : cur_p->data[i].p,
});
}
if (!process_token(result, slot)) {
// release slot because of stop condition
slot.release();
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
continue;
}
}
result.tok = id;
const auto * cur_p = gpt_sampler_get_candidates(slot.smpl);
for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
result.probs.push_back({
cur_p->data[i].id,
i >= cur_p->size ? 0.0f : cur_p->data[i].p,
});
// check if the slot supports speculative decoding
if (!slot.can_speculate()) {
continue;
}
if (!process_token(result, slot)) {
// release slot because of stop condition
slot.release();
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
struct common_speculative_params params_spec;
params_spec.n_draft = slot.params.speculative.n_max;
params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
params_spec.p_min = slot.params.speculative.p_min;
llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
// ignore small drafts
if (slot.params.speculative.n_min > (int) draft.size()) {
continue;
}
slot.i_batch = -1;
// construct the speculation batch
llama_batch_clear(slot.batch_spec);
llama_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true);
for (size_t i = 0; i < draft.size(); ++i) {
llama_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
}
llama_decode(ctx, slot.batch_spec);
// the accepted tokens from the speculation
const auto ids = gpt_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
slot.n_past += ids.size();
slot.n_decoded += ids.size();
slot.cache_tokens.push_back(id);
slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result;
result.tok = ids[i];
if (!process_token(result, slot)) {
// release slot because of stop condition
slot.release();
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
break;
}
}
SRV_DBG("accepted %d/%d draft tokens\n", (int) ids.size() - 1, (int) draft.size());
}
}

View file

@ -41,7 +41,7 @@ int main(int argc, char ** argv) {
gpt_init();
if (params.model_draft.empty()) {
if (params.speculative.model.empty()) {
LOG_ERR("%s: --model-draft is required\n", __func__);
return 1;
}
@ -68,7 +68,7 @@ int main(int argc, char ** argv) {
// load the draft model
// make a hard copy of params to use for the draft model
gpt_params params_draft = params;
params_draft.model = params_draft.model_draft;
params_draft.model = params_draft.speculative.model;
params_draft.n_gpu_layers = params_draft.n_gpu_layers_draft;
params_draft.n_world = 1; // do not split the draft model across devices
params_draft.rank = 0; // always load the draft model on the head device

View file

@ -20944,6 +20944,12 @@ struct llama_context * llama_new_context_with_model(
ctx->cparams.rank = params.rank;
ctx->cparams.force = params.force;
ctx->cparams.original_next_rank = (params.rank + 1) % params.n_world;
auto &hparams = model->hparams;
auto &cparams = ctx->cparams;
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
ctx->logits_all = params.logits_all;
return ctx;
}