Mirror of https://github.com/LostRuins/koboldcpp.git (synced 2025-09-10 17:14:36 +00:00)
refactored a lot of code, remove bantokens, move it to api
This commit is contained in:
parent 4ec8a9c57b
commit c230b78906

6 changed files with 214 additions and 76 deletions
@@ -837,17 +837,6 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
     gptj_ctx_v3.hparams.rope_freq_scale = neox_ctx_v3.hparams.rope_freq_scale = rope_freq_scale;
     gptj_ctx_v3.hparams.rope_freq_base = neox_ctx_v3.hparams.rope_freq_base = rope_freq_base;
 
-    //handle custom token bans
-    banned_tokens.clear();
-    for(int x=0;x<ban_token_max;++x)
-    {
-        std::string word = inputs.banned_tokens[x];
-        if(word!="")
-        {
-            banned_tokens.push_back(word);
-        }
-    }
-
     //this is used for the mem_per_token eval, openblas needs more RAM
     bool v3_use_scratch = ggml_v3_cpu_has_gpublas();
 
@@ -1624,6 +1613,41 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
         }
     }
 
+    //handle custom token bans
+    banned_tokens.clear();
+    for(int x=0;x<ban_token_max;++x)
+    {
+        std::string word = inputs.banned_tokens[x];
+        if(word!="")
+        {
+            banned_tokens.push_back(word);
+        }
+    }
+    banned_token_ids.clear();
+    if(banned_tokens.size()>0)
+    {
+        if(debugmode==1)
+        {
+            printf("\nBanning %zu token sequences...",banned_tokens.size());
+        }
+        for(int v=0;v<n_vocab;++v)
+        {
+            std::string word = FileFormatTokenizeID(v,file_format, true);
+            for(int i=0;i<banned_tokens.size();++i)
+            {
+                if (word.find(banned_tokens[i]) != std::string::npos)
+                {
+                    banned_token_ids.push_back(v);
+                    break;
+                }
+            }
+        }
+        if(debugmode==1)
+        {
+            printf("\nBanned a total of %zu tokens.\n",banned_token_ids.size());
+        }
+    }
+
     logit_biases.clear();
     for(int x=0;x<logit_bias_max;++x)
     {
@@ -1993,25 +2017,6 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
         printf("\nWarning! n_vocab is invalid, maybe bad format!");
     }
 
-    //prepare banned tokens
-    if(banned_token_ids.size()==0 && banned_tokens.size()>0)
-    {
-        printf("\n[First Run] Banning %zu token sequences...",banned_tokens.size());
-        for(int v=0;v<n_vocab;++v)
-        {
-            std::string word = FileFormatTokenizeID(v,file_format, true);
-            for(int i=0;i<banned_tokens.size();++i)
-            {
-                if (word.find(banned_tokens[i]) != std::string::npos)
-                {
-                    banned_token_ids.push_back(v);
-                    break;
-                }
-            }
-        }
-        printf("\nBanned a total of %zu tokens.\n",banned_token_ids.size());
-    }
-
     if(allow_regular_prints)
     {
         printf("\n");
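
For context, the ban matching that now runs inside gpttype_generate (instead of once at model load) is a plain substring scan over the vocabulary: every vocab entry whose detokenized text contains one of the requested banned sequences gets its id added to banned_token_ids. Below is a minimal, self-contained sketch of just that matching step; the toy vocabulary and main() are illustrative only, since the real code pulls token strings via FileFormatTokenizeID and iterates up to the model's n_vocab.

// Hypothetical standalone sketch (not part of the commit): the same
// substring-ban logic as in the hunk above, run against a made-up vocabulary.
#include <cstdio>
#include <string>
#include <vector>

int main()
{
    // Stand-in for the model vocabulary: index = token id, value = token text.
    std::vector<std::string> vocab = {"Hello", " world", "fox", "foxes", "!"};
    // Sequences the caller asked to ban, now supplied per request via the API.
    std::vector<std::string> banned_tokens = {"fox"};

    std::vector<int> banned_token_ids;
    for (int v = 0; v < (int)vocab.size(); ++v)
    {
        const std::string &word = vocab[v];
        for (size_t i = 0; i < banned_tokens.size(); ++i)
        {
            // Substring match: any vocab entry containing a banned sequence is banned.
            if (word.find(banned_tokens[i]) != std::string::npos)
            {
                banned_token_ids.push_back(v);
                break;
            }
        }
    }

    // Prints "Banned 2 of 5 tokens" here, since both "fox" and "foxes" match.
    printf("Banned %zu of %zu tokens\n", banned_token_ids.size(), vocab.size());
    return 0;
}

Because the match is substring-based, banning a short sequence can knock out many vocabulary entries at once, which is why the debug path in the diff reports the total number of banned token ids.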