Refactored a lot of code: removed banned-token handling from model loading and moved it to the generation API

This commit is contained in:
Concedo 2024-04-27 17:57:13 +08:00
parent 4ec8a9c57b
commit c230b78906
6 changed files with 214 additions and 76 deletions

View file

@ -837,17 +837,6 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
gptj_ctx_v3.hparams.rope_freq_scale = neox_ctx_v3.hparams.rope_freq_scale = rope_freq_scale;
gptj_ctx_v3.hparams.rope_freq_base = neox_ctx_v3.hparams.rope_freq_base = rope_freq_base;
//handle custom token bans
banned_tokens.clear();
for(int x=0;x<ban_token_max;++x)
{
std::string word = inputs.banned_tokens[x];
if(word!="")
{
banned_tokens.push_back(word);
}
}
//this is used for the mem_per_token eval, openblas needs more RAM
bool v3_use_scratch = ggml_v3_cpu_has_gpublas();
@ -1624,6 +1613,41 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
}
}
//handle custom token bans
banned_tokens.clear();
for(int x=0;x<ban_token_max;++x)
{
std::string word = inputs.banned_tokens[x];
if(word!="")
{
banned_tokens.push_back(word);
}
}
banned_token_ids.clear();
if(banned_tokens.size()>0)
{
if(debugmode==1)
{
printf("\nBanning %zu token sequences...",banned_tokens.size());
}
for(int v=0;v<n_vocab;++v)
{
std::string word = FileFormatTokenizeID(v,file_format, true);
for(int i=0;i<banned_tokens.size();++i)
{
if (word.find(banned_tokens[i]) != std::string::npos)
{
banned_token_ids.push_back(v);
break;
}
}
}
if(debugmode==1)
{
printf("\nBanned a total of %zu tokens.\n",banned_token_ids.size());
}
}
logit_biases.clear();
for(int x=0;x<logit_bias_max;++x)
{
@ -1993,25 +2017,6 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
printf("\nWarning! n_vocab is invalid, maybe bad format!");
}
//prepare banned tokens
if(banned_token_ids.size()==0 && banned_tokens.size()>0)
{
printf("\n[First Run] Banning %zu token sequences...",banned_tokens.size());
for(int v=0;v<n_vocab;++v)
{
std::string word = FileFormatTokenizeID(v,file_format, true);
for(int i=0;i<banned_tokens.size();++i)
{
if (word.find(banned_tokens[i]) != std::string::npos)
{
banned_token_ids.push_back(v);
break;
}
}
}
printf("\nBanned a total of %zu tokens.\n",banned_token_ids.size());
}
if(allow_regular_prints)
{
printf("\n");