mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 17:14:36 +00:00
better tool calls
This commit is contained in:
parent
2853baf1e8
commit
3210b378e8
3 changed files with 27 additions and 0 deletions
|
@ -128,6 +128,7 @@ static std::vector<std::string> stop_sequence;
|
|||
static std::vector<int> special_stop_sequence; //for stop sequences that don't have a string representation
|
||||
static std::vector<std::string> banned_tokens;
|
||||
static std::vector<int> banned_token_ids;
|
||||
static std::vector<int> toolcall_prevented_ids; //temp ban these id for the first 3 tokens generated, to prevent empty replies
|
||||
static std::vector<std::string> banned_phrases;
|
||||
static std::unordered_multimap<gpt_vocab::id, std::vector<gpt_vocab::id>> dry_sequence_breakers; // Multi-mapping from first token of sequence to tail of sequence (tail is empty for a single token)
|
||||
static std::vector<int> dry_repeat_count; // Indexed as last_n_tokens
|
||||
|
@ -3266,6 +3267,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
|
|||
}
|
||||
|
||||
banned_token_ids.clear();
|
||||
toolcall_prevented_ids.clear();
|
||||
if(banned_tokens.size()>0)
|
||||
{
|
||||
if(debugmode==1 && !is_quiet)
|
||||
|
@ -3290,6 +3292,18 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
|
|||
printf("\nBanned a total of %zu individual tokens.\n",banned_token_ids.size());
|
||||
}
|
||||
}
|
||||
if(inputs.tool_call_fix)
|
||||
{
|
||||
for(int v=0;v<n_vocab;++v)
|
||||
{
|
||||
std::string word = FileFormatTokenizeID(v,file_format, true);
|
||||
word = toLowerCase(word);
|
||||
if (word.find(']') != std::string::npos)
|
||||
{
|
||||
toolcall_prevented_ids.push_back(v);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if(debugmode==1 && !is_quiet && banned_phrases.size()>0)
|
||||
{
|
||||
|
@ -4078,6 +4092,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
|
|||
float * logitsPtr;
|
||||
float lowestLogit = 0;
|
||||
int btsize = banned_token_ids.size();
|
||||
int tcpreventsize = toolcall_prevented_ids.size();
|
||||
|
||||
//sample pending logits. usually only 1, unless speculative decoding
|
||||
int logits_to_sample = 1;
|
||||
|
@ -4144,6 +4159,14 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
|
|||
logitsPtr[banned_token_ids[t]]=lowestLogit;
|
||||
}
|
||||
}
|
||||
bool tcpreventtoks = ((kcpp_data->n_predict - remaining_tokens)<3);
|
||||
if(tcpreventsize>0 && tcpreventtoks && std::count(concat_output.begin(), concat_output.end(), '[')<=1)
|
||||
{
|
||||
for(int t=0;t<tcpreventsize;++t)
|
||||
{
|
||||
logitsPtr[toolcall_prevented_ids[t]]=lowestLogit;
|
||||
}
|
||||
}
|
||||
|
||||
//handle temp bans from antislop
|
||||
if (antislop_banned_token_ids.find(n_past) != antislop_banned_token_ids.end()) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue