speculative decoding initial impl completed (+6 squashed commit)

Squashed commit:

[0a6306ca0] draft wip dont use (will be squashed)

[a758a1c9c] wip dont use (will be squashed)

[e1994d3ce] wip dont use

[f59690d68] wip

[77228147d] wip on spec decoding. dont use yet

[2445bca54] wip adding speculative decoding (+1 squashed commits)

Squashed commits:

[50e341bb7] wip adding speculative decoding
This commit is contained in:
Concedo 2024-11-27 00:16:51 +08:00
parent b9e99c69e8
commit f75bbb945f
9 changed files with 539 additions and 280 deletions

View file

@ -35,6 +35,7 @@ extern "C"
lora_filename = inputs.lora_filename;
lora_base = inputs.lora_base;
mmproj_filename = inputs.mmproj_filename;
draftmodel_filename = inputs.draftmodel_filename;
int forceversion = inputs.forceversion;

View file

@ -40,6 +40,7 @@ struct load_model_inputs
const char * model_filename = nullptr;
const char * lora_filename = nullptr;
const char * lora_base = nullptr;
const char * draftmodel_filename = nullptr;
const char * mmproj_filename = nullptr;
const bool use_mmap = false;
const bool use_mlock = false;
@ -197,6 +198,7 @@ extern std::string executable_path;
extern std::string lora_filename;
extern std::string lora_base;
extern std::string mmproj_filename;
extern std::string draftmodel_filename;
extern std::vector<std::string> generated_tokens;
extern bool generation_finished;
extern float last_eval_time;

View file

@ -40,8 +40,10 @@
#include "mpt_v3.cpp"
#include "examples/llava/clip.h"
#include "examples/llava/llava.h"
#include "common/common.h"
//const
const int speculative_chunk_amt = 16; //do it in chunks of this many tokens
const int extra_context_handle_fragmentation = 120;
const int LLAVA_TOKEN_IDENTIFIER_A = -998; //alternate between both, changing when image changes
const int LLAVA_TOKEN_IDENTIFIER_B = -999;
@ -51,6 +53,7 @@ std::string executable_path = "";
std::string lora_filename = "";
std::string lora_base = "";
std::string mmproj_filename = "";
std::string draftmodel_filename = "";
bool generation_finished;
float last_process_time = 0;
float last_eval_time = 0;
@ -90,6 +93,7 @@ static rwkv_context * rwkv_ctx_v3;
static llama_v2_context * llama_ctx_v2;
static llama_v3_context * llama_ctx_v3;
static llama_context * llama_ctx_v4;
static llama_context * draft_ctx = nullptr; //will remain null if speculative is unused
static clip_ctx * clp_ctx = nullptr; //for llava
static clip_image_u8 * clp_img_data = nullptr; //most recent image
@ -487,6 +491,10 @@ void ContextRewind(std::vector<int> &embd, std::vector<int> &current_context_tok
if (file_format == FileFormat::GGUF_GENERIC)
{
llama_kv_cache_seq_rm(llama_ctx_v4, 0, n_past, -1);
if(draft_ctx)
{
llama_kv_cache_seq_rm(draft_ctx, 0, n_past, -1);
}
}
embd.clear();
@ -527,6 +535,170 @@ const char * kcpp_print_system_info(void) {
return s.c_str();
}
struct kcpp_embd_batch { //duplcated from llava_embd_batch
std::vector<int32_t> pos;
std::vector<int32_t> n_seq_id;
std::vector<int32_t> seq_id_0;
std::vector<int32_t *> seq_ids;
std::vector<int8_t> logits;
llama_batch batch;
kcpp_embd_batch(float * embd, int32_t n_tokens, int32_t npast) {
int32_t seq_id = 0;
pos.resize(n_tokens);
n_seq_id.resize(n_tokens);
seq_ids.resize(n_tokens + 1);
logits.resize(n_tokens);
seq_id_0.resize(1);
seq_id_0[0] = seq_id;
seq_ids [n_tokens] = nullptr;
batch = {
/*n_tokens =*/ n_tokens,
/*tokens =*/ nullptr,
/*embd =*/ embd,
/*pos =*/ pos.data(),
/*n_seq_id =*/ n_seq_id.data(),
/*seq_id =*/ seq_ids.data(),
/*logits =*/ logits.data(),
};
for (int i = 0; i < n_tokens; i++) {
batch.pos [i] = npast + i;
batch.n_seq_id[i] = 1;
batch.seq_id [i] = seq_id_0.data();
batch.logits [i] = false;
}
}
kcpp_embd_batch(std::vector<llama_token> & tokens, int32_t npast, bool return_all_logits) {
int32_t seq_id = 0;
int32_t n_tokens = tokens.size();
pos.resize(n_tokens);
n_seq_id.resize(n_tokens);
seq_ids.resize(n_tokens + 1);
logits.resize(n_tokens);
seq_id_0.resize(1);
seq_id_0[0] = seq_id;
seq_ids [n_tokens] = nullptr;
batch = {
/*n_tokens =*/ n_tokens,
/*tokens =*/ tokens.data(),
/*embd =*/ nullptr,
/*pos =*/ pos.data(),
/*n_seq_id =*/ n_seq_id.data(),
/*seq_id =*/ seq_ids.data(),
/*logits =*/ logits.data(),
};
for (int i = 0; i < n_tokens; i++) {
batch.pos [i] = npast + i;
batch.n_seq_id[i] = 1;
batch.seq_id [i] = seq_id_0.data();
batch.logits [i] = (return_all_logits?true:false);
}
batch.logits[n_tokens - 1] = true;
}
};
//loads a model for speculative decoding.
static void speculative_decoding_setup(std::string spec_model_filename, const llama_model_params & base_model_params, const llama_context_params & base_ctx_params, int base_n_vocab)
{
llama_model_params draft_model_params = llama_model_default_params();
llama_context_params draft_ctx_params = llama_context_default_params();
draft_model_params.use_mmap = base_model_params.use_mmap;
draft_model_params.use_mlock = base_model_params.use_mlock;
draft_model_params.n_gpu_layers = 999; //assume they want to fully offload the speculative model. Otherwise, why even use it?
draft_ctx_params.n_ctx = base_ctx_params.n_ctx;
draft_ctx_params.logits_all = false;
draft_ctx_params.offload_kqv = base_ctx_params.offload_kqv;
draft_model_params.main_gpu = base_model_params.main_gpu;
draft_model_params.split_mode = llama_split_mode::LLAMA_SPLIT_MODE_LAYER;
draft_ctx_params.n_batch = base_ctx_params.n_batch;
draft_ctx_params.n_ubatch = base_ctx_params.n_ubatch;
draft_ctx_params.n_threads = base_ctx_params.n_threads;
draft_ctx_params.n_threads_batch = base_ctx_params.n_threads_batch;
draft_ctx_params.flash_attn = base_ctx_params.flash_attn;
draft_ctx_params.type_k = base_ctx_params.type_k;
draft_ctx_params.type_v = base_ctx_params.type_v;
llama_model * draftmodel = llama_load_model_from_file(spec_model_filename.c_str(), draft_model_params);
draft_ctx = llama_new_context_with_model(draftmodel, draft_ctx_params);
if(draft_ctx == NULL)
{
printf("Error: failed to load speculative decoding draft model '%s'\n", spec_model_filename.c_str());
printf("Speculative Decoding will not be used!\n");
}
else
{
int draftvocab = llama_n_vocab(draftmodel);
if(draftvocab!=base_n_vocab)
{
printf("Error: Draft model vocab of (%d) does not match base vocab of (%d). Speculative decoding cannot be used!\n",draftvocab,base_n_vocab);
llama_free(draft_ctx);
draft_ctx = nullptr;
}else if(llama_model_is_recurrent(draftmodel))
{
printf("Error: Speculative decoding cannot be used with Recurrent draft models!\n");
llama_free(draft_ctx);
draft_ctx = nullptr;
}
}
}
static speculative_draft_result speculative_decoding_eval_chunk(llama_context * draft_ctx, llama_context * main_ctx, const llama_tokens & embd, const int n_vocab, const int & n_past)
{
speculative_draft_result results;
results.draft_success = false;
if(embd.size()==0)
{
printf("\nERROR: Speculate on empty batch!\n");
return results;
}
if(embd.size()>1)
{
printf("\nERROR: Speculative decoding applied on large batch!\n");
return results;
}
int draft_npast = n_past;
int actual_npast = n_past;
std::vector<int> temp_embd;
std::vector<int> drafted_ids;
temp_embd.push_back(embd[0]);
drafted_ids.push_back(embd[0]);
for(int i=0;i<speculative_chunk_amt;++i)
{
kcpp_embd_batch batch1 = kcpp_embd_batch(temp_embd, draft_npast, false);
auto draftok = (llama_decode(draft_ctx, batch1.batch)==0);
if(!draftok)
{
printf("\nERROR: Speculative draft model 1 failed!\n");
return results;
}
float * draftlogits = llama_get_logits(draft_ctx);
//greedy sample the draft model
int topid = std::max_element(draftlogits, draftlogits + n_vocab) - draftlogits;
drafted_ids.push_back(topid);
temp_embd.clear();
temp_embd.push_back(topid);
++draft_npast;
}
//now that we have our drafted tokens, we form a batch and PP it
kcpp_embd_batch batch2 = kcpp_embd_batch(drafted_ids, actual_npast, true);
auto draftok = (llama_decode(main_ctx, batch2.batch)==0); //actual eval for big model
if(!draftok)
{
printf("\nERROR: Speculative draft model 2 failed!\n");
return results;
}
results.drafted_amount = 0;
for(int i=0;i<drafted_ids.size()-1;++i)
{
results.drafted_amount += 1;
float * fulllogits = llama_get_logits_ith(main_ctx,i);
results.draftids.push_back(drafted_ids[i+1]);
results.actual_logits.push_back(fulllogits);
}
results.draft_success = true;
return results;
}
// KCPP SAMPLING FUNCTIONS
void sample_softmax(llama_token_data_array * cur_p) {
GGML_ASSERT(cur_p->size > 0);
@ -1554,66 +1726,6 @@ static void load_grammar(const std::string & gammarstr)
}
}
struct kcpp_embd_batch { //duplcated from llava_embd_batch
std::vector<int32_t> pos;
std::vector<int32_t> n_seq_id;
std::vector<int32_t> seq_id_0;
std::vector<int32_t *> seq_ids;
std::vector<int8_t> logits;
llama_batch batch;
kcpp_embd_batch(float * embd, int32_t n_tokens, int32_t npast) {
int32_t seq_id = 0;
pos.resize(n_tokens);
n_seq_id.resize(n_tokens);
seq_ids.resize(n_tokens + 1);
logits.resize(n_tokens);
seq_id_0.resize(1);
seq_id_0[0] = seq_id;
seq_ids [n_tokens] = nullptr;
batch = {
/*n_tokens =*/ n_tokens,
/*tokens =*/ nullptr,
/*embd =*/ embd,
/*pos =*/ pos.data(),
/*n_seq_id =*/ n_seq_id.data(),
/*seq_id =*/ seq_ids.data(),
/*logits =*/ logits.data(),
};
for (int i = 0; i < n_tokens; i++) {
batch.pos [i] = npast + i;
batch.n_seq_id[i] = 1;
batch.seq_id [i] = seq_id_0.data();
batch.logits [i] = false;
}
}
kcpp_embd_batch(std::vector<llama_token> & tokens, int32_t npast) {
int32_t seq_id = 0;
int32_t n_tokens = tokens.size();
pos.resize(n_tokens);
n_seq_id.resize(n_tokens);
seq_ids.resize(n_tokens + 1);
logits.resize(n_tokens);
seq_id_0.resize(1);
seq_id_0[0] = seq_id;
seq_ids [n_tokens] = nullptr;
batch = {
/*n_tokens =*/ n_tokens,
/*tokens =*/ tokens.data(),
/*embd =*/ nullptr,
/*pos =*/ pos.data(),
/*n_seq_id =*/ n_seq_id.data(),
/*seq_id =*/ seq_ids.data(),
/*logits =*/ logits.data(),
};
for (int i = 0; i < n_tokens; i++) {
batch.pos [i] = npast + i;
batch.n_seq_id[i] = 1;
batch.seq_id [i] = seq_id_0.data();
batch.logits [i] = false;
}
batch.logits[n_tokens - 1] = true;
}
};
static bool kcpp_eval_image(llama_context * ctx_llama, float * img_embd, int num_img_tokens, int n_batch, int * n_past) {
int n_embd = llama_n_embd(llama_get_model(ctx_llama));
@ -1635,7 +1747,7 @@ static bool kcpp_eval_image(llama_context * ctx_llama, float * img_embd, int num
//given an old GGUF context and a new context that has some middle portion removed,
//find and remove the middle portion from the old context from the KV. Does not fast forward after this destructive action
void PurgeMissingTokens(llama_context * ctx, std::vector<int> &current_context_tokens, std::vector<int> &new_context_tokens, const int genamt, const int nctx)
void PurgeMissingTokens(llama_context * ctx, llama_context * draft_ctx, std::vector<int> &current_context_tokens, std::vector<int> &new_context_tokens, const int genamt, const int nctx)
{
//scan from start old and new ctx, until first mismatch found, save as p0
//check remaining old and new ctx for longest common subseq, which needs to be at 256 tokens
@ -1688,8 +1800,13 @@ void PurgeMissingTokens(llama_context * ctx, std::vector<int> &current_context_t
//extract the unwanted tokens out from context and KV
int diff = found - trimstart;
llama_kv_cache_seq_rm(llama_ctx_v4, 0, trimstart, trimstart + diff);
llama_kv_cache_seq_add(llama_ctx_v4, 0, trimstart + diff, -1, -diff);
llama_kv_cache_seq_rm(ctx, 0, trimstart, trimstart + diff);
llama_kv_cache_seq_add(ctx, 0, trimstart + diff, -1, -diff);
if(draft_ctx)
{
llama_kv_cache_seq_rm(draft_ctx, 0, trimstart, trimstart + diff);
llama_kv_cache_seq_add(draft_ctx, 0, trimstart + diff, -1, -diff);
}
for (size_t i = trimstart + diff; i < current_context_tokens.size() - 1; i++)
{
@ -1791,6 +1908,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
kcpp_data->use_contextshift = inputs.use_contextshift;
kcpp_data->use_fastforward = inputs.use_fastforward;
debugmode = inputs.debugmode;
draft_ctx = nullptr;
auto clamped_max_context_length = inputs.max_context_length;
@ -2136,6 +2254,23 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
n_vocab = llama_n_vocab(llamamodel);
if(draftmodel_filename !="" && file_format==FileFormat::GGUF_GENERIC)
{
if(llama_model_is_recurrent(llamamodel))
{
printf("Error: Speculative decoding cannot be used with Recurrent models!\n");
}
else if(clp_ctx!=nullptr)
{
printf("Error: Speculative decoding cannot be used with multimodal vision projectors!\n");
}
else
{
printf("\nAttempting to load draft model for speculative decoding. It will be fully offloaded if possible. Vocab must match the main model.\n");
speculative_decoding_setup(draftmodel_filename, model_params, llama_ctx_params, n_vocab);
}
}
//determine mem per token
std::vector<int> tmp = {1, 2, 3, 4};
llama_kv_cache_clear(llama_ctx_v4);
@ -3090,6 +3225,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
if(n_past==0)
{
llama_kv_cache_clear(llama_ctx_v4);
if(draft_ctx)
{
llama_kv_cache_clear(draft_ctx);
}
}
else if(embd_inp.size()==0)
{
@ -3105,7 +3244,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
{
if(kcpp_data->use_fastforward && kcpp_data->use_contextshift && (file_format == FileFormat::GGUF_GENERIC))
{
PurgeMissingTokens(llama_ctx_v4, current_context_tokens, embd_inp, inputs.max_length, nctx);
PurgeMissingTokens(llama_ctx_v4, draft_ctx, current_context_tokens, embd_inp, inputs.max_length, nctx);
triggersc = false;
}
if(kcpp_data->use_fastforward)
@ -3116,6 +3255,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
if(file_format == FileFormat::GGUF_GENERIC)
{
llama_kv_cache_seq_rm(llama_ctx_v4, 0, n_past, -1);
if(draft_ctx)
{
llama_kv_cache_seq_rm(draft_ctx, 0, n_past, -1);
}
}
}
@ -3153,6 +3296,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
bool startedsampling = false;
bool v3_use_scratch = true; //for normal inference always use scratch
speculative_draft_result draft_results; //only use if drafting was used
bool draft_used = false;
time0 = timer_check();
timer_start();
@ -3223,9 +3369,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
if (embdsize > 0)
{
bool evalres = false;
if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
{
evalres = (llama_v2_eval(llama_ctx_v2, embd.data(), embdsize, n_past, GetThreadsToUse(blasmode))==0);
@ -3236,8 +3380,20 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
}
else if(file_format == FileFormat::GGUF_GENERIC)
{
kcpp_embd_batch batch = kcpp_embd_batch(embd, n_past);
evalres = (llama_decode(llama_ctx_v4, batch.batch)==0);
if(embd.size()!=1 || draft_ctx==nullptr || remaining_tokens<=speculative_chunk_amt || grammar!=nullptr || startedsampling==false) //for large batch, or if no draft model, PP/TG as usual
{
draft_used = false;
kcpp_embd_batch batch = kcpp_embd_batch(embd, n_past, false);
evalres = (llama_decode(llama_ctx_v4, batch.batch)==0);
if(draft_ctx)
{
evalres = (evalres && (llama_decode(draft_ctx, batch.batch)==0));
}
} else { //individual tokens AND speculative is used (generation)
draft_used = true;
draft_results = speculative_decoding_eval_chunk(draft_ctx, llama_ctx_v4, embd, n_vocab, n_past);
evalres = draft_results.draft_success;
}
}
else if(file_format==FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
{
@ -3255,8 +3411,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
}
else
{
bool ignoreLogits = (!startedsampling && ((int)embd_inp.size() > input_consumed + 2));
evalres = rwkv_eval(rwkv_ctx_v3, GetThreadsToUse(blasmode), embd[0], rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, ignoreLogits?nullptr:rwkv_ctx_v3->logits_out);
bool ignoreLogits = (!startedsampling && ((int)embd_inp.size() > input_consumed + 2));
evalres = rwkv_eval(rwkv_ctx_v3, GetThreadsToUse(blasmode), embd[0], rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, ignoreLogits?nullptr:rwkv_ctx_v3->logits_out);
}
memcpy(logits.data(), rwkv_ctx_v3->logits_out, sizeof(float) * rwkv_vocab.size());
@ -3350,230 +3506,284 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
float * logitsPtr;
float lowestLogit = 0;
int btsize = banned_token_ids.size();
if(file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 || file_format == FileFormat::GGJT_3 || file_format == FileFormat::GGUF_GENERIC)
//sample pending logits. usually only 1, unless speculative decoding
int logits_to_sample = 1;
int logits_sampled = 0;
bool abort_draft = false;
if(draft_used)
{
if(file_format == FileFormat::GGUF_GENERIC)
logits_to_sample = draft_results.drafted_amount;
}
while(logits_sampled<logits_to_sample && remaining_tokens>0 && !abort_draft)
{
if(logits_sampled>0)
{
logitsPtr = llama_get_logits(llama_ctx_v4);
//this is not the first loop, so we need to increment some things
n_past += 1;
}
else if(file_format == FileFormat::GGJT_3)
if(file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 || file_format == FileFormat::GGJT_3 || file_format == FileFormat::GGUF_GENERIC)
{
logitsPtr = llama_v3_get_logits(llama_ctx_v3);
if(file_format == FileFormat::GGUF_GENERIC)
{
if(draft_used)
{
logitsPtr = draft_results.actual_logits[logits_sampled];
}
else
{
logitsPtr = llama_get_logits(llama_ctx_v4);
}
}
else if(file_format == FileFormat::GGJT_3)
{
logitsPtr = llama_v3_get_logits(llama_ctx_v3);
}
else
{
logitsPtr = llama_v2_get_logits(llama_ctx_v2);
}
lowestLogit = LowestLogit(logitsPtr,n_vocab);
}
else
{
logitsPtr = llama_v2_get_logits(llama_ctx_v2);
}
lowestLogit = LowestLogit(logitsPtr,n_vocab);
}
else
{
logitsPtr = logits.data();
lowestLogit = LowestLogit(logits);
}
if (!inputs.allow_eos_token && !inputs.bypass_eos_token)
{
// set the logit of the eos token to very low to avoid sampling it
if(eosID!=LLAMA_TOKEN_NULL)
{
logitsPtr[eosID] = lowestLogit;
}
if(eotID!=-1)
{
logitsPtr[eotID] = lowestLogit;
}
}
if(btsize>0)
{
for(int t=0;t<btsize;++t)
{
logitsPtr[banned_token_ids[t]]=lowestLogit;
}
}
//handle temp bans from antislop
if (antislop_banned_token_ids.find(n_past) != antislop_banned_token_ids.end()) {
std::vector<int>& bans = antislop_banned_token_ids[n_past];
for(int t=0;t<bans.size();++t)
{
logitsPtr[bans[t]]=lowestLogit;
}
}
id = SampleLogits(logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty, kcpp_data->rep_pen_slope, presence_penalty,
top_k, top_a, top_p, min_p, typical_p, tfs_z, temp, rng,
kcpp_data->mirostat, kcpp_data->mirostat_tau, kcpp_data->mirostat_eta,
kcpp_data->dry_multiplier, kcpp_data->dry_base,
kcpp_data->dry_allowed_length, kcpp_data->dry_penalty_last_n, kcpp_data->xtc_threshold, kcpp_data->xtc_probability,
sampler_order, grammar, dynatemp_range, dynatemp_exponent, smoothing_factor);
if (grammar != nullptr) {
grammar_accept_token(file_format, n_vocab, grammar, id);
}
if (!last_n_tokens.empty())
{
last_n_tokens.erase(last_n_tokens.begin());
}
last_n_tokens.push_back(id);
current_context_tokens.push_back(id);
// add it to the context
embd.push_back(id);
// decrement remaining sampling budget
--remaining_tokens;
for (auto eid : embd)
{
std::string tokenizedstr = FileFormatTokenizeID(eid, file_format, inputs.render_special);
if(!inputs.render_special && (eid==eosID || (eid==eotID && eid!=-1) || VecContainsIntVal(special_stop_sequence,id))) //extra filter to avoid unwanted special tokens
{
tokenizedstr = ""; //prevent render
logitsPtr = logits.data(); //legacy rwkv, neox, gptj etc
lowestLogit = LowestLogit(logits);
}
delayed_generated_tokens.push_back(tokenizedstr);
while(delayed_generated_tokens.size() > delayed_generated_tokens_limit && delayed_generated_tokens.size() > 0)
if (!inputs.allow_eos_token && !inputs.bypass_eos_token)
{
generated_tokens.push_back(delayed_generated_tokens[0]);
concat_output_mtx.lock();
concat_output += delayed_generated_tokens[0];
concat_output_mtx.unlock();
delayed_generated_tokens.pop_front();
}
}
if (startedsampling && allow_regular_prints)
{
printf("\rGenerating (%d / %d tokens)", (kcpp_data->n_predict - remaining_tokens), kcpp_data->n_predict);
}
if(debugmode==1 && top_picks_history.size()>0)
{
printf(" [");
bool firstloop = true;
TopPicksData toppick = top_picks_history[top_picks_history.size()-1];
std::string topstr = toppick.selected_token;
::utreplace(topstr, "\n", "\\n");
printf("(%s %.2f%%)", RemoveBell(topstr).c_str(), toppick.selected_probability*100);
int maxtoshow = (toppick.tokenid.size()>4?4:toppick.tokenid.size());
for (int i=0;i<maxtoshow;++i)
{
if(toppick.tokenid[i]==toppick.selected_tokenid)
// set the logit of the eos token to very low to avoid sampling it
if(eosID!=LLAMA_TOKEN_NULL)
{
continue;
logitsPtr[eosID] = lowestLogit;
}
printf(" ");
std::string tokenizedstr = toppick.tokens[i];
::utreplace(tokenizedstr, "\n", "\\n");
printf("(%s %.2f%%)", RemoveBell(tokenizedstr).c_str(), toppick.p[i]*100);
}
printf("]\n");
}
//anti slop detection
if (banned_phrases.size() > 0)
{
std::string scanstr = "";
for (int i = 0; i < delayed_generated_tokens.size(); ++i)
{
scanstr += delayed_generated_tokens[i];
}
scanstr = toLowerCase(scanstr);
for (const auto &matched : banned_phrases)
{
std::string matched_lower = toLowerCase(matched);
if (scanstr.find(matched_lower) != std::string::npos)
if(eotID!=-1)
{
//find the position in the string that contains all necessary tokens
std::string checkstr = "";
int rewind_amt = 0;
for (int i = delayed_generated_tokens.size() - 1; i >= 0; --i)
logitsPtr[eotID] = lowestLogit;
}
}
if(btsize>0)
{
for(int t=0;t<btsize;++t)
{
logitsPtr[banned_token_ids[t]]=lowestLogit;
}
}
//handle temp bans from antislop
if (antislop_banned_token_ids.find(n_past) != antislop_banned_token_ids.end()) {
std::vector<int>& bans = antislop_banned_token_ids[n_past];
for(int t=0;t<bans.size();++t)
{
logitsPtr[bans[t]]=lowestLogit;
}
}
id = SampleLogits(logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty, kcpp_data->rep_pen_slope, presence_penalty,
top_k, top_a, top_p, min_p, typical_p, tfs_z, temp, rng,
kcpp_data->mirostat, kcpp_data->mirostat_tau, kcpp_data->mirostat_eta,
kcpp_data->dry_multiplier, kcpp_data->dry_base,
kcpp_data->dry_allowed_length, kcpp_data->dry_penalty_last_n, kcpp_data->xtc_threshold, kcpp_data->xtc_probability,
sampler_order, grammar, dynatemp_range, dynatemp_exponent, smoothing_factor);
if(draft_used)
{
int32_t draftedid = draft_results.draftids[logits_sampled];
if(debugmode==1)
{
std::string drafttok = FileFormatTokenizeID(draftedid, file_format, true);
std::string realtok = FileFormatTokenizeID(id, file_format, true);
printf("(Draft %d/%d): Predicted=%d (%s), Actual=%d (%s) [%s]\n",(logits_sampled+1),logits_to_sample,draftedid,drafttok.c_str(),id,realtok.c_str(),(draftedid==id?"PASS":"FAIL"));
}
if(draftedid!=id) //draft mismatch, abort
{
abort_draft = true;
}
}
if (grammar != nullptr) {
grammar_accept_token(file_format, n_vocab, grammar, id);
}
if (!last_n_tokens.empty())
{
last_n_tokens.erase(last_n_tokens.begin());
}
last_n_tokens.push_back(id);
current_context_tokens.push_back(id);
// add it to the context
embd.clear();
embd.push_back(id);
// decrement remaining sampling budget
--remaining_tokens;
for (auto eid : embd)
{
std::string tokenizedstr = FileFormatTokenizeID(eid, file_format, inputs.render_special);
if(!inputs.render_special && (eid==eosID || (eid==eotID && eid!=-1) || VecContainsIntVal(special_stop_sequence,id))) //extra filter to avoid unwanted special tokens
{
tokenizedstr = ""; //prevent render
}
delayed_generated_tokens.push_back(tokenizedstr);
while(delayed_generated_tokens.size() > delayed_generated_tokens_limit && delayed_generated_tokens.size() > 0)
{
generated_tokens.push_back(delayed_generated_tokens[0]);
concat_output_mtx.lock();
concat_output += delayed_generated_tokens[0];
concat_output_mtx.unlock();
delayed_generated_tokens.pop_front();
}
}
if (startedsampling && allow_regular_prints)
{
printf("\rGenerating (%d / %d tokens)", (kcpp_data->n_predict - remaining_tokens), kcpp_data->n_predict);
}
if(debugmode==1 && top_picks_history.size()>0)
{
printf(" [");
bool firstloop = true;
TopPicksData toppick = top_picks_history[top_picks_history.size()-1];
std::string topstr = toppick.selected_token;
::utreplace(topstr, "\n", "\\n");
printf("(%s %.2f%%)", RemoveBell(topstr).c_str(), toppick.selected_probability*100);
int maxtoshow = (toppick.tokenid.size()>4?4:toppick.tokenid.size());
for (int i=0;i<maxtoshow;++i)
{
if(toppick.tokenid[i]==toppick.selected_tokenid)
{
checkstr = delayed_generated_tokens[i] + checkstr;
++rewind_amt;
if (toLowerCase(checkstr).find(matched_lower) != std::string::npos)
continue;
}
printf(" ");
std::string tokenizedstr = toppick.tokens[i];
::utreplace(tokenizedstr, "\n", "\\n");
printf("(%s %.2f%%)", RemoveBell(tokenizedstr).c_str(), toppick.p[i]*100);
}
printf("]\n");
}
//anti slop detection
if (banned_phrases.size() > 0)
{
std::string scanstr = "";
for (int i = 0; i < delayed_generated_tokens.size(); ++i)
{
scanstr += delayed_generated_tokens[i];
}
scanstr = toLowerCase(scanstr);
for (const auto &matched : banned_phrases)
{
std::string matched_lower = toLowerCase(matched);
if (scanstr.find(matched_lower) != std::string::npos)
{
//find the position in the string that contains all necessary tokens
std::string checkstr = "";
int rewind_amt = 0;
for (int i = delayed_generated_tokens.size() - 1; i >= 0; --i)
{
checkstr = delayed_generated_tokens[i] + checkstr;
++rewind_amt;
if (toLowerCase(checkstr).find(matched_lower) != std::string::npos)
{
break;
}
}
if (rewind_amt > 0 && (current_context_tokens.size() - rewind_amt) > 0)
{
int last_tok = current_context_tokens[current_context_tokens.size() - rewind_amt];
delayed_generated_tokens.resize(delayed_generated_tokens.size() - rewind_amt);
ContextRewind(embd, current_context_tokens, n_past, last_n_tokens, rewind_amt);
//immediately terminate drafting if used
abort_draft = true;
// Check if the key exists
int banindex = n_past+1;
if (antislop_banned_token_ids.find(banindex) == antislop_banned_token_ids.end()) {
antislop_banned_token_ids[banindex] = std::vector<int>();
}
std::vector<int>& current_ids = antislop_banned_token_ids[banindex];
current_ids.push_back(last_tok);
if (allow_regular_prints && debugmode == 1)
{
auto match_clean = matched;
replace_all(match_clean, "\n", "\\n");
printf("\n(Banned Phrase Detected: %s - Add ID %d to banlist at index %d, and rewinding %d tokens)\n", match_clean.c_str(), last_tok, banindex, rewind_amt);
}
break;
}
}
if (rewind_amt > 0 && (current_context_tokens.size() - rewind_amt) > 0)
}
}
bool earlystopped = false;
if(!inputs.bypass_eos_token && inputs.allow_eos_token && (id==eosID || (id==eotID && id!=-1)))
{
stopper_unused_tokens = remaining_tokens;
if(allow_regular_prints)
{
printf("\n(EOS token triggered! ID:%d)",id);
}
remaining_tokens = 0;
last_stop_reason = stop_reason::EOS_TOKEN_HIT;
earlystopped = true;
}
if(!earlystopped)
{
for (const auto &matched : special_stop_sequence)
{
if(id==matched)
{
int last_tok = current_context_tokens[current_context_tokens.size() - rewind_amt];
delayed_generated_tokens.resize(delayed_generated_tokens.size() - rewind_amt);
ContextRewind(embd, current_context_tokens, n_past, last_n_tokens, rewind_amt);
// Check if the key exists
int banindex = n_past+1;
if (antislop_banned_token_ids.find(banindex) == antislop_banned_token_ids.end()) {
antislop_banned_token_ids[banindex] = std::vector<int>();
}
std::vector<int>& current_ids = antislop_banned_token_ids[banindex];
current_ids.push_back(last_tok);
if (allow_regular_prints && debugmode == 1)
stopper_unused_tokens = remaining_tokens;
if(allow_regular_prints)
{
auto match_clean = matched;
replace_all(match_clean, "\n", "\\n");
printf("\n(Banned Phrase Detected: %s - Add ID %d to banlist at index %d, and rewinding %d tokens)\n", match_clean.c_str(), last_tok, banindex, rewind_amt);
printf("\n(Special Stop Token Triggered! ID:%d)",matched);
}
remaining_tokens = 0;
last_stop_reason = stop_reason::EOS_TOKEN_HIT;
earlystopped = true;
break;
}
}
}
}
bool earlystopped = false;
if(!inputs.bypass_eos_token && inputs.allow_eos_token && (id==eosID || (id==eotID && id!=-1)))
{
stopper_unused_tokens = remaining_tokens;
if(allow_regular_prints)
if(!earlystopped)
{
printf("\n(EOS token triggered! ID:%d)",id);
}
remaining_tokens = 0;
last_stop_reason = stop_reason::EOS_TOKEN_HIT;
earlystopped = true;
}
if(!earlystopped)
{
for (const auto &matched : special_stop_sequence)
{
if(id==matched)
for (const auto &matched : stop_sequence)
{
stopper_unused_tokens = remaining_tokens;
if(allow_regular_prints)
if (concat_output.find(matched) != std::string::npos)
{
printf("\n(Special Stop Token Triggered! ID:%d)",matched);
stopper_unused_tokens = remaining_tokens;
remaining_tokens = 0;
if(allow_regular_prints)
{
auto match_clean = matched;
replace_all(match_clean, "\n", "\\n");
printf("\n(Stop sequence triggered: %s)", match_clean.c_str());
}
last_stop_reason = stop_reason::CUSTOM_STOPPER;
earlystopped = true;
break;
}
remaining_tokens = 0;
last_stop_reason = stop_reason::EOS_TOKEN_HIT;
earlystopped = true;
break;
}
}
logits_sampled += 1;
}
if(!earlystopped)
//if we have somehow skipped ahead (e.g drafting), ensure that all tokens after npast are purged
if (file_format == FileFormat::GGUF_GENERIC && draft_used)
{
for (const auto &matched : stop_sequence)
{
if (concat_output.find(matched) != std::string::npos)
{
stopper_unused_tokens = remaining_tokens;
remaining_tokens = 0;
if(allow_regular_prints)
{
auto match_clean = matched;
replace_all(match_clean, "\n", "\\n");
printf("\n(Stop sequence triggered: %s)", match_clean.c_str());
}
last_stop_reason = stop_reason::CUSTOM_STOPPER;
earlystopped = true;
break;
}
llama_kv_cache_seq_rm(llama_ctx_v4, 0, n_past, -1);
if (draft_ctx) {
llama_kv_cache_seq_rm(draft_ctx, 0, n_past, -1);
}
}
@ -3611,6 +3821,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
}
for(int i=0;i<llava_images.size();++i)
{
//note: no handling for draft_ctx as we don't support vision for it
if(i>0 && sepsize>0)
{
//add a separator between each image
@ -3701,7 +3912,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
float pt2 = (time2*1000.0/(realnpredict==0?1:realnpredict));
float ts2 = (1000.0/pt2);
float tokens_per_second = (realnpredict == 0 ? 0 : realnpredict / (time1 + time2));
printf("\nCtxLimit:%d/%d, Amt:%d/%d, Init:%.2fs, Process:%.2fs (%.1fms/T = %.2fT/s), Generate:%.2fs (%.1fms/T = %.2fT/s), Total:%.2fs (%.2fT/s)",(int)current_context_tokens.size(),(int)nctx, realnpredict, kcpp_data->n_predict, time0, time1, pt1, ts1, time2, pt2, ts2, (time1 + time2), tokens_per_second);
printf("\n[%s] CtxLimit:%d/%d, Amt:%d/%d, Init:%.2fs, Process:%.2fs (%.1fms/T = %.2fT/s), Generate:%.2fs (%.1fms/T = %.2fT/s), Total:%.2fs (%.2fT/s)",get_timestamp_str().c_str(),(int)current_context_tokens.size(),(int)nctx, realnpredict, kcpp_data->n_predict, time0, time1, pt1, ts1, time2, pt2, ts2, (time1 + time2), tokens_per_second);
fflush(stdout);
output.status = 1;
int finaltokcount = (int)current_context_tokens.size()-realnpredict;

View file

@ -131,6 +131,7 @@ class load_model_inputs(ctypes.Structure):
("model_filename", ctypes.c_char_p),
("lora_filename", ctypes.c_char_p),
("lora_base", ctypes.c_char_p),
("draftmodel_filename", ctypes.c_char_p),
("mmproj_filename", ctypes.c_char_p),
("use_mmap", ctypes.c_bool),
("use_mlock", ctypes.c_bool),
@ -672,24 +673,27 @@ def read_gguf_metadata(file_path):
except Exception as ex:
return None
def extract_modelfile_params(filepath,sdfilepath,whisperfilepath,mmprojfilepath):
def extract_modelfile_params(filepath,sdfilepath,whisperfilepath,mmprojfilepath,draftmodelpath):
global modelfile_extracted_meta
modelfile_extracted_meta = None
sdfsize = 0
whisperfsize = 0
mmprojsize = 0
draftmodelsize = 0
if sdfilepath and os.path.exists(sdfilepath):
sdfsize = os.path.getsize(sdfilepath)
if whisperfilepath and os.path.exists(whisperfilepath):
whisperfsize = os.path.getsize(whisperfilepath)
if mmprojfilepath and os.path.exists(mmprojfilepath):
mmprojsize = os.path.getsize(mmprojfilepath)
if draftmodelpath and os.path.exists(draftmodelpath):
draftmodelsize = os.path.getsize(draftmodelpath)
if filepath and os.path.exists(filepath):
try:
fsize = os.path.getsize(filepath)
if fsize>10000000: #dont bother with models < 10mb as they are probably bad
ggufmeta = read_gguf_metadata(filepath)
modelfile_extracted_meta = [ggufmeta,fsize,sdfsize,whisperfsize,mmprojsize] #extract done. note that meta may be null
modelfile_extracted_meta = [ggufmeta,fsize,sdfsize,whisperfsize,mmprojsize,draftmodelsize] #extract done. note that meta may be null
except Exception as ex:
modelfile_extracted_meta = None
@ -702,7 +706,7 @@ def autoset_gpu_layers(ctxsize,sdquanted,bbs): #shitty algo to determine how man
if showusedmemwarning and usedmem > (2.5*1024*1024*1024):
showusedmemwarning = False
print(f"Note: KoboldCpp has detected that a significant amount of GPU VRAM ({usedmem/1024/1024} MB) is currently used by another application.\nFor best results, you may wish to close that application and then restart KoboldCpp.\n***")
reservedmem = max(1.5*1024*1024*1024,(0.5*1024*1024*1024 + usedmem)) # determine vram overhead
reservedmem = max(1.3*1024*1024*1024,(0.5*1024*1024*1024 + usedmem)) # determine vram overhead
try:
if not modelfile_extracted_meta:
return 0
@ -719,6 +723,9 @@ def autoset_gpu_layers(ctxsize,sdquanted,bbs): #shitty algo to determine how man
mem -= 350*1024*1024
if modelfile_extracted_meta[4] > 1024*1024*10: #mmproj tax
mem -= 350*1024*1024
if modelfile_extracted_meta[5] > 1024*1024*10: #draft model tax
mem -= (modelfile_extracted_meta[5] * 1.5)
mem = 0 if mem < 0 else mem
csmul = 1.0
if cs:
@ -732,8 +739,8 @@ def autoset_gpu_layers(ctxsize,sdquanted,bbs): #shitty algo to determine how man
headcount = ggufmeta[1]
headkvlen = (ggufmeta[2] if ggufmeta[2] > 0 else 128)
ratio = (mem-usedmem)/(fsize*csmul*1.6*(1.0 if bbs <= 512 else 1.2))
computemem = layers*(4 if bbs <= 512 else (bbs/128))*headkvlen*cs*4*1.5 # apply blasbatchsize calculations if over 512
contextmem = layers*headcount*headkvlen*cs*4*1.1
computemem = layers*(4 if bbs <= 512 else (bbs/128))*headkvlen*cs*4*1.55 # apply blasbatchsize calculations if over 512
contextmem = layers*headcount*headkvlen*cs*4*1.15
if headcount > 0:
ratio = max(ratio, (mem - reservedmem - computemem) / (fsize + contextmem))
layerlimit = min(int(ratio*layers), (layers + 3))
@ -877,6 +884,7 @@ def load_model(model_filename):
if len(args.lora) > 1:
inputs.lora_base = args.lora[1].encode("UTF-8")
inputs.draftmodel_filename = args.draftmodel.encode("UTF-8") if args.draftmodel else "".encode("UTF-8")
inputs.mmproj_filename = args.mmproj.encode("UTF-8") if args.mmproj else "".encode("UTF-8")
inputs.use_smartcontext = args.smartcontext
inputs.use_contextshift = (0 if args.noshift else 1)
@ -1510,15 +1518,18 @@ ws ::= | " " | "\n" [ \t]{0,20}
elif api_format==6:
detokstr = ""
tokids = genparams.get('context', [])
adapter_obj = {} if chatcompl_adapter is None else chatcompl_adapter
user_message_start = adapter_obj.get("user_start", "\n\n### Instruction:\n")
assistant_message_start = adapter_obj.get("assistant_start", "\n\n### Response:\n")
try:
detokstr = detokenize_ids(tokids)
except Exception as e:
utfprint("Ollama Context Error: " + str(e))
ollamasysprompt = genparams.get('system', "")
ollamabodyprompt = detokstr + "\n\n### Instruction:\n" + genparams.get('prompt', "") + "\n\n### Response:\n"
ollamabodyprompt = f"{detokstr}{user_message_start}{genparams.get('prompt', '')}{assistant_message_start}"
genparams["stop_sequence"] = genparams.get('stop', [])
genparams["stop_sequence"].append("\n### Instruction:")
genparams["stop_sequence"].append("\n### Response:")
genparams["stop_sequence"].append(user_message_start.strip())
genparams["stop_sequence"].append(assistant_message_start.strip())
genparams["trim_stop"] = True
genparams["ollamasysprompt"] = ollamasysprompt
genparams["ollamabodyprompt"] = ollamabodyprompt
@ -2374,9 +2385,8 @@ Enter Prompt:<br>
return
is_quiet = args.quiet
utfprint(f"\n{datetime.now().strftime('[%H:%M:%S] Input Received')}")
if (args.debugmode != -1 and not is_quiet) or args.debugmode >= 1:
utfprint(f"Input: " + json.dumps(genparams))
utfprint(f"\nInput: " + json.dumps(genparams))
if args.foreground:
bring_terminal_to_foreground()
@ -2751,6 +2761,7 @@ def show_gui():
lora_base_var = ctk.StringVar()
preloadstory_var = ctk.StringVar()
mmproj_var = ctk.StringVar()
draftmodel_var = ctk.StringVar()
nomodel = ctk.IntVar(value=0)
port_var = ctk.StringVar(value=defaultport)
@ -2929,7 +2940,8 @@ def show_gui():
sdfilepath = sd_model_var.get()
whisperfilepath = whisper_model_var.get()
mmprojfilepath = mmproj_var.get()
extract_modelfile_params(filepath,sdfilepath,whisperfilepath,mmprojfilepath)
draftmodelpath = draftmodel_var.get()
extract_modelfile_params(filepath,sdfilepath,whisperfilepath,mmprojfilepath,draftmodelpath)
changed_gpulayers_estimate()
pass
@ -3234,18 +3246,20 @@ def show_gui():
makefileentry(model_tab, "Text Lora:", "Select Lora File",lora_var, 3,width=280,tooltiptxt="Select an optional GGML LoRA adapter to use.\nLeave blank to skip.")
makefileentry(model_tab, "Text Lora Base:", "Select Lora Base File", lora_base_var, 5,width=280,tooltiptxt="Select an optional F16 GGML LoRA base file to use.\nLeave blank to skip.")
makefileentry(model_tab, "Vision mmproj:", "Select Vision mmproj File", mmproj_var, 7,width=280,tooltiptxt="Select a mmproj file to use for vision models like LLaVA.\nLeave blank to skip.")
makefileentry(model_tab, "Preloaded Story:", "Select Preloaded Story File", preloadstory_var, 9,width=280,tooltiptxt="Select an optional KoboldAI JSON savefile \nto be served on launch to any client.")
makefileentry(model_tab, "ChatCompletions Adapter:", "Select ChatCompletions Adapter File", chatcompletionsadapter_var, 12, width=250, filetypes=[("JSON Adapter", "*.json")], tooltiptxt="Select an optional ChatCompletions Adapter JSON file to force custom instruct tags.")
makefileentry(model_tab, "Speculative Model:", "Select Draft Text Model File", draftmodel_var, 9,width=280,tooltiptxt="Select a draft text model file to use for speculative decoding.\nLeave blank to skip.")
makefileentry(model_tab, "Preloaded Story:", "Select Preloaded Story File", preloadstory_var, 11,width=280,tooltiptxt="Select an optional KoboldAI JSON savefile \nto be served on launch to any client.")
makefileentry(model_tab, "ChatCompletions Adapter:", "Select ChatCompletions Adapter File", chatcompletionsadapter_var, 14, width=250, filetypes=[("JSON Adapter", "*.json")], tooltiptxt="Select an optional ChatCompletions Adapter JSON file to force custom instruct tags.")
def pickpremadetemplate():
initialDir = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'kcpp_adapters')
initialDir = initialDir if os.path.isdir(initialDir) else None
fnam = askopenfilename(title="Pick Premade ChatCompletions Adapter",filetypes=[("JSON Adapter", "*.json")], initialdir=initialDir)
if fnam:
chatcompletionsadapter_var.set(fnam)
ctk.CTkButton(model_tab, 64, text="Pick Premade", command=pickpremadetemplate).grid(row=13, column=0, padx=322, stick="nw")
ctk.CTkButton(model_tab, 64, text="Pick Premade", command=pickpremadetemplate).grid(row=15, column=0, padx=322, stick="nw")
mmproj_var.trace("w", gui_changed_modelfile)
makecheckbox(model_tab, "Allow Launch Without Models", nomodel, 15, tooltiptxt="Allows running the WebUI with no model loaded.")
draftmodel_var.trace("w", gui_changed_modelfile)
makecheckbox(model_tab, "Allow Launch Without Models", nomodel, 17, tooltiptxt="Allows running the WebUI with no model loaded.")
# Network Tab
network_tab = tabcontent["Network"]
@ -3489,6 +3503,7 @@ def show_gui():
except Exception as ex2:
pass
args.mmproj = None if mmproj_var.get() == "" else mmproj_var.get()
args.draftmodel = None if draftmodel_var.get() == "" else draftmodel_var.get()
args.ssl = None if (ssl_cert_var.get() == "" or ssl_key_var.get() == "") else ([ssl_cert_var.get(), ssl_key_var.get()])
args.password = None if (password_var.get() == "") else (password_var.get())
@ -3649,6 +3664,7 @@ def show_gui():
lora_var.set(dict["lora"][0])
mmproj_var.set(dict["mmproj"] if ("mmproj" in dict and dict["mmproj"]) else "")
draftmodel_var.set(dict["draftmodel"] if ("draftmodel" in dict and dict["draftmodel"]) else "")
ssl_cert_var.set("")
ssl_key_var.set("")
@ -4442,6 +4458,10 @@ def main(launch_args,start_server=True):
dlfile = download_model_from_url(args.whispermodel,[".gguf",".bin"])
if dlfile:
args.whispermodel = dlfile
if args.draftmodel and args.draftmodel!="":
dlfile = download_model_from_url(args.draftmodel,[".gguf"])
if dlfile:
args.draftmodel = dlfile
# sanitize and replace the default vanity name. remember me....
if args.model_param and args.model_param!="":
@ -4517,7 +4537,7 @@ def main(launch_args,start_server=True):
pass
if args.gpulayers==-1:
if MaxMemory[0] > 0 and (not args.usecpu) and ((args.usecublas is not None) or (args.usevulkan is not None) or (args.useclblast is not None) or sys.platform=="darwin"):
extract_modelfile_params(args.model_param,args.sdmodel,args.whispermodel,args.mmproj)
extract_modelfile_params(args.model_param,args.sdmodel,args.whispermodel,args.mmproj,args.draftmodel)
layeramt = autoset_gpu_layers(args.contextsize,args.sdquant,args.blasbatchsize)
print(f"Auto Recommended GPU Layers: {layeramt}")
args.gpulayers = layeramt
@ -4923,6 +4943,7 @@ if __name__ == '__main__':
advparser.add_argument("--ssl", help="Allows all content to be served over SSL instead. A valid UNENCRYPTED SSL cert and key .pem files must be provided", metavar=('[cert_pem]', '[key_pem]'), nargs='+')
advparser.add_argument("--nocertify", help="Allows insecure SSL connections. Use this if you have cert errors and need to bypass certificate restrictions.", action='store_true')
advparser.add_argument("--mmproj", help="Select a multimodal projector file for vision models like LLaVA.", default="")
advparser.add_argument("--draftmodel", help="Load a small draft model for speculative decoding. It will be fully offloaded. Vocab must match the main model.", default="")
advparser.add_argument("--password", help="Enter a password required to use this instance. This key will be required for all text endpoints. Image endpoints are not secured.", default=None)
advparser.add_argument("--ignoremissing", help="Ignores all missing non-essential files, just skipping them instead.", action='store_true')
advparser.add_argument("--chatcompletionsadapter", help="Select an optional ChatCompletions Adapter JSON file to force custom instruct tags.", default="")

View file

@ -507,4 +507,12 @@ struct llava_image
float * clp_img_embd = nullptr; //this holds dynamic memory and must be freed each use!
};
struct speculative_draft_result
{
std::vector<int32_t> draftids;
std::vector<float *> actual_logits;
bool draft_success = false;
int drafted_amount = 0;
};
const float default_norm_eps = 1e-5f;

View file

@ -357,11 +357,12 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
int img2imgC = 3; // Assuming RGB image
std::vector<uint8_t> resized_image_buf(img2imgW * img2imgH * img2imgC);
std::string ts = get_timestamp_str();
if(!is_quiet)
{
printf("\nGenerating Image (%d steps)\n",inputs.sample_steps);
printf("\n[%s] Generating Image (%d steps)\n",ts.c_str(),inputs.sample_steps);
}else{
printf("\nGenerating (%d st.)\n",inputs.sample_steps);
printf("\n[%s] Generating (%d st.)\n",ts.c_str(),inputs.sample_steps);
}
fflush(stdout);

View file

@ -8,6 +8,7 @@
#include <locale>
#include <codecvt>
#include <sstream>
#include <ctime>
void utreplace(std::string & str, const std::string & needle, const std::string & replacement) {
@ -302,3 +303,14 @@ std::vector<uint8_t> kcpp_base64_decode(const std::string & encoded_string)
return ret;
}
std::string get_timestamp_str()
{
std::time_t t = std::time(nullptr);
std::tm* now = std::localtime(&t);
char buffer[16]; // Buffer to hold "hh:mm:ss" and null terminator
std::sprintf(buffer, "%02d:%02d:%02d", now->tm_hour, now->tm_min, now->tm_sec);
// Convert the buffer to a std::string
std::string timestamp(buffer);
return timestamp;
}

View file

@ -57,3 +57,5 @@ bool should_transpose_layer(std::string name);
void kcpp_graph_compute_helper(ggml_v3_cgraph * graph, int n_threads);
std::vector<uint8_t> kcpp_base64_decode(const std::string & encoded_string);
std::string get_timestamp_str();

View file

@ -273,11 +273,12 @@ whisper_generation_outputs whispertype_generate(const whisper_generation_inputs
// output text transcription
whisper_output_text = output_txt(whisper_ctx, pcmf32s);
std::string ts = get_timestamp_str();
if(!inputs.quiet)
{
printf("\nWhisper Transcribe Output: %s",whisper_output_text.c_str());
printf("\n[%s] Whisper Transcribe Output: %s",ts.c_str(),whisper_output_text.c_str());
} else {
printf("\nWhisper Transcribe Done.");
printf("\n[%s] Whisper Transcribe Done.",ts.c_str());
}
output.text = whisper_output_text.c_str();
output.status = 1;