better tool calls

This commit is contained in:
Concedo 2025-08-20 22:11:31 +08:00
parent 2853baf1e8
commit 3210b378e8
3 changed files with 27 additions and 0 deletions

View file

@ -128,6 +128,7 @@ static std::vector<std::string> stop_sequence;
static std::vector<int> special_stop_sequence; //for stop sequences that don't have a string representation
static std::vector<std::string> banned_tokens;
static std::vector<int> banned_token_ids;
static std::vector<int> toolcall_prevented_ids; //temp ban these id for the first 3 tokens generated, to prevent empty replies
static std::vector<std::string> banned_phrases;
static std::unordered_multimap<gpt_vocab::id, std::vector<gpt_vocab::id>> dry_sequence_breakers; // Multi-mapping from first token of sequence to tail of sequence (tail is empty for a single token)
static std::vector<int> dry_repeat_count; // Indexed as last_n_tokens
@ -3266,6 +3267,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
}
banned_token_ids.clear();
toolcall_prevented_ids.clear();
if(banned_tokens.size()>0)
{
if(debugmode==1 && !is_quiet)
@ -3290,6 +3292,18 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
printf("\nBanned a total of %zu individual tokens.\n",banned_token_ids.size());
}
}
if(inputs.tool_call_fix)
{
for(int v=0;v<n_vocab;++v)
{
std::string word = FileFormatTokenizeID(v,file_format, true);
word = toLowerCase(word);
if (word.find(']') != std::string::npos)
{
toolcall_prevented_ids.push_back(v);
}
}
}
if(debugmode==1 && !is_quiet && banned_phrases.size()>0)
{
@ -4078,6 +4092,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
float * logitsPtr;
float lowestLogit = 0;
int btsize = banned_token_ids.size();
int tcpreventsize = toolcall_prevented_ids.size();
//sample pending logits. usually only 1, unless speculative decoding
int logits_to_sample = 1;
@ -4144,6 +4159,14 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
logitsPtr[banned_token_ids[t]]=lowestLogit;
}
}
bool tcpreventtoks = ((kcpp_data->n_predict - remaining_tokens)<3);
if(tcpreventsize>0 && tcpreventtoks && std::count(concat_output.begin(), concat_output.end(), '[')<=1)
{
for(int t=0;t<tcpreventsize;++t)
{
logitsPtr[toolcall_prevented_ids[t]]=lowestLogit;
}
}
//handle temp bans from antislop
if (antislop_banned_token_ids.find(n_past) != antislop_banned_token_ids.end()) {