mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-11 01:24:36 +00:00
added the ability to ban any substring tokens
This commit is contained in:
parent
27a0907cfa
commit
8424a35c62
3 changed files with 62 additions and 2 deletions
|
@ -76,6 +76,8 @@ static size_t mem_per_token = 0;
|
|||
static std::vector<float> logits;
|
||||
static std::vector<int> smartcontext;
|
||||
static std::vector<std::string> stop_sequence;
|
||||
static std::vector<std::string> banned_tokens;
|
||||
static std::vector<int> banned_token_ids;
|
||||
static std::vector<llama_token_data> top_picks;
|
||||
static int remaining_tokens = 0;
|
||||
static int stopper_unused_tokens = 0;
|
||||
|
@ -344,6 +346,17 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
|||
= gpt2_ctx_v1.hparams.n_ctx = gpt2_ctx_v2.hparams.n_ctx = gpt2_ctx_v3.hparams.n_ctx
|
||||
= mpt_ctx_v3.hparams.n_ctx = params.n_ctx;
|
||||
|
||||
//handle custom token bans
|
||||
banned_tokens.clear();
|
||||
for(int x=0;x<ban_token_max;++x)
|
||||
{
|
||||
std::string word = inputs.banned_tokens[x];
|
||||
if(word!="")
|
||||
{
|
||||
banned_tokens.push_back(word);
|
||||
}
|
||||
}
|
||||
|
||||
//this is used for the mem_per_token eval, openblas needs more RAM
|
||||
bool use_scratch = ggml_cpu_has_gpublas();
|
||||
|
||||
|
@ -1064,6 +1077,25 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
|||
printf("Bad format!");
|
||||
}
|
||||
|
||||
//prepare banned tokens
|
||||
if(banned_token_ids.size()==0 && banned_tokens.size()>0)
|
||||
{
|
||||
printf("\n[First Run] Banning %d token sequences...",banned_tokens.size());
|
||||
for(int v=0;v<n_vocab;++v)
|
||||
{
|
||||
std::string word = FileFormatTokenizeID(v,file_format);
|
||||
for(int i=0;i<banned_tokens.size();++i)
|
||||
{
|
||||
if (word.find(banned_tokens[i]) != std::string::npos)
|
||||
{
|
||||
banned_token_ids.push_back(v);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
printf("\nBanned a total of %d tokens.\n",banned_token_ids.size());
|
||||
}
|
||||
|
||||
if(debugmode!=-1)
|
||||
{
|
||||
printf("\n");
|
||||
|
@ -1221,6 +1253,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
|||
|
||||
unsigned int eosID = 0;
|
||||
float * logitsPtr;
|
||||
int btsize = banned_token_ids.size();
|
||||
if(file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 || file_format == FileFormat::GGJT_3)
|
||||
{
|
||||
if(file_format == FileFormat::GGJT_3)
|
||||
|
@ -1239,6 +1272,14 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
|||
// set the logit of the eos token (2) to zero to avoid sampling it
|
||||
logitsPtr[eosID] = 0;
|
||||
}
|
||||
|
||||
if(btsize>0)
|
||||
{
|
||||
for(int t=0;t<btsize;++t)
|
||||
{
|
||||
logitsPtr[banned_token_ids[t]]=0;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
@ -1293,6 +1334,14 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
|||
}
|
||||
}
|
||||
|
||||
if(btsize>0)
|
||||
{
|
||||
int topid = std::min_element(logits.begin(), logits.end()) - logits.begin();
|
||||
for (int t = 0; t < btsize; ++t)
|
||||
{
|
||||
logits[banned_token_ids[t]] = (logits[topid] < 0 ? logits[topid] : 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
id = SampleLogits(logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue