mirror of https://github.com/LostRuins/koboldcpp.git (synced 2025-09-09 16:44:35 +00:00)
wip cfg scale
This commit is contained in:
parent 13cee48740
commit 38a8778f24
3 changed files with 90 additions and 10 deletions
expose.h (2 changes)

@@ -78,6 +78,8 @@ struct generation_inputs
const int seed = 0;
const char * prompt = nullptr;
const char * memory = nullptr;
+const char * negative_prompt = nullptr;
+const float guidance_scale = 1;
const char * images[images_max] = {};
const int max_context_length = 0;
const int max_length = 0;
@@ -1578,6 +1578,22 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
}

+void sample_guidance(struct llama_context * ctx, struct llama_context * guidance_ctx, int n_vocab, float scale)
+{
+    float * guidanceLogitsPtr = llama_get_logits(guidance_ctx);
+    float * mainLogitsPtr = llama_get_logits(ctx);
+
+    if (scale < 0) {
+        scale = 0;
+    }
+
+    for (int i = 0; i < n_vocab; ++i) {
+        float logit_guidance = guidanceLogitsPtr[i];
+        float logit_main = mainLogitsPtr[i];
+        mainLogitsPtr[i] = scale * (logit_main-logit_guidance) + logit_guidance;
+    }
+}
+
int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float rep_pen_slope, float presence_penalty, float top_k, float top_a, float top_p, float min_p, float typical_p, float tfs, float nsigma, float temp, std::mt19937 & rng,
int mirostat, float mirostat_tau, float mirostat_eta, float dry_multiplier, float dry_base, int dry_allowed_length, int dry_penalty_last_n, float xtc_threshold, float xtc_probability,
const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dynatemp_range, float dynatemp_exponent, float smoothing_factor)
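Note: the blend in sample_guidance above is standard classifier-free guidance applied to logits: the output is the negative-prompt (guidance) logit plus scale times the difference between the main and guidance logits, so scale=1 leaves the main logits untouched, scale=0 collapses to the negative-prompt logits, and scale>1 pushes further away from whatever the negative prompt favours. A minimal standalone sketch of the same arithmetic (plain Python, illustrative only, not part of the commit):

def blend_logits(main_logits, guidance_logits, scale):
    # mirrors sample_guidance: out[i] = scale * (main[i] - guidance[i]) + guidance[i]
    scale = max(scale, 0.0)  # the C++ code clamps negative scales to zero
    return [g + scale * (m - g) for m, g in zip(main_logits, guidance_logits)]

# scale=1 reproduces the main logits; larger scales amplify the difference
# between the main and negative-prompt distributions.
main = [2.0, 0.5, -1.0]
neg  = [3.0, 0.5, -2.0]
print(blend_logits(main, neg, 1.0))  # [2.0, 0.5, -1.0]
print(blend_logits(main, neg, 2.0))  # [1.0, 0.5, 0.0]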
@@ -2105,7 +2121,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
auto er = llama_v3_eval(llama_ctx_v3, tmp.data(), tmp.size(), 0, kcpp_data->n_threads);
if(er!=0)
{
-    printf("\nLLAMA EVAL returned nonzero!\n");
+    printf("\nModel Warmup Failed! (code:%d)\n",er);
}
return ModelLoadResult::SUCCESS;
}
@@ -2411,11 +2427,15 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
auto er = llama_decode(llama_ctx_v4, llama_batch_get_one(tmp.data(), tmp.size()));
if(er!=0)
{
-    printf("\nLLAMA EVAL returned nonzero: %d\n",er);
+    printf("\nModel Warmup Failed! (code:%d)\n",er);
}
tmp = {1};
llama_kv_self_clear(llama_ctx_v4);
er = llama_decode(llama_ctx_v4, llama_batch_get_one(tmp.data(), tmp.size()));
+if(er!=0)
+{
+    printf("\nModel Warmup Failed! (code:%d)\n",er);
+}
return ModelLoadResult::SUCCESS;
}
else if (file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
@@ -3097,6 +3117,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
}

std::string addedmemory = inputs.memory;
+std::string negative_prompt = inputs.negative_prompt;

//clear previous run llava embd memory, just-in-time free
for(int i=0;i<llava_images.size();++i)
@@ -3283,6 +3304,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
std::vector<int> embd_inp_mem; //for storing added memory
std::vector<int> llava_sep; //to separate between different llava images
std::vector<int> llava_intro; //to separate between different llava images
+std::vector<int> guidance_embd; //holds the guidance prompt
bool llava_embds_built = false;

int32_t nctx = kcpp_data->n_ctx;
@@ -3357,18 +3379,31 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
    }
}

+std::vector<int> negprompt_tokens;
+int guidance_n_past = 0;
+if(guidance_ctx)
+{
+    llama_kv_self_clear(guidance_ctx);
+    //prepare negative prompt
+    if(negative_prompt!="" && inputs.guidance_scale!=1.0f)
+    {
+        TokenizeString(negative_prompt+"\n", negprompt_tokens, file_format, add_bos_token);
+    }
+}
+
//added special memory, overwrite if needed
-if(embd_inp_mem.size()>0)
+if (embd_inp_mem.size() + negprompt_tokens.size() > 0)
{
    //remove bos token from prompt, it'll be taken from memory
    std::vector<int> bos;
    TokenizeString("", bos, file_format, add_bos_token);
-   if (bos.size()>0 && !embd_inp.empty() && bos[0]==embd_inp[0]) {
+   if (bos.size()>0 && !embd_inp.empty() && bos[0]==embd_inp[0]) { //strip away bos if exists
        embd_inp.erase(embd_inp.begin());
    }

    //shorten memory if needed
-   if (embd_inp_mem.size() + kcpp_data->n_predict + 4 > nctx)
+   if (embd_inp_mem.size() > 0 && embd_inp_mem.size() + kcpp_data->n_predict + 4 > nctx)
    {
        int offset = embd_inp_mem.size() - nctx + kcpp_data->n_predict + 4;
        embd_inp_mem = std::vector<int>(embd_inp_mem.begin() + offset, embd_inp_mem.end());
@@ -3380,7 +3415,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
}

//shorten main prompt by trimming the front if needed
-int addmemtokens = embd_inp_mem.size();
+int addmemtokens = embd_inp_mem.size() + negprompt_tokens.size() + 1;
int totalsize = (addmemtokens + embd_inp.size() + kcpp_data->n_predict);
if(totalsize > nctx)
{
@@ -3394,6 +3429,34 @@ generation_outputs gpttype_generate(const generation_inputs inputs)

    //stick memory to front of prompt
    embd_inp.insert(embd_inp.begin(), embd_inp_mem.begin(), embd_inp_mem.end());
+   if(add_bos_token && embd_inp.size()>0 && bos.size()>0 && bos[0]!=embd_inp[0])
+   {
+       embd_inp.insert(embd_inp.begin(), bos[0]); //insert bos at front, if added
+   }
}

+//prepare negative prompt
+if(guidance_ctx && negprompt_tokens.size()>0 && inputs.guidance_scale!=1.0f)
+{
+    guidance_embd = embd_inp; //clone main prompt
+    std::vector<int> bos;
+    TokenizeString("", bos, file_format, add_bos_token);
+    if (bos.size()>0 && !guidance_embd.empty() && bos[0]==guidance_embd[0]) {
+        guidance_embd.erase(guidance_embd.begin());
+    }
+
+    // Insert at the beginning of everything. size is already handled
+    guidance_embd.insert(guidance_embd.begin(), negprompt_tokens.begin(), negprompt_tokens.end());
+
+    //eval the guidance prompt
+    printf("Preparing Negative Prompt (%zu tokens)\n", guidance_embd.size());
+    kcpp_embd_batch batch = kcpp_embd_batch(guidance_embd, 0, use_mrope, false);
+    auto er = (llama_decode(guidance_ctx, batch.batch)==0);
+    if(er!=0)
+    {
+        printf("\nProcess Negative Prompt Failed! (code:%d)\n",er);
+    }
+    guidance_n_past += guidance_embd.size();
+}

//determine how much npast we have to rewind from the current state
@@ -3467,10 +3530,6 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
        }
    }
}
-if(guidance_ctx)
-{
-    llama_kv_self_clear(guidance_ctx);
-}

bool blasmode = (embd_inp.size() >= 32 && kcpp_cpu_has_blas() && kcpp_data->n_batch>=32);

@@ -3591,6 +3650,13 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
}
else if(file_format == FileFormat::GGUF_GENERIC)
{
+   if(guidance_ctx && negprompt_tokens.size()>0 && inputs.guidance_scale!=1.0f && embd.size()==1 && startedsampling)
+   {
+       //eval for negative prompt
+       kcpp_embd_batch batch = kcpp_embd_batch(embd, guidance_n_past, use_mrope, false);
+       evalres = (evalres && (llama_decode(guidance_ctx, batch.batch)==0));
+       guidance_n_past += 1;
+   }
    if(embd.size()!=1 || draft_ctx==nullptr || remaining_tokens<=speculative_chunk_amt || grammar!=nullptr || startedsampling==false) //for large batch, or if no draft model, PP/TG as usual
    {
        draft_used = false;
@@ -3769,6 +3835,12 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
    lowestLogit = LowestLogit(logits);
}

+if(file_format == FileFormat::GGUF_GENERIC && guidance_ctx && negprompt_tokens.size()>0 && inputs.guidance_scale!=1.0f)
+{
+    sample_guidance(llama_ctx_v4, guidance_ctx, n_vocab, inputs.guidance_scale);
+}
+
//handle token bans
if (!inputs.allow_eos_token && !inputs.bypass_eos_token)
{
    // set the logit of the eos token to very low to avoid sampling it
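Taken together, the C++ changes wire classifier-free guidance into the GGUF path: the guidance context is primed once with negative prompt + prompt, then every sampled token is fed to both contexts and sample_guidance blends the two sets of logits before the normal samplers run. A schematic sketch of that control flow (plain Python with stand-in callables; the names below are illustrative and are not the koboldcpp or llama.cpp API):

def generate_with_cfg(decode_main, decode_guidance, sample, prompt_tokens,
                      negprompt_tokens, guidance_scale, max_new_tokens):
    # prompt processing: the main context sees the prompt, the guidance
    # context sees negative prompt + prompt (mirrors guidance_embd above)
    main_logits = decode_main(prompt_tokens)
    guidance_logits = decode_guidance(negprompt_tokens + prompt_tokens)
    out = []
    for _ in range(max_new_tokens):
        # blend as in sample_guidance, then sample one token as usual
        blended = [g + guidance_scale * (m - g)
                   for m, g in zip(main_logits, guidance_logits)]
        tok = sample(blended)
        out.append(tok)
        # feed the chosen token to BOTH contexts so they stay in sync
        main_logits = decode_main([tok])
        guidance_logits = decode_guidance([tok])
    return out

# toy usage with a fake 5-token vocabulary
if __name__ == "__main__":
    import random
    random.seed(0)
    fake_decode = lambda tokens: [random.uniform(-1, 1) for _ in range(5)]
    argmax = lambda logits: max(range(len(logits)), key=logits.__getitem__)
    print(generate_with_cfg(fake_decode, fake_decode, argmax, [1, 2], [3], 2.0, 4))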
@@ -198,6 +198,8 @@ class generation_inputs(ctypes.Structure):
    _fields_ = [("seed", ctypes.c_int),
                ("prompt", ctypes.c_char_p),
                ("memory", ctypes.c_char_p),
+               ("negative_prompt", ctypes.c_char_p),
+               ("guidance_scale", ctypes.c_float),
                ("images", ctypes.c_char_p * images_max),
                ("max_context_length", ctypes.c_int),
                ("max_length", ctypes.c_int),
@@ -1247,6 +1249,8 @@ def generate(genparams, stream_flag=False):

    prompt = genparams.get('prompt', "")
    memory = genparams.get('memory', "")
+   negative_prompt = genparams.get('negative_prompt', "")
+   guidance_scale = tryparsefloat(genparams.get('guidance_scale', 1.0),1.0)
    images = genparams.get('images', [])
    max_context_length = tryparseint(genparams.get('max_context_length', maxctx),maxctx)
    max_length = tryparseint(genparams.get('max_length', args.defaultgenamt),args.defaultgenamt)
@@ -1327,6 +1331,8 @@ def generate(genparams, stream_flag=False):
    inputs = generation_inputs()
    inputs.prompt = prompt.encode("UTF-8")
    inputs.memory = memory.encode("UTF-8")
+   inputs.negative_prompt = negative_prompt.encode("UTF-8")
+   inputs.guidance_scale = guidance_scale
    for n in range(images_max):
        if not images or n >= len(images):
            inputs.images[n] = "".encode("UTF-8")
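On the Python side the two new fields flow straight from the request genparams into the ctypes struct above. Assuming they are driven through koboldcpp's usual /api/v1/generate JSON body (the routing itself is not part of this diff, and the keys may still change since the commit is marked wip), a request could look roughly like this:

import json, urllib.request

# Hypothetical payload: "negative_prompt" and "guidance_scale" are the two keys
# this commit starts reading out of genparams; guidance_scale=1.0 leaves CFG off.
payload = {
    "prompt": "Write a short story about a lighthouse.",
    "negative_prompt": "violence, gore",
    "guidance_scale": 1.5,
    "max_length": 120,
}
req = urllib.request.Request(
    "http://localhost:5001/api/v1/generate",  # default local endpoint (assumed)
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
# print(urllib.request.urlopen(req).read().decode("utf-8"))  # needs a running server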