wip cfg scale

Concedo 2025-05-06 23:06:25 +08:00
parent 13cee48740
commit 38a8778f24
3 changed files with 90 additions and 10 deletions


@@ -78,6 +78,8 @@ struct generation_inputs
const int seed = 0;
const char * prompt = nullptr;
const char * memory = nullptr;
const char * negative_prompt = nullptr;
const float guidance_scale = 1;
const char * images[images_max] = {};
const int max_context_length = 0;
const int max_length = 0;


@@ -1578,6 +1578,22 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
}
void sample_guidance(struct llama_context * ctx, struct llama_context * guidance_ctx, int n_vocab, float scale)
{
float * guidanceLogitsPtr = llama_get_logits(guidance_ctx);
float * mainLogitsPtr = llama_get_logits(ctx);
if (scale < 0) {
scale = 0;
}
for (int i = 0; i < n_vocab; ++i) {
float logit_guidance = guidanceLogitsPtr[i];
float logit_main = mainLogitsPtr[i];
mainLogitsPtr[i] = scale * (logit_main-logit_guidance) + logit_guidance;
}
}
int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float rep_pen_slope, float presence_penalty, float top_k, float top_a, float top_p, float min_p, float typical_p, float tfs, float nsigma, float temp, std::mt19937 & rng,
int mirostat, float mirostat_tau, float mirostat_eta, float dry_multiplier, float dry_base, int dry_allowed_length, int dry_penalty_last_n, float xtc_threshold, float xtc_probability,
const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dynatemp_range, float dynatemp_exponent, float smoothing_factor)
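
Note on the math: sample_guidance() applies the usual classifier-free guidance blend, mixed = guidance + scale * (main - guidance). scale = 1 leaves the main logits untouched (which is why the rest of the patch skips CFG when guidance_scale == 1.0), scale = 0 reproduces the negative-prompt context, larger scales push the distribution further away from it, and negative scales are clamped to 0. A small standalone sketch of the same arithmetic (plain Python, names are illustrative and not part of the patch):

def mix_guidance(main_logits, guidance_logits, scale):
    # mirrors sample_guidance(): the result equals main_logits when scale == 1
    scale = max(scale, 0.0)
    return [g + scale * (m - g) for m, g in zip(main_logits, guidance_logits)]

main = [2.0, 1.0]   # logits from the main context
neg  = [3.0, 0.5]   # logits from the guidance (negative prompt) context
print(mix_guidance(main, neg, 1.0))   # [2.0, 1.0] -> unchanged
print(mix_guidance(main, neg, 2.0))   # [1.0, 1.5] -> pushed away from the negative prompt
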
@@ -2105,7 +2121,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
auto er = llama_v3_eval(llama_ctx_v3, tmp.data(), tmp.size(), 0, kcpp_data->n_threads);
if(er!=0)
{
printf("\nLLAMA EVAL returned nonzero!\n");
printf("\nModel Warmup Failed! (code:%d)\n",er);
}
return ModelLoadResult::SUCCESS;
}
@@ -2411,11 +2427,15 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
auto er = llama_decode(llama_ctx_v4, llama_batch_get_one(tmp.data(), tmp.size()));
if(er!=0)
{
printf("\nLLAMA EVAL returned nonzero: %d\n",er);
printf("\nModel Warmup Failed! (code:%d)\n",er);
}
tmp = {1};
llama_kv_self_clear(llama_ctx_v4);
er = llama_decode(llama_ctx_v4, llama_batch_get_one(tmp.data(), tmp.size()));
if(er!=0)
{
printf("\nModel Warmup Failed! (code:%d)\n",er);
}
return ModelLoadResult::SUCCESS;
}
else if (file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
@@ -3097,6 +3117,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
}
std::string addedmemory = inputs.memory;
std::string negative_prompt = inputs.negative_prompt;
//clear previous run llava embd memory, just-in-time free
for(int i=0;i<llava_images.size();++i)
@@ -3283,6 +3304,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
std::vector<int> embd_inp_mem; //for storing added memory
std::vector<int> llava_sep; //to separate between different llava images
std::vector<int> llava_intro; //to separate between different llava images
std::vector<int> guidance_embd; //holds the guidance prompt
bool llava_embds_built = false;
int32_t nctx = kcpp_data->n_ctx;
@@ -3357,18 +3379,31 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
}
}
std::vector<int> negprompt_tokens;
int guidance_n_past = 0;
if(guidance_ctx)
{
llama_kv_self_clear(guidance_ctx);
//prepare negative prompt
if(negative_prompt!="" && inputs.guidance_scale!=1.0f)
{
TokenizeString(negative_prompt+"\n", negprompt_tokens, file_format, add_bos_token);
}
}
//added special memory, overwrite if needed
if(embd_inp_mem.size()>0)
if (embd_inp_mem.size() + negprompt_tokens.size() > 0)
{
//remove bos token from prompt, it'll be taken from memory
std::vector<int> bos;
TokenizeString("", bos, file_format, add_bos_token);
if (bos.size()>0 && !embd_inp.empty() && bos[0]==embd_inp[0]) {
if (bos.size()>0 && !embd_inp.empty() && bos[0]==embd_inp[0]) { //strip away bos if exists
embd_inp.erase(embd_inp.begin());
}
//shorten memory if needed
if (embd_inp_mem.size() + kcpp_data->n_predict + 4 > nctx)
if (embd_inp_mem.size() > 0 && embd_inp_mem.size() + kcpp_data->n_predict + 4 > nctx)
{
int offset = embd_inp_mem.size() - nctx + kcpp_data->n_predict + 4;
embd_inp_mem = std::vector<int>(embd_inp_mem.begin() + offset, embd_inp_mem.end());
@@ -3380,7 +3415,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
}
//shorten main prompt by trimming the front if needed
int addmemtokens = embd_inp_mem.size();
int addmemtokens = embd_inp_mem.size() + negprompt_tokens.size() + 1;
int totalsize = (addmemtokens + embd_inp.size() + kcpp_data->n_predict);
if(totalsize > nctx)
{
@@ -3394,6 +3429,34 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
//stick memory to front of prompt
embd_inp.insert(embd_inp.begin(), embd_inp_mem.begin(), embd_inp_mem.end());
if(add_bos_token && embd_inp.size()>0 && bos.size()>0 && bos[0]!=embd_inp[0])
{
embd_inp.insert(embd_inp.begin(), bos[0]); //insert bos at front, if added
}
}
//prepare negative prompt
if(guidance_ctx && negprompt_tokens.size()>0 && inputs.guidance_scale!=1.0f)
{
guidance_embd = embd_inp; //clone main prompt
std::vector<int> bos;
TokenizeString("", bos, file_format, add_bos_token);
if (bos.size()>0 && !guidance_embd.empty() && bos[0]==guidance_embd[0]) {
guidance_embd.erase(guidance_embd.begin());
}
// Insert the negative prompt tokens at the very beginning; context-size trimming was already handled above
guidance_embd.insert(guidance_embd.begin(), negprompt_tokens.begin(), negprompt_tokens.end());
//eval the guidance prompt
printf("Preparing Negative Prompt (%zu tokens)\n", guidance_embd.size());
kcpp_embd_batch batch = kcpp_embd_batch(guidance_embd, 0, use_mrope, false);
auto er = llama_decode(guidance_ctx, batch.batch);
if(er!=0)
{
printf("\nProcess Negative Prompt Failed! (code:%d)\n",er);
}
guidance_n_past += guidance_embd.size();
}
//determine how much npast we have to rewind from the current state
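
In other words, the guidance context is primed once per request: the negative prompt is tokenized with a trailing newline, the already-assembled main prompt is cloned with its leading BOS stripped, the negative tokens are prepended, and the whole sequence is decoded into guidance_ctx, with guidance_n_past recording how far that cache has advanced. A rough sketch of the resulting token layout (toy token ids, purely illustrative):

def build_guidance_prompt(main_tokens, negative_tokens, bos_id):
    # clone the main prompt, drop a leading BOS if present,
    # then stick the negative-prompt tokens on the front (mirrors the hunk above)
    guidance = list(main_tokens)
    if guidance and guidance[0] == bos_id:
        guidance = guidance[1:]
    return negative_tokens + guidance

main_prompt = [1, 50, 51, 52]   # [BOS, "a", "photo", "of"]
negative    = [60, 61, 9]       # ["blurry", "text", "\n"]
print(build_guidance_prompt(main_prompt, negative, bos_id=1))
# [60, 61, 9, 50, 51, 52] -> decoded into guidance_ctx, guidance_n_past = 6
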
@@ -3467,10 +3530,6 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
}
}
}
if(guidance_ctx)
{
llama_kv_self_clear(guidance_ctx);
}
bool blasmode = (embd_inp.size() >= 32 && kcpp_cpu_has_blas() && kcpp_data->n_batch>=32);
@@ -3591,6 +3650,13 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
}
else if(file_format == FileFormat::GGUF_GENERIC)
{
if(guidance_ctx && negprompt_tokens.size()>0 && inputs.guidance_scale!=1.0f && embd.size()==1 && startedsampling)
{
//eval for negative prompt
kcpp_embd_batch batch = kcpp_embd_batch(embd, guidance_n_past, use_mrope, false);
evalres = (evalres && (llama_decode(guidance_ctx, batch.batch)==0));
guidance_n_past += 1;
}
if(embd.size()!=1 || draft_ctx==nullptr || remaining_tokens<=speculative_chunk_amt || grammar!=nullptr || startedsampling==false) //for large batch, or if no draft model, PP/TG as usual
{
draft_used = false;
@@ -3769,6 +3835,12 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
lowestLogit = LowestLogit(logits);
}
if(file_format == FileFormat::GGUF_GENERIC && guidance_ctx && negprompt_tokens.size()>0 && inputs.guidance_scale!=1.0f)
{
sample_guidance(llama_ctx_v4, guidance_ctx, n_vocab, inputs.guidance_scale);
}
//handle token bans
if (!inputs.allow_eos_token && !inputs.bypass_eos_token)
{
// set the logit of the eos token to very low to avoid sampling it
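
Taken together, the GGUF path now runs two contexts in lockstep once sampling has started: each freshly sampled token is also decoded into guidance_ctx (the embd.size()==1 branch above, advancing guidance_n_past), and sample_guidance() blends the two logit vectors right before the token-ban and sampler stages. A compressed sketch of that per-token loop, with stand-in decode/sample helpers rather than the real llama.cpp calls:

def generate_with_cfg(decode_main, decode_guidance, sample, prompt, guidance_prompt, scale, steps):
    # decode_* are hypothetical helpers that return the logits for the last token fed in
    main_logits = decode_main(prompt)
    guid_logits = decode_guidance(guidance_prompt)
    out = []
    for _ in range(steps):
        mixed = [g + scale * (m - g) for m, g in zip(main_logits, guid_logits)]
        tok = sample(mixed)
        out.append(tok)
        # feed the chosen token to BOTH contexts so their KV caches stay aligned
        main_logits = decode_main([tok])
        guid_logits = decode_guidance([tok])
    return out
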


@@ -198,6 +198,8 @@ class generation_inputs(ctypes.Structure):
_fields_ = [("seed", ctypes.c_int),
("prompt", ctypes.c_char_p),
("memory", ctypes.c_char_p),
("negative_prompt", ctypes.c_char_p),
("guidance_scale", ctypes.c_float),
("images", ctypes.c_char_p * images_max),
("max_context_length", ctypes.c_int),
("max_length", ctypes.c_int),
@@ -1247,6 +1249,8 @@ def generate(genparams, stream_flag=False):
prompt = genparams.get('prompt', "")
memory = genparams.get('memory', "")
negative_prompt = genparams.get('negative_prompt', "")
guidance_scale = tryparsefloat(genparams.get('guidance_scale', 1.0),1.0)
images = genparams.get('images', [])
max_context_length = tryparseint(genparams.get('max_context_length', maxctx),maxctx)
max_length = tryparseint(genparams.get('max_length', args.defaultgenamt),args.defaultgenamt)
@@ -1327,6 +1331,8 @@ def generate(genparams, stream_flag=False):
inputs = generation_inputs()
inputs.prompt = prompt.encode("UTF-8")
inputs.memory = memory.encode("UTF-8")
inputs.negative_prompt = negative_prompt.encode("UTF-8")
inputs.guidance_scale = guidance_scale
for n in range(images_max):
if not images or n >= len(images):
inputs.images[n] = "".encode("UTF-8")
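
On the client side the two new fields ride along with the rest of genparams, so a generate request can set them like any other sampler option. A minimal sketch against a locally running instance (field names taken from the parsing code above; the port and response shape are the usual KoboldAI-API defaults, and 1.0 leaves guidance disabled):

import json, urllib.request

payload = {
    "prompt": "A short tourist guide to Paris.\n",
    "max_length": 120,
    "negative_prompt": "Writing in French.",  # -> inputs.negative_prompt
    "guidance_scale": 1.5,                    # -> inputs.guidance_scale (1.0 = off)
}
req = urllib.request.Request("http://localhost:5001/api/v1/generate",
                             data=json.dumps(payload).encode("utf-8"),
                             headers={"Content-Type": "application/json"})
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read())["results"][0]["text"])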