better tool calls

Concedo 2025-08-20 22:11:31 +08:00
parent 2853baf1e8
commit 3210b378e8
3 changed files with 27 additions and 0 deletions


@@ -110,6 +110,7 @@ struct generation_inputs
     const int sampler_len = 0;
     const bool allow_eos_token = false;
     const bool bypass_eos_token = false;
+    const bool tool_call_fix = false; //this prevents close square bracket ] from being generated early.
     const bool render_special = false;
     const bool stream_sse = false;
     const char * grammar = nullptr;


@@ -128,6 +128,7 @@ static std::vector<std::string> stop_sequence;
 static std::vector<int> special_stop_sequence; //for stop sequences that don't have a string representation
 static std::vector<std::string> banned_tokens;
 static std::vector<int> banned_token_ids;
+static std::vector<int> toolcall_prevented_ids; //temporarily ban these ids for the first 3 generated tokens, to prevent empty replies
 static std::vector<std::string> banned_phrases;
 static std::unordered_multimap<gpt_vocab::id, std::vector<gpt_vocab::id>> dry_sequence_breakers; // Multi-mapping from first token of sequence to tail of sequence (tail is empty for a single token)
 static std::vector<int> dry_repeat_count; // Indexed as last_n_tokens
@@ -3266,6 +3267,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
     }
     banned_token_ids.clear();
+    toolcall_prevented_ids.clear();
     if(banned_tokens.size()>0)
     {
         if(debugmode==1 && !is_quiet)
@@ -3290,6 +3292,18 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
         printf("\nBanned a total of %zu individual tokens.\n",banned_token_ids.size());
     }
 }
+if(inputs.tool_call_fix)
+{
+    for(int v=0;v<n_vocab;++v)
+    {
+        std::string word = FileFormatTokenizeID(v,file_format, true);
+        word = toLowerCase(word);
+        if (word.find(']') != std::string::npos)
+        {
+            toolcall_prevented_ids.push_back(v);
+        }
+    }
+}
 if(debugmode==1 && !is_quiet && banned_phrases.size()>0)
 {
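
For context: the block above runs once per request when tool_call_fix is set, detokenizing every vocabulary id and recording those whose surface form contains ']'. A minimal standalone sketch of the same scan follows; the tiny hard-coded vocabulary and the detokenize() helper are illustrative stand-ins for the real model vocabulary and FileFormatTokenizeID(), not part of koboldcpp.

#include <cstdio>
#include <string>
#include <vector>

// Stand-in for FileFormatTokenizeID(): maps a token id to its string form.
static std::string detokenize(int id, const std::vector<std::string> &vocab)
{
    return vocab[id];
}

int main()
{
    // Toy vocabulary; a real model has tens of thousands of entries.
    const std::vector<std::string> vocab = {"Hello", "]", "]]", " [", "world", "\"}]"};
    std::vector<int> toolcall_prevented_ids;

    // Collect every id whose token text contains a closing square bracket,
    // including multi-character tokens such as "]]" and "\"}]".
    for (int v = 0; v < (int)vocab.size(); ++v)
    {
        if (detokenize(v, vocab).find(']') != std::string::npos)
        {
            toolcall_prevented_ids.push_back(v);
        }
    }
    printf("flagged %zu of %zu tokens\n", toolcall_prevented_ids.size(), vocab.size());
    return 0;
}

Because the match is a plain substring test, any token that merely contains ']' is flagged. That breadth is deliberate: it stops the model from closing the tool-call JSON array in its first few tokens regardless of how the bracket happens to be tokenized.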
@@ -4078,6 +4092,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
 float * logitsPtr;
 float lowestLogit = 0;
 int btsize = banned_token_ids.size();
+int tcpreventsize = toolcall_prevented_ids.size();
 //sample pending logits. usually only 1, unless speculative decoding
 int logits_to_sample = 1;
@@ -4144,6 +4159,14 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
         logitsPtr[banned_token_ids[t]]=lowestLogit;
     }
 }
+bool tcpreventtoks = ((kcpp_data->n_predict - remaining_tokens)<3);
+if(tcpreventsize>0 && tcpreventtoks && std::count(concat_output.begin(), concat_output.end(), '[')<=1)
+{
+    for(int t=0;t<tcpreventsize;++t)
+    {
+        logitsPtr[toolcall_prevented_ids[t]]=lowestLogit;
+    }
+}
 //handle temp bans from antislop
 if (antislop_banned_token_ids.find(n_past) != antislop_banned_token_ids.end()) {
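
The suppression itself is ordinary logit masking, gated twice: it only applies while (n_predict - remaining_tokens) < 3, i.e. for the first three sampled tokens, and only while the text emitted so far contains at most one '[', so a tool-call array that is already underway is left alone. A self-contained sketch of that gate follows; the numbers are invented, and lowestLogit is approximated with a large negative constant where the real code derives it from the live logit array.

#include <algorithm>
#include <cstdio>
#include <string>
#include <vector>

int main()
{
    // Invented state: 128 tokens requested, 126 still pending, so this is
    // the third sampled token and we are inside the 3-token window.
    const int n_predict = 128;
    const int remaining_tokens = 126;
    const std::string concat_output = "[";           // output text so far
    const std::vector<int> toolcall_prevented_ids = {7, 42};
    std::vector<float> logits(64, 0.0f);
    const float lowestLogit = -1e9f;                 // sampling will never pick this

    const bool in_window = (n_predict - remaining_tokens) < 3;
    const bool still_opening =
        std::count(concat_output.begin(), concat_output.end(), '[') <= 1;

    if (!toolcall_prevented_ids.empty() && in_window && still_opening)
    {
        for (const int id : toolcall_prevented_ids)
        {
            logits[id] = lowestLogit; // mask every ']'-bearing token
        }
    }
    printf("logits[7]=%g logits[42]=%g\n", logits[7], logits[42]);
    return 0;
}

Once three tokens have been sampled, or a second '[' appears in the output, the gate falls through and the ']' tokens become available again, so well-formed tool calls can still close normally.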


@@ -244,6 +244,7 @@ class generation_inputs(ctypes.Structure):
     ("sampler_len", ctypes.c_int),
     ("allow_eos_token", ctypes.c_bool),
     ("bypass_eos_token", ctypes.c_bool),
+    ("tool_call_fix", ctypes.c_bool),
     ("render_special", ctypes.c_bool),
     ("stream_sse", ctypes.c_bool),
     ("grammar", ctypes.c_char_p),
@@ -1477,6 +1478,7 @@ def generate(genparams, stream_flag=False):
     banned_strings = genparams.get('banned_strings', []) # SillyTavern uses that name
     banned_tokens = genparams.get('banned_tokens', banned_strings)
     bypass_eos_token = genparams.get('bypass_eos', False)
+    tool_call_fix = genparams.get('using_openai_tools', False)
     custom_token_bans = genparams.get('custom_token_bans', '')
     for tok in custom_token_bans.split(','):
@@ -1535,6 +1537,7 @@ def generate(genparams, stream_flag=False):
     inputs.grammar_retain_state = grammar_retain_state
     inputs.allow_eos_token = not ban_eos_token
     inputs.bypass_eos_token = bypass_eos_token
+    inputs.tool_call_fix = tool_call_fix
     inputs.render_special = render_special
     if mirostat in (1, 2):
         inputs.mirostat = mirostat
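
Because the flag is read from the using_openai_tools genparam rather than a new request field, the workaround presumably engages automatically whenever a request arrives through the OpenAI-compatible endpoint with tools attached; this commit adds no separate user-facing switch.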