handle token unbanning over api

This commit is contained in:
Concedo 2023-08-30 10:51:49 +08:00
parent f2c02dd06d
commit 89495c0716
3 changed files with 10 additions and 5 deletions

View file

@ -1458,7 +1458,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
}
float lowestLogit = LowestLogit(logitsPtr,n_vocab);
if (!unbanTokens)
if (!unbanTokens && !inputs.unban_tokens_rt)
{
// set the logit of the eos token to the lowest logit value to avoid sampling it
logitsPtr[eosID] = lowestLogit;
@ -1476,7 +1476,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
{
logitsPtr = logits.data();
float lowestLogit = LowestLogit(logits);
if (!unbanTokens)
if (!unbanTokens && !inputs.unban_tokens_rt)
{
// gpt2 uses negative logits, so we can't zero it
// set the logit of the eos token to minimum to avoid sampling it
@ -1580,7 +1580,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
printf("]\n");
}
if(unbanTokens && id==eosID)
if((unbanTokens||inputs.unban_tokens_rt) && id==eosID)
{
stopper_unused_tokens = remaining_tokens;
printf("\n(EOS token triggered!)");