From 3210b378e818f8172252560023c23f7585c8e73b Mon Sep 17 00:00:00 2001 From: Concedo <39025047+LostRuins@users.noreply.github.com> Date: Wed, 20 Aug 2025 22:11:31 +0800 Subject: [PATCH] better tool calls --- expose.h | 1 + gpttype_adapter.cpp | 23 +++++++++++++++++++++++ koboldcpp.py | 3 +++ 3 files changed, 27 insertions(+) diff --git a/expose.h b/expose.h index d34f06bfd..9967bfdca 100644 --- a/expose.h +++ b/expose.h @@ -110,6 +110,7 @@ struct generation_inputs const int sampler_len = 0; const bool allow_eos_token = false; const bool bypass_eos_token = false; + const bool tool_call_fix = false; //this prevents close square bracket ] from being generated early. const bool render_special = false; const bool stream_sse = false; const char * grammar = nullptr; diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp index d2b8c832c..22654b8fa 100644 --- a/gpttype_adapter.cpp +++ b/gpttype_adapter.cpp @@ -128,6 +128,7 @@ static std::vector stop_sequence; static std::vector special_stop_sequence; //for stop sequences that don't have a string representation static std::vector banned_tokens; static std::vector banned_token_ids; +static std::vector toolcall_prevented_ids; //temp ban these ids for the first 3 tokens generated, to prevent empty replies static std::vector banned_phrases; static std::unordered_multimap> dry_sequence_breakers; // Multi-mapping from first token of sequence to tail of sequence (tail is empty for a single token) static std::vector dry_repeat_count; // Indexed as last_n_tokens @@ -3266,6 +3267,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs) } banned_token_ids.clear(); + toolcall_prevented_ids.clear(); if(banned_tokens.size()>0) { if(debugmode==1 && !is_quiet) @@ -3290,6 +3292,18 @@ generation_outputs gpttype_generate(const generation_inputs inputs) printf("\nBanned a total of %zu individual tokens.\n",banned_token_ids.size()); } } + if(inputs.tool_call_fix) + { + for(int v=0;v0) { @@ -4078,6 +4092,7 @@ generation_outputs
gpttype_generate(const generation_inputs inputs) float * logitsPtr; float lowestLogit = 0; int btsize = banned_token_ids.size(); + int tcpreventsize = toolcall_prevented_ids.size(); //sample pending logits. usually only 1, unless speculative decoding int logits_to_sample = 1; @@ -4144,6 +4159,14 @@ generation_outputs gpttype_generate(const generation_inputs inputs) logitsPtr[banned_token_ids[t]]=lowestLogit; } } + bool tcpreventtoks = ((kcpp_data->n_predict - remaining_tokens)<3); + if(tcpreventsize>0 && tcpreventtoks && std::count(concat_output.begin(), concat_output.end(), '[')<=1) + { + for(int t=0;t