diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp index a7e134b8e..15f80dba3 100644 --- a/gpttype_adapter.cpp +++ b/gpttype_adapter.cpp @@ -505,7 +505,7 @@ void sample_dry(int n_ctx, int penalty_range, float penalty_multiplier, float pe if (penalty_multiplier <= 0.0f || penalty_base <= 0.0f) { return; } - if (penalty_range <= 0) { + if (penalty_range <= 0 || penalty_range>n_ctx) { penalty_range = n_ctx; } auto last_n_repeat = std::min(std::min((int)current_context_tokens.size(), penalty_range), n_ctx); @@ -843,6 +843,8 @@ int mirostat, float mirostat_tau, float mirostat_eta, float dry_multiplier, floa sample_grammar(file_format, n_vocab, &candidates_p, grammar); } + sample_dry(n_ctx, dry_penalty_last_n, dry_multiplier, dry_base, dry_allowed_length, dry_sequence_breakers, &candidates_p); + //prefilter to top 5k tokens for improved speed llama_sample_top_k(nullptr, &candidates_p, 5000, 1); @@ -901,7 +903,6 @@ int mirostat, float mirostat_tau, float mirostat_eta, float dry_multiplier, floa break; case KCPP_SAMPLER_REP_PEN: sample_rep_pen(n_ctx, rep_pen_range, rep_pen, rep_pen_slope, presence_penalty, &candidates_p); - sample_dry(n_ctx, dry_penalty_last_n, dry_multiplier, dry_base, dry_allowed_length, dry_sequence_breakers, &candidates_p); break; default: printf("\nSampleLogits: Unknown Sampler : %d",sampler_order[i]); @@ -1956,6 +1957,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs) last_stop_reason = stop_reason::OUT_OF_TOKENS; stop_sequence.clear(); special_stop_sequence.clear(); + dry_repeat_count.clear(); + dry_sequence_breakers.clear(); + dry_max_token_repeat.clear(); + for(int x=0;x