Mirror of https://github.com/LostRuins/koboldcpp.git (synced 2025-09-10 00:54:41 +00:00)

Commit f75bbb945f (parent b9e99c69e8)

speculative decoding initial impl completed (+6 squashed commits)

Squashed commits: [0a6306ca0] draft wip dont use (will be squashed) [a758a1c9c] wip dont use (will be squashed) [e1994d3ce] wip dont use [f59690d68] wip [77228147d] wip on spec decoding. dont use yet [2445bca54] wip adding speculative decoding (+1 squashed commit) Squashed commits: [50e341bb7] wip adding speculative decoding

9 changed files with 539 additions and 280 deletions
@@ -40,8 +40,10 @@
 #include "mpt_v3.cpp"
 #include "examples/llava/clip.h"
 #include "examples/llava/llava.h"
 #include "common/common.h"

+//const
+const int speculative_chunk_amt = 16; //do it in chunks of this many tokens
 const int extra_context_handle_fragmentation = 120;
 const int LLAVA_TOKEN_IDENTIFIER_A = -998; //alternate between both, changing when image changes
 const int LLAVA_TOKEN_IDENTIFIER_B = -999;
@@ -51,6 +53,7 @@ std::string executable_path = "";
 std::string lora_filename = "";
 std::string lora_base = "";
 std::string mmproj_filename = "";
+std::string draftmodel_filename = "";
 bool generation_finished;
 float last_process_time = 0;
 float last_eval_time = 0;
@@ -90,6 +93,7 @@ static rwkv_context * rwkv_ctx_v3;
 static llama_v2_context * llama_ctx_v2;
 static llama_v3_context * llama_ctx_v3;
 static llama_context * llama_ctx_v4;
+static llama_context * draft_ctx = nullptr; //will remain null if speculative is unused

 static clip_ctx * clp_ctx = nullptr; //for llava
 static clip_image_u8 * clp_img_data = nullptr; //most recent image
@@ -487,6 +491,10 @@ void ContextRewind(std::vector<int> &embd, std::vector<int> &current_context_tokens
 if (file_format == FileFormat::GGUF_GENERIC)
 {
 llama_kv_cache_seq_rm(llama_ctx_v4, 0, n_past, -1);
+if(draft_ctx)
+{
+    llama_kv_cache_seq_rm(draft_ctx, 0, n_past, -1);
+}
 }

 embd.clear();
@@ -527,6 +535,170 @@ const char * kcpp_print_system_info(void) {
 return s.c_str();
 }

+struct kcpp_embd_batch { //duplicated from llava_embd_batch
+    std::vector<int32_t> pos;
+    std::vector<int32_t> n_seq_id;
+    std::vector<int32_t> seq_id_0;
+    std::vector<int32_t *> seq_ids;
+    std::vector<int8_t> logits;
+    llama_batch batch;
+    kcpp_embd_batch(float * embd, int32_t n_tokens, int32_t npast) {
+        int32_t seq_id = 0;
+        pos.resize(n_tokens);
+        n_seq_id.resize(n_tokens);
+        seq_ids.resize(n_tokens + 1);
+        logits.resize(n_tokens);
+        seq_id_0.resize(1);
+        seq_id_0[0] = seq_id;
+        seq_ids [n_tokens] = nullptr;
+        batch = {
+            /*n_tokens =*/ n_tokens,
+            /*tokens =*/ nullptr,
+            /*embd =*/ embd,
+            /*pos =*/ pos.data(),
+            /*n_seq_id =*/ n_seq_id.data(),
+            /*seq_id =*/ seq_ids.data(),
+            /*logits =*/ logits.data(),
+        };
+        for (int i = 0; i < n_tokens; i++) {
+            batch.pos [i] = npast + i;
+            batch.n_seq_id[i] = 1;
+            batch.seq_id [i] = seq_id_0.data();
+            batch.logits [i] = false;
+        }
+    }
+    kcpp_embd_batch(std::vector<llama_token> & tokens, int32_t npast, bool return_all_logits) {
+        int32_t seq_id = 0;
+        int32_t n_tokens = tokens.size();
+        pos.resize(n_tokens);
+        n_seq_id.resize(n_tokens);
+        seq_ids.resize(n_tokens + 1);
+        logits.resize(n_tokens);
+        seq_id_0.resize(1);
+        seq_id_0[0] = seq_id;
+        seq_ids [n_tokens] = nullptr;
+        batch = {
+            /*n_tokens =*/ n_tokens,
+            /*tokens =*/ tokens.data(),
+            /*embd =*/ nullptr,
+            /*pos =*/ pos.data(),
+            /*n_seq_id =*/ n_seq_id.data(),
+            /*seq_id =*/ seq_ids.data(),
+            /*logits =*/ logits.data(),
+        };
+        for (int i = 0; i < n_tokens; i++) {
+            batch.pos [i] = npast + i;
+            batch.n_seq_id[i] = 1;
+            batch.seq_id [i] = seq_id_0.data();
+            batch.logits [i] = (return_all_logits?true:false);
+        }
+        batch.logits[n_tokens - 1] = true;
+    }
+};
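A note on the two constructors above: the float* form feeds embedding batches (llava), while the token form drives normal decoding; `return_all_logits` requests one logit row per position instead of only the last row, which is exactly what the speculative verification pass needs. A minimal illustrative driver follows (sketch only; `my_ctx`, `my_tokens`, `npast` are hypothetical names, not from this commit):

    #include <vector>
    #include "llama.h"

    // Sketch: drives the kcpp_embd_batch wrapper defined above.
    static bool decode_tokens(llama_context * my_ctx, std::vector<llama_token> & my_tokens,
                              int npast, bool want_all_logits) {
        // want_all_logits=false: normal decoding, only the final logit row.
        // want_all_logits=true: one logit row per position, as the speculative
        // verification pass requires so each drafted token can be checked.
        kcpp_embd_batch b(my_tokens, npast, want_all_logits);
        return llama_decode(my_ctx, b.batch) == 0;
    }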

+//loads a model for speculative decoding.
+static void speculative_decoding_setup(std::string spec_model_filename, const llama_model_params & base_model_params, const llama_context_params & base_ctx_params, int base_n_vocab)
+{
+    llama_model_params draft_model_params = llama_model_default_params();
+    llama_context_params draft_ctx_params = llama_context_default_params();
+
+    draft_model_params.use_mmap = base_model_params.use_mmap;
+    draft_model_params.use_mlock = base_model_params.use_mlock;
+    draft_model_params.n_gpu_layers = 999; //assume they want to fully offload the speculative model. Otherwise, why even use it?
+    draft_ctx_params.n_ctx = base_ctx_params.n_ctx;
+    draft_ctx_params.logits_all = false;
+    draft_ctx_params.offload_kqv = base_ctx_params.offload_kqv;
+    draft_model_params.main_gpu = base_model_params.main_gpu;
+    draft_model_params.split_mode = llama_split_mode::LLAMA_SPLIT_MODE_LAYER;
+    draft_ctx_params.n_batch = base_ctx_params.n_batch;
+    draft_ctx_params.n_ubatch = base_ctx_params.n_ubatch;
+    draft_ctx_params.n_threads = base_ctx_params.n_threads;
+    draft_ctx_params.n_threads_batch = base_ctx_params.n_threads_batch;
+    draft_ctx_params.flash_attn = base_ctx_params.flash_attn;
+    draft_ctx_params.type_k = base_ctx_params.type_k;
+    draft_ctx_params.type_v = base_ctx_params.type_v;
+
+    llama_model * draftmodel = llama_load_model_from_file(spec_model_filename.c_str(), draft_model_params);
+    draft_ctx = llama_new_context_with_model(draftmodel, draft_ctx_params);
+    if(draft_ctx == NULL)
+    {
+        printf("Error: failed to load speculative decoding draft model '%s'\n", spec_model_filename.c_str());
+        printf("Speculative Decoding will not be used!\n");
+    }
+    else
+    {
+        int draftvocab = llama_n_vocab(draftmodel);
+        if(draftvocab!=base_n_vocab)
+        {
+            printf("Error: Draft model vocab of (%d) does not match base vocab of (%d). Speculative decoding cannot be used!\n",draftvocab,base_n_vocab);
+            llama_free(draft_ctx);
+            draft_ctx = nullptr;
+        }
+        else if(llama_model_is_recurrent(draftmodel))
+        {
+            printf("Error: Speculative decoding cannot be used with Recurrent draft models!\n");
+            llama_free(draft_ctx);
+            draft_ctx = nullptr;
+        }
+    }
+}
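For context, a minimal sketch of a call site (placeholder filename and vocab size; the real call appears in gpttype_load_model later in this diff). On any failure the function leaves the file-scope `draft_ctx` as nullptr, so callers only need to test that pointer:

    #include "llama.h"

    // Sketch of a hypothetical call site; values here are placeholders.
    static void try_enable_drafting() {
        llama_model_params mparams = llama_model_default_params();
        llama_context_params cparams = llama_context_default_params();
        cparams.n_ctx = 4096; // should mirror the main model's context size
        speculative_decoding_setup("draft-model.gguf", mparams, cparams, /*base_n_vocab=*/32000);
        if (draft_ctx != nullptr) {
            // speculative decoding is available
        }
    }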

+static speculative_draft_result speculative_decoding_eval_chunk(llama_context * draft_ctx, llama_context * main_ctx, const llama_tokens & embd, const int n_vocab, const int & n_past)
+{
+    speculative_draft_result results;
+    results.draft_success = false;
+    if(embd.size()==0)
+    {
+        printf("\nERROR: Speculate on empty batch!\n");
+        return results;
+    }
+    if(embd.size()>1)
+    {
+        printf("\nERROR: Speculative decoding applied on large batch!\n");
+        return results;
+    }
+    int draft_npast = n_past;
+    int actual_npast = n_past;
+    std::vector<int> temp_embd;
+    std::vector<int> drafted_ids;
+    temp_embd.push_back(embd[0]);
+    drafted_ids.push_back(embd[0]);
+    for(int i=0;i<speculative_chunk_amt;++i)
+    {
+        kcpp_embd_batch batch1 = kcpp_embd_batch(temp_embd, draft_npast, false);
+        auto draftok = (llama_decode(draft_ctx, batch1.batch)==0);
+        if(!draftok)
+        {
+            printf("\nERROR: Speculative draft model 1 failed!\n");
+            return results;
+        }
+        float * draftlogits = llama_get_logits(draft_ctx);
+        //greedy sample the draft model
+        int topid = std::max_element(draftlogits, draftlogits + n_vocab) - draftlogits;
+        drafted_ids.push_back(topid);
+        temp_embd.clear();
+        temp_embd.push_back(topid);
+        ++draft_npast;
+    }
+    //now that we have our drafted tokens, we form a batch and PP it
+    kcpp_embd_batch batch2 = kcpp_embd_batch(drafted_ids, actual_npast, true);
+    auto draftok = (llama_decode(main_ctx, batch2.batch)==0); //actual eval for big model
+    if(!draftok)
+    {
+        printf("\nERROR: Speculative draft model 2 failed!\n");
+        return results;
+    }
+    results.drafted_amount = 0;
+    for(int i=0;i<drafted_ids.size()-1;++i)
+    {
+        results.drafted_amount += 1;
+        float * fulllogits = llama_get_logits_ith(main_ctx,i);
+        results.draftids.push_back(drafted_ids[i+1]);
+        results.actual_logits.push_back(fulllogits);
+    }
+    results.draft_success = true;
+    return results;
+}
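This function is the core of the feature: greedily draft speculative_chunk_amt tokens with the small model, then score the whole chunk with one batched pass of the big model. The sampling loop further down accepts drafted tokens one by one until the first mismatch; note the real loop samples with the full sampler stack, not argmax. A self-contained sketch of that accept-until-mismatch rule, with argmax as a stand-in for sampling:

    #include <cstdio>
    #include <vector>

    // Stand-in for greedy sampling over a logits row.
    static int argmax(const std::vector<float> & logits) {
        int best = 0;
        for (size_t i = 1; i < logits.size(); ++i)
            if (logits[i] > logits[best]) best = (int)i;
        return best;
    }

    // Given the draft's proposed tokens and the big model's logits for each
    // drafted position, count how many drafted tokens are accepted: the loop
    // stops at the first position where the big model disagrees.
    static int count_accepted(const std::vector<int> & drafted,
                              const std::vector<std::vector<float>> & big_logits) {
        int accepted = 0;
        for (size_t i = 0; i < drafted.size() && i < big_logits.size(); ++i) {
            int real = argmax(big_logits[i]); // what the big model would emit
            if (real != drafted[i]) break;    // mismatch: discard the rest
            ++accepted;
        }
        return accepted;
    }

    int main() {
        std::vector<int> drafted = {5, 9, 2};
        std::vector<std::vector<float>> big_logits = {
            {0.f, 0.f, 0.f, 0.f, 0.f, 9.f}, // argmax = 5, matches drafted[0]
            {0.f, 1.f, 0.f},                // argmax = 1, mismatch at position 1
            {0.f, 0.f, 7.f},
        };
        printf("accepted %d of %zu drafted tokens\n",
               count_accepted(drafted, big_logits), drafted.size());
        return 0;
    }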

 // KCPP SAMPLING FUNCTIONS
 void sample_softmax(llama_token_data_array * cur_p) {
 GGML_ASSERT(cur_p->size > 0);
@@ -1554,66 +1726,6 @@ static void load_grammar(const std::string & gammarstr)
 }
 }

-struct kcpp_embd_batch { //duplcated from llava_embd_batch
-    std::vector<int32_t> pos;
-    std::vector<int32_t> n_seq_id;
-    std::vector<int32_t> seq_id_0;
-    std::vector<int32_t *> seq_ids;
-    std::vector<int8_t> logits;
-    llama_batch batch;
-    kcpp_embd_batch(float * embd, int32_t n_tokens, int32_t npast) {
-        int32_t seq_id = 0;
-        pos.resize(n_tokens);
-        n_seq_id.resize(n_tokens);
-        seq_ids.resize(n_tokens + 1);
-        logits.resize(n_tokens);
-        seq_id_0.resize(1);
-        seq_id_0[0] = seq_id;
-        seq_ids [n_tokens] = nullptr;
-        batch = {
-            /*n_tokens =*/ n_tokens,
-            /*tokens =*/ nullptr,
-            /*embd =*/ embd,
-            /*pos =*/ pos.data(),
-            /*n_seq_id =*/ n_seq_id.data(),
-            /*seq_id =*/ seq_ids.data(),
-            /*logits =*/ logits.data(),
-        };
-        for (int i = 0; i < n_tokens; i++) {
-            batch.pos [i] = npast + i;
-            batch.n_seq_id[i] = 1;
-            batch.seq_id [i] = seq_id_0.data();
-            batch.logits [i] = false;
-        }
-    }
-    kcpp_embd_batch(std::vector<llama_token> & tokens, int32_t npast) {
-        int32_t seq_id = 0;
-        int32_t n_tokens = tokens.size();
-        pos.resize(n_tokens);
-        n_seq_id.resize(n_tokens);
-        seq_ids.resize(n_tokens + 1);
-        logits.resize(n_tokens);
-        seq_id_0.resize(1);
-        seq_id_0[0] = seq_id;
-        seq_ids [n_tokens] = nullptr;
-        batch = {
-            /*n_tokens =*/ n_tokens,
-            /*tokens =*/ tokens.data(),
-            /*embd =*/ nullptr,
-            /*pos =*/ pos.data(),
-            /*n_seq_id =*/ n_seq_id.data(),
-            /*seq_id =*/ seq_ids.data(),
-            /*logits =*/ logits.data(),
-        };
-        for (int i = 0; i < n_tokens; i++) {
-            batch.pos [i] = npast + i;
-            batch.n_seq_id[i] = 1;
-            batch.seq_id [i] = seq_id_0.data();
-            batch.logits [i] = false;
-        }
-        batch.logits[n_tokens - 1] = true;
-    }
-};
 static bool kcpp_eval_image(llama_context * ctx_llama, float * img_embd, int num_img_tokens, int n_batch, int * n_past) {
 int n_embd = llama_n_embd(llama_get_model(ctx_llama));

@@ -1635,7 +1747,7 @@ static bool kcpp_eval_image(llama_context * ctx_llama, float * img_embd, int num_img_tokens, int n_batch, int * n_past) {

 //given an old GGUF context and a new context that has some middle portion removed,
 //find and remove the middle portion from the old context from the KV. Does not fast forward after this destructive action
-void PurgeMissingTokens(llama_context * ctx, std::vector<int> &current_context_tokens, std::vector<int> &new_context_tokens, const int genamt, const int nctx)
+void PurgeMissingTokens(llama_context * ctx, llama_context * draft_ctx, std::vector<int> &current_context_tokens, std::vector<int> &new_context_tokens, const int genamt, const int nctx)
 {
 //scan from start old and new ctx, until first mismatch found, save as p0
 //check remaining old and new ctx for longest common subseq, which needs to be at 256 tokens
@@ -1688,8 +1800,13 @@ void PurgeMissingTokens(llama_context * ctx, std::vector<int> &current_context_tokens

 //extract the unwanted tokens out from context and KV
 int diff = found - trimstart;
-llama_kv_cache_seq_rm(llama_ctx_v4, 0, trimstart, trimstart + diff);
-llama_kv_cache_seq_add(llama_ctx_v4, 0, trimstart + diff, -1, -diff);
+llama_kv_cache_seq_rm(ctx, 0, trimstart, trimstart + diff);
+llama_kv_cache_seq_add(ctx, 0, trimstart + diff, -1, -diff);
+if(draft_ctx)
+{
+    llama_kv_cache_seq_rm(draft_ctx, 0, trimstart, trimstart + diff);
+    llama_kv_cache_seq_add(draft_ctx, 0, trimstart + diff, -1, -diff);
+}

 for (size_t i = trimstart + diff; i < current_context_tokens.size() - 1; i++)
 {
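The calls above delete the KV entries for positions [trimstart, trimstart + diff) and then shift every later entry back by diff so positions stay contiguous; the change mirrors that edit into the draft context so both caches stay aligned. The same edit applied to a plain token vector, as a self-contained illustration (not koboldcpp code):

    #include <cstdio>
    #include <vector>

    // Mimics llama_kv_cache_seq_rm + llama_kv_cache_seq_add(..., -diff):
    // erase [trimstart, trimstart+diff) and let the tail slide left.
    static void trim_middle(std::vector<int> & tokens, int trimstart, int diff) {
        tokens.erase(tokens.begin() + trimstart, tokens.begin() + trimstart + diff);
    }

    int main() {
        std::vector<int> tokens = {10, 11, 12, 13, 14, 15};
        trim_middle(tokens, /*trimstart=*/2, /*diff=*/2); // drop {12, 13}
        for (int t : tokens) printf("%d ", t);            // prints: 10 11 14 15
        printf("\n");
        return 0;
    }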
@@ -1791,6 +1908,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
 kcpp_data->use_contextshift = inputs.use_contextshift;
 kcpp_data->use_fastforward = inputs.use_fastforward;
 debugmode = inputs.debugmode;
+draft_ctx = nullptr;

 auto clamped_max_context_length = inputs.max_context_length;

@@ -2136,6 +2254,23 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in

 n_vocab = llama_n_vocab(llamamodel);

+if(draftmodel_filename != "" && file_format==FileFormat::GGUF_GENERIC)
+{
+    if(llama_model_is_recurrent(llamamodel))
+    {
+        printf("Error: Speculative decoding cannot be used with Recurrent models!\n");
+    }
+    else if(clp_ctx!=nullptr)
+    {
+        printf("Error: Speculative decoding cannot be used with multimodal vision projectors!\n");
+    }
+    else
+    {
+        printf("\nAttempting to load draft model for speculative decoding. It will be fully offloaded if possible. Vocab must match the main model.\n");
+        speculative_decoding_setup(draftmodel_filename, model_params, llama_ctx_params, n_vocab);
+    }
+}
+
 //determine mem per token
 std::vector<int> tmp = {1, 2, 3, 4};
 llama_kv_cache_clear(llama_ctx_v4);
@@ -3090,6 +3225,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
 if(n_past==0)
 {
 llama_kv_cache_clear(llama_ctx_v4);
+if(draft_ctx)
+{
+    llama_kv_cache_clear(draft_ctx);
+}
 }
 else if(embd_inp.size()==0)
 {
@@ -3105,7 +3244,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
 {
 if(kcpp_data->use_fastforward && kcpp_data->use_contextshift && (file_format == FileFormat::GGUF_GENERIC))
 {
-PurgeMissingTokens(llama_ctx_v4, current_context_tokens, embd_inp, inputs.max_length, nctx);
+PurgeMissingTokens(llama_ctx_v4, draft_ctx, current_context_tokens, embd_inp, inputs.max_length, nctx);
 triggersc = false;
 }
 if(kcpp_data->use_fastforward)
@@ -3116,6 +3255,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
 if(file_format == FileFormat::GGUF_GENERIC)
 {
 llama_kv_cache_seq_rm(llama_ctx_v4, 0, n_past, -1);
+if(draft_ctx)
+{
+    llama_kv_cache_seq_rm(draft_ctx, 0, n_past, -1);
+}
 }
 }

@@ -3153,6 +3296,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
 bool startedsampling = false;
 bool v3_use_scratch = true; //for normal inference always use scratch

+speculative_draft_result draft_results; //only use if drafting was used
+bool draft_used = false;
+
 time0 = timer_check();
 timer_start();

@@ -3223,9 +3369,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)

 if (embdsize > 0)
 {
-
 bool evalres = false;
-
 if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
 {
 evalres = (llama_v2_eval(llama_ctx_v2, embd.data(), embdsize, n_past, GetThreadsToUse(blasmode))==0);
@@ -3236,8 +3380,20 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
 }
 else if(file_format == FileFormat::GGUF_GENERIC)
 {
-kcpp_embd_batch batch = kcpp_embd_batch(embd, n_past);
-evalres = (llama_decode(llama_ctx_v4, batch.batch)==0);
+if(embd.size()!=1 || draft_ctx==nullptr || remaining_tokens<=speculative_chunk_amt || grammar!=nullptr || startedsampling==false) //for large batch, or if no draft model, PP/TG as usual
+{
+    draft_used = false;
+    kcpp_embd_batch batch = kcpp_embd_batch(embd, n_past, false);
+    evalres = (llama_decode(llama_ctx_v4, batch.batch)==0);
+    if(draft_ctx)
+    {
+        evalres = (evalres && (llama_decode(draft_ctx, batch.batch)==0));
+    }
+} else { //individual tokens AND speculative is used (generation)
+    draft_used = true;
+    draft_results = speculative_decoding_eval_chunk(draft_ctx, llama_ctx_v4, embd, n_vocab, n_past);
+    evalres = draft_results.draft_success;
+}
 }
 else if(file_format==FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
 {
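The new branch above only takes the speculative path for a single-token generation step with a loaded draft context, more than one chunk of budget left, no active grammar, and sampling already started; everything else falls back to ordinary prompt processing or plain generation. Restated as a hypothetical predicate (names mirror the locals above; not part of the commit):

    // Sketch: the routing decision from the hunk above as one predicate.
    static bool should_use_speculative(size_t embd_size, bool have_draft_ctx,
                                       int remaining_tokens, bool have_grammar,
                                       bool startedsampling, int chunk_amt) {
        return embd_size == 1               // generating, not prompt processing
            && have_draft_ctx               // a draft model loaded successfully
            && remaining_tokens > chunk_amt // enough budget for a whole chunk
            && !have_grammar                // grammar constraints disable drafting
            && startedsampling;             // prompt ingestion is finished
    }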
@@ -3255,8 +3411,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
 }
 else
 {
-bool ignoreLogits = (!startedsampling && ((int)embd_inp.size() > input_consumed + 2));
-evalres = rwkv_eval(rwkv_ctx_v3, GetThreadsToUse(blasmode), embd[0], rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, ignoreLogits?nullptr:rwkv_ctx_v3->logits_out);
+bool ignoreLogits = (!startedsampling && ((int)embd_inp.size() > input_consumed + 2));
+evalres = rwkv_eval(rwkv_ctx_v3, GetThreadsToUse(blasmode), embd[0], rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, ignoreLogits?nullptr:rwkv_ctx_v3->logits_out);
 }

 memcpy(logits.data(), rwkv_ctx_v3->logits_out, sizeof(float) * rwkv_vocab.size());
@@ -3350,230 +3506,284 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
 float * logitsPtr;
 float lowestLogit = 0;
 int btsize = banned_token_ids.size();

+//sample pending logits. usually only 1, unless speculative decoding
+int logits_to_sample = 1;
+int logits_sampled = 0;
+bool abort_draft = false;
+if(draft_used)
+{
+    logits_to_sample = draft_results.drafted_amount;
+}
+while(logits_sampled<logits_to_sample && remaining_tokens>0 && !abort_draft)
+{
+    if(logits_sampled>0)
+    {
+        //this is not the first loop, so we need to increment some things
+        n_past += 1;
+    }
     if(file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 || file_format == FileFormat::GGJT_3 || file_format == FileFormat::GGUF_GENERIC)
     {
         if(file_format == FileFormat::GGUF_GENERIC)
         {
+            if(draft_used)
+            {
+                logitsPtr = draft_results.actual_logits[logits_sampled];
+            }
+            else
+            {
                 logitsPtr = llama_get_logits(llama_ctx_v4);
+            }
         }
         else if(file_format == FileFormat::GGJT_3)
         {
             logitsPtr = llama_v3_get_logits(llama_ctx_v3);
         }
         else
         {
             logitsPtr = llama_v2_get_logits(llama_ctx_v2);
         }
         lowestLogit = LowestLogit(logitsPtr,n_vocab);
     }
     else
     {
         logitsPtr = logits.data(); //legacy rwkv, neox, gptj etc
         lowestLogit = LowestLogit(logits);
     }

     if (!inputs.allow_eos_token && !inputs.bypass_eos_token)
     {
         // set the logit of the eos token to very low to avoid sampling it
         if(eosID!=LLAMA_TOKEN_NULL)
         {
             logitsPtr[eosID] = lowestLogit;
         }
         if(eotID!=-1)
         {
             logitsPtr[eotID] = lowestLogit;
         }
     }
     if(btsize>0)
     {
         for(int t=0;t<btsize;++t)
         {
             logitsPtr[banned_token_ids[t]]=lowestLogit;
         }
     }

     //handle temp bans from antislop
     if (antislop_banned_token_ids.find(n_past) != antislop_banned_token_ids.end()) {
         std::vector<int>& bans = antislop_banned_token_ids[n_past];
         for(int t=0;t<bans.size();++t)
         {
             logitsPtr[bans[t]]=lowestLogit;
         }
     }

     id = SampleLogits(logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty, kcpp_data->rep_pen_slope, presence_penalty,
     top_k, top_a, top_p, min_p, typical_p, tfs_z, temp, rng,
     kcpp_data->mirostat, kcpp_data->mirostat_tau, kcpp_data->mirostat_eta,
     kcpp_data->dry_multiplier, kcpp_data->dry_base,
     kcpp_data->dry_allowed_length, kcpp_data->dry_penalty_last_n, kcpp_data->xtc_threshold, kcpp_data->xtc_probability,
     sampler_order, grammar, dynatemp_range, dynatemp_exponent, smoothing_factor);

+    if(draft_used)
+    {
+        int32_t draftedid = draft_results.draftids[logits_sampled];
+        if(debugmode==1)
+        {
+            std::string drafttok = FileFormatTokenizeID(draftedid, file_format, true);
+            std::string realtok = FileFormatTokenizeID(id, file_format, true);
+            printf("(Draft %d/%d): Predicted=%d (%s), Actual=%d (%s) [%s]\n",(logits_sampled+1),logits_to_sample,draftedid,drafttok.c_str(),id,realtok.c_str(),(draftedid==id?"PASS":"FAIL"));
+        }
+        if(draftedid!=id) //draft mismatch, abort
+        {
+            abort_draft = true;
+        }
+    }

     if (grammar != nullptr) {
         grammar_accept_token(file_format, n_vocab, grammar, id);
     }

     if (!last_n_tokens.empty())
     {
         last_n_tokens.erase(last_n_tokens.begin());
     }
     last_n_tokens.push_back(id);
     current_context_tokens.push_back(id);

     // add it to the context
+    embd.clear();
     embd.push_back(id);

     // decrement remaining sampling budget
     --remaining_tokens;

     for (auto eid : embd)
     {
         std::string tokenizedstr = FileFormatTokenizeID(eid, file_format, inputs.render_special);
         if(!inputs.render_special && (eid==eosID || (eid==eotID && eid!=-1) || VecContainsIntVal(special_stop_sequence,id))) //extra filter to avoid unwanted special tokens
         {
             tokenizedstr = ""; //prevent render
         }

         delayed_generated_tokens.push_back(tokenizedstr);
         while(delayed_generated_tokens.size() > delayed_generated_tokens_limit && delayed_generated_tokens.size() > 0)
         {
             generated_tokens.push_back(delayed_generated_tokens[0]);
             concat_output_mtx.lock();
             concat_output += delayed_generated_tokens[0];
             concat_output_mtx.unlock();
             delayed_generated_tokens.pop_front();
         }
     }

     if (startedsampling && allow_regular_prints)
     {
         printf("\rGenerating (%d / %d tokens)", (kcpp_data->n_predict - remaining_tokens), kcpp_data->n_predict);
     }
     if(debugmode==1 && top_picks_history.size()>0)
     {
         printf(" [");
         bool firstloop = true;
         TopPicksData toppick = top_picks_history[top_picks_history.size()-1];
         std::string topstr = toppick.selected_token;
         ::utreplace(topstr, "\n", "\\n");
         printf("(%s %.2f%%)", RemoveBell(topstr).c_str(), toppick.selected_probability*100);
         int maxtoshow = (toppick.tokenid.size()>4?4:toppick.tokenid.size());
         for (int i=0;i<maxtoshow;++i)
         {
             if(toppick.tokenid[i]==toppick.selected_tokenid)
             {
                 continue;
             }
             printf(" ");
             std::string tokenizedstr = toppick.tokens[i];
             ::utreplace(tokenizedstr, "\n", "\\n");
             printf("(%s %.2f%%)", RemoveBell(tokenizedstr).c_str(), toppick.p[i]*100);
         }
         printf("]\n");
     }

     //anti slop detection
     if (banned_phrases.size() > 0)
     {
         std::string scanstr = "";
         for (int i = 0; i < delayed_generated_tokens.size(); ++i)
         {
             scanstr += delayed_generated_tokens[i];
         }
         scanstr = toLowerCase(scanstr);
         for (const auto &matched : banned_phrases)
         {
             std::string matched_lower = toLowerCase(matched);
             if (scanstr.find(matched_lower) != std::string::npos)
             {
                 //find the position in the string that contains all necessary tokens
                 std::string checkstr = "";
                 int rewind_amt = 0;
                 for (int i = delayed_generated_tokens.size() - 1; i >= 0; --i)
                 {
                     checkstr = delayed_generated_tokens[i] + checkstr;
                     ++rewind_amt;
                     if (toLowerCase(checkstr).find(matched_lower) != std::string::npos)
                     {
                         break;
                     }
                 }
                 if (rewind_amt > 0 && (current_context_tokens.size() - rewind_amt) > 0)
                 {
                     int last_tok = current_context_tokens[current_context_tokens.size() - rewind_amt];
                     delayed_generated_tokens.resize(delayed_generated_tokens.size() - rewind_amt);
                     ContextRewind(embd, current_context_tokens, n_past, last_n_tokens, rewind_amt);

+                    //immediately terminate drafting if used
+                    abort_draft = true;

                     // Check if the key exists
                     int banindex = n_past+1;
                     if (antislop_banned_token_ids.find(banindex) == antislop_banned_token_ids.end()) {
                         antislop_banned_token_ids[banindex] = std::vector<int>();
                     }
                     std::vector<int>& current_ids = antislop_banned_token_ids[banindex];
                     current_ids.push_back(last_tok);

                     if (allow_regular_prints && debugmode == 1)
                     {
                         auto match_clean = matched;
                         replace_all(match_clean, "\n", "\\n");
                         printf("\n(Banned Phrase Detected: %s - Add ID %d to banlist at index %d, and rewinding %d tokens)\n", match_clean.c_str(), last_tok, banindex, rewind_amt);
                     }

                     break;
                 }
             }
         }
     }

     bool earlystopped = false;
     if(!inputs.bypass_eos_token && inputs.allow_eos_token && (id==eosID || (id==eotID && id!=-1)))
     {
         stopper_unused_tokens = remaining_tokens;
         if(allow_regular_prints)
         {
             printf("\n(EOS token triggered! ID:%d)",id);
         }
         remaining_tokens = 0;
         last_stop_reason = stop_reason::EOS_TOKEN_HIT;
         earlystopped = true;
     }

     if(!earlystopped)
     {
         for (const auto &matched : special_stop_sequence)
         {
             if(id==matched)
             {
                 stopper_unused_tokens = remaining_tokens;
                 if(allow_regular_prints)
                 {
                     printf("\n(Special Stop Token Triggered! ID:%d)",matched);
                 }
                 remaining_tokens = 0;
                 last_stop_reason = stop_reason::EOS_TOKEN_HIT;
                 earlystopped = true;
                 break;
             }
         }
     }

     if(!earlystopped)
     {
         for (const auto &matched : stop_sequence)
         {
             if (concat_output.find(matched) != std::string::npos)
             {
                 stopper_unused_tokens = remaining_tokens;
                 remaining_tokens = 0;
                 if(allow_regular_prints)
                 {
                     auto match_clean = matched;
                     replace_all(match_clean, "\n", "\\n");
                     printf("\n(Stop sequence triggered: %s)", match_clean.c_str());
                 }
                 last_stop_reason = stop_reason::CUSTOM_STOPPER;
                 earlystopped = true;
                 break;
             }
         }
     }

+    logits_sampled += 1;
+}

+//if we have somehow skipped ahead (e.g drafting), ensure that all tokens after npast are purged
+if (file_format == FileFormat::GGUF_GENERIC && draft_used)
+{
+    llama_kv_cache_seq_rm(llama_ctx_v4, 0, n_past, -1);
+    if (draft_ctx) {
+        llama_kv_cache_seq_rm(draft_ctx, 0, n_past, -1);
+    }
+}
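A back-of-envelope model of why a chunk of 16 pays off: if each drafted token independently matches the big model with probability a, one verification pass commits the accepted prefix plus one corrective token, i.e. (1 - a^(K+1)) / (1 - a) tokens on average. This is only a rough model under a strong independence assumption, not something measured in this commit:

    #include <cmath>
    #include <cstdio>

    int main() {
        // i.i.d. per-token acceptance probability a, chunk size K
        // (speculative_chunk_amt is 16 in this commit).
        const int K = 16;
        const double rates[] = {0.5, 0.7, 0.9};
        for (double a : rates) {
            // Expected tokens committed per big-model pass:
            // accepted prefix E[N] = a(1-a^K)/(1-a), plus 1 corrective token.
            double expected = (1.0 - std::pow(a, K + 1)) / (1.0 - a);
            printf("a=%.1f -> %.2f tokens per verification pass\n", a, expected);
        }
        return 0;
    }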
@@ -3611,6 +3821,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
 }
 for(int i=0;i<llava_images.size();++i)
 {
+    //note: no handling for draft_ctx as we don't support vision for it
     if(i>0 && sepsize>0)
     {
         //add a separator between each image
@@ -3701,7 +3912,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
 float pt2 = (time2*1000.0/(realnpredict==0?1:realnpredict));
 float ts2 = (1000.0/pt2);
 float tokens_per_second = (realnpredict == 0 ? 0 : realnpredict / (time1 + time2));
-printf("\nCtxLimit:%d/%d, Amt:%d/%d, Init:%.2fs, Process:%.2fs (%.1fms/T = %.2fT/s), Generate:%.2fs (%.1fms/T = %.2fT/s), Total:%.2fs (%.2fT/s)",(int)current_context_tokens.size(),(int)nctx, realnpredict, kcpp_data->n_predict, time0, time1, pt1, ts1, time2, pt2, ts2, (time1 + time2), tokens_per_second);
+printf("\n[%s] CtxLimit:%d/%d, Amt:%d/%d, Init:%.2fs, Process:%.2fs (%.1fms/T = %.2fT/s), Generate:%.2fs (%.1fms/T = %.2fT/s), Total:%.2fs (%.2fT/s)",get_timestamp_str().c_str(),(int)current_context_tokens.size(),(int)nctx, realnpredict, kcpp_data->n_predict, time0, time1, pt1, ts1, time2, pt2, ts2, (time1 + time2), tokens_per_second);
 fflush(stdout);
 output.status = 1;
 int finaltokcount = (int)current_context_tokens.size()-realnpredict;