mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-19 08:00:25 +00:00
added preliminary support for reasoning budget
This commit is contained in:
parent
79882d669a
commit
0b37cb9a57
5 changed files with 96 additions and 2 deletions
|
|
@ -5195,12 +5195,14 @@ Current version indicated by LITEVER below.
|
|||
if (before) {
|
||||
input += `${before}`;
|
||||
}
|
||||
if (middle) {
|
||||
if (middle && middle.trim()) {
|
||||
input += `<span class="color_lightgreen">${middle}</span>`;
|
||||
}
|
||||
if (after) {
|
||||
input += `${after}`;
|
||||
}
|
||||
input = input.replaceAll(thinkstartpl, escape_html(localsettings.start_thinking_tag));
|
||||
input = input.replaceAll(thinkendpl, escape_html(localsettings.stop_thinking_tag));
|
||||
}
|
||||
input = replaceAll(input,"\n","<br>",false);
|
||||
return input;
|
||||
|
|
|
|||
1
expose.h
1
expose.h
|
|
@ -142,6 +142,7 @@ struct generation_inputs
|
|||
const logit_bias * logit_biases = nullptr;
|
||||
const int banned_tokens_len = 0;
|
||||
const char ** banned_tokens = nullptr;
|
||||
const int reasoning_budget = 0;
|
||||
};
|
||||
struct generation_outputs
|
||||
{
|
||||
|
|
|
|||
|
|
@ -1870,6 +1870,56 @@ const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dyna
|
|||
return id;
|
||||
}
|
||||
|
||||
static int apply_reasoning_budget(int id, const std::vector<int> & start_think, const std::vector<int> & end_think, int budget)
|
||||
{
|
||||
if(budget<=0 || start_think.size()==0 || end_think.size()!=1) //start_think can be 1-3 tokens long, end_think is always 1 token
|
||||
{
|
||||
return id;
|
||||
}
|
||||
|
||||
int end_think_index = -1;
|
||||
int start_think_index = -1;
|
||||
int ctx_size = (int)current_context_tokens.size();
|
||||
|
||||
for (int i = ctx_size - 1; i >= 0; --i) { // Search backwards for the latest end_think token
|
||||
if (end_think_index == -1 && current_context_tokens[i] == end_think[0]) {
|
||||
end_think_index = i;
|
||||
}
|
||||
if (start_think_index == -1) { // Search backwards for the latest start_think sequence
|
||||
int seq_len = (int) start_think.size();
|
||||
if (i - seq_len + 1 >= 0) {
|
||||
bool match = true;
|
||||
for (int j = 0; j < seq_len; ++j) {
|
||||
if (current_context_tokens[i - seq_len + 1 + j] != start_think[j]) {
|
||||
match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match) {
|
||||
start_think_index = i; // index of the last token of the start_think sequence
|
||||
}
|
||||
}
|
||||
}
|
||||
if (start_think_index != -1 && end_think_index != -1) { // Early exit once both are found
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (start_think_index == -1) { // If no start_think found, do nothing
|
||||
return id;
|
||||
}
|
||||
|
||||
if (end_think_index != -1 && end_think_index > start_think_index) { // If end_think comes after start_think, thinking is already closed
|
||||
return id;
|
||||
}
|
||||
|
||||
int tokens_since_start = ctx_size - 1 - start_think_index; // start_think is unclosed, check budget
|
||||
if (tokens_since_start >= budget) {
|
||||
return end_think[0]; // Force-close thinking by returning the end_think token
|
||||
}
|
||||
|
||||
return id;
|
||||
}
|
||||
|
||||
static void grammar_accept_token(FileFormat file_format, int32_t n_vocab, struct llama_grammar * grammar, llama_token token)
|
||||
{
|
||||
|
|
@ -3741,6 +3791,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
|
|||
kcpp_data->smoothing_curve = inputs.smoothing_curve;
|
||||
kcpp_data->adaptive_target = inputs.adaptive_target;
|
||||
kcpp_data->adaptive_decay = inputs.adaptive_decay;
|
||||
kcpp_data->reasoning_budget = inputs.reasoning_budget;
|
||||
|
||||
adaptive_p_weighted_sum = 0;
|
||||
adaptive_p_total_weight = 0;
|
||||
|
|
@ -3862,6 +3913,34 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
|
|||
}
|
||||
}
|
||||
|
||||
//thinking budget handling
|
||||
std::vector<int> thinking_start_sequence;
|
||||
std::vector<int> thinking_end_sequence;
|
||||
std::string chat_template = "";
|
||||
if (file_format == FileFormat::GGUF_GENERIC) {
|
||||
chat_template = gpttype_get_chat_template();
|
||||
if (file_format_meta.model_architecture == llm_arch::LLM_ARCH_GEMMA4) {
|
||||
TokenizeString("<|channel>thought",thinking_start_sequence,file_format,false);
|
||||
TokenizeString("<channel|>",thinking_end_sequence,file_format,false);
|
||||
//sanity check, start is 2 tokens and end is 1
|
||||
if(thinking_start_sequence.size()!=2 || thinking_end_sequence.size()!=1)
|
||||
{
|
||||
thinking_start_sequence.clear();
|
||||
thinking_end_sequence.clear();
|
||||
}
|
||||
} else {
|
||||
TokenizeString("<think>",thinking_start_sequence,file_format,false);
|
||||
TokenizeString("</think>",thinking_end_sequence,file_format,false);
|
||||
//sanity check, start is 1 tokens and end is 1
|
||||
if(thinking_start_sequence.size()!=1 || thinking_end_sequence.size()!=1)
|
||||
{
|
||||
thinking_start_sequence.clear();
|
||||
thinking_end_sequence.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
bool stream_sse = inputs.stream_sse;
|
||||
bool allow_regular_prints = (!is_quiet && debugmode!=-1);
|
||||
|
||||
|
|
@ -4842,6 +4921,13 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
|
|||
adaptive_p_update_history(original_prob, adaptive_p_weighted_sum, adaptive_p_total_weight, adaptive_decay);
|
||||
}
|
||||
|
||||
//apply reasoning budget
|
||||
int newid = apply_reasoning_budget(id, thinking_start_sequence, thinking_end_sequence, kcpp_data->reasoning_budget);
|
||||
if (id != newid) {
|
||||
printf("\n(Reasoning Budget of %d tokens exceeded! Attempting to stop thinking, insert token %d!)\n", kcpp_data->reasoning_budget, newid);
|
||||
id = newid;
|
||||
}
|
||||
|
||||
if(draft_used)
|
||||
{
|
||||
int32_t draftedid = draft_results.draftids[logits_sampled];
|
||||
|
|
|
|||
|
|
@ -325,7 +325,8 @@ class generation_inputs(ctypes.Structure):
|
|||
("logit_biases_len", ctypes.c_int),
|
||||
("logit_biases", ctypes.POINTER(logit_bias)),
|
||||
("banned_tokens_len", ctypes.c_int),
|
||||
("banned_tokens", ctypes.POINTER(ctypes.c_char_p))]
|
||||
("banned_tokens", ctypes.POINTER(ctypes.c_char_p)),
|
||||
("reasoning_budget", ctypes.c_int)]
|
||||
|
||||
class generation_outputs(ctypes.Structure):
|
||||
_fields_ = [("status", ctypes.c_int),
|
||||
|
|
@ -2008,6 +2009,7 @@ def generate(genparams, stream_flag=False):
|
|||
ban_eos_token = genparams.get('ban_eos_token', False)
|
||||
stream_sse = stream_flag
|
||||
grammar = genparams.get('grammar', '')
|
||||
reasoning_budget = tryparseint(genparams.get('reasoning_budget', 0),0)
|
||||
#translate grammar if its json
|
||||
try:
|
||||
grammarjson = json.loads(grammar)
|
||||
|
|
@ -2184,6 +2186,8 @@ def generate(genparams, stream_flag=False):
|
|||
for n, tok in enumerate(banned_tokens):
|
||||
inputs.banned_tokens[n] = tok.encode("UTF-8")
|
||||
|
||||
inputs.reasoning_budget = reasoning_budget
|
||||
|
||||
currentusergenkey = genkey
|
||||
totalgens += 1
|
||||
#early exit if aborted
|
||||
|
|
|
|||
|
|
@ -54,6 +54,7 @@ struct kcpp_params {
|
|||
float dynatemp_exponent = 1.0f;
|
||||
float adaptive_target = -1.0f; // 0.0 - 1.0, <=0.0 is disabled
|
||||
float adaptive_decay = 0.9f;
|
||||
int reasoning_budget = 0; //if > 0, controls thinking budget
|
||||
|
||||
std::string model_filename = ""; // model path
|
||||
std::string prompt = "";
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue