Warn that RNN models do not support context rewind, so the Anti-Slop sampler will not work

This commit is contained in:
Concedo 2026-03-06 14:02:51 +08:00
parent 389773070f
commit e36d7b6464

View file

@ -530,16 +530,16 @@ static std::string toLowerCase(const std::string& str) {
}
void ContextRewind(std::vector<int> &embd, std::vector<int> &current_context_tokens, int &n_past, std::vector<int> &last_n_tokens, const int amount_rewind)
bool ContextRewind(std::vector<int> &embd, std::vector<int> &current_context_tokens, int &n_past, std::vector<int> &last_n_tokens, const int amount_rewind)
{
if(amount_rewind<=0 || current_context_tokens.size()==0)
{
return; //do nothing
return true; //do nothing
}
if(embd.size()>1)
{
printf("\nWARNING: Don't use context rewind when in batch processing phase!\n");
return;
return false;
}
bool is_recurrent = false;
if(file_format==FileFormat::GGUF_GENERIC)
@ -552,12 +552,12 @@ void ContextRewind(std::vector<int> &embd, std::vector<int> &current_context_tok
}
if(file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2 || is_recurrent)
{
if(!showed_rnn_warning && debugmode==1 && !is_quiet)
if(!showed_rnn_warning)
{
showed_rnn_warning = true;
printf("\nWARNING: RNN models do not support context rewind!\n");
printf("\n!!!\nWARNING: RNN models do not support context rewind! Anti-Slop sampler will not work!\n!!!\n");
}
return;
return false;
}
if (amount_rewind >= last_n_tokens.size())
@ -610,6 +610,7 @@ void ContextRewind(std::vector<int> &embd, std::vector<int> &current_context_tok
{
embd.push_back(current_context_tokens[current_context_tokens.size()-1]);
}
return true;
}
const char * kcpp_print_system_info(void) {
@ -4900,27 +4901,32 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
if (rewind_amt > 0 && (current_context_tokens.size() - rewind_amt) > 0)
{
int last_tok = current_context_tokens[current_context_tokens.size() - rewind_amt];
delayed_generated_tokens.resize(delayed_generated_tokens.size() - rewind_amt);
ContextRewind(embd, current_context_tokens, n_past, last_n_tokens, rewind_amt);
bool rwok = ContextRewind(embd, current_context_tokens, n_past, last_n_tokens, rewind_amt);
//immediately terminate drafting if used
abort_draft = true;
// Check if the key exists
int banindex = n_past+1;
if (antislop_banned_token_ids.find(banindex) == antislop_banned_token_ids.end()) {
antislop_banned_token_ids[banindex] = std::vector<int>();
}
std::vector<int>& current_ids = antislop_banned_token_ids[banindex];
current_ids.push_back(last_tok);
if (allow_regular_prints && debugmode == 1)
if(rwok)
{
auto match_clean = matched;
replace_all(match_clean, "\n", "\\n");
printf("\n(Banned Phrase Detected: %s - Add ID %d to banlist at index %d, and rewinding %d tokens)\n", match_clean.c_str(), last_tok, banindex, rewind_amt);
}
delayed_generated_tokens.resize(delayed_generated_tokens.size() - rewind_amt);
// Check if the key exists
int banindex = n_past+1;
if (antislop_banned_token_ids.find(banindex) == antislop_banned_token_ids.end()) {
antislop_banned_token_ids[banindex] = std::vector<int>();
}
std::vector<int>& current_ids = antislop_banned_token_ids[banindex];
current_ids.push_back(last_tok);
if (allow_regular_prints && debugmode == 1)
{
auto match_clean = matched;
replace_all(match_clean, "\n", "\\n");
printf("\n(Banned Phrase Detected: %s - Add ID %d to banlist at index %d, and rewinding %d tokens)\n", match_clean.c_str(), last_tok, banindex, rewind_amt);
}
}
break;
}
}