From 0b37cb9a57b50aed2c74a7112b05c978a16a315e Mon Sep 17 00:00:00 2001 From: Concedo <39025047+LostRuins@users.noreply.github.com> Date: Sat, 18 Apr 2026 11:56:33 +0800 Subject: [PATCH] added preliminary support for reasoning budget --- embd_res/klite.embd | 4 +- expose.h | 1 + gpttype_adapter.cpp | 86 +++++++++++++++++++++++++++++++++++++++++++ koboldcpp.py | 6 ++- otherarch/otherarch.h | 1 + 5 files changed, 96 insertions(+), 2 deletions(-) diff --git a/embd_res/klite.embd b/embd_res/klite.embd index 38190d7aa..03b407bad 100644 --- a/embd_res/klite.embd +++ b/embd_res/klite.embd @@ -5195,12 +5195,14 @@ Current version indicated by LITEVER below. if (before) { input += `${before}`; } - if (middle) { + if (middle && middle.trim()) { input += `${middle}`; } if (after) { input += `${after}`; } + input = input.replaceAll(thinkstartpl, escape_html(localsettings.start_thinking_tag)); + input = input.replaceAll(thinkendpl, escape_html(localsettings.stop_thinking_tag)); } input = replaceAll(input,"\n","
",false); return input; diff --git a/expose.h b/expose.h index ee4a3ce99..090daed2f 100644 --- a/expose.h +++ b/expose.h @@ -142,6 +142,7 @@ struct generation_inputs const logit_bias * logit_biases = nullptr; const int banned_tokens_len = 0; const char ** banned_tokens = nullptr; + const int reasoning_budget = 0; }; struct generation_outputs { diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp index 4a29de1fa..5b1390399 100644 --- a/gpttype_adapter.cpp +++ b/gpttype_adapter.cpp @@ -1870,6 +1870,56 @@ const std::vector & sampler_order, llama_grammar * grammar, float dyna return id; } +static int apply_reasoning_budget(int id, const std::vector & start_think, const std::vector & end_think, int budget) +{ + if(budget<=0 || start_think.size()==0 || end_think.size()!=1) //start_think can be 1-3 tokens long, end_think is always 1 token + { + return id; + } + + int end_think_index = -1; + int start_think_index = -1; + int ctx_size = (int)current_context_tokens.size(); + + for (int i = ctx_size - 1; i >= 0; --i) { // Search backwards for the latest end_think token + if (end_think_index == -1 && current_context_tokens[i] == end_think[0]) { + end_think_index = i; + } + if (start_think_index == -1) { // Search backwards for the latest start_think sequence + int seq_len = (int) start_think.size(); + if (i - seq_len + 1 >= 0) { + bool match = true; + for (int j = 0; j < seq_len; ++j) { + if (current_context_tokens[i - seq_len + 1 + j] != start_think[j]) { + match = false; + break; + } + } + if (match) { + start_think_index = i; // index of the last token of the start_think sequence + } + } + } + if (start_think_index != -1 && end_think_index != -1) { // Early exit once both are found + break; + } + } + + if (start_think_index == -1) { // If no start_think found, do nothing + return id; + } + + if (end_think_index != -1 && end_think_index > start_think_index) { // If end_think comes after start_think, thinking is already closed + return id; + } + + int tokens_since_start = ctx_size - 1 - start_think_index; // start_think is unclosed, check budget + if (tokens_since_start >= budget) { + return end_think[0]; // Force-close thinking by returning the end_think token + } + + return id; +} static void grammar_accept_token(FileFormat file_format, int32_t n_vocab, struct llama_grammar * grammar, llama_token token) { @@ -3741,6 +3791,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs) kcpp_data->smoothing_curve = inputs.smoothing_curve; kcpp_data->adaptive_target = inputs.adaptive_target; kcpp_data->adaptive_decay = inputs.adaptive_decay; + kcpp_data->reasoning_budget = inputs.reasoning_budget; adaptive_p_weighted_sum = 0; adaptive_p_total_weight = 0; @@ -3862,6 +3913,34 @@ generation_outputs gpttype_generate(const generation_inputs inputs) } } + //thinking budget handling + std::vector thinking_start_sequence; + std::vector thinking_end_sequence; + std::string chat_template = ""; + if (file_format == FileFormat::GGUF_GENERIC) { + chat_template = gpttype_get_chat_template(); + if (file_format_meta.model_architecture == llm_arch::LLM_ARCH_GEMMA4) { + TokenizeString("<|channel>thought",thinking_start_sequence,file_format,false); + TokenizeString("",thinking_end_sequence,file_format,false); + //sanity check, start is 2 tokens and end is 1 + if(thinking_start_sequence.size()!=2 || thinking_end_sequence.size()!=1) + { + thinking_start_sequence.clear(); + thinking_end_sequence.clear(); + } + } else { + TokenizeString("",thinking_start_sequence,file_format,false); + TokenizeString("",thinking_end_sequence,file_format,false); + //sanity check, start is 1 tokens and end is 1 + if(thinking_start_sequence.size()!=1 || thinking_end_sequence.size()!=1) + { + thinking_start_sequence.clear(); + thinking_end_sequence.clear(); + } + } + } + + bool stream_sse = inputs.stream_sse; bool allow_regular_prints = (!is_quiet && debugmode!=-1); @@ -4842,6 +4921,13 @@ generation_outputs gpttype_generate(const generation_inputs inputs) adaptive_p_update_history(original_prob, adaptive_p_weighted_sum, adaptive_p_total_weight, adaptive_decay); } + //apply reasoning budget + int newid = apply_reasoning_budget(id, thinking_start_sequence, thinking_end_sequence, kcpp_data->reasoning_budget); + if (id != newid) { + printf("\n(Reasoning Budget of %d tokens exceeded! Attempting to stop thinking, insert token %d!)\n", kcpp_data->reasoning_budget, newid); + id = newid; + } + if(draft_used) { int32_t draftedid = draft_results.draftids[logits_sampled]; diff --git a/koboldcpp.py b/koboldcpp.py index 68f2c79ca..6fb20bdc1 100755 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -325,7 +325,8 @@ class generation_inputs(ctypes.Structure): ("logit_biases_len", ctypes.c_int), ("logit_biases", ctypes.POINTER(logit_bias)), ("banned_tokens_len", ctypes.c_int), - ("banned_tokens", ctypes.POINTER(ctypes.c_char_p))] + ("banned_tokens", ctypes.POINTER(ctypes.c_char_p)), + ("reasoning_budget", ctypes.c_int)] class generation_outputs(ctypes.Structure): _fields_ = [("status", ctypes.c_int), @@ -2008,6 +2009,7 @@ def generate(genparams, stream_flag=False): ban_eos_token = genparams.get('ban_eos_token', False) stream_sse = stream_flag grammar = genparams.get('grammar', '') + reasoning_budget = tryparseint(genparams.get('reasoning_budget', 0),0) #translate grammar if its json try: grammarjson = json.loads(grammar) @@ -2184,6 +2186,8 @@ def generate(genparams, stream_flag=False): for n, tok in enumerate(banned_tokens): inputs.banned_tokens[n] = tok.encode("UTF-8") + inputs.reasoning_budget = reasoning_budget + currentusergenkey = genkey totalgens += 1 #early exit if aborted diff --git a/otherarch/otherarch.h b/otherarch/otherarch.h index 1e7b5bb2a..ec28c5ab8 100644 --- a/otherarch/otherarch.h +++ b/otherarch/otherarch.h @@ -54,6 +54,7 @@ struct kcpp_params { float dynatemp_exponent = 1.0f; float adaptive_target = -1.0f; // 0.0 - 1.0, <=0.0 is disabled float adaptive_decay = 0.9f; + int reasoning_budget = 0; //if > 0, controls thinking budget std::string model_filename = ""; // model path std::string prompt = "";