Mirror of https://github.com/LostRuins/koboldcpp.git, synced 2025-09-10 17:14:36 +00:00
Commit 3210b378e8 (parent 2853baf1e8): better tool calls

3 changed files with 27 additions and 0 deletions:

 expose.h            |  1 +
 gpttype_adapter.cpp | 23 +++++++++++++++++++++++
 koboldcpp.py        |  3 +++
expose.h

@@ -110,6 +110,7 @@ struct generation_inputs
     const int sampler_len = 0;
     const bool allow_eos_token = false;
     const bool bypass_eos_token = false;
+    const bool tool_call_fix = false; //this prevents close square bracket ] from being generated early.
     const bool render_special = false;
     const bool stream_sse = false;
     const char * grammar = nullptr;
gpttype_adapter.cpp

@@ -128,6 +128,7 @@ static std::vector<std::string> stop_sequence;
 static std::vector<int> special_stop_sequence; //for stop sequences that don't have a string representation
 static std::vector<std::string> banned_tokens;
 static std::vector<int> banned_token_ids;
+static std::vector<int> toolcall_prevented_ids; //temp ban these ids for the first 3 tokens generated, to prevent empty replies
 static std::vector<std::string> banned_phrases;
 static std::unordered_multimap<gpt_vocab::id, std::vector<gpt_vocab::id>> dry_sequence_breakers; // Multi-mapping from first token of sequence to tail of sequence (tail is empty for a single token)
 static std::vector<int> dry_repeat_count; // Indexed as last_n_tokens
@@ -3266,6 +3267,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
     }

     banned_token_ids.clear();
+    toolcall_prevented_ids.clear();
     if(banned_tokens.size()>0)
     {
         if(debugmode==1 && !is_quiet)
@@ -3290,6 +3292,18 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
             printf("\nBanned a total of %zu individual tokens.\n",banned_token_ids.size());
         }
     }
+    if(inputs.tool_call_fix)
+    {
+        for(int v=0;v<n_vocab;++v)
+        {
+            std::string word = FileFormatTokenizeID(v,file_format, true);
+            word = toLowerCase(word);
+            if (word.find(']') != std::string::npos)
+            {
+                toolcall_prevented_ids.push_back(v);
+            }
+        }
+    }

     if(debugmode==1 && !is_quiet && banned_phrases.size()>0)
     {
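The hunk above walks the whole vocabulary once per request and collects every token whose detokenized piece contains ']'. Below is a minimal, self-contained sketch of the same idea against a toy vocabulary; the inline lowercasing and the hard-coded pieces are illustrative stand-ins for FileFormatTokenizeID and toLowerCase, not the real APIs:

#include <algorithm>
#include <cctype>
#include <cstdio>
#include <string>
#include <vector>

int main()
{
    // Toy vocabulary: index = token id, value = detokenized piece
    // (a stand-in for FileFormatTokenizeID; not the real tokenizer).
    std::vector<std::string> vocab = {"Hello", "[", "]", "\"]", " world", "]}"};
    std::vector<int> toolcall_prevented_ids;

    for (int v = 0; v < (int)vocab.size(); ++v)
    {
        // Lowercase the piece, mirroring toLowerCase in the diff.
        std::string word = vocab[v];
        std::transform(word.begin(), word.end(), word.begin(),
                       [](unsigned char c) { return (char)std::tolower(c); });
        if (word.find(']') != std::string::npos)
        {
            toolcall_prevented_ids.push_back(v); // any piece containing ']' is collected
        }
    }

    for (int id : toolcall_prevented_ids)
    {
        std::printf("prevented id %d: \"%s\"\n", id, vocab[id].c_str());
    }
    return 0;
}

Note that lowercasing is a no-op for ']' itself; the sketch keeps it only to mirror the diff's sequence of operations.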
@@ -4078,6 +4092,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
     float * logitsPtr;
     float lowestLogit = 0;
     int btsize = banned_token_ids.size();
+    int tcpreventsize = toolcall_prevented_ids.size();

     //sample pending logits. usually only 1, unless speculative decoding
     int logits_to_sample = 1;
@@ -4144,6 +4159,14 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
                 logitsPtr[banned_token_ids[t]]=lowestLogit;
             }
         }
+        bool tcpreventtoks = ((kcpp_data->n_predict - remaining_tokens)<3);
+        if(tcpreventsize>0 && tcpreventtoks && std::count(concat_output.begin(), concat_output.end(), '[')<=1)
+        {
+            for(int t=0;t<tcpreventsize;++t)
+            {
+                logitsPtr[toolcall_prevented_ids[t]]=lowestLogit;
+            }
+        }

        //handle temp bans from antislop
        if (antislop_banned_token_ids.find(n_past) != antislop_banned_token_ids.end()) {
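At sampling time the collected ids are only floored under two conditions: fewer than three tokens have been generated so far (kcpp_data->n_predict - remaining_tokens < 3), and the output so far contains at most one '[', i.e. the tool-call array has been opened but not yet populated. The sketch below replays that gate on toy logits; the variable names mirror the diff, while the toy data and the use of min_element to stand in for lowestLogit are illustrative assumptions:

#include <algorithm>
#include <cstdio>
#include <string>
#include <vector>

int main()
{
    // Toy logits, one per token id; id 2 stands for a piece containing ']'.
    std::vector<float> logits = {0.1f, 2.0f, 3.5f, 1.2f};
    std::vector<int> toolcall_prevented_ids = {2};

    // Stand-in for lowestLogit, which the real code tracks elsewhere.
    float lowestLogit = *std::min_element(logits.begin(), logits.end());

    int n_predict = 128;             // requested generation length (toy value)
    int remaining_tokens = 127;      // i.e. exactly 1 token generated so far
    std::string concat_output = "["; // text generated so far

    // Gate 1: only the first 3 generated tokens are affected.
    bool tcpreventtoks = ((n_predict - remaining_tokens) < 3);
    // Gate 2: at most one '[' so far, i.e. the tool-call array is still empty.
    bool array_unfilled =
        std::count(concat_output.begin(), concat_output.end(), '[') <= 1;

    if (!toolcall_prevented_ids.empty() && tcpreventtoks && array_unfilled)
    {
        for (int id : toolcall_prevented_ids)
        {
            logits[id] = lowestLogit; // ']' pieces become effectively unsampleable
        }
    }

    for (std::size_t i = 0; i < logits.size(); ++i)
    {
        std::printf("id %zu -> %.2f\n", i, logits[i]);
    }
    return 0;
}

Once three tokens have been generated, or a second '[' appears, the gate lapses, so legitimate closing brackets later in the tool call are unaffected.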
koboldcpp.py

@@ -244,6 +244,7 @@ class generation_inputs(ctypes.Structure):
                 ("sampler_len", ctypes.c_int),
                 ("allow_eos_token", ctypes.c_bool),
                 ("bypass_eos_token", ctypes.c_bool),
+                ("tool_call_fix", ctypes.c_bool),
                 ("render_special", ctypes.c_bool),
                 ("stream_sse", ctypes.c_bool),
                 ("grammar", ctypes.c_char_p),
@@ -1477,6 +1478,7 @@ def generate(genparams, stream_flag=False):
     banned_strings = genparams.get('banned_strings', []) # SillyTavern uses that name
     banned_tokens = genparams.get('banned_tokens', banned_strings)
     bypass_eos_token = genparams.get('bypass_eos', False)
+    tool_call_fix = genparams.get('using_openai_tools', False)
     custom_token_bans = genparams.get('custom_token_bans', '')

     for tok in custom_token_bans.split(','):
@@ -1535,6 +1537,7 @@ def generate(genparams, stream_flag=False):
     inputs.grammar_retain_state = grammar_retain_state
     inputs.allow_eos_token = not ban_eos_token
     inputs.bypass_eos_token = bypass_eos_token
+    inputs.tool_call_fix = tool_call_fix
     inputs.render_special = render_special
     if mirostat in (1, 2):
         inputs.mirostat = mirostat
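End to end: the Python handler reads using_openai_tools from genparams, carries it across the ctypes boundary as the new tool_call_fix bool, and the C++ sampler then suppresses ']' pieces for the first three generated tokens. The net effect is that an OpenAI-style tool-call array cannot be closed before it contains any content, which is what previously produced empty tool-call replies.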