From 2e8e42a5ad2949fa8b7f4171fddb0d88a684750b Mon Sep 17 00:00:00 2001 From: DeEMO Date: Mon, 23 Jun 2025 20:36:32 +0800 Subject: [PATCH] Add speculative decoding support to the server and command-line interfaces --- Makefile | 8 + common/arg.cpp | 4 +- common/common.cpp | 2 +- common/common.h | 18 +- common/sampling.cpp | 39 ++++ common/sampling.h | 21 +++ common/speculative.cpp | 271 +++++++++++++++++++++++++++ common/speculative.h | 28 +++ examples/server/server.cpp | 221 +++++++++++++++++++--- examples/speculative/speculative.cpp | 4 +- src/llama.cpp | 6 + 11 files changed, 591 insertions(+), 31 deletions(-) create mode 100644 common/speculative.cpp create mode 100644 common/speculative.h diff --git a/Makefile b/Makefile index 39ae0b9f..b1bbc0ed 100644 --- a/Makefile +++ b/Makefile @@ -963,6 +963,7 @@ OBJ_COMMON = \ common/console.o \ common/ngram-cache.o \ common/sampling.o \ + common/speculative.o \ common/train.o \ common/build-info.o \ common/json-schema-to-grammar.o @@ -1239,6 +1240,13 @@ common/json-schema-to-grammar.o: \ common/json-schema-to-grammar.h $(CXX) $(CXXFLAGS) -c $< -o $@ +# speculative +common/speculative.o: \ + common/speculative.cpp \ + common/speculative.h \ + include/llama.h + $(CXX) $(CXXFLAGS) -c $< -o $@ + common/train.o: \ common/train.cpp \ common/train.h diff --git a/common/arg.cpp b/common/arg.cpp index c971bb49..813f87e8 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1704,9 +1704,9 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex, {"-md", "--model-draft"}, "FNAME", "draft model for speculative decoding (default: unused)", [](gpt_params & params, const std::string & value) { - params.model_draft = value; + params.speculative.model = value; } - ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER})); add_opt(llama_arg( {"-mu", "--model-url"}, "MODEL_URL", "model download url (default: unused)", diff --git a/common/common.cpp b/common/common.cpp index 92c95b3b..b6182b22 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -3111,7 +3111,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta); fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false"); fprintf(stream, "model: %s # default: %s\n", params.model.c_str(), DEFAULT_MODEL_PATH); - fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str()); + fprintf(stream, "model_draft: %s # default:\n", params.speculative.model.c_str()); fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false"); fprintf(stream, "n_gpu_layers: %d # default: -1\n", params.n_gpu_layers); fprintf(stream, "n_predict: %d # default: -1 (unlimited)\n", params.n_predict); diff --git a/common/common.h b/common/common.h index 54eb7e8b..044dfdf5 100644 --- a/common/common.h +++ b/common/common.h @@ -33,6 +33,8 @@ struct llama_lora_adapter_container : llama_lora_adapter_info { struct llama_lora_adapter * adapter; }; +using llama_tokens = std::vector; + // build info extern int LLAMA_BUILD_NUMBER; extern char const * LLAMA_COMMIT; @@ -141,6 +143,20 @@ struct gpt_sampler_params { std::string print() const; }; +struct common_params_speculative { + int32_t n_ctx = 0; // draft context size + int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding + int32_t n_min = 5; // minimum number of draft tokens to use for speculative decoding + int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default) + float p_split = 0.1f; // speculative decoding split probability + float p_min = 0.9f; // minimum speculative decoding probability (greedy) + + struct cpu_params cpuparams; + struct cpu_params cpuparams_batch; + + std::string model = ""; // draft model for speculative decoding // NOLINT +}; + struct gpt_params { int32_t n_world = 1; // number of devices to use int32_t rank = 0; // my rank for distributed inference @@ -198,9 +214,9 @@ struct gpt_params { enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings struct gpt_sampler_params sparams; + struct common_params_speculative speculative; std::string model = ""; // model path // NOLINT - std::string model_draft = ""; // draft model for speculative decoding // NOLINT std::string model_alias = "unknown"; // model alias // NOLINT std::string model_url = ""; // model url to download // NOLINT std::string hf_token = ""; // HF token // NOLINT diff --git a/common/sampling.cpp b/common/sampling.cpp index 3dc7f112..b8db920a 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -318,6 +318,45 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context return cur_p.data[cur_p.selected].id; } +std::vector gpt_sampler_sample_and_accept_n(struct gpt_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const llama_tokens & draft, bool grammar_first) { + GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1"); + + std::vector result; + result.reserve(idxs.size()); + + size_t i = 0; + for (; i < draft.size(); i++) { + const llama_token id = gpt_sampler_sample(gsmpl, ctx, idxs[i], grammar_first); + + gpt_sampler_accept(gsmpl, id, true); + + result.push_back(id); + + if (draft[i] != id) { + break; + } + } + + if (i == draft.size()) { + const llama_token id = gpt_sampler_sample(gsmpl, ctx, idxs[i], grammar_first); + + gpt_sampler_accept(gsmpl, id, true); + + result.push_back(id); + } + + return result; +} + +std::vector gpt_sampler_sample_and_accept_n(struct gpt_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) { + std::vector idxs(draft.size() + 1); + for (size_t i = 0; i < idxs.size(); ++i) { + idxs[i] = i; + } + + return gpt_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first); +} + uint32_t gpt_sampler_get_seed(const struct gpt_sampler * gsmpl) { return llama_sampler_get_seed(gsmpl->chain); } diff --git a/common/sampling.h b/common/sampling.h index d0e1a920..82d6b023 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -60,6 +60,27 @@ void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler * // llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false); +// generalized version of gpt_sampler_sample +// +// will cross-reference the sampled tokens with a batch of draft tokens and accept those that match +// if the sampler disagrees at some point, we stop and return the accepted tokens up to now +// +// gpt_sampler_sample_n(gsmpl, ctx, { idx }, {}); +// +// is equivalent to +// +// gpt_sampler_sample(gsmpl, ctx, idx); +// gpt_sampler_accept(gsmpl, token, true); +// +// requires: idxs.size() == draft.size() + 1 +// +// returns at least 1 token, up to idxs.size() +// +std::vector gpt_sampler_sample_and_accept_n(struct gpt_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const llama_tokens & draft, bool grammar_first = false); + +// assume idxs == [ 0, 1, 2, ..., draft.size() ] +std::vector gpt_sampler_sample_and_accept_n(struct gpt_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false); + uint32_t gpt_sampler_get_seed(const struct gpt_sampler * gsmpl); // helpers diff --git a/common/speculative.cpp b/common/speculative.cpp new file mode 100644 index 00000000..8a03d08c --- /dev/null +++ b/common/speculative.cpp @@ -0,0 +1,271 @@ +#include "speculative.h" + +#include "log.h" +#include "common.h" +#include "sampling.h" + +#include + +#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128 +#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 + +struct common_speculative { + struct llama_context * ctx; + struct gpt_sampler * smpl; + + llama_batch batch; + llama_tokens prompt; +}; + +struct common_speculative * common_speculative_init( + struct llama_context * ctx_dft) { + auto * result = new common_speculative { + /* .ctx = */ ctx_dft, + /* .smpl = */ nullptr, + /* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1), + /* .prompt = */ {}, + }; + + // TODO: optimize or pass from outside? +#if 0 + { + common_params_sampling params; + params.no_perf = false; + + params.top_k = 40; + params.top_p = 0.9; + + params.samplers = { + COMMON_SAMPLER_TYPE_TOP_K, + COMMON_SAMPLER_TYPE_TOP_P, + COMMON_SAMPLER_TYPE_INFILL, + }; + + result->smpl = gpt_sampler_init(llama_get_model(ctx_dft), params); + } +#else + { + gpt_sampler_params params; + params.no_perf = false; + + params.top_k = 10; + + params.samplers = { + GPT_SAMPLER_TYPE_TOP_K, + }; + + result->smpl = gpt_sampler_init(llama_get_model(ctx_dft), params); + } +#endif + + llama_update_context_with_rankworld(result->ctx, 0, 1, 0, 1); + + return result; +} + +void common_speculative_free(struct common_speculative * spec) { + gpt_sampler_free(spec->smpl); + + llama_batch_free(spec->batch); + + delete spec; +} + +bool common_speculative_are_compatible( + const struct llama_context * ctx_tgt, + const struct llama_context * ctx_dft) { + const struct llama_model * model_tgt = llama_get_model(ctx_tgt); + const struct llama_model * model_dft = llama_get_model(ctx_dft); + + const bool vocab_type_tgt = llama_vocab_type(model_tgt); + LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt); + + const bool vocab_type_dft = llama_vocab_type(model_dft); + LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft); + + if (vocab_type_tgt != vocab_type_dft) { + LOG_ERR("%s: draft model vocab type must match target model to use speculation but " + "vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt); + return false; + } + + if (llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) || + llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) || + llama_token_bos(model_tgt) != llama_token_bos(model_dft) || + llama_token_eos(model_tgt) != llama_token_eos(model_dft) + ) { + LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__); + return false; + } + + { + const int n_vocab_tgt = llama_n_vocab(model_tgt); + const int n_vocab_dft = llama_n_vocab(model_dft); + + const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft); + + if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) { + LOG_ERR("%s: draft model vocab must closely match target model to use speculation but " + "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n", + __func__, n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE); + return false; + } + + for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) { + const char * token_text_tgt = llama_token_get_text(model_tgt, i); + const char * token_text_dft = llama_token_get_text(model_dft, i); + if (std::strcmp(token_text_tgt, token_text_dft) != 0) { + LOG_ERR("%s: draft model vocab must match target model to use speculation but " + "token %d content differs - target '%s', draft '%s'\n", __func__, i, + llama_token_to_piece(ctx_tgt, i).c_str(), + llama_token_to_piece(ctx_dft, i).c_str()); + return false; + } + } + } + + return true; +} + +llama_tokens common_speculative_gen_draft( + struct common_speculative * spec, + struct common_speculative_params params, + const llama_tokens & prompt_tgt, + llama_token id_last) { + auto & batch = spec->batch; + auto & ctx = spec->ctx; + auto & smpl = spec->smpl; + auto & prompt = spec->prompt; + + int reuse_i = 0; + int reuse_n = 0; + + const int n_ctx = llama_n_ctx(ctx) - params.n_draft; + + const int i_start = std::max(0, (int) prompt_tgt.size() - n_ctx); + + // reuse as much as possible from the old draft context + // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt + for (int i = 0; i < (int) prompt.size(); ++i) { + int cur = 0; + while (i_start + cur < (int) prompt_tgt.size() && + i + cur < (int) prompt.size() && + prompt_tgt[i_start + cur] == prompt[i + cur]) { + cur++; + } + + if ((cur >= params.n_reuse || n_ctx >= (int) prompt_tgt.size()) && cur > reuse_n) { + reuse_i = i; + reuse_n = cur; + } + } + + LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt.size()); + + llama_tokens result; + result.reserve(params.n_draft); + + if (reuse_n == 0) { + llama_kv_cache_clear(ctx); + + prompt.clear(); + } else { + // this happens when a previous draft has been discarded (for example, due to being too small), but the + // target model agreed with it. in this case, we simply pass back the previous results to save compute + if (reuse_i + reuse_n < (int) prompt.size() && prompt[reuse_i + reuse_n] == id_last) { + for (int i = reuse_i + reuse_n + 1; i < (int) prompt.size(); ++i) { + result.push_back(prompt[i]); + + if (params.n_draft <= (int) result.size()) { + break; + } + } + + return result; + } + + if (reuse_i > 0) { + llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i); + llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i); + + prompt.erase(prompt.begin(), prompt.begin() + reuse_i); + } + + if (reuse_n < (int) prompt.size()) { + llama_kv_cache_seq_rm (ctx, 0, reuse_n, -1); + + prompt.erase(prompt.begin() + reuse_n, prompt.end()); + } + } + + // prepare a batch to evaluate any new tokens in the prompt + llama_batch_clear(batch); + + for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) { + //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]); + llama_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false); + + prompt.push_back(prompt_tgt[i]); + } + + // we should rarely end-up here during normal decoding + if (batch.n_tokens > 0) { + //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); + + llama_decode(ctx, batch); + } + + const llama_pos n_past = prompt.size(); + + LOG_DBG("%s: n_past = %d\n", __func__, n_past); + + llama_batch_clear(batch); + llama_batch_add (batch, id_last, n_past, { 0 }, true); + + prompt.push_back(id_last); + + //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str()); + + llama_decode(ctx, batch); + + gpt_sampler_reset(smpl); + + // sample n_draft tokens from the draft model + for (int i = 0; i < params.n_draft; ++i) { + llama_batch_clear(batch); + + gpt_sampler_sample(smpl, ctx, 0, true); + + const auto * cur_p = gpt_sampler_get_candidates(smpl); + + for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { + LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", + k, i, cur_p->data[k].id, cur_p->data[k].p, llama_token_to_piece(ctx, cur_p->data[k].id).c_str()); + } + + // add drafted token for each sequence + const llama_token id = cur_p->data[0].id; + + // only collect very high-confidence draft tokens + if (cur_p->data[0].p < params.p_min) { + break; + } + + gpt_sampler_accept(smpl, id, true); + + result.push_back(id); + + if (params.n_draft <= (int) result.size()) { + break; + } + + llama_batch_add(batch, id, n_past + i + 1, { 0 }, true); + + // evaluate the drafted tokens on the draft model + llama_decode(ctx, batch); + + prompt.push_back(id); + } + + return result; +} \ No newline at end of file diff --git a/common/speculative.h b/common/speculative.h new file mode 100644 index 00000000..0af10ea4 --- /dev/null +++ b/common/speculative.h @@ -0,0 +1,28 @@ +#pragma once + +#include "llama.h" +#include "common.h" + +struct common_speculative; + +struct common_speculative_params { + int n_draft = 16; // max drafted tokens + int n_reuse = 256; + + float p_min = 0.9f; // min probabiliy required to accept a token in the draft +}; + +struct common_speculative * common_speculative_init(struct llama_context * ctx_dft); + +void common_speculative_free(struct common_speculative * spec); + +bool common_speculative_are_compatible( + const struct llama_context * ctx_tgt, + const struct llama_context * ctx_dft); + +// sample up to n_draft tokens and add them to the batch using the draft model +llama_tokens common_speculative_gen_draft( + struct common_speculative * spec, + struct common_speculative_params params, + const llama_tokens & prompt, + llama_token id_last); \ No newline at end of file diff --git a/examples/server/server.cpp b/examples/server/server.cpp index af39f1ac..d086eaf3 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -6,6 +6,7 @@ #include "sampling.h" #include "json-schema-to-grammar.h" #include "llama.h" +#include "speculative.h" // Change JSON_ASSERT from assert() to GGML_ASSERT: #define JSON_ASSERT GGML_ASSERT @@ -133,6 +134,9 @@ struct slot_params { int32_t n_predict = -1; // new tokens to predict std::vector antiprompt; + + struct gpt_sampler_params sampling; + struct common_params_speculative speculative; json input_prefix; json input_suffix; @@ -142,6 +146,12 @@ struct server_slot { int id; int id_task = -1; + llama_batch batch_spec; + + llama_context * ctx_dft = nullptr; + + common_speculative * spec = nullptr; + // the index relative to completion multi-task request size_t index = 0; @@ -231,7 +241,7 @@ struct server_slot { generated_token_probs.clear(); } - bool has_budget(gpt_params &global_params) { + bool has_budget(const gpt_params &global_params) { if (params.n_predict == -1 && global_params.n_predict == -1) { return true; // limitless } @@ -251,6 +261,10 @@ struct server_slot { return state != SLOT_STATE_IDLE; } + bool can_speculate() const { + return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt; + } + void add_token(const completion_token_output & token) { if (!is_processing()) { SLT_WRN(*this, "%s", "slot is not processing\n"); @@ -615,6 +629,9 @@ struct server_context { gpt_params params; + llama_model * model_dft = nullptr; + llama_context_params cparams_dft; + llama_batch batch = {}; bool clean_kv_cache = true; @@ -652,17 +669,33 @@ struct server_context { model = nullptr; } + if (model_dft) { + llama_free_model(model_dft); + model_dft = nullptr; + } + // Clear any sampling context for (server_slot & slot : slots) { if (slot.smpl != nullptr) { gpt_sampler_free(slot.smpl); } + slot.smpl = nullptr; + + llama_free(slot.ctx_dft); + slot.ctx_dft = nullptr; + + common_speculative_free(slot.spec); + slot.spec = nullptr; + + llama_batch_free(slot.batch_spec); } llama_batch_free(batch); } bool load_model(const gpt_params & params_) { + SRV_INF("loading model '%s'\n", params.model.c_str()); + params = params_; // dedicate one sequence to the system prompt @@ -685,6 +718,44 @@ struct server_context { add_bos_token = llama_add_bos_token(model); has_eos_token = !llama_add_eos_token(model); + + if (!params.speculative.model.empty()) { + SRV_INF("loading draft model '%s'\n", params.speculative.model.c_str()); + + auto params_dft = params; + + params_dft.model = params.speculative.model; + params_dft.n_ctx = params.speculative.n_ctx; + params_dft.n_gpu_layers = params.speculative.n_gpu_layers; + params_dft.n_world = 1; // do not split the draft model across devicesAdd commentMore actions + params_dft.rank = 0; // always load the draft model on the head device + + std::fill_n(params_dft.n_layer_window, params.n_world, 0); + + llama_init_result llama_init_dft = llama_init_from_gpt_params(params_dft); + + model_dft = llama_init_dft.model; + + if (model_dft == nullptr) { + SRV_ERR("failed to load draft model, '%s'\n", params.speculative.model.c_str()); + return false; + } + + if (!common_speculative_are_compatible(ctx, llama_init_dft.context)) { + SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params.speculative.model.c_str(), params.model.c_str()); + + llama_free (llama_init_dft.context); + llama_free_model(llama_init_dft.model); + + return false; + } + + cparams_dft = llama_context_params_from_gpt_params(params); + cparams_dft.n_batch = llama_n_ctx(llama_init_dft.context); + + // the context is not needed - we will create one for each slot + llama_free(llama_init_dft.context); + } return true; } @@ -708,6 +779,30 @@ struct server_context { slot.id = i; slot.n_ctx = n_ctx_slot; slot.n_predict = params.n_predict; + + if (model_dft) { + slot.batch_spec = llama_batch_init(params.speculative.n_max + 1, 0, 1); + + slot.ctx_dft = llama_new_context_with_model(model_dft, cparams_dft); + + if (llama_context_setup_backend(model, cparams_dft, slot.ctx_dft) == nullptr) { + SRV_ERR("%s: failed to setup context with model '%s'\n", __func__, params.model.c_str()); + llama_free(slot.ctx_dft); + llama_free_model(model); + return; + } + + if (slot.ctx_dft == nullptr) { + SRV_ERR("%s", "failed to create draft context\n"); + return; + } + + slot.spec = common_speculative_init(slot.ctx_dft); + if (slot.spec == nullptr) { + SRV_ERR("%s", "failed to create speculator\n"); + return; + } + } SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx); @@ -875,6 +970,8 @@ struct server_context { slot_params default_params; // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them) auto default_sparams = params.sparams; + default_params.speculative = params.speculative; + const auto & data = task.data; if (data.count("__oaicompat") != 0) { @@ -909,6 +1006,12 @@ struct server_context { slot.sparams.seed = json_value(data, "seed", default_sparams.seed); slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep); + + slot.params.speculative.n_min = json_value(data, "speculative.n_min", default_params.speculative.n_min); + slot.params.speculative.n_max = json_value(data, "speculative.n_max", default_params.speculative.n_max); + slot.params.speculative.p_min = json_value(data, "speculative.p_min", default_params.speculative.p_min); + + slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min); // process "json_schema" and "grammar" if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) { @@ -1049,6 +1152,12 @@ struct server_context { return false; } } + + if (slot.ctx_dft) { + llama_batch_free(slot.batch_spec); + + slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1); + } slot.state = SLOT_STATE_PROCESSING_PROMPT; slot.prompt_tokens.clear(); @@ -2357,38 +2466,100 @@ struct server_context { continue; // continue loop of slots } - completion_token_output result; - const llama_token id = gpt_sampler_sample(slot.smpl, ctx, slot.i_batch - i); + llama_token id; - gpt_sampler_accept(slot.smpl, id, true); + { + completion_token_output result; - slot.n_decoded += 1; - if (slot.n_decoded == 1) { - slot.t_start_generation = ggml_time_us(); - slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; - metrics.on_prompt_eval(slot); + id = gpt_sampler_sample(slot.smpl, ctx, slot.i_batch - i); + + slot.i_batch = -1; + + gpt_sampler_accept(slot.smpl, id, true); + + slot.n_decoded += 1; + if (slot.n_decoded == 1) { + slot.t_start_generation = ggml_time_us(); + slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; + metrics.on_prompt_eval(slot); + } + + result.tok = id; + + const auto * cur_p = gpt_sampler_get_candidates(slot.smpl); + + for (size_t i = 0; i < (size_t) slot.params.sampling.n_probs; ++i) { + result.probs.push_back({ + cur_p->data[i].id, + i >= cur_p->size ? 0.0f : cur_p->data[i].p, + }); + } + + if (!process_token(result, slot)) { + // release slot because of stop condition + slot.release(); + slot.print_timings(); + send_final_response(slot); + metrics.on_prediction(slot); + continue; + } } - result.tok = id; - - const auto * cur_p = gpt_sampler_get_candidates(slot.smpl); - - for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) { - result.probs.push_back({ - cur_p->data[i].id, - i >= cur_p->size ? 0.0f : cur_p->data[i].p, - }); + // check if the slot supports speculative decoding + if (!slot.can_speculate()) { + continue; } - if (!process_token(result, slot)) { - // release slot because of stop condition - slot.release(); - slot.print_timings(); - send_final_response(slot); - metrics.on_prediction(slot); + struct common_speculative_params params_spec; + params_spec.n_draft = slot.params.speculative.n_max; + params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max; + params_spec.p_min = slot.params.speculative.p_min; + + llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id); + + // ignore small drafts + if (slot.params.speculative.n_min > (int) draft.size()) { + continue; } - slot.i_batch = -1; + // construct the speculation batch + llama_batch_clear(slot.batch_spec); + llama_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true); + + for (size_t i = 0; i < draft.size(); ++i) { + llama_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true); + } + + llama_decode(ctx, slot.batch_spec); + + // the accepted tokens from the speculation + const auto ids = gpt_sampler_sample_and_accept_n(slot.smpl, ctx, draft); + + slot.n_past += ids.size(); + slot.n_decoded += ids.size(); + + slot.cache_tokens.push_back(id); + slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1); + + llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1); + + for (size_t i = 0; i < ids.size(); ++i) { + completion_token_output result; + + result.tok = ids[i]; + + if (!process_token(result, slot)) { + // release slot because of stop condition + slot.release(); + slot.print_timings(); + send_final_response(slot); + metrics.on_prediction(slot); + break; + } + } + + SRV_DBG("accepted %d/%d draft tokens\n", (int) ids.size() - 1, (int) draft.size()); + } } diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 8527f8b6..0a6c4701 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -41,7 +41,7 @@ int main(int argc, char ** argv) { gpt_init(); - if (params.model_draft.empty()) { + if (params.speculative.model.empty()) { LOG_ERR("%s: --model-draft is required\n", __func__); return 1; } @@ -68,7 +68,7 @@ int main(int argc, char ** argv) { // load the draft model // make a hard copy of params to use for the draft model gpt_params params_draft = params; - params_draft.model = params_draft.model_draft; + params_draft.model = params_draft.speculative.model; params_draft.n_gpu_layers = params_draft.n_gpu_layers_draft; params_draft.n_world = 1; // do not split the draft model across devices params_draft.rank = 0; // always load the draft model on the head device diff --git a/src/llama.cpp b/src/llama.cpp index e5965fd8..a57715e3 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -20944,6 +20944,12 @@ struct llama_context * llama_new_context_with_model( ctx->cparams.rank = params.rank; ctx->cparams.force = params.force; ctx->cparams.original_next_rank = (params.rank + 1) % params.n_world; + + auto &hparams = model->hparams; + auto &cparams = ctx->cparams; + cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; + ctx->logits_all = params.logits_all; + return ctx; }