Mirror of https://github.com/Lizonghang/prima.cpp.git (synced 2025-09-10 00:24:34 +00:00)
fix speculative decoding
This commit is contained in:
parent: e50b3aa473
commit: dc875bbef9
4 changed files with 75 additions and 28 deletions
@@ -12,7 +12,7 @@
 #include <string>
 #include <vector>
 
-#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  100
+#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  128
 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
 
 struct seq_draft {
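Note: SPEC_VOCAB_MAX_SIZE_DIFFERENCE bounds how far the target and draft vocab sizes may diverge before the two models are rejected as incompatible; this commit widens the bound from 100 to 128. A minimal sketch of how the two constants are typically used, modeled on the upstream llama.cpp speculative example; the helper below is an assumption, not code from this commit:

    #include <cstdlib>

    #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  128
    #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5

    // hypothetical helper: vocab sizes may differ slightly (e.g. padded token
    // tables), but beyond the threshold the draft/target pairing is rejected
    static bool spec_vocab_size_ok(int n_vocab_tgt, int n_vocab_dft) {
        return std::abs(n_vocab_tgt - n_vocab_dft) <= SPEC_VOCAB_MAX_SIZE_DIFFERENCE;
    }
    // token ids below SPEC_VOCAB_CHECK_START_TOKEN_ID are treated as special
    // tokens and skipped when comparing the text of each shared token id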
@@ -65,23 +65,29 @@ int main(int argc, char ** argv) {
     llama_context * ctx_tgt = NULL;
     llama_context * ctx_dft = NULL;
 
+    // load the draft model
+    // make a hard copy of params to use for the draft model
+    gpt_params params_draft = params;
+    params_draft.model = params_draft.model_draft;
+    params_draft.n_gpu_layers = params_draft.n_gpu_layers_draft;
+    params_draft.n_world = 1; // do not split the draft model across devices
+    params_draft.rank = 0; // always load the draft model on the head device
+    std::fill_n(params_draft.n_layer_window, params.n_world, 0);
+
+    if (params_draft.draft_cpuparams.n_threads > 0) {
+        params_draft.cpuparams.n_threads = params_draft.draft_cpuparams.n_threads;
+    }
+
+    params_draft.cpuparams_batch.n_threads = params_draft.draft_cpuparams_batch.n_threads;
+    llama_init_result llama_init_dft = llama_init_from_gpt_params(params_draft);
+    model_dft = llama_init_dft.model;
+    ctx_dft = llama_init_dft.context;
+
     // load the target model
     llama_init_result llama_init_tgt = llama_init_from_gpt_params(params);
     model_tgt = llama_init_tgt.model;
     ctx_tgt = llama_init_tgt.context;
 
-    // load the draft model
-    params.model = params.model_draft;
-    params.n_gpu_layers = params.n_gpu_layers_draft;
-    if (params.draft_cpuparams.n_threads > 0) {
-        params.cpuparams.n_threads = params.draft_cpuparams.n_threads;
-    }
-
-    params.cpuparams_batch.n_threads = params.draft_cpuparams_batch.n_threads;
-    llama_init_result llama_init_dft = llama_init_from_gpt_params(params);
-    model_dft = llama_init_dft.model;
-    ctx_dft = llama_init_dft.context;
-
     const bool vocab_type_tgt = llama_vocab_type(model_tgt);
     LOG_DBG("vocab_type tgt: %d\n", vocab_type_tgt);
 
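The key change above: the draft model is now loaded from a hard copy of params before the target model is loaded, instead of mutating params in place afterwards, and the copy is pinned to n_world = 1, rank = 0 so the small draft model always runs whole on the head device rather than being split like the target. A minimal sketch of the copy-then-override pattern, using a hypothetical trimmed-down struct rather than the real gpt_params:

    #include <algorithm>
    #include <string>

    // hypothetical stand-in for gpt_params, reduced to the fields involved
    struct params_like {
        std::string model, model_draft;
        int n_world = 4, rank = 2;      // distributed settings for the target
        int n_layer_window[8] = {8, 8, 8, 8};
    };

    static params_like make_draft_params(const params_like & params) {
        params_like p = params;  // hard copy: the caller's target params stay intact
        p.model   = p.model_draft;
        p.n_world = 1;           // do not split the draft model across devices
        p.rank    = 0;           // always load the draft model on the head device
        std::fill_n(p.n_layer_window, params.n_world, 0);
        return p;
    }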
@@ -161,9 +167,6 @@ int main(int argc, char ** argv) {
 
     const auto t_enc_end = ggml_time_us();
 
-    // the 2 models should have the same vocab
-    //GGML_ASSERT(n_vocab == llama_n_vocab(model_dft));
-
     // how many tokens to draft each time
     int n_draft = params.n_draft;
 
@@ -180,8 +183,6 @@ int main(int argc, char ** argv) {
     // target model sampling context (reuse the llama_context's sampling instance)
     struct gpt_sampler * smpl = gpt_sampler_init(model_tgt, params.sparams);
 
-    struct llama_sampler * softmax = llama_sampler_init_softmax();
-
     // draft sequence data
     std::vector<seq_draft> drafts(n_seq_dft);
 
@@ -258,10 +259,13 @@ int main(int argc, char ** argv) {
                    float r = u_dist(rng);
                    llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), LLAMA_TOKEN_NULL, true };
 
-                    //GGML_ASSERT(dist_tgt.size <= dist_dft.size);
+                    // if (dist_tgt.size > dist_dft.size) {
+                    //     LOG_ERR("dist_tgt.size (%zu) must be less than or equal to dist_dft.size (%zu)\n", dist_tgt.size, dist_dft.size);
+                    //     GGML_ASSERT(dist_tgt.size <= dist_dft.size);
+                    // }
 
                     // acquire the token probabilities assigned by the draft and target models
-                    for (size_t i = 0; i < dist_tgt.size; i++) {
+                    for (size_t i = 0; i < dist_tgt.size && i < dist_dft.size; i++) {
                         if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) {
                             p_tgt = dist_tgt.data[i].p;
                         }
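The loop bound now also checks i < dist_dft.size: the target distribution can be larger than the stored draft distribution, and the old bound allowed reads past the end of dist_dft. For context, the probabilities gathered here feed the standard speculative acceptance test driven by r. A minimal, self-contained sketch of that test, stating the usual speculative sampling rule rather than quoting code from this file:

    // accept a drafted token with probability min(1, p_tgt / p_dft),
    // where r is drawn once per drafted token from U(0,1)
    static bool accept_drafted(float p_tgt, float p_dft, float r) {
        if (p_dft <= 0.0f) {
            return false; // the draft model never proposed this token
        }
        return r <= p_tgt / p_dft;
    }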
@@ -406,7 +410,6 @@ int main(int argc, char ** argv) {
         {
             LOG_DBG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", token_id, token_str.c_str());
 
             // TODO: simplify
             {
                 LOG_DBG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);
 
@@ -418,6 +421,12 @@ int main(int argc, char ** argv) {
                 llama_kv_cache_seq_keep(ctx_tgt, s_keep);
                 llama_kv_cache_seq_cp  (ctx_tgt, s_keep, 0, -1, -1);
                 llama_kv_cache_seq_keep(ctx_tgt, 0);
+
+                // notify other devices to manage the KV cache in the same way
+                llama_send_kv_cache_seq_rm  (ctx_tgt, s_keep, n_past_tgt, -1);
+                llama_send_kv_cache_seq_keep(ctx_tgt, s_keep);
+                llama_send_kv_cache_seq_cp  (ctx_tgt, s_keep, 0, -1, -1);
+                llama_send_kv_cache_seq_keep(ctx_tgt, 0);
             }
 
             for (int s = 0; s < n_seq_dft; ++s) {
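In prima.cpp the target model's layers are sharded across devices, so each device holds part of the KV cache. The fix therefore pairs every local KV-cache edit on the head device with a matching llama_send_kv_cache_* call that tells the other devices to replay the same edit. A sketch of that pairing as a hypothetical wrapper; the llama_send_* calls are taken from this diff, but the wrapper itself is an assumption:

    // hypothetical wrapper: apply each sequence edit locally, then mirror it
    // on the remote devices so every KV-cache shard stays consistent
    static void keep_sequence_everywhere(llama_context * ctx, llama_seq_id s_keep, int n_past) {
        llama_kv_cache_seq_rm       (ctx, s_keep, n_past, -1);
        llama_send_kv_cache_seq_rm  (ctx, s_keep, n_past, -1);
        llama_kv_cache_seq_keep     (ctx, s_keep);
        llama_send_kv_cache_seq_keep(ctx, s_keep);
        llama_kv_cache_seq_cp       (ctx, s_keep, 0, -1, -1);
        llama_send_kv_cache_seq_cp  (ctx, s_keep, 0, -1, -1);
        llama_kv_cache_seq_keep     (ctx, 0);
        llama_send_kv_cache_seq_keep(ctx, 0);
    }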
@@ -435,7 +444,6 @@ int main(int argc, char ** argv) {
         llama_batch_add  (batch_dft, token_id, n_past_dft, { 0 }, true);
 
         llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
         // LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
         llama_decode(ctx_dft, batch_dft);
 
         ++n_past_dft;
@@ -575,12 +583,13 @@ int main(int argc, char ** argv) {
 
         // evaluate the target model on the drafted tokens
        {
-            llama_kv_cache_seq_keep(ctx_tgt, 0);
+            llama_kv_cache_seq_keep     (ctx_tgt, 0);
+            llama_send_kv_cache_seq_keep(ctx_tgt, 0);
             for (int s = 1; s < n_seq_dft; ++s) {
-                llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
+                llama_kv_cache_seq_cp     (ctx_tgt, 0, s, -1, -1);
+                llama_send_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
             }
 
             // LOG_DBG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
             llama_decode(ctx_tgt, batch_tgt);
             ++n_past_tgt;
         }
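Same pattern on the decode path: each draft sequence s receives a copy of sequence 0's KV cache, locally and on the remote devices, so that the single llama_decode over batch_tgt can verify all drafted branches in one target pass. A condensed restatement of the new flow, simplified from the hunk above:

    // replicate sequence 0 into every draft slot, mirroring each copy remotely,
    // then score all drafted continuations with one batched decode
    for (int s = 1; s < n_seq_dft; ++s) {
        llama_kv_cache_seq_cp     (ctx_tgt, 0, s, -1, -1);
        llama_send_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
    }
    llama_decode(ctx_tgt, batch_tgt);
    ++n_past_tgt;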
@@ -612,7 +621,7 @@ int main(int argc, char ** argv) {
 
     LOG_INF("\n");
     LOG_INF("draft:\n\n");
     // TODO: print sampling/grammar timings for all drafts
 
     llama_perf_context_print(ctx_dft);
 
     LOG_INF("\n");
@@ -624,7 +633,6 @@ int main(int argc, char ** argv) {
         gpt_sampler_free(drafts[s].smpl);
     }
 
-    llama_sampler_free(softmax);
     llama_batch_free(batch_dft);
 
     llama_free(ctx_tgt);