added token unbanning

This commit is contained in:
Concedo 2023-04-24 21:50:20 +08:00
parent 1b9b9068b1
commit 3962eb39c7
3 changed files with 32 additions and 16 deletions

View file

@ -41,6 +41,7 @@ static int n_past = 0;
static int n_threads = 4;
static int n_batch = 8;
static bool useSmartContext = false;
static bool unbanTokens = false;
static int blasbatchsize = 512;
static std::string modelname;
static std::vector<gpt_vocab::id> last_n_tokens;
@ -65,6 +66,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
n_batch = params.n_batch = inputs.batch_size;
modelname = params.model = inputs.model_filename;
useSmartContext = inputs.use_smartcontext;
unbanTokens = inputs.unban_tokens;
blasbatchsize = inputs.blasbatchsize;
params.memory_f16 = inputs.f16_kv;
params.n_ctx = inputs.max_context_length;
@ -366,7 +368,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
}
params.n_batch = bbs; //received reports of 1024 and above crashing on some models
//params.n_threads = 1; //do not limit here anymore.
if(!ggml_cpu_has_cublas())
{
params.n_threads = 1; //do not limit here anymore.
}
}
current_context_tokens.resize(n_past);
@ -512,28 +517,35 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
if(file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT)
{
auto logits = llama_get_logits(llama_ctx_v1);
// set the logit of the eos token (2) to zero to avoid sampling it
logits[llama_token_eos()] = 0;
//set logits of opening square bracket to zero.
logits[518] = 0;
logits[29961] = 0;
if (!unbanTokens)
{
// set the logit of the eos token (2) to zero to avoid sampling it
logits[llama_token_eos()] = 0;
//set logits of opening square bracket to zero.
logits[518] = 0;
logits[29961] = 0;
}
id = llama_sample_top_p_top_k(llama_ctx_v1, last_n_tokens.data(), last_n_tokens.size(), top_k, top_p, temp, repeat_penalty);
}
else
{
// set the logit of the eos token (2) to zero to avoid sampling it
if((file_format == FileFormat::GPT2_1 ||
file_format == FileFormat::GPT2_2 ||
file_format == FileFormat::GPTJ_1 ||
file_format == FileFormat::GPTJ_2 ||
file_format == FileFormat::GPTJ_3)
&& logits.size()>50256)
if (!unbanTokens)
{
logits[50256] = (logits[50256] < 0 ? logits[50256] : 0);
// set the logit of the eos token (2) to zero to avoid sampling it
if ((file_format == FileFormat::GPT2_1 ||
file_format == FileFormat::GPT2_2 ||
file_format == FileFormat::GPTJ_1 ||
file_format == FileFormat::GPTJ_2 ||
file_format == FileFormat::GPTJ_3) &&
logits.size() > 50256)
{
logits[50256] = (logits[50256] < 0 ? logits[50256] : 0);
}
//gpt2 uses negative logits, so we cant zero it
}
//gpt2 uses negative logits, so we cant zero it
id = gptj_sample_top_p_top_k(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_k, top_p, temp, rng);
}