speculative decoding initial impl completed (+6 squashed commits)

Squashed commit:

[0a6306ca0] draft wip dont use (will be squashed)

[a758a1c9c] wip dont use (will be squashed)

[e1994d3ce] wip dont use

[f59690d68] wip

[77228147d] wip on spec decoding. dont use yet

[2445bca54] wip adding speculative decoding (+1 squashed commits)

Squashed commits:

[50e341bb7] wip adding speculative decoding
Concedo 2024-11-27 00:16:51 +08:00
parent b9e99c69e8
commit f75bbb945f
9 changed files with 539 additions and 280 deletions
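
At a high level, the change pairs the main model with a small, fully offloaded draft model: during generation the draft greedily proposes a fixed chunk of tokens (speculative_chunk_amt = 16), the main model scores the whole drafted chunk in a single batch, and sampling then walks the chunk, accepting drafted tokens until the first position where the main model's sampled token disagrees, at which point the rest of the draft is discarded and the KV caches are trimmed back to n_past. A toy Python illustration of that accept/reject loop follows (stand-in callables over integer token ids, greedy verification in place of the real sampler; this is not the C++ code in the diffs below):

```python
# Toy sketch of the draft-then-verify loop added in this commit (not the real implementation).
# draft_model / main_model are hypothetical stand-ins: callables mapping a context to a token id.
from typing import Callable, List

def speculative_step(context: List[int],
                     draft_model: Callable[[List[int]], int],
                     main_model: Callable[[List[int]], int],
                     chunk: int = 16) -> List[int]:
    # 1) Draft: greedily propose `chunk` tokens with the cheap model.
    drafted, ctx = [], list(context)
    for _ in range(chunk):
        t = draft_model(ctx)
        drafted.append(t)
        ctx.append(t)
    # 2) Verify: the main model scores every drafted position (one batch in the real code).
    #    The main model's token is always kept; a mismatch aborts the remainder of the draft.
    accepted, ctx = [], list(context)
    for t in drafted:
        actual = main_model(ctx)
        accepted.append(actual)
        if actual != t:
            break
        ctx.append(actual)
    return accepted

# Tiny demo: the draft agrees with the main model except at certain context lengths.
main = lambda ctx: (len(ctx) * 7) % 100
draft = lambda ctx: main(ctx) if len(ctx) % 5 else (main(ctx) + 1) % 100
print(speculative_step([1, 2, 3], draft, main))  # accepts a few tokens, stops at the first mismatch
```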


@@ -35,6 +35,7 @@ extern "C"
lora_filename = inputs.lora_filename;
lora_base = inputs.lora_base;
mmproj_filename = inputs.mmproj_filename;
draftmodel_filename = inputs.draftmodel_filename;
int forceversion = inputs.forceversion;


@@ -40,6 +40,7 @@ struct load_model_inputs
const char * model_filename = nullptr;
const char * lora_filename = nullptr;
const char * lora_base = nullptr;
const char * draftmodel_filename = nullptr;
const char * mmproj_filename = nullptr;
const bool use_mmap = false;
const bool use_mlock = false;
@@ -197,6 +198,7 @@ extern std::string executable_path;
extern std::string lora_filename;
extern std::string lora_base;
extern std::string mmproj_filename;
extern std::string draftmodel_filename;
extern std::vector<std::string> generated_tokens;
extern bool generation_finished;
extern float last_eval_time;


@@ -40,8 +40,10 @@
#include "mpt_v3.cpp"
#include "examples/llava/clip.h"
#include "examples/llava/llava.h"
#include "common/common.h"
//const
const int speculative_chunk_amt = 16; //do it in chunks of this many tokens
const int extra_context_handle_fragmentation = 120;
const int LLAVA_TOKEN_IDENTIFIER_A = -998; //alternate between both, changing when image changes
const int LLAVA_TOKEN_IDENTIFIER_B = -999;
@@ -51,6 +53,7 @@ std::string executable_path = "";
std::string lora_filename = "";
std::string lora_base = "";
std::string mmproj_filename = "";
std::string draftmodel_filename = "";
bool generation_finished;
float last_process_time = 0;
float last_eval_time = 0;
@@ -90,6 +93,7 @@ static rwkv_context * rwkv_ctx_v3;
static llama_v2_context * llama_ctx_v2;
static llama_v3_context * llama_ctx_v3;
static llama_context * llama_ctx_v4;
static llama_context * draft_ctx = nullptr; //will remain null if speculative is unused
static clip_ctx * clp_ctx = nullptr; //for llava
static clip_image_u8 * clp_img_data = nullptr; //most recent image
@@ -487,6 +491,10 @@ void ContextRewind(std::vector<int> &embd, std::vector<int> &current_context_tok
if (file_format == FileFormat::GGUF_GENERIC)
{
llama_kv_cache_seq_rm(llama_ctx_v4, 0, n_past, -1);
if(draft_ctx)
{
llama_kv_cache_seq_rm(draft_ctx, 0, n_past, -1);
}
}
embd.clear();
@@ -527,6 +535,170 @@ const char * kcpp_print_system_info(void) {
return s.c_str();
}
struct kcpp_embd_batch { //duplicated from llava_embd_batch
std::vector<int32_t> pos;
std::vector<int32_t> n_seq_id;
std::vector<int32_t> seq_id_0;
std::vector<int32_t *> seq_ids;
std::vector<int8_t> logits;
llama_batch batch;
kcpp_embd_batch(float * embd, int32_t n_tokens, int32_t npast) {
int32_t seq_id = 0;
pos.resize(n_tokens);
n_seq_id.resize(n_tokens);
seq_ids.resize(n_tokens + 1);
logits.resize(n_tokens);
seq_id_0.resize(1);
seq_id_0[0] = seq_id;
seq_ids [n_tokens] = nullptr;
batch = {
/*n_tokens =*/ n_tokens,
/*tokens =*/ nullptr,
/*embd =*/ embd,
/*pos =*/ pos.data(),
/*n_seq_id =*/ n_seq_id.data(),
/*seq_id =*/ seq_ids.data(),
/*logits =*/ logits.data(),
};
for (int i = 0; i < n_tokens; i++) {
batch.pos [i] = npast + i;
batch.n_seq_id[i] = 1;
batch.seq_id [i] = seq_id_0.data();
batch.logits [i] = false;
}
}
kcpp_embd_batch(std::vector<llama_token> & tokens, int32_t npast, bool return_all_logits) {
int32_t seq_id = 0;
int32_t n_tokens = tokens.size();
pos.resize(n_tokens);
n_seq_id.resize(n_tokens);
seq_ids.resize(n_tokens + 1);
logits.resize(n_tokens);
seq_id_0.resize(1);
seq_id_0[0] = seq_id;
seq_ids [n_tokens] = nullptr;
batch = {
/*n_tokens =*/ n_tokens,
/*tokens =*/ tokens.data(),
/*embd =*/ nullptr,
/*pos =*/ pos.data(),
/*n_seq_id =*/ n_seq_id.data(),
/*seq_id =*/ seq_ids.data(),
/*logits =*/ logits.data(),
};
for (int i = 0; i < n_tokens; i++) {
batch.pos [i] = npast + i;
batch.n_seq_id[i] = 1;
batch.seq_id [i] = seq_id_0.data();
batch.logits [i] = (return_all_logits?true:false);
}
batch.logits[n_tokens - 1] = true;
}
};
//loads a model for speculative decoding.
static void speculative_decoding_setup(std::string spec_model_filename, const llama_model_params & base_model_params, const llama_context_params & base_ctx_params, int base_n_vocab)
{
llama_model_params draft_model_params = llama_model_default_params();
llama_context_params draft_ctx_params = llama_context_default_params();
draft_model_params.use_mmap = base_model_params.use_mmap;
draft_model_params.use_mlock = base_model_params.use_mlock;
draft_model_params.n_gpu_layers = 999; //assume they want to fully offload the speculative model. Otherwise, why even use it?
draft_ctx_params.n_ctx = base_ctx_params.n_ctx;
draft_ctx_params.logits_all = false;
draft_ctx_params.offload_kqv = base_ctx_params.offload_kqv;
draft_model_params.main_gpu = base_model_params.main_gpu;
draft_model_params.split_mode = llama_split_mode::LLAMA_SPLIT_MODE_LAYER;
draft_ctx_params.n_batch = base_ctx_params.n_batch;
draft_ctx_params.n_ubatch = base_ctx_params.n_ubatch;
draft_ctx_params.n_threads = base_ctx_params.n_threads;
draft_ctx_params.n_threads_batch = base_ctx_params.n_threads_batch;
draft_ctx_params.flash_attn = base_ctx_params.flash_attn;
draft_ctx_params.type_k = base_ctx_params.type_k;
draft_ctx_params.type_v = base_ctx_params.type_v;
llama_model * draftmodel = llama_load_model_from_file(spec_model_filename.c_str(), draft_model_params);
draft_ctx = llama_new_context_with_model(draftmodel, draft_ctx_params);
if(draft_ctx == NULL)
{
printf("Error: failed to load speculative decoding draft model '%s'\n", spec_model_filename.c_str());
printf("Speculative Decoding will not be used!\n");
}
else
{
int draftvocab = llama_n_vocab(draftmodel);
if(draftvocab!=base_n_vocab)
{
printf("Error: Draft model vocab of (%d) does not match base vocab of (%d). Speculative decoding cannot be used!\n",draftvocab,base_n_vocab);
llama_free(draft_ctx);
draft_ctx = nullptr;
}else if(llama_model_is_recurrent(draftmodel))
{
printf("Error: Speculative decoding cannot be used with Recurrent draft models!\n");
llama_free(draft_ctx);
draft_ctx = nullptr;
}
}
}
static speculative_draft_result speculative_decoding_eval_chunk(llama_context * draft_ctx, llama_context * main_ctx, const llama_tokens & embd, const int n_vocab, const int & n_past)
{
speculative_draft_result results;
results.draft_success = false;
if(embd.size()==0)
{
printf("\nERROR: Speculate on empty batch!\n");
return results;
}
if(embd.size()>1)
{
printf("\nERROR: Speculative decoding applied on large batch!\n");
return results;
}
int draft_npast = n_past;
int actual_npast = n_past;
std::vector<int> temp_embd;
std::vector<int> drafted_ids;
temp_embd.push_back(embd[0]);
drafted_ids.push_back(embd[0]);
for(int i=0;i<speculative_chunk_amt;++i)
{
kcpp_embd_batch batch1 = kcpp_embd_batch(temp_embd, draft_npast, false);
auto draftok = (llama_decode(draft_ctx, batch1.batch)==0);
if(!draftok)
{
printf("\nERROR: Speculative draft model 1 failed!\n");
return results;
}
float * draftlogits = llama_get_logits(draft_ctx);
//greedy sample the draft model
int topid = std::max_element(draftlogits, draftlogits + n_vocab) - draftlogits;
drafted_ids.push_back(topid);
temp_embd.clear();
temp_embd.push_back(topid);
++draft_npast;
}
//now that we have our drafted tokens, we form a batch and PP it
kcpp_embd_batch batch2 = kcpp_embd_batch(drafted_ids, actual_npast, true);
auto draftok = (llama_decode(main_ctx, batch2.batch)==0); //actual eval for big model
if(!draftok)
{
printf("\nERROR: Speculative draft model 2 failed!\n");
return results;
}
results.drafted_amount = 0;
for(int i=0;i<drafted_ids.size()-1;++i)
{
results.drafted_amount += 1;
float * fulllogits = llama_get_logits_ith(main_ctx,i);
results.draftids.push_back(drafted_ids[i+1]);
results.actual_logits.push_back(fulllogits);
}
results.draft_success = true;
return results;
}
// KCPP SAMPLING FUNCTIONS
void sample_softmax(llama_token_data_array * cur_p) {
GGML_ASSERT(cur_p->size > 0);
@@ -1554,66 +1726,6 @@ static void load_grammar(const std::string & gammarstr)
}
}
struct kcpp_embd_batch { //duplcated from llava_embd_batch
std::vector<int32_t> pos;
std::vector<int32_t> n_seq_id;
std::vector<int32_t> seq_id_0;
std::vector<int32_t *> seq_ids;
std::vector<int8_t> logits;
llama_batch batch;
kcpp_embd_batch(float * embd, int32_t n_tokens, int32_t npast) {
int32_t seq_id = 0;
pos.resize(n_tokens);
n_seq_id.resize(n_tokens);
seq_ids.resize(n_tokens + 1);
logits.resize(n_tokens);
seq_id_0.resize(1);
seq_id_0[0] = seq_id;
seq_ids [n_tokens] = nullptr;
batch = {
/*n_tokens =*/ n_tokens,
/*tokens =*/ nullptr,
/*embd =*/ embd,
/*pos =*/ pos.data(),
/*n_seq_id =*/ n_seq_id.data(),
/*seq_id =*/ seq_ids.data(),
/*logits =*/ logits.data(),
};
for (int i = 0; i < n_tokens; i++) {
batch.pos [i] = npast + i;
batch.n_seq_id[i] = 1;
batch.seq_id [i] = seq_id_0.data();
batch.logits [i] = false;
}
}
kcpp_embd_batch(std::vector<llama_token> & tokens, int32_t npast) {
int32_t seq_id = 0;
int32_t n_tokens = tokens.size();
pos.resize(n_tokens);
n_seq_id.resize(n_tokens);
seq_ids.resize(n_tokens + 1);
logits.resize(n_tokens);
seq_id_0.resize(1);
seq_id_0[0] = seq_id;
seq_ids [n_tokens] = nullptr;
batch = {
/*n_tokens =*/ n_tokens,
/*tokens =*/ tokens.data(),
/*embd =*/ nullptr,
/*pos =*/ pos.data(),
/*n_seq_id =*/ n_seq_id.data(),
/*seq_id =*/ seq_ids.data(),
/*logits =*/ logits.data(),
};
for (int i = 0; i < n_tokens; i++) {
batch.pos [i] = npast + i;
batch.n_seq_id[i] = 1;
batch.seq_id [i] = seq_id_0.data();
batch.logits [i] = false;
}
batch.logits[n_tokens - 1] = true;
}
};
static bool kcpp_eval_image(llama_context * ctx_llama, float * img_embd, int num_img_tokens, int n_batch, int * n_past) {
int n_embd = llama_n_embd(llama_get_model(ctx_llama));
@@ -1635,7 +1747,7 @@ static bool kcpp_eval_image(llama_context * ctx_llama, float * img_embd, int num
//given an old GGUF context and a new context that has some middle portion removed,
//find and remove the middle portion from the old context from the KV. Does not fast forward after this destructive action
void PurgeMissingTokens(llama_context * ctx, std::vector<int> &current_context_tokens, std::vector<int> &new_context_tokens, const int genamt, const int nctx)
void PurgeMissingTokens(llama_context * ctx, llama_context * draft_ctx, std::vector<int> &current_context_tokens, std::vector<int> &new_context_tokens, const int genamt, const int nctx)
{
//scan from start old and new ctx, until first mismatch found, save as p0
//check remaining old and new ctx for longest common subseq, which needs to be at 256 tokens
@@ -1688,8 +1800,13 @@ void PurgeMissingTokens(llama_context * ctx, std::vector<int> &current_context_t
//extract the unwanted tokens out from context and KV
int diff = found - trimstart;
llama_kv_cache_seq_rm(llama_ctx_v4, 0, trimstart, trimstart + diff);
llama_kv_cache_seq_rm(ctx, 0, trimstart, trimstart + diff);
llama_kv_cache_seq_add(llama_ctx_v4, 0, trimstart + diff, -1, -diff);
llama_kv_cache_seq_add(ctx, 0, trimstart + diff, -1, -diff);
if(draft_ctx)
{
llama_kv_cache_seq_rm(draft_ctx, 0, trimstart, trimstart + diff);
llama_kv_cache_seq_add(draft_ctx, 0, trimstart + diff, -1, -diff);
}
for (size_t i = trimstart + diff; i < current_context_tokens.size() - 1; i++)
{
@@ -1791,6 +1908,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
kcpp_data->use_contextshift = inputs.use_contextshift;
kcpp_data->use_fastforward = inputs.use_fastforward;
debugmode = inputs.debugmode;
draft_ctx = nullptr;
auto clamped_max_context_length = inputs.max_context_length;
@@ -2136,6 +2254,23 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
n_vocab = llama_n_vocab(llamamodel);
if(draftmodel_filename !="" && file_format==FileFormat::GGUF_GENERIC)
{
if(llama_model_is_recurrent(llamamodel))
{
printf("Error: Speculative decoding cannot be used with Recurrent models!\n");
}
else if(clp_ctx!=nullptr)
{
printf("Error: Speculative decoding cannot be used with multimodal vision projectors!\n");
}
else
{
printf("\nAttempting to load draft model for speculative decoding. It will be fully offloaded if possible. Vocab must match the main model.\n");
speculative_decoding_setup(draftmodel_filename, model_params, llama_ctx_params, n_vocab);
}
}
//determine mem per token
std::vector<int> tmp = {1, 2, 3, 4};
llama_kv_cache_clear(llama_ctx_v4);
@@ -3090,6 +3225,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
if(n_past==0)
{
llama_kv_cache_clear(llama_ctx_v4);
if(draft_ctx)
{
llama_kv_cache_clear(draft_ctx);
}
}
else if(embd_inp.size()==0)
{
@@ -3105,7 +3244,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
{
if(kcpp_data->use_fastforward && kcpp_data->use_contextshift && (file_format == FileFormat::GGUF_GENERIC))
{
PurgeMissingTokens(llama_ctx_v4, current_context_tokens, embd_inp, inputs.max_length, nctx);
PurgeMissingTokens(llama_ctx_v4, draft_ctx, current_context_tokens, embd_inp, inputs.max_length, nctx);
triggersc = false;
}
if(kcpp_data->use_fastforward)
@@ -3116,6 +3255,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
if(file_format == FileFormat::GGUF_GENERIC)
{
llama_kv_cache_seq_rm(llama_ctx_v4, 0, n_past, -1);
if(draft_ctx)
{
llama_kv_cache_seq_rm(draft_ctx, 0, n_past, -1);
}
}
}
@@ -3153,6 +3296,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
bool startedsampling = false;
bool v3_use_scratch = true; //for normal inference always use scratch
speculative_draft_result draft_results; //only use if drafting was used
bool draft_used = false;
time0 = timer_check();
timer_start();
@@ -3223,9 +3369,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
if (embdsize > 0)
{
bool evalres = false;
if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
{
evalres = (llama_v2_eval(llama_ctx_v2, embd.data(), embdsize, n_past, GetThreadsToUse(blasmode))==0);
@@ -3236,8 +3380,20 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
}
else if(file_format == FileFormat::GGUF_GENERIC)
{
kcpp_embd_batch batch = kcpp_embd_batch(embd, n_past);
if(embd.size()!=1 || draft_ctx==nullptr || remaining_tokens<=speculative_chunk_amt || grammar!=nullptr || startedsampling==false) //for large batch, or if no draft model, PP/TG as usual
{
draft_used = false;
kcpp_embd_batch batch = kcpp_embd_batch(embd, n_past, false);
evalres = (llama_decode(llama_ctx_v4, batch.batch)==0);
if(draft_ctx)
{
evalres = (evalres && (llama_decode(draft_ctx, batch.batch)==0));
}
} else { //individual tokens AND speculative is used (generation)
draft_used = true;
draft_results = speculative_decoding_eval_chunk(draft_ctx, llama_ctx_v4, embd, n_vocab, n_past);
evalres = draft_results.draft_success;
}
}
else if(file_format==FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
{
@@ -3350,12 +3506,35 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
float * logitsPtr;
float lowestLogit = 0;
int btsize = banned_token_ids.size();
//sample pending logits. usually only 1, unless speculative decoding
int logits_to_sample = 1;
int logits_sampled = 0;
bool abort_draft = false;
if(draft_used)
{
logits_to_sample = draft_results.drafted_amount;
}
while(logits_sampled<logits_to_sample && remaining_tokens>0 && !abort_draft)
{
if(logits_sampled>0)
{
//this is not the first loop, so we need to increment some things
n_past += 1;
}
if(file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 || file_format == FileFormat::GGJT_3 || file_format == FileFormat::GGUF_GENERIC)
{
if(file_format == FileFormat::GGUF_GENERIC)
{
if(draft_used)
{
logitsPtr = draft_results.actual_logits[logits_sampled];
}
else
{
logitsPtr = llama_get_logits(llama_ctx_v4);
}
}
else if(file_format == FileFormat::GGJT_3)
{
logitsPtr = llama_v3_get_logits(llama_ctx_v3);
@@ -3368,7 +3547,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
}
else
{
logitsPtr = logits.data();
logitsPtr = logits.data(); //legacy rwkv, neox, gptj etc
lowestLogit = LowestLogit(logits);
}
@@ -3408,6 +3587,21 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
kcpp_data->dry_allowed_length, kcpp_data->dry_penalty_last_n, kcpp_data->xtc_threshold, kcpp_data->xtc_probability,
sampler_order, grammar, dynatemp_range, dynatemp_exponent, smoothing_factor);
if(draft_used)
{
int32_t draftedid = draft_results.draftids[logits_sampled];
if(debugmode==1)
{
std::string drafttok = FileFormatTokenizeID(draftedid, file_format, true);
std::string realtok = FileFormatTokenizeID(id, file_format, true);
printf("(Draft %d/%d): Predicted=%d (%s), Actual=%d (%s) [%s]\n",(logits_sampled+1),logits_to_sample,draftedid,drafttok.c_str(),id,realtok.c_str(),(draftedid==id?"PASS":"FAIL"));
}
if(draftedid!=id) //draft mismatch, abort
{
abort_draft = true;
}
}
if (grammar != nullptr) {
grammar_accept_token(file_format, n_vocab, grammar, id);
}
@@ -3420,6 +3614,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
current_context_tokens.push_back(id);
// add it to the context
embd.clear();
embd.push_back(id);
// decrement remaining sampling budget
@@ -3503,6 +3698,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
delayed_generated_tokens.resize(delayed_generated_tokens.size() - rewind_amt);
ContextRewind(embd, current_context_tokens, n_past, last_n_tokens, rewind_amt);
//immediately terminate drafting if used
abort_draft = true;
// Check if the key exists
int banindex = n_past+1;
if (antislop_banned_token_ids.find(banindex) == antislop_banned_token_ids.end()) {
@@ -3577,6 +3775,18 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
}
}
logits_sampled += 1;
}
//if we have somehow skipped ahead (e.g drafting), ensure that all tokens after npast are purged
if (file_format == FileFormat::GGUF_GENERIC && draft_used)
{
llama_kv_cache_seq_rm(llama_ctx_v4, 0, n_past, -1);
if (draft_ctx) {
llama_kv_cache_seq_rm(draft_ctx, 0, n_past, -1);
}
}
fflush(stdout);
}
else
@@ -3611,6 +3821,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
}
for(int i=0;i<llava_images.size();++i)
{
//note: no handling for draft_ctx as we don't support vision for it
if(i>0 && sepsize>0)
{
//add a separator between each image
@@ -3701,7 +3912,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
float pt2 = (time2*1000.0/(realnpredict==0?1:realnpredict));
float ts2 = (1000.0/pt2);
float tokens_per_second = (realnpredict == 0 ? 0 : realnpredict / (time1 + time2));
printf("\nCtxLimit:%d/%d, Amt:%d/%d, Init:%.2fs, Process:%.2fs (%.1fms/T = %.2fT/s), Generate:%.2fs (%.1fms/T = %.2fT/s), Total:%.2fs (%.2fT/s)",(int)current_context_tokens.size(),(int)nctx, realnpredict, kcpp_data->n_predict, time0, time1, pt1, ts1, time2, pt2, ts2, (time1 + time2), tokens_per_second);
printf("\n[%s] CtxLimit:%d/%d, Amt:%d/%d, Init:%.2fs, Process:%.2fs (%.1fms/T = %.2fT/s), Generate:%.2fs (%.1fms/T = %.2fT/s), Total:%.2fs (%.2fT/s)",get_timestamp_str().c_str(),(int)current_context_tokens.size(),(int)nctx, realnpredict, kcpp_data->n_predict, time0, time1, pt1, ts1, time2, pt2, ts2, (time1 + time2), tokens_per_second);
fflush(stdout);
output.status = 1;
int finaltokcount = (int)current_context_tokens.size()-realnpredict;


@@ -131,6 +131,7 @@ class load_model_inputs(ctypes.Structure):
("model_filename", ctypes.c_char_p),
("lora_filename", ctypes.c_char_p),
("lora_base", ctypes.c_char_p),
("draftmodel_filename", ctypes.c_char_p),
("mmproj_filename", ctypes.c_char_p),
("use_mmap", ctypes.c_bool),
("use_mlock", ctypes.c_bool),
@@ -672,24 +673,27 @@ def read_gguf_metadata(file_path):
except Exception as ex:
return None
def extract_modelfile_params(filepath,sdfilepath,whisperfilepath,mmprojfilepath):
def extract_modelfile_params(filepath,sdfilepath,whisperfilepath,mmprojfilepath,draftmodelpath):
global modelfile_extracted_meta
modelfile_extracted_meta = None
sdfsize = 0
whisperfsize = 0
mmprojsize = 0
draftmodelsize = 0
if sdfilepath and os.path.exists(sdfilepath):
sdfsize = os.path.getsize(sdfilepath)
if whisperfilepath and os.path.exists(whisperfilepath):
whisperfsize = os.path.getsize(whisperfilepath)
if mmprojfilepath and os.path.exists(mmprojfilepath):
mmprojsize = os.path.getsize(mmprojfilepath)
if draftmodelpath and os.path.exists(draftmodelpath):
draftmodelsize = os.path.getsize(draftmodelpath)
if filepath and os.path.exists(filepath):
try:
fsize = os.path.getsize(filepath)
if fsize>10000000: #dont bother with models < 10mb as they are probably bad
ggufmeta = read_gguf_metadata(filepath)
modelfile_extracted_meta = [ggufmeta,fsize,sdfsize,whisperfsize,mmprojsize] #extract done. note that meta may be null
modelfile_extracted_meta = [ggufmeta,fsize,sdfsize,whisperfsize,mmprojsize,draftmodelsize] #extract done. note that meta may be null
except Exception as ex:
modelfile_extracted_meta = None
@@ -702,7 +706,7 @@ def autoset_gpu_layers(ctxsize,sdquanted,bbs): #shitty algo to determine how man
if showusedmemwarning and usedmem > (2.5*1024*1024*1024):
showusedmemwarning = False
print(f"Note: KoboldCpp has detected that a significant amount of GPU VRAM ({usedmem/1024/1024} MB) is currently used by another application.\nFor best results, you may wish to close that application and then restart KoboldCpp.\n***")
reservedmem = max(1.5*1024*1024*1024,(0.5*1024*1024*1024 + usedmem)) # determine vram overhead
reservedmem = max(1.3*1024*1024*1024,(0.5*1024*1024*1024 + usedmem)) # determine vram overhead
try:
if not modelfile_extracted_meta:
return 0
@@ -719,6 +723,9 @@ def autoset_gpu_layers(ctxsize,sdquanted,bbs): #shitty algo to determine how man
mem -= 350*1024*1024
if modelfile_extracted_meta[4] > 1024*1024*10: #mmproj tax
mem -= 350*1024*1024
if modelfile_extracted_meta[5] > 1024*1024*10: #draft model tax
mem -= (modelfile_extracted_meta[5] * 1.5)
mem = 0 if mem < 0 else mem
csmul = 1.0
if cs:
@@ -732,8 +739,8 @@ def autoset_gpu_layers(ctxsize,sdquanted,bbs): #shitty algo to determine how man
headcount = ggufmeta[1]
headkvlen = (ggufmeta[2] if ggufmeta[2] > 0 else 128)
ratio = (mem-usedmem)/(fsize*csmul*1.6*(1.0 if bbs <= 512 else 1.2))
computemem = layers*(4 if bbs <= 512 else (bbs/128))*headkvlen*cs*4*1.5 # apply blasbatchsize calculations if over 512
computemem = layers*(4 if bbs <= 512 else (bbs/128))*headkvlen*cs*4*1.55 # apply blasbatchsize calculations if over 512
contextmem = layers*headcount*headkvlen*cs*4*1.1
contextmem = layers*headcount*headkvlen*cs*4*1.15
if headcount > 0:
ratio = max(ratio, (mem - reservedmem - computemem) / (fsize + contextmem))
layerlimit = min(int(ratio*layers), (layers + 3))
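
For reference, a rough restatement of the new "draft model tax" in the layer-estimation heuristic above: when a draft model larger than 10 MB is selected, roughly 1.5x its file size is subtracted from the usable VRAM budget before the GPU-layer estimate is computed (a minimal sketch with hypothetical sizes, not the function above):

```python
# Minimal sketch of the draft-model adjustment in autoset_gpu_layers (hypothetical numbers).
GiB = 1024 * 1024 * 1024

def apply_draft_model_tax(mem: int, draftmodelsize: int) -> int:
    """Reserve roughly 1.5x the draft model's file size out of the VRAM budget."""
    if draftmodelsize > 1024 * 1024 * 10:  # ignore absent or tiny draft models
        mem -= int(draftmodelsize * 1.5)
    return max(mem, 0)

# Example: a 12 GiB budget and a 1 GiB draft model leave about 10.5 GiB for the main model.
print(apply_draft_model_tax(12 * GiB, 1 * GiB) / GiB)
```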
@@ -877,6 +884,7 @@ def load_model(model_filename):
if len(args.lora) > 1:
inputs.lora_base = args.lora[1].encode("UTF-8")
inputs.draftmodel_filename = args.draftmodel.encode("UTF-8") if args.draftmodel else "".encode("UTF-8")
inputs.mmproj_filename = args.mmproj.encode("UTF-8") if args.mmproj else "".encode("UTF-8")
inputs.use_smartcontext = args.smartcontext
inputs.use_contextshift = (0 if args.noshift else 1)
@@ -1510,15 +1518,18 @@ ws ::= | " " | "\n" [ \t]{0,20}
elif api_format==6:
detokstr = ""
tokids = genparams.get('context', [])
adapter_obj = {} if chatcompl_adapter is None else chatcompl_adapter
user_message_start = adapter_obj.get("user_start", "\n\n### Instruction:\n")
assistant_message_start = adapter_obj.get("assistant_start", "\n\n### Response:\n")
try:
detokstr = detokenize_ids(tokids)
except Exception as e:
utfprint("Ollama Context Error: " + str(e))
ollamasysprompt = genparams.get('system', "")
ollamabodyprompt = detokstr + "\n\n### Instruction:\n" + genparams.get('prompt', "") + "\n\n### Response:\n"
ollamabodyprompt = f"{detokstr}{user_message_start}{genparams.get('prompt', '')}{assistant_message_start}"
genparams["stop_sequence"] = genparams.get('stop', [])
genparams["stop_sequence"].append("\n### Instruction:")
genparams["stop_sequence"].append(user_message_start.strip())
genparams["stop_sequence"].append("\n### Response:")
genparams["stop_sequence"].append(assistant_message_start.strip())
genparams["trim_stop"] = True
genparams["ollamasysprompt"] = ollamasysprompt
genparams["ollamabodyprompt"] = ollamabodyprompt
@@ -2374,9 +2385,8 @@ Enter Prompt:<br>
return
is_quiet = args.quiet
utfprint(f"\n{datetime.now().strftime('[%H:%M:%S] Input Received')}")
if (args.debugmode != -1 and not is_quiet) or args.debugmode >= 1:
utfprint(f"Input: " + json.dumps(genparams))
utfprint(f"\nInput: " + json.dumps(genparams))
if args.foreground:
bring_terminal_to_foreground()
@@ -2751,6 +2761,7 @@ def show_gui():
lora_base_var = ctk.StringVar()
preloadstory_var = ctk.StringVar()
mmproj_var = ctk.StringVar()
draftmodel_var = ctk.StringVar()
nomodel = ctk.IntVar(value=0)
port_var = ctk.StringVar(value=defaultport)
@@ -2929,7 +2940,8 @@ def show_gui():
sdfilepath = sd_model_var.get()
whisperfilepath = whisper_model_var.get()
mmprojfilepath = mmproj_var.get()
extract_modelfile_params(filepath,sdfilepath,whisperfilepath,mmprojfilepath)
draftmodelpath = draftmodel_var.get()
extract_modelfile_params(filepath,sdfilepath,whisperfilepath,mmprojfilepath,draftmodelpath)
changed_gpulayers_estimate()
pass
@@ -3234,18 +3246,20 @@ def show_gui():
makefileentry(model_tab, "Text Lora:", "Select Lora File",lora_var, 3,width=280,tooltiptxt="Select an optional GGML LoRA adapter to use.\nLeave blank to skip.")
makefileentry(model_tab, "Text Lora Base:", "Select Lora Base File", lora_base_var, 5,width=280,tooltiptxt="Select an optional F16 GGML LoRA base file to use.\nLeave blank to skip.")
makefileentry(model_tab, "Vision mmproj:", "Select Vision mmproj File", mmproj_var, 7,width=280,tooltiptxt="Select a mmproj file to use for vision models like LLaVA.\nLeave blank to skip.")
makefileentry(model_tab, "Preloaded Story:", "Select Preloaded Story File", preloadstory_var, 9,width=280,tooltiptxt="Select an optional KoboldAI JSON savefile \nto be served on launch to any client.")
makefileentry(model_tab, "Speculative Model:", "Select Draft Text Model File", draftmodel_var, 9,width=280,tooltiptxt="Select a draft text model file to use for speculative decoding.\nLeave blank to skip.")
makefileentry(model_tab, "ChatCompletions Adapter:", "Select ChatCompletions Adapter File", chatcompletionsadapter_var, 12, width=250, filetypes=[("JSON Adapter", "*.json")], tooltiptxt="Select an optional ChatCompletions Adapter JSON file to force custom instruct tags.")
makefileentry(model_tab, "Preloaded Story:", "Select Preloaded Story File", preloadstory_var, 11,width=280,tooltiptxt="Select an optional KoboldAI JSON savefile \nto be served on launch to any client.")
makefileentry(model_tab, "ChatCompletions Adapter:", "Select ChatCompletions Adapter File", chatcompletionsadapter_var, 14, width=250, filetypes=[("JSON Adapter", "*.json")], tooltiptxt="Select an optional ChatCompletions Adapter JSON file to force custom instruct tags.")
def pickpremadetemplate():
initialDir = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'kcpp_adapters')
initialDir = initialDir if os.path.isdir(initialDir) else None
fnam = askopenfilename(title="Pick Premade ChatCompletions Adapter",filetypes=[("JSON Adapter", "*.json")], initialdir=initialDir)
if fnam:
chatcompletionsadapter_var.set(fnam)
ctk.CTkButton(model_tab, 64, text="Pick Premade", command=pickpremadetemplate).grid(row=13, column=0, padx=322, stick="nw")
ctk.CTkButton(model_tab, 64, text="Pick Premade", command=pickpremadetemplate).grid(row=15, column=0, padx=322, stick="nw")
mmproj_var.trace("w", gui_changed_modelfile)
makecheckbox(model_tab, "Allow Launch Without Models", nomodel, 15, tooltiptxt="Allows running the WebUI with no model loaded.")
draftmodel_var.trace("w", gui_changed_modelfile)
makecheckbox(model_tab, "Allow Launch Without Models", nomodel, 17, tooltiptxt="Allows running the WebUI with no model loaded.")
# Network Tab
network_tab = tabcontent["Network"]
@@ -3489,6 +3503,7 @@ def show_gui():
except Exception as ex2:
pass
args.mmproj = None if mmproj_var.get() == "" else mmproj_var.get()
args.draftmodel = None if draftmodel_var.get() == "" else draftmodel_var.get()
args.ssl = None if (ssl_cert_var.get() == "" or ssl_key_var.get() == "") else ([ssl_cert_var.get(), ssl_key_var.get()])
args.password = None if (password_var.get() == "") else (password_var.get())
@@ -3649,6 +3664,7 @@ def show_gui():
lora_var.set(dict["lora"][0])
mmproj_var.set(dict["mmproj"] if ("mmproj" in dict and dict["mmproj"]) else "")
draftmodel_var.set(dict["draftmodel"] if ("draftmodel" in dict and dict["draftmodel"]) else "")
ssl_cert_var.set("")
ssl_key_var.set("")
@@ -4442,6 +4458,10 @@ def main(launch_args,start_server=True):
dlfile = download_model_from_url(args.whispermodel,[".gguf",".bin"])
if dlfile:
args.whispermodel = dlfile
if args.draftmodel and args.draftmodel!="":
dlfile = download_model_from_url(args.draftmodel,[".gguf"])
if dlfile:
args.draftmodel = dlfile
# sanitize and replace the default vanity name. remember me....
if args.model_param and args.model_param!="":
@@ -4517,7 +4537,7 @@ def main(launch_args,start_server=True):
pass
if args.gpulayers==-1:
if MaxMemory[0] > 0 and (not args.usecpu) and ((args.usecublas is not None) or (args.usevulkan is not None) or (args.useclblast is not None) or sys.platform=="darwin"):
extract_modelfile_params(args.model_param,args.sdmodel,args.whispermodel,args.mmproj)
extract_modelfile_params(args.model_param,args.sdmodel,args.whispermodel,args.mmproj,args.draftmodel)
layeramt = autoset_gpu_layers(args.contextsize,args.sdquant,args.blasbatchsize)
print(f"Auto Recommended GPU Layers: {layeramt}")
args.gpulayers = layeramt
@@ -4923,6 +4943,7 @@ if __name__ == '__main__':
advparser.add_argument("--ssl", help="Allows all content to be served over SSL instead. A valid UNENCRYPTED SSL cert and key .pem files must be provided", metavar=('[cert_pem]', '[key_pem]'), nargs='+')
advparser.add_argument("--nocertify", help="Allows insecure SSL connections. Use this if you have cert errors and need to bypass certificate restrictions.", action='store_true')
advparser.add_argument("--mmproj", help="Select a multimodal projector file for vision models like LLaVA.", default="")
advparser.add_argument("--draftmodel", help="Load a small draft model for speculative decoding. It will be fully offloaded. Vocab must match the main model.", default="")
advparser.add_argument("--password", help="Enter a password required to use this instance. This key will be required for all text endpoints. Image endpoints are not secured.", default=None)
advparser.add_argument("--ignoremissing", help="Ignores all missing non-essential files, just skipping them instead.", action='store_true')
advparser.add_argument("--chatcompletionsadapter", help="Select an optional ChatCompletions Adapter JSON file to force custom instruct tags.", default="")


@@ -507,4 +507,12 @@ struct llava_image
float * clp_img_embd = nullptr; //this holds dynamic memory and must be freed each use!
};
struct speculative_draft_result
{
std::vector<int32_t> draftids;
std::vector<float *> actual_logits;
bool draft_success = false;
int drafted_amount = 0;
};
const float default_norm_eps = 1e-5f;


@@ -357,11 +357,12 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
int img2imgC = 3; // Assuming RGB image
std::vector<uint8_t> resized_image_buf(img2imgW * img2imgH * img2imgC);
std::string ts = get_timestamp_str();
if(!is_quiet)
{
printf("\nGenerating Image (%d steps)\n",inputs.sample_steps);
printf("\n[%s] Generating Image (%d steps)\n",ts.c_str(),inputs.sample_steps);
}else{
printf("\nGenerating (%d st.)\n",inputs.sample_steps);
printf("\n[%s] Generating (%d st.)\n",ts.c_str(),inputs.sample_steps);
}
fflush(stdout);


@@ -8,6 +8,7 @@
#include <locale>
#include <codecvt>
#include <sstream>
#include <ctime>
void utreplace(std::string & str, const std::string & needle, const std::string & replacement) {
@@ -302,3 +303,14 @@ std::vector<uint8_t> kcpp_base64_decode(const std::string & encoded_string)
return ret;
}
std::string get_timestamp_str()
{
std::time_t t = std::time(nullptr);
std::tm* now = std::localtime(&t);
char buffer[16]; // Buffer to hold "hh:mm:ss" and null terminator
std::sprintf(buffer, "%02d:%02d:%02d", now->tm_hour, now->tm_min, now->tm_sec);
// Convert the buffer to a std::string
std::string timestamp(buffer);
return timestamp;
}


@@ -57,3 +57,5 @@ bool should_transpose_layer(std::string name);
void kcpp_graph_compute_helper(ggml_v3_cgraph * graph, int n_threads);
std::vector<uint8_t> kcpp_base64_decode(const std::string & encoded_string);
std::string get_timestamp_str();


@@ -273,11 +273,12 @@ whisper_generation_outputs whispertype_generate(const whisper_generation_inputs
// output text transcription
whisper_output_text = output_txt(whisper_ctx, pcmf32s);
std::string ts = get_timestamp_str();
if(!inputs.quiet)
{
printf("\nWhisper Transcribe Output: %s",whisper_output_text.c_str());
printf("\n[%s] Whisper Transcribe Output: %s",ts.c_str(),whisper_output_text.c_str());
} else {
printf("\nWhisper Transcribe Done.");
printf("\n[%s] Whisper Transcribe Done.",ts.c_str());
}
output.text = whisper_output_text.c_str();
output.status = 1;