added preliminary support for reasoning budget

This commit is contained in:
Concedo 2026-04-18 11:56:33 +08:00
parent 79882d669a
commit 0b37cb9a57
5 changed files with 96 additions and 2 deletions

View file

@ -5195,12 +5195,14 @@ Current version indicated by LITEVER below.
if (before) {
input += `${before}`;
}
if (middle) {
if (middle && middle.trim()) {
input += `<span class="color_lightgreen">${middle}</span>`;
}
if (after) {
input += `${after}`;
}
input = input.replaceAll(thinkstartpl, escape_html(localsettings.start_thinking_tag));
input = input.replaceAll(thinkendpl, escape_html(localsettings.stop_thinking_tag));
}
input = replaceAll(input,"\n","<br>",false);
return input;

View file

@ -142,6 +142,7 @@ struct generation_inputs
const logit_bias * logit_biases = nullptr;
const int banned_tokens_len = 0;
const char ** banned_tokens = nullptr;
const int reasoning_budget = 0;
};
struct generation_outputs
{

View file

@ -1870,6 +1870,56 @@ const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dyna
return id;
}
static int apply_reasoning_budget(int id, const std::vector<int> & start_think, const std::vector<int> & end_think, int budget)
{
if(budget<=0 || start_think.size()==0 || end_think.size()!=1) //start_think can be 1-3 tokens long, end_think is always 1 token
{
return id;
}
int end_think_index = -1;
int start_think_index = -1;
int ctx_size = (int)current_context_tokens.size();
for (int i = ctx_size - 1; i >= 0; --i) { // Search backwards for the latest end_think token
if (end_think_index == -1 && current_context_tokens[i] == end_think[0]) {
end_think_index = i;
}
if (start_think_index == -1) { // Search backwards for the latest start_think sequence
int seq_len = (int) start_think.size();
if (i - seq_len + 1 >= 0) {
bool match = true;
for (int j = 0; j < seq_len; ++j) {
if (current_context_tokens[i - seq_len + 1 + j] != start_think[j]) {
match = false;
break;
}
}
if (match) {
start_think_index = i; // index of the last token of the start_think sequence
}
}
}
if (start_think_index != -1 && end_think_index != -1) { // Early exit once both are found
break;
}
}
if (start_think_index == -1) { // If no start_think found, do nothing
return id;
}
if (end_think_index != -1 && end_think_index > start_think_index) { // If end_think comes after start_think, thinking is already closed
return id;
}
int tokens_since_start = ctx_size - 1 - start_think_index; // start_think is unclosed, check budget
if (tokens_since_start >= budget) {
return end_think[0]; // Force-close thinking by returning the end_think token
}
return id;
}
static void grammar_accept_token(FileFormat file_format, int32_t n_vocab, struct llama_grammar * grammar, llama_token token)
{
@ -3741,6 +3791,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
kcpp_data->smoothing_curve = inputs.smoothing_curve;
kcpp_data->adaptive_target = inputs.adaptive_target;
kcpp_data->adaptive_decay = inputs.adaptive_decay;
kcpp_data->reasoning_budget = inputs.reasoning_budget;
adaptive_p_weighted_sum = 0;
adaptive_p_total_weight = 0;
@ -3862,6 +3913,34 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
}
}
//thinking budget handling
std::vector<int> thinking_start_sequence;
std::vector<int> thinking_end_sequence;
std::string chat_template = "";
if (file_format == FileFormat::GGUF_GENERIC) {
chat_template = gpttype_get_chat_template();
if (file_format_meta.model_architecture == llm_arch::LLM_ARCH_GEMMA4) {
TokenizeString("<|channel>thought",thinking_start_sequence,file_format,false);
TokenizeString("<channel|>",thinking_end_sequence,file_format,false);
//sanity check, start is 2 tokens and end is 1
if(thinking_start_sequence.size()!=2 || thinking_end_sequence.size()!=1)
{
thinking_start_sequence.clear();
thinking_end_sequence.clear();
}
} else {
TokenizeString("<think>",thinking_start_sequence,file_format,false);
TokenizeString("</think>",thinking_end_sequence,file_format,false);
//sanity check, start is 1 tokens and end is 1
if(thinking_start_sequence.size()!=1 || thinking_end_sequence.size()!=1)
{
thinking_start_sequence.clear();
thinking_end_sequence.clear();
}
}
}
bool stream_sse = inputs.stream_sse;
bool allow_regular_prints = (!is_quiet && debugmode!=-1);
@ -4842,6 +4921,13 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
adaptive_p_update_history(original_prob, adaptive_p_weighted_sum, adaptive_p_total_weight, adaptive_decay);
}
//apply reasoning budget
int newid = apply_reasoning_budget(id, thinking_start_sequence, thinking_end_sequence, kcpp_data->reasoning_budget);
if (id != newid) {
printf("\n(Reasoning Budget of %d tokens exceeded! Attempting to stop thinking, insert token %d!)\n", kcpp_data->reasoning_budget, newid);
id = newid;
}
if(draft_used)
{
int32_t draftedid = draft_results.draftids[logits_sampled];

View file

@ -325,7 +325,8 @@ class generation_inputs(ctypes.Structure):
("logit_biases_len", ctypes.c_int),
("logit_biases", ctypes.POINTER(logit_bias)),
("banned_tokens_len", ctypes.c_int),
("banned_tokens", ctypes.POINTER(ctypes.c_char_p))]
("banned_tokens", ctypes.POINTER(ctypes.c_char_p)),
("reasoning_budget", ctypes.c_int)]
class generation_outputs(ctypes.Structure):
_fields_ = [("status", ctypes.c_int),
@ -2008,6 +2009,7 @@ def generate(genparams, stream_flag=False):
ban_eos_token = genparams.get('ban_eos_token', False)
stream_sse = stream_flag
grammar = genparams.get('grammar', '')
reasoning_budget = tryparseint(genparams.get('reasoning_budget', 0),0)
#translate grammar if its json
try:
grammarjson = json.loads(grammar)
@ -2184,6 +2186,8 @@ def generate(genparams, stream_flag=False):
for n, tok in enumerate(banned_tokens):
inputs.banned_tokens[n] = tok.encode("UTF-8")
inputs.reasoning_budget = reasoning_budget
currentusergenkey = genkey
totalgens += 1
#early exit if aborted

View file

@ -54,6 +54,7 @@ struct kcpp_params {
float dynatemp_exponent = 1.0f;
float adaptive_target = -1.0f; // 0.0 - 1.0, <=0.0 is disabled
float adaptive_decay = 0.9f;
int reasoning_budget = 0; //if > 0, controls thinking budget
std::string model_filename = ""; // model path
std::string prompt = "";