From 80ce8a50b33fd90c94decf7fe6e231a8e8a11c88 Mon Sep 17 00:00:00 2001
From: Concedo <39025047+LostRuins@users.noreply.github.com>
Date: Sat, 16 May 2026 15:20:46 +0800
Subject: [PATCH] allow token bans and eos handling in

---
 gpttype_adapter.cpp | 37 +++++++++++++++++++++++++++++--------
 koboldcpp.py        | 11 ++++++++---
 2 files changed, 37 insertions(+), 11 deletions(-)

diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index ce0c24b90..f77b805d0 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -3417,6 +3417,7 @@ struct BatchGenerateRequest
     std::string prompt;
     std::string prompt_added_memory;
     std::vector<std::string> stop_sequences;
+    std::vector<llama_logit_bias> logit_biases;
     int max_context_length = 0;
     int max_length = 0;
     int seed = 0;
@@ -3566,7 +3567,7 @@ static bool batch_inputs_eligible(const generation_inputs & inputs)
     {
         return false;
     }
-    if(inputs.images_len > 0 || inputs.audio_len > 0 || inputs.guidance_scale != 1.0f || !inputs.allow_eos_token)
+    if(inputs.images_len > 0 || inputs.audio_len > 0 || inputs.guidance_scale != 1.0f)
     {
         return false;
     }
@@ -3574,7 +3575,7 @@ static bool batch_inputs_eligible(const generation_inputs & inputs)
     {
         return false;
     }
-    if(inputs.logit_biases_len > 0 || inputs.banned_tokens_len > 0 || inputs.dry_multiplier > 0.0f)
+    if(inputs.banned_tokens_len > 0 || inputs.dry_multiplier > 0.0f)
     {
         return false;
     }
@@ -3741,6 +3742,11 @@ static llama_sampler * batch_build_sampler(const BatchGenerateRequest & req)
         req.rep_pen,
         req.rep_pen_slope,
         req.presence_penalty));
+    if(req.logit_biases.size()>0)
+    {
+        int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(llama_get_model(llama_ctx_v4)));
+        llama_sampler_chain_add(chain, llama_sampler_init_logit_bias(n_vocab, req.logit_biases.size(), req.logit_biases.data()));
+    }
     if(req.top_k > 0)
     {
         llama_sampler_chain_add(chain, llama_sampler_init_top_k(req.top_k));
@@ -3974,7 +3980,7 @@ static void batch_worker_loop()
         }
         const llama_vocab * vocab = llama_model_get_vocab(llama_get_model(llama_ctx_v4));
-        llama_token eos = llama_vocab_eos(vocab);
+        const std::vector<int> eog_tokens = GetEogIDs(file_format,n_vocab);
         for(int request_id : decode_ids)
         {
             BatchGenerateRequest * req = batch_find_request_locked(request_id);
@@ -3983,12 +3989,9 @@
                 continue;
             }
             llama_token sampled = llama_sampler_sample(req->sampler, llama_ctx_v4, req->i_batch);
-            if(!req->allow_eos_token && !req->bypass_eos_token && sampled == eos)
-            {
-                sampled = llama_sampler_sample(req->sampler, llama_ctx_v4, req->i_batch);
-            }
             req->completion_token_count++;
-            if(sampled == eos && !req->bypass_eos_token)
+            bool is_eog = std::find(eog_tokens.begin(), eog_tokens.end(), sampled) != eog_tokens.end();
+            if(is_eog && !req->bypass_eos_token)
             {
                 batch_finish_request_locked(*req, stop_reason::EOS_TOKEN_HIT);
                 continue;
             }
@@ -4061,6 +4064,24 @@ int gpttype_batch_generate_submit(const generation_inputs inputs)
     req->allow_eos_token = inputs.allow_eos_token;
     req->bypass_eos_token = inputs.bypass_eos_token;
     req->render_special = inputs.render_special;
+    req->logit_biases = {};
+    for(int i = 0; i < inputs.logit_biases_len; ++i)
+    {
+        int32_t t_id = inputs.logit_biases[i].token_id;
+        float bias = inputs.logit_biases[i].bias;
+        if(t_id >= 0 && t_id < n_vocab && bias!=0)
+        {
+            req->logit_biases.push_back({t_id, bias});
+        }
+    }
+    if(!req->allow_eos_token && !req->bypass_eos_token) //eos token bans
+    {
+        const std::vector<int> eog_tokens = GetEogIDs(file_format,n_vocab);
+        for(int x = 0; x < eog_tokens.size(); ++x)
+        {
+            req->logit_biases.push_back({eog_tokens[x], -999.0f});
+        }
+    }
     for(int i = 0; i < inputs.stop_sequence_len; ++i)
     {
         if(inputs.stop_sequence[i])
diff --git a/koboldcpp.py b/koboldcpp.py
index 1694289dd..aed8fdd15 100644
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -2292,20 +2292,25 @@ def continuous_batching_python_eligible(genparams, api_format):
     if not getattr(args, "noshift", False) or getattr(args, "smartcontext", False) or getattr(args, "draftmodel", "") or getattr(args, "enableguidance", False):
         return False
     if genparams.get("negative_prompt") or genparams.get("images") or genparams.get("audio"):
+        utfprint("Batching disabled due to media",0)
         return False
-    if genparams.get("ban_eos_token", False):
-        return False
-    if genparams.get("grammar") or genparams.get("grammar_retain_state") or genparams.get("logit_bias") or genparams.get("banned_tokens") or genparams.get("banned_strings"):
+    if genparams.get("grammar") or genparams.get("grammar_retain_state") or genparams.get("banned_tokens") or genparams.get("banned_strings"):
+        utfprint("Batching disabled due to grammar or bans",0)
         return False
     if tryparsefloat(genparams.get("dry_multiplier", 0), 0) or tryparseint(genparams.get("mirostat", 0), 0) or tryparsefloat(genparams.get("xtc_probability", 0), 0) or tryparsefloat(genparams.get("nsigma", 0), 0):
+        utfprint("Batching disabled due to samplers set 1",0)
         return False
     if tryparsefloat(genparams.get("smoothing_factor", 0), 0) or tryparsefloat(genparams.get("adaptive_target", -1), -1) > 0 or genparams.get("using_openai_tools", False):
+        utfprint("Batching disabled due to samplers set 2",0)
         return False
     if tryparsefloat(genparams.get("top_a", 0), 0) or tryparsefloat(genparams.get("tfs", 1), 1) != 1 or tryparsefloat(genparams.get("dynatemp_range", 0), 0):
+        utfprint("Batching disabled due to samplers set 3",0)
         return False
     if genparams.get("sampler_order") and genparams.get("sampler_order") != [6, 0, 1, 3, 4, 2, 5]:
+        utfprint("Batching disabled due to sampler order",0)
         return False
     if genparams.get("reasoning_effort"):
+        utfprint("Batching disabled due to reasoning",0)
         return False
     return True
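
Note: below is a minimal client-side sketch of a request that this patch makes eligible for continuous batching. The prompt text and the token id are hypothetical, and it assumes a local KoboldCpp server on the default port 5001 exposing the standard /api/v1/generate endpoint; the exact shape accepted for "logit_bias" should be checked against the server's API documentation. The payload keys themselves ("logit_bias", "ban_eos_token", "max_length", "temperature") match the genparams fields handled above.

import requests

# With this patch, "logit_bias" and "ban_eos_token" no longer disqualify a
# request from the continuous-batching path: per-request logit biases are
# attached to the sampler chain, and banning EOS is implemented by pushing a
# -999.0 bias for every end-of-generation token.
payload = {
    "prompt": "Write a short story about a robot.",  # hypothetical example prompt
    "max_length": 128,
    "temperature": 0.7,
    "logit_bias": {"2293": -40.0},  # hypothetical token id -> bias value
    "ban_eos_token": True,
}
resp = requests.post("http://localhost:5001/api/v1/generate", json=payload, timeout=300)
print(resp.json()["results"][0]["text"])

Requests that still use grammar, banned_tokens, banned_strings, or the DRY/mirostat/XTC samplers continue to fall back to the single-request path, as reported by the new utfprint diagnostics in continuous_batching_python_eligible.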