mirror of https://github.com/Lizonghang/prima.cpp.git
synced 2025-09-05 16:59:03 +00:00

Add support for speculative decoding in llama-server

commit f032680cab
12 changed files with 640 additions and 46 deletions
Makefile

@@ -963,6 +963,7 @@ OBJ_COMMON = \
     common/console.o \
     common/ngram-cache.o \
     common/sampling.o \
+    common/speculative.o \
     common/train.o \
     common/build-info.o \
     common/json-schema-to-grammar.o
@@ -1239,6 +1240,13 @@ common/json-schema-to-grammar.o: \
 	common/json-schema-to-grammar.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@

+# speculative
+common/speculative.o: \
+	common/speculative.cpp \
+	common/speculative.h \
+	include/llama.h
+	$(CXX) $(CXXFLAGS) -c $< -o $@
+
 common/train.o: \
 	common/train.cpp \
 	common/train.h
@@ -381,6 +381,7 @@ curl -X POST http://localhost:8080/v1/cancel \
 ```

 **9. How to use speculative decoding?**

 Please see "[Power prima.cpp with speculative decoding: Further speeds up by up to 80%](https://github.com/Lizonghang/prima.cpp/discussions/29)".
+
 ## ❤️ Acknowledgment
@@ -627,12 +627,19 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         }
     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
     add_opt(llama_arg(
-        {"--draft"}, "N",
-        format("number of tokens to draft for speculative decoding (default: %d)", params.n_draft),
+        {"--draft-max", "--draft", "--draft-n"}, "N",
+        format("number of tokens to draft for speculative decoding (default: %d)", params.speculative.n_max),
         [](gpt_params & params, int value) {
-            params.n_draft = value;
+            params.speculative.n_max = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP}));
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER}));
+    add_opt(llama_arg(
+        {"--draft-min", "--draft-n-min"}, "N",
+        format("minimum number of draft tokens to use for speculative decoding (default: %d)", params.speculative.n_min),
+        [](gpt_params & params, int value) {
+            params.speculative.n_min = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER}));
     add_opt(llama_arg(
         {"-ps", "--p-split"}, "N",
         format("speculative decoding split probability (default: %.1f)", (double)params.p_split),
@@ -640,6 +647,13 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
             params.p_split = std::stof(value);
         }
     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
+    add_opt(llama_arg(
+        {"--draft-p-min"}, "P",
+        format("minimum speculative decoding probability (greedy) (default: %.1f)", (double)params.speculative.p_min),
+        [](gpt_params & params, const std::string & value) {
+            params.speculative.p_min = std::stof(value);
+        }
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
     add_opt(llama_arg(
         {"-lcs", "--lookup-cache-static"}, "FNAME",
         "path to static lookup cache to use for lookup decoding (not updated by generation)",
@@ -659,6 +673,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         format("size of the prompt context (default: %d, 0 = loaded from model)", params.n_ctx),
         [](gpt_params & params, int value) {
             params.n_ctx = value;
+            params.speculative.n_ctx = value;
         }
     ).set_env("LLAMA_ARG_CTX_SIZE"));
     add_opt(llama_arg(
@@ -1555,13 +1570,14 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         {"-ngld", "--gpu-layers-draft", "--n-gpu-layers-draft"}, "N",
         "number of layers to store in VRAM for the draft model",
         [](gpt_params & params, int value) {
-            params.n_gpu_layers_draft = value;
+            params.n_gpu_layers_draft = value; // TODO: remove
+            params.speculative.n_gpu_layers = value;
             if (!llama_supports_gpu_offload()) {
                 fprintf(stderr, "warning: not compiled with GPU offload support, --gpu-layers-draft option will be ignored\n");
                 fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
             }
         }
-    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
     add_opt(llama_arg(
         {"-sm", "--split-mode"}, "{none,layer,row}",
         "how to split the model across multiple GPUs, one of:\n"
@@ -1704,9 +1720,9 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         {"-md", "--model-draft"}, "FNAME",
         "draft model for speculative decoding (default: unused)",
         [](gpt_params & params, const std::string & value) {
-            params.model_draft = value;
+            params.speculative.model = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
     add_opt(llama_arg(
         {"-mu", "--model-url"}, "MODEL_URL",
         "model download url (default: unused)",
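Taken together, these changes re-scope the existing draft flags to the server example and add per-feature knobs. A hypothetical launch of the server with speculation enabled (model paths are placeholders) would combine the re-scoped flags like so:

    llama-server -m target.gguf -md draft.gguf --draft-max 16 --draft-min 5 --draft-p-min 0.9 -ngld 99

Note that --draft is kept as an alias of --draft-max, so existing scripts keep working.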
@@ -3111,7 +3111,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta);
     fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false");
     fprintf(stream, "model: %s # default: %s\n", params.model.c_str(), DEFAULT_MODEL_PATH);
-    fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str());
+    fprintf(stream, "model_draft: %s # default:\n", params.speculative.model.c_str());
     fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false");
     fprintf(stream, "n_gpu_layers: %d # default: -1\n", params.n_gpu_layers);
     fprintf(stream, "n_predict: %d # default: -1 (unlimited)\n", params.n_predict);
@@ -33,6 +33,8 @@ struct llama_lora_adapter_container : llama_lora_adapter_info {
     struct llama_lora_adapter * adapter;
 };

+using llama_tokens = std::vector<llama_token>;
+
 // build info
 extern int LLAMA_BUILD_NUMBER;
 extern char const * LLAMA_COMMIT;
@@ -141,6 +143,20 @@ struct gpt_sampler_params {
     std::string print() const;
 };

+struct common_params_speculative {
+    int32_t n_ctx        =     0; // draft context size
+    int32_t n_max        =    16; // maximum number of tokens to draft during speculative decoding
+    int32_t n_min        =     5; // minimum number of draft tokens to use for speculative decoding
+    int32_t n_gpu_layers =    -1; // number of layers to store in VRAM for the draft model (-1 - use default)
+    float   p_split      =  0.1f; // speculative decoding split probability
+    float   p_min        =  0.9f; // minimum speculative decoding probability (greedy)
+
+    struct cpu_params cpuparams;
+    struct cpu_params cpuparams_batch;
+
+    std::string model = ""; // draft model for speculative decoding // NOLINT
+};
+
 struct gpt_params {
     int32_t n_world = 1; // number of devices to use
     int32_t rank    = 0; // my rank for distributed inference
@@ -161,7 +177,6 @@ struct gpt_params {
     int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_keep = 0; // number of tokens to keep from initial prompt
-    int32_t n_draft = 5; // number of tokens to draft during speculative decoding
     int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
     int32_t n_parallel = 1; // number of parallel sequences to decode
     int32_t n_sequences = 1; // number of sequences to decode
@@ -198,9 +213,9 @@ struct gpt_params {
     enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings

     struct gpt_sampler_params sparams;
+    struct common_params_speculative speculative;

     std::string model = ""; // model path // NOLINT
-    std::string model_draft = ""; // draft model for speculative decoding // NOLINT
     std::string model_alias = "unknown"; // model alias // NOLINT
     std::string model_url = ""; // model url to download // NOLINT
     std::string hf_token = ""; // HF token // NOLINT
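Since the new struct carries its own defaults, callers only override the fields they care about. A minimal sketch (the model path is a placeholder):

    gpt_params params;
    params.speculative.model = "models/draft.gguf"; // placeholder path
    params.speculative.n_max = 8;                   // lower the 16-token default
    // n_min (5), p_min (0.9f) and n_ctx (0 = follow the target) keep their defaults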
@@ -318,6 +318,45 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
     return cur_p.data[cur_p.selected].id;
 }

+std::vector<llama_token> gpt_sampler_sample_and_accept_n(struct gpt_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
+    GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
+
+    std::vector<llama_token> result;
+    result.reserve(idxs.size());
+
+    size_t i = 0;
+    for (; i < draft.size(); i++) {
+        const llama_token id = gpt_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
+
+        gpt_sampler_accept(gsmpl, id, true);
+
+        result.push_back(id);
+
+        if (draft[i] != id) {
+            break;
+        }
+    }
+
+    if (i == draft.size()) {
+        const llama_token id = gpt_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
+
+        gpt_sampler_accept(gsmpl, id, true);
+
+        result.push_back(id);
+    }
+
+    return result;
+}
+
+std::vector<llama_token> gpt_sampler_sample_and_accept_n(struct gpt_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
+    std::vector<int> idxs(draft.size() + 1);
+    for (size_t i = 0; i < idxs.size(); ++i) {
+        idxs[i] = i;
+    }
+
+    return gpt_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
+}
+
 uint32_t gpt_sampler_get_seed(const struct gpt_sampler * gsmpl) {
     return llama_sampler_get_seed(gsmpl->chain);
 }
@@ -60,6 +60,27 @@ void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler *
 //
 llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);

+// generalized version of gpt_sampler_sample
+//
+// will cross-reference the sampled tokens with a batch of draft tokens and accept those that match
+// if the sampler disagrees at some point, we stop and return the accepted tokens up to now
+//
+//      gpt_sampler_sample_n(gsmpl, ctx, { idx }, {});
+//
+//   is equivalent to
+//
+//      gpt_sampler_sample(gsmpl, ctx, idx);
+//      gpt_sampler_accept(gsmpl, token, true);
+//
+// requires: idxs.size() == draft.size() + 1
+//
+// returns at least 1 token, up to idxs.size()
+//
+std::vector<llama_token> gpt_sampler_sample_and_accept_n(struct gpt_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);

+// assume idxs == [ 0, 1, 2, ..., draft.size() ]
+std::vector<llama_token> gpt_sampler_sample_and_accept_n(struct gpt_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
+
 uint32_t gpt_sampler_get_seed(const struct gpt_sampler * gsmpl);

 // helpers
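A minimal usage sketch of the two-argument overload, assuming the target context was just decoded on a batch holding the last accepted token followed by the draft tokens, with logits enabled at every position (gsmpl, ctx and draft are the caller's objects):

    // draft: tokens proposed by the draft model
    const auto ids = gpt_sampler_sample_and_accept_n(gsmpl, ctx, draft);

    // ids contains between 1 and draft.size() + 1 tokens: every draft token the
    // target sampler reproduced, followed by the first token it sampled
    // differently (or one token beyond the draft if it agreed with all of it)
    const int n_accepted = (int) ids.size() - 1;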
common/speculative.cpp (new file, 271 lines)

@@ -0,0 +1,271 @@
+#include "speculative.h"
+
+#include "log.h"
+#include "common.h"
+#include "sampling.h"
+
+#include <cstring>
+
+#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  128
+#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
+
+struct common_speculative {
+    struct llama_context * ctx;
+    struct gpt_sampler   * smpl;
+
+    llama_batch  batch;
+    llama_tokens prompt;
+};
+
+struct common_speculative * common_speculative_init(
+        struct llama_context * ctx_dft) {
+    auto * result = new common_speculative {
+        /* .ctx    = */ ctx_dft,
+        /* .smpl   = */ nullptr,
+        /* .batch  = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
+        /* .prompt = */ {},
+    };
+
+    // TODO: optimize or pass from outside?
+#if 0
+    {
+        common_params_sampling params;
+        params.no_perf = false;
+
+        params.top_k = 40;
+        params.top_p = 0.9;
+
+        params.samplers = {
+            COMMON_SAMPLER_TYPE_TOP_K,
+            COMMON_SAMPLER_TYPE_TOP_P,
+            COMMON_SAMPLER_TYPE_INFILL,
+        };
+
+        result->smpl = gpt_sampler_init(llama_get_model(ctx_dft), params);
+    }
+#else
+    {
+        gpt_sampler_params params;
+        params.no_perf = false;
+
+        params.top_k = 10;
+
+        params.samplers = {
+            GPT_SAMPLER_TYPE_TOP_K,
+        };
+
+        result->smpl = gpt_sampler_init(llama_get_model(ctx_dft), params);
+    }
+#endif
+
+    llama_update_context_with_rankworld(result->ctx, 0, 1, 0, 1);
+
+    return result;
+}
+
+void common_speculative_free(struct common_speculative * spec) {
+    gpt_sampler_free(spec->smpl);
+
+    llama_batch_free(spec->batch);
+
+    delete spec;
+}
+
+bool common_speculative_are_compatible(
+        const struct llama_context * ctx_tgt,
+        const struct llama_context * ctx_dft) {
+    const struct llama_model * model_tgt = llama_get_model(ctx_tgt);
+    const struct llama_model * model_dft = llama_get_model(ctx_dft);
+
+    const bool vocab_type_tgt = llama_vocab_type(model_tgt);
+    LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);
+
+    const bool vocab_type_dft = llama_vocab_type(model_dft);
+    LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
+
+    if (vocab_type_tgt != vocab_type_dft) {
+        LOG_ERR("%s: draft model vocab type must match target model to use speculation but "
+                "vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt);
+        return false;
+    }
+
+    if (llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) ||
+        llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) ||
+        llama_token_bos(model_tgt) != llama_token_bos(model_dft) ||
+        llama_token_eos(model_tgt) != llama_token_eos(model_dft)
+    ) {
+        LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__);
+        return false;
+    }
+
+    {
+        const int n_vocab_tgt = llama_n_vocab(model_tgt);
+        const int n_vocab_dft = llama_n_vocab(model_dft);
+
+        const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft);
+
+        if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
+            LOG_ERR("%s: draft model vocab must closely match target model to use speculation but "
+                    "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
+                    __func__, n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
+            return false;
+        }
+
+        for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
+            const char * token_text_tgt = llama_token_get_text(model_tgt, i);
+            const char * token_text_dft = llama_token_get_text(model_dft, i);
+            if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
+                LOG_ERR("%s: draft model vocab must match target model to use speculation but "
+                        "token %d content differs - target '%s', draft '%s'\n", __func__, i,
+                        llama_token_to_piece(ctx_tgt, i).c_str(),
+                        llama_token_to_piece(ctx_dft, i).c_str());
+                return false;
+            }
+        }
+    }
+
+    return true;
+}
+
+llama_tokens common_speculative_gen_draft(
+        struct common_speculative * spec,
+        struct common_speculative_params params,
+        const llama_tokens & prompt_tgt,
+        llama_token id_last) {
+    auto & batch  = spec->batch;
+    auto & ctx    = spec->ctx;
+    auto & smpl   = spec->smpl;
+    auto & prompt = spec->prompt;
+
+    int reuse_i = 0;
+    int reuse_n = 0;
+
+    const int n_ctx = llama_n_ctx(ctx) - params.n_draft;
+
+    const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);
+
+    // reuse as much as possible from the old draft context
+    // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
+    for (int i = 0; i < (int) prompt.size(); ++i) {
+        int cur = 0;
+        while (i_start + cur < (int) prompt_tgt.size() &&
+               i       + cur < (int) prompt.size() &&
+               prompt_tgt[i_start + cur] == prompt[i + cur]) {
+            cur++;
+        }
+
+        if ((cur >= params.n_reuse || n_ctx >= (int) prompt_tgt.size()) && cur > reuse_n) {
+            reuse_i = i;
+            reuse_n = cur;
+        }
+    }
+
+    LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt.size());
+
+    llama_tokens result;
+    result.reserve(params.n_draft);
+
+    if (reuse_n == 0) {
+        llama_kv_cache_clear(ctx);
+
+        prompt.clear();
+    } else {
+        // this happens when a previous draft has been discarded (for example, due to being too small), but the
+        // target model agreed with it. in this case, we simply pass back the previous results to save compute
+        if (reuse_i + reuse_n < (int) prompt.size() && prompt[reuse_i + reuse_n] == id_last) {
+            for (int i = reuse_i + reuse_n + 1; i < (int) prompt.size(); ++i) {
+                result.push_back(prompt[i]);
+
+                if (params.n_draft <= (int) result.size()) {
+                    break;
+                }
+            }
+
+            return result;
+        }
+
+        if (reuse_i > 0) {
+            llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i);
+            llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
+
+            prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
+        }
+
+        if (reuse_n < (int) prompt.size()) {
+            llama_kv_cache_seq_rm (ctx, 0, reuse_n, -1);
+
+            prompt.erase(prompt.begin() + reuse_n, prompt.end());
+        }
+    }
+
+    // prepare a batch to evaluate any new tokens in the prompt
+    llama_batch_clear(batch);
+
+    for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) {
+        //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
+        llama_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);
+
+        prompt.push_back(prompt_tgt[i]);
+    }
+
+    // we should rarely end-up here during normal decoding
+    if (batch.n_tokens > 0) {
+        //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
+
+        llama_decode(ctx, batch);
+    }
+
+    const llama_pos n_past = prompt.size();
+
+    LOG_DBG("%s: n_past = %d\n", __func__, n_past);
+
+    llama_batch_clear(batch);
+    llama_batch_add (batch, id_last, n_past, { 0 }, true);
+
+    prompt.push_back(id_last);
+
+    //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str());
+
+    llama_decode(ctx, batch);
+
+    gpt_sampler_reset(smpl);
+
+    // sample n_draft tokens from the draft model
+    for (int i = 0; i < params.n_draft; ++i) {
+        llama_batch_clear(batch);
+
+        gpt_sampler_sample(smpl, ctx, 0, true);
+
+        const auto * cur_p = gpt_sampler_get_candidates(smpl);
+
+        for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
+            LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
+                    k, i, cur_p->data[k].id, cur_p->data[k].p, llama_token_to_piece(ctx, cur_p->data[k].id).c_str());
+        }
+
+        // add drafted token for each sequence
+        const llama_token id = cur_p->data[0].id;
+
+        // only collect very high-confidence draft tokens
+        if (cur_p->data[0].p < params.p_min) {
+            break;
+        }
+
+        gpt_sampler_accept(smpl, id, true);
+
+        result.push_back(id);
+
+        if (params.n_draft <= (int) result.size()) {
+            break;
+        }
+
+        llama_batch_add(batch, id, n_past + i + 1, { 0 }, true);
+
+        // evaluate the drafted tokens on the draft model
+        llama_decode(ctx, batch);
+
+        prompt.push_back(id);
+    }
+
+    return result;
+}
common/speculative.h (new file, 28 lines)

@@ -0,0 +1,28 @@
+#pragma once
+
+#include "llama.h"
+#include "common.h"
+
+struct common_speculative;
+
+struct common_speculative_params {
+    int n_draft = 16;  // max drafted tokens
+    int n_reuse = 256;
+
+    float p_min = 0.9f; // min probability required to accept a token in the draft
+};
+
+struct common_speculative * common_speculative_init(struct llama_context * ctx_dft);
+
+void common_speculative_free(struct common_speculative * spec);
+
+bool common_speculative_are_compatible(
+        const struct llama_context * ctx_tgt,
+        const struct llama_context * ctx_dft);
+
+// sample up to n_draft tokens and add them to the batch using the draft model
+llama_tokens common_speculative_gen_draft(
+        struct common_speculative * spec,
+        struct common_speculative_params params,
+        const llama_tokens & prompt,
+        llama_token id_last);
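A minimal end-to-end sketch of this API, assuming ctx_tgt and ctx_dft are already-initialized contexts for the target and draft models and that prompt_tgt / id_last track the target's decoding state (all other plumbing elided):

    if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) {
        return 1; // vocab or special tokens differ too much for speculation
    }

    struct common_speculative * spec = common_speculative_init(ctx_dft);

    struct common_speculative_params params_spec;
    params_spec.n_draft = 16;
    params_spec.p_min   = 0.9f;

    llama_tokens draft = common_speculative_gen_draft(spec, params_spec, prompt_tgt, id_last);

    // ... decode [id_last, draft...] on the target and verify the draft with
    // gpt_sampler_sample_and_accept_n(), as server.cpp does below ...

    common_speculative_free(spec);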
@@ -6,6 +6,7 @@
 #include "sampling.h"
 #include "json-schema-to-grammar.h"
 #include "llama.h"
+#include "speculative.h"

 // Change JSON_ASSERT from assert() to GGML_ASSERT:
 #define JSON_ASSERT GGML_ASSERT
@@ -126,13 +127,16 @@ struct server_task_result {

 struct slot_params {
     bool stream       = true;
-    bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt
+    bool cache_prompt = true;  // remember the prompt to avoid reprocessing all prompt

     int32_t n_keep    =  0; // number of tokens to keep from initial prompt
     int32_t n_discard =  0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
     int32_t n_predict = -1; // new tokens to predict

     std::vector<std::string> antiprompt;

+    struct gpt_sampler_params sampling;
+    struct common_params_speculative speculative;
+
     json input_prefix;
     json input_suffix;
@@ -142,6 +146,12 @@ struct server_slot {
     int id;
     int id_task = -1;

+    llama_batch batch_spec;
+
+    llama_context * ctx_dft = nullptr;
+
+    common_speculative * spec = nullptr;
+
     // the index relative to completion multi-task request
     size_t index = 0;
@@ -231,7 +241,7 @@ struct server_slot {
         generated_token_probs.clear();
     }

-    bool has_budget(gpt_params &global_params) {
+    bool has_budget(const gpt_params &global_params) {
         if (params.n_predict == -1 && global_params.n_predict == -1) {
             return true; // limitless
         }
@@ -251,6 +261,10 @@ struct server_slot {
         return state != SLOT_STATE_IDLE;
     }

+    bool can_speculate() const {
+        return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt;
+    }
+
     void add_token(const completion_token_output & token) {
         if (!is_processing()) {
             SLT_WRN(*this, "%s", "slot is not processing\n");
@@ -615,6 +629,9 @@ struct server_context {

     gpt_params params;

+    llama_model * model_dft = nullptr;
+    llama_context_params cparams_dft;
+
     llama_batch batch = {};

     bool clean_kv_cache = true;
@@ -652,21 +669,71 @@ struct server_context {
             model = nullptr;
         }

+        if (model_dft) {
+            llama_free_model(model_dft);
+            model_dft = nullptr;
+        }
+
         // Clear any sampling context
         for (server_slot & slot : slots) {
             if (slot.smpl != nullptr) {
                 gpt_sampler_free(slot.smpl);
             }
+            slot.smpl = nullptr;
+
+            llama_free(slot.ctx_dft);
+            slot.ctx_dft = nullptr;
+
+            common_speculative_free(slot.spec);
+            slot.spec = nullptr;
+
+            llama_batch_free(slot.batch_spec);
         }

         llama_batch_free(batch);
     }

     bool load_model(const gpt_params & params_) {
+        SRV_INF("loading model '%s'\n", params.model.c_str());
+
         params = params_;

         // dedicate one sequence to the system prompt
         params.n_parallel += 1;

+        // load draft model first
+        llama_init_result llama_init_dft;
+        if (!params.speculative.model.empty()) {
+            SRV_INF("loading draft model '%s'\n", params.speculative.model.c_str());
+
+            auto params_dft = params;
+
+            params_dft.model        = params.speculative.model;
+            params_dft.n_ctx        = params.speculative.n_ctx;
+            params_dft.n_gpu_layers = params.speculative.n_gpu_layers;
+            params_dft.use_mlock    = true;
+            params_dft.n_world      = 1; // do not split the draft model across devices
+            params_dft.rank         = 0; // always load the draft model on the head device
+
+            std::fill_n(params_dft.n_layer_window, params.n_world, 0);
+
+            llama_init_dft = llama_init_from_gpt_params(params_dft);
+
+            model_dft = llama_init_dft.model;
+
+            if (model_dft == nullptr) {
+                SRV_ERR("failed to load draft model, '%s'\n", params.speculative.model.c_str());
+                return false;
+            }
+
+            cparams_dft = llama_context_params_from_gpt_params(params);
+            cparams_dft.n_batch = llama_n_ctx(llama_init_dft.context);
+            cparams_dft.n_world = 1;
+            cparams_dft.rank    = 0;
+            std::fill_n(cparams_dft.n_layer_window, 32, 0);
+            cparams_dft.n_layer_window[0] = llama_n_layer(model_dft);
+            cparams_dft.n_gpu_layers      = params.speculative.n_gpu_layers;
+        }
+
         llama_init_result llama_init = llama_init_from_gpt_params(params);
@@ -685,6 +752,22 @@ struct server_context {

         add_bos_token = llama_add_bos_token(model);
         has_eos_token = !llama_add_eos_token(model);

+        if (!params.speculative.model.empty()) {
+            if (!common_speculative_are_compatible(ctx, llama_init_dft.context)) {
+                SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params.speculative.model.c_str(), params.model.c_str());
+
+                llama_free      (llama_init_dft.context);
+                llama_free_model(llama_init_dft.model);
+
+                model_dft = nullptr;
+
+                return false;
+            }
+
+            // the context is not needed - we will create one for each slot
+            llama_free(llama_init_dft.context);
+        }
+
         return true;
     }
@@ -708,6 +791,30 @@ struct server_context {
             slot.id = i;
             slot.n_ctx = n_ctx_slot;
             slot.n_predict = params.n_predict;

+            if (model_dft) {
+                slot.batch_spec = llama_batch_init(params.speculative.n_max + 1, 0, 1);
+
+                slot.ctx_dft = llama_new_context_with_model(model_dft, cparams_dft);
+
+                if (llama_context_setup_backend(model_dft, cparams_dft, slot.ctx_dft) == nullptr) {
+                    SRV_ERR("%s: failed to setup context with model '%s'\n", __func__, params.speculative.model.c_str());
+                    llama_free(slot.ctx_dft);
+                    llama_free_model(model_dft);
+                    return;
+                }
+
+                if (slot.ctx_dft == nullptr) {
+                    SRV_ERR("%s", "failed to create draft context\n");
+                    return;
+                }
+
+                slot.spec = common_speculative_init(slot.ctx_dft);
+                if (slot.spec == nullptr) {
+                    SRV_ERR("%s", "failed to create speculator\n");
+                    return;
+                }
+            }
+
             SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
@@ -875,6 +982,8 @@ struct server_context {
         slot_params default_params;
         // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
         auto default_sparams = params.sparams;
+        default_params.speculative = params.speculative;
+
         const auto & data = task.data;

         if (data.count("__oaicompat") != 0) {
@@ -886,7 +995,7 @@ struct server_context {
         }

         slot.params.stream = json_value(data, "stream", false);
-        slot.params.cache_prompt = json_value(data, "cache_prompt", false);
+        slot.params.cache_prompt = json_value(data, "cache_prompt", true);
         slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict));
         slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
         slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
@@ -909,6 +1018,12 @@ struct server_context {
         slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
         slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
         slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);

+        slot.params.speculative.n_min = json_value(data, "speculative.n_min", default_params.speculative.n_min);
+        slot.params.speculative.n_max = json_value(data, "speculative.n_max", default_params.speculative.n_max);
+        slot.params.speculative.p_min = json_value(data, "speculative.p_min", default_params.speculative.p_min);
+
+        slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min);
+
         // process "json_schema" and "grammar"
         if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
@@ -1049,6 +1164,12 @@ struct server_context {
                 return false;
             }
         }

+        if (slot.ctx_dft) {
+            llama_batch_free(slot.batch_spec);
+
+            slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
+        }
+
         slot.state = SLOT_STATE_PROCESSING_PROMPT;
         slot.prompt_tokens.clear();
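For reference, a sketch of a completion request that exercises the new per-request keys parsed above; the flat "speculative.*" spellings are exactly what json_value() looks up, while the values here are only illustrative (json is the nlohmann type already used by server.cpp):

    json data = {
        { "prompt",            "Once upon a time" },
        { "n_predict",         64   },
        { "cache_prompt",      true },  // speculation requires prompt caching
        { "speculative.n_max", 16   },
        { "speculative.n_min", 5    },
        { "speculative.p_min", 0.9  },
    };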
@@ -2357,38 +2478,101 @@ struct server_context {
                     continue; // continue loop of slots
                 }

-                completion_token_output result;
-                const llama_token id = gpt_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
+                llama_token id;

-                gpt_sampler_accept(slot.smpl, id, true);
+                {
+                    completion_token_output result;

-                slot.n_decoded += 1;
-                if (slot.n_decoded == 1) {
-                    slot.t_start_generation = ggml_time_us();
-                    slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
-                    metrics.on_prompt_eval(slot);
-                }
+                    id = gpt_sampler_sample(slot.smpl, ctx, slot.i_batch - i);

-                result.tok = id;
+                    slot.i_batch = -1;

-                const auto * cur_p = gpt_sampler_get_candidates(slot.smpl);
+                    gpt_sampler_accept(slot.smpl, id, true);

-                for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
-                    result.probs.push_back({
-                        cur_p->data[i].id,
-                        i >= cur_p->size ? 0.0f : cur_p->data[i].p,
-                    });
-                }
+                    slot.n_decoded += 1;
+                    if (slot.n_decoded == 1) {
+                        slot.t_start_generation = ggml_time_us();
+                        slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
+                        metrics.on_prompt_eval(slot);
+                    }

-                if (!process_token(result, slot)) {
-                    // release slot because of stop condition
-                    slot.release();
-                    slot.print_timings();
-                    send_final_response(slot);
-                    metrics.on_prediction(slot);
-                    continue;
-                }
+                    result.tok = id;
+
+                    const auto * cur_p = gpt_sampler_get_candidates(slot.smpl);
+
+                    for (size_t i = 0; i < (size_t) slot.params.sampling.n_probs; ++i) {
+                        result.probs.push_back({
+                            cur_p->data[i].id,
+                            i >= cur_p->size ? 0.0f : cur_p->data[i].p,
+                        });
+                    }
+
+                    if (!process_token(result, slot)) {
+                        // release slot because of stop condition
+                        slot.release();
+                        slot.print_timings();
+                        send_final_response(slot);
+                        metrics.on_prediction(slot);
+                        continue;
+                    }
+                }

-                slot.i_batch = -1;
+                // check if the slot supports speculative decoding
+                if (!slot.can_speculate()) {
+                    continue;
+                }
+
+                struct common_speculative_params params_spec;
+                params_spec.n_draft = slot.params.speculative.n_max;
+                params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
+                params_spec.p_min   = slot.params.speculative.p_min;
+
+                llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
+
+                // ignore small drafts
+                if (slot.params.speculative.n_min > (int) draft.size()) {
+                    continue;
+                }
+
+                // construct the speculation batch
+                llama_batch_clear(slot.batch_spec);
+                llama_batch_add  (slot.batch_spec, id, slot.n_past, { slot.id + 1 }, true);
+
+                for (size_t i = 0; i < draft.size(); ++i) {
+                    llama_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id + 1 }, true);
+                }
+
+                llama_decode(ctx, slot.batch_spec, true);
+
+                // the accepted tokens from the speculation
+                const auto ids = gpt_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
+
+                slot.n_past    += ids.size();
+                slot.n_decoded += ids.size();
+
+                slot.cache_tokens.push_back(id);
+                slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
+
+                llama_kv_cache_seq_rm     (ctx, slot.id + 1, slot.n_past, -1);
+                llama_send_kv_cache_seq_rm(ctx, slot.id    , slot.n_past, -1);
+
+                for (size_t i = 0; i < ids.size(); ++i) {
+                    completion_token_output result;
+
+                    result.tok = ids[i];
+
+                    if (!process_token(result, slot)) {
+                        // release slot because of stop condition
+                        slot.release();
+                        slot.print_timings();
+                        send_final_response(slot);
+                        metrics.on_prediction(slot);
+                        break;
+                    }
+                }
+
+                SRV_DBG("accepted %d/%d draft tokens\n", (int) ids.size() - 1, (int) draft.size());
             }
         }
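A worked example of the bookkeeping above, with made-up token ids: if draft == { 11, 22, 33 } and the target sampler reproduces 11 and 22 but samples 44 where the draft said 33, then ids == { 11, 22, 44 }. The slot advances n_past and n_decoded by ids.size() == 3; cache_tokens receives id followed by ids without its final element (the final element is the target's own sample, not yet a confirmed cache entry); the KV cache is trimmed back to the new n_past to drop the rejected tail of the speculation batch; and the log line reports ids.size() - 1 == 2 of 3 draft tokens accepted.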
@@ -3384,6 +3568,8 @@ int main(int argc, char ** argv) {
     LOG_INF("%s: loading model\n", __func__);

     if (!ctx_server.load_model(params)) {
+        char * stop_signal = nullptr;
+        llama_free_sockets(ctx_server.ctx, &stop_signal);
         clean_up();
         t.join();
         LOG_ERR("%s: exiting due to model loading error\n", __func__);
@@ -41,7 +41,7 @@ int main(int argc, char ** argv) {

     gpt_init();

-    if (params.model_draft.empty()) {
+    if (params.speculative.model.empty()) {
         LOG_ERR("%s: --model-draft is required\n", __func__);
         return 1;
     }
@@ -68,7 +68,7 @@ int main(int argc, char ** argv) {
     // load the draft model
     // make a hard copy of params to use for the draft model
     gpt_params params_draft = params;
-    params_draft.model = params_draft.model_draft;
+    params_draft.model = params_draft.speculative.model;
     params_draft.n_gpu_layers = params_draft.n_gpu_layers_draft;
     params_draft.n_world = 1; // do not split the draft model across devices
     params_draft.rank = 0;    // always load the draft model on the head device
@@ -169,7 +169,7 @@ int main(int argc, char ** argv) {
     const auto t_enc_end = ggml_time_us();

     // how many tokens to draft each time
-    int n_draft = params.n_draft;
+    int n_draft = params.speculative.n_max;

     int n_predict = 0;
     int n_drafted = 0;
@@ -17878,7 +17878,7 @@ static void llama_send_meta(zmq::socket_t & socket, struct sync_meta * meta, boo
     if (meta->pos != nullptr) {
         send_msgs.emplace_back("pos", strlen("pos"));
-        send_msgs.emplace_back(meta->pos, meta->n_ctx * sizeof(llama_pos));
+        send_msgs.emplace_back(meta->pos, meta->n_tokens * sizeof(llama_pos));
     }

     if (meta->n_seq_id != nullptr) {
@@ -17986,8 +17986,8 @@ static int llama_recv_meta(zmq::socket_t & socket, struct sync_meta * meta) {
     }

     if (key == "pos") {
-        meta->pos = (llama_pos *) malloc(meta->n_ctx * sizeof(llama_pos));
-        std::memcpy(meta->pos, data_msg.data(), meta->n_ctx * sizeof(llama_pos));
+        meta->pos = (llama_pos *) malloc(meta->n_tokens * sizeof(llama_pos));
+        std::memcpy(meta->pos, data_msg.data(), meta->n_tokens * sizeof(llama_pos));
     }

     if (key == "n_seq_id") {
@@ -18304,8 +18304,8 @@ static int llama_decode_internal(
     if (meta.n_tokens > 0) {
         batch_all.n_tokens = meta.n_tokens;
         if (meta.pos != nullptr) {
-            batch_all.pos = (llama_pos *) malloc(meta.n_ctx * sizeof(llama_pos));
-            std::memcpy(batch_all.pos, meta.pos, meta.n_ctx * sizeof(llama_pos));
+            batch_all.pos = (llama_pos *) malloc(meta.n_tokens * sizeof(llama_pos));
+            std::memcpy(batch_all.pos, meta.pos, meta.n_tokens * sizeof(llama_pos));
         }
         if (meta.n_seq_id != nullptr) {
             batch_all.n_seq_id = (int32_t *) malloc(meta.n_tokens * sizeof(int32_t));
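All three fixes above size the serialized pos array by n_tokens instead of n_ctx. The reason is that llama_batch allocates one llama_pos per token actually in the batch, so for a small speculative verification batch inside a large context the old code read and copied far past the allocation. A short illustration with assumed sizes:

    llama_batch batch = llama_batch_init(/*n_tokens=*/9, /*embd=*/0, /*n_seq_max=*/1);
    // batch.pos holds 9 llama_pos entries (36 bytes). With n_ctx = 4096 the old
    // code copied 4096 * sizeof(llama_pos) = 16 KiB from this 36-byte buffer;
    // bounding the copy by n_tokens keeps it inside the allocation.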
@@ -20944,6 +20944,12 @@ struct llama_context * llama_new_context_with_model(
     ctx->cparams.rank = params.rank;
     ctx->cparams.force = params.force;
     ctx->cparams.original_next_rank = (params.rank + 1) % params.n_world;

+    auto &hparams = model->hparams;
+    auto &cparams = ctx->cparams;
+    cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
+    ctx->logits_all = params.logits_all;
+
     return ctx;
 }
@@ -22083,6 +22089,9 @@ void llama_model_compute_buf_size(
         // this value may vary by GPU and CUDA version, but it's lower than 400 MiB in most cases,
         // another 300 MiB is used to prevent accidental OOM.
         *gpu_buf += 700 * 1024 * 1024;
+    } else if (backend == BACKEND_METAL) {
+        // 300 MiB is used to prevent accidental OOM, e.g., automatic quantization conversion.
+        *gpu_buf += 300 * 1024 * 1024;
     }
 }