diff --git a/embd_res/klite.embd b/embd_res/klite.embd
index 38190d7aa..03b407bad 100644
--- a/embd_res/klite.embd
+++ b/embd_res/klite.embd
@@ -5195,12 +5195,14 @@ Current version indicated by LITEVER below.
if (before) {
input += `${before}`;
}
- if (middle) {
+ if (middle && middle.trim()) {
input += `${middle}`;
}
if (after) {
input += `${after}`;
}
+ input = input.replaceAll(thinkstartpl, escape_html(localsettings.start_thinking_tag));
+ input = input.replaceAll(thinkendpl, escape_html(localsettings.stop_thinking_tag));
}
input = replaceAll(input,"\n","
",false);
return input;
diff --git a/expose.h b/expose.h
index ee4a3ce99..090daed2f 100644
--- a/expose.h
+++ b/expose.h
@@ -142,6 +142,7 @@ struct generation_inputs
const logit_bias * logit_biases = nullptr;
const int banned_tokens_len = 0;
const char ** banned_tokens = nullptr;
+ const int reasoning_budget = 0;
};
struct generation_outputs
{
diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index 4a29de1fa..5b1390399 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -1870,6 +1870,56 @@ const std::vector & sampler_order, llama_grammar * grammar, float dyna
return id;
}
+static int apply_reasoning_budget(int id, const std::vector & start_think, const std::vector & end_think, int budget)
+{
+ if(budget<=0 || start_think.size()==0 || end_think.size()!=1) //start_think can be 1-3 tokens long, end_think is always 1 token
+ {
+ return id;
+ }
+
+ int end_think_index = -1;
+ int start_think_index = -1;
+ int ctx_size = (int)current_context_tokens.size();
+
+ for (int i = ctx_size - 1; i >= 0; --i) { // Search backwards for the latest end_think token
+ if (end_think_index == -1 && current_context_tokens[i] == end_think[0]) {
+ end_think_index = i;
+ }
+ if (start_think_index == -1) { // Search backwards for the latest start_think sequence
+ int seq_len = (int) start_think.size();
+ if (i - seq_len + 1 >= 0) {
+ bool match = true;
+ for (int j = 0; j < seq_len; ++j) {
+ if (current_context_tokens[i - seq_len + 1 + j] != start_think[j]) {
+ match = false;
+ break;
+ }
+ }
+ if (match) {
+ start_think_index = i; // index of the last token of the start_think sequence
+ }
+ }
+ }
+ if (start_think_index != -1 && end_think_index != -1) { // Early exit once both are found
+ break;
+ }
+ }
+
+ if (start_think_index == -1) { // If no start_think found, do nothing
+ return id;
+ }
+
+ if (end_think_index != -1 && end_think_index > start_think_index) { // If end_think comes after start_think, thinking is already closed
+ return id;
+ }
+
+ int tokens_since_start = ctx_size - 1 - start_think_index; // start_think is unclosed, check budget
+ if (tokens_since_start >= budget) {
+ return end_think[0]; // Force-close thinking by returning the end_think token
+ }
+
+ return id;
+}
static void grammar_accept_token(FileFormat file_format, int32_t n_vocab, struct llama_grammar * grammar, llama_token token)
{
@@ -3741,6 +3791,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
kcpp_data->smoothing_curve = inputs.smoothing_curve;
kcpp_data->adaptive_target = inputs.adaptive_target;
kcpp_data->adaptive_decay = inputs.adaptive_decay;
+ kcpp_data->reasoning_budget = inputs.reasoning_budget;
adaptive_p_weighted_sum = 0;
adaptive_p_total_weight = 0;
@@ -3862,6 +3913,34 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
}
}
+ //thinking budget handling
+ std::vector thinking_start_sequence;
+ std::vector thinking_end_sequence;
+ std::string chat_template = "";
+ if (file_format == FileFormat::GGUF_GENERIC) {
+ chat_template = gpttype_get_chat_template();
+ if (file_format_meta.model_architecture == llm_arch::LLM_ARCH_GEMMA4) {
+ TokenizeString("<|channel>thought",thinking_start_sequence,file_format,false);
+ TokenizeString("",thinking_end_sequence,file_format,false);
+ //sanity check, start is 2 tokens and end is 1
+ if(thinking_start_sequence.size()!=2 || thinking_end_sequence.size()!=1)
+ {
+ thinking_start_sequence.clear();
+ thinking_end_sequence.clear();
+ }
+ } else {
+ TokenizeString("",thinking_start_sequence,file_format,false);
+ TokenizeString("",thinking_end_sequence,file_format,false);
+ //sanity check, start is 1 tokens and end is 1
+ if(thinking_start_sequence.size()!=1 || thinking_end_sequence.size()!=1)
+ {
+ thinking_start_sequence.clear();
+ thinking_end_sequence.clear();
+ }
+ }
+ }
+
+
bool stream_sse = inputs.stream_sse;
bool allow_regular_prints = (!is_quiet && debugmode!=-1);
@@ -4842,6 +4921,13 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
adaptive_p_update_history(original_prob, adaptive_p_weighted_sum, adaptive_p_total_weight, adaptive_decay);
}
+ //apply reasoning budget
+ int newid = apply_reasoning_budget(id, thinking_start_sequence, thinking_end_sequence, kcpp_data->reasoning_budget);
+ if (id != newid) {
+ printf("\n(Reasoning Budget of %d tokens exceeded! Attempting to stop thinking, insert token %d!)\n", kcpp_data->reasoning_budget, newid);
+ id = newid;
+ }
+
if(draft_used)
{
int32_t draftedid = draft_results.draftids[logits_sampled];
diff --git a/koboldcpp.py b/koboldcpp.py
index 68f2c79ca..6fb20bdc1 100755
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -325,7 +325,8 @@ class generation_inputs(ctypes.Structure):
("logit_biases_len", ctypes.c_int),
("logit_biases", ctypes.POINTER(logit_bias)),
("banned_tokens_len", ctypes.c_int),
- ("banned_tokens", ctypes.POINTER(ctypes.c_char_p))]
+ ("banned_tokens", ctypes.POINTER(ctypes.c_char_p)),
+ ("reasoning_budget", ctypes.c_int)]
class generation_outputs(ctypes.Structure):
_fields_ = [("status", ctypes.c_int),
@@ -2008,6 +2009,7 @@ def generate(genparams, stream_flag=False):
ban_eos_token = genparams.get('ban_eos_token', False)
stream_sse = stream_flag
grammar = genparams.get('grammar', '')
+ reasoning_budget = tryparseint(genparams.get('reasoning_budget', 0),0)
#translate grammar if its json
try:
grammarjson = json.loads(grammar)
@@ -2184,6 +2186,8 @@ def generate(genparams, stream_flag=False):
for n, tok in enumerate(banned_tokens):
inputs.banned_tokens[n] = tok.encode("UTF-8")
+ inputs.reasoning_budget = reasoning_budget
+
currentusergenkey = genkey
totalgens += 1
#early exit if aborted
diff --git a/otherarch/otherarch.h b/otherarch/otherarch.h
index 1e7b5bb2a..ec28c5ab8 100644
--- a/otherarch/otherarch.h
+++ b/otherarch/otherarch.h
@@ -54,6 +54,7 @@ struct kcpp_params {
float dynatemp_exponent = 1.0f;
float adaptive_target = -1.0f; // 0.0 - 1.0, <=0.0 is disabled
float adaptive_decay = 0.9f;
+ int reasoning_budget = 0; //if > 0, controls thinking budget
std::string model_filename = ""; // model path
std::string prompt = "";