From 173c7272d583fa20f7df1f2c266f6b8597abd540 Mon Sep 17 00:00:00 2001 From: Concedo <39025047+LostRuins@users.noreply.github.com> Date: Mon, 6 May 2024 18:01:49 +0800 Subject: [PATCH] EOS bypass mode added --- expose.h | 1 + gpttype_adapter.cpp | 4 +- klite.embd | 151 ++++++++++++++++++++++++++++++-------------- koboldcpp.py | 5 +- 4 files changed, 112 insertions(+), 49 deletions(-) diff --git a/expose.h b/expose.h index 04565a357..2242f58a0 100644 --- a/expose.h +++ b/expose.h @@ -82,6 +82,7 @@ struct generation_inputs const samplers sampler_order[KCPP_SAMPLER_MAX]; const int sampler_len; const bool allow_eos_token; + const bool bypass_eos_token = false; const bool render_special; const char * stop_sequence[stop_token_max]; const bool stream_sse; diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp index fe6dfc1c4..91bf28aa0 100644 --- a/gpttype_adapter.cpp +++ b/gpttype_adapter.cpp @@ -2205,7 +2205,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs) lowestLogit = LowestLogit(logits); } - if (!inputs.allow_eos_token) + if (!inputs.allow_eos_token && !inputs.bypass_eos_token) { // set the logit of the eos token to very low to avoid sampling it logitsPtr[eosID] = lowestLogit; @@ -2274,7 +2274,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs) printf("]\n"); } - if(inputs.allow_eos_token && (id==eosID || (id==eotID && id!=-1))) + if(!inputs.bypass_eos_token && inputs.allow_eos_token && (id==eosID || (id==eotID && id!=-1))) { stopper_unused_tokens = remaining_tokens; if(allow_regular_prints) diff --git a/klite.embd b/klite.embd index 4f8c5cd4a..0c0efdb8f 100644 --- a/klite.embd +++ b/klite.embd @@ -7,7 +7,7 @@ Just copy this single static HTML file anywhere and open it in a browser, or fro Please go to https://github.com/LostRuins/lite.koboldai.net for updates on Kobold Lite. If you are submitting a pull request for Lite, PLEASE use the above repo, not the KoboldCpp one. Kobold Lite is under the AGPL v3.0 License unless otherwise exempted. Please do not remove this line. -Current version: 137 +Current version: 138 -Concedo --> @@ -2081,6 +2081,7 @@ Current version: 137 const instructmodels1 = ["gpt4all","supercot","hermes","airoboros","chrono","wizard","mantis","vicuna","manticore","alpaca","myth","xwin","spicyboros","mlewd","mxlewd","mistral","maid","mixtral","estopia","fighter","fimbul"]; const instructmodels2 = ["erebus","nerys","nerybus","janeway","opt","llama"]; const defaultmodels = ["gpt4all","supercot","hermes","airoboros","chrono","wizard","mantis","vicuna","manticore","alpaca","myth","xwin","spicyboros","mlewd","mxlewd","llama","mistral","maid","mixtral","estopia","fighter","fimbul"]; + const ignoredmodels = ["tinyllama"]; const instructstartplaceholder = "\n{{[INPUT]}}\n"; const instructendplaceholder = "\n{{[OUTPUT]}}\n"; @@ -3705,7 +3706,7 @@ Current version: 137 trimsentences: true, //trim to last punctuation trimwhitespace: false, //trim trailing whitespace compressnewlines: false, //compress multiple newlines - eos_ban_mode: 0, //allow the EOS token when using locally 0=auto,1=unban,2=ban + eos_ban_mode: 0, //allow the EOS token when using locally 0=auto,1=unban,2=ban,3=bypass opmode: 4, //what mode are we in? 1=story, 2=adventure, 3=chat, 4=instruct adventure_is_action: false, //in adventure mode, determine story or action adventure_context_mod: true, //extra injection for adventure mode @@ -3995,10 +3996,21 @@ Current version: 137 if (mdls.length > 0) { for (var i = 0; i < mdls.length; ++i) { - for (var j = 0; j < defaultmodels.length; ++j) { - if (mdls[i].name.trim().toLowerCase().includes(defaultmodels[j].trim().toLowerCase()) || - defaultmodels[j].trim().toLowerCase().includes(mdls[i].name.trim().toLowerCase())) { - selected_models.push(mdls[i]); + let skipignored = false; + for(let k=0;k= 0) { submit_payload.use_default_badwordsids = determine_if_ban_eos(input_was_empty); + if(is_using_kcpp_with_added_memory()) + { + submit_payload.bypass_eos = (localsettings.eos_ban_mode == 3?true:false); + } } let pseudostreaming = (determine_streaming_type()==1); @@ -15447,6 +15505,7 @@ Current version: 137 +
diff --git a/koboldcpp.py b/koboldcpp.py index 3f72bf71f..40e081c7a 100644 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -82,6 +82,7 @@ class generation_inputs(ctypes.Structure): ("sampler_order", ctypes.c_int * sampler_order_max), ("sampler_len", ctypes.c_int), ("allow_eos_token", ctypes.c_bool), + ("bypass_eos_token", ctypes.c_bool), ("render_special", ctypes.c_bool), ("stop_sequence", ctypes.c_char_p * stop_token_max), ("stream_sse", ctypes.c_bool), @@ -396,7 +397,7 @@ def load_model(model_filename): ret = handle.load_model(inputs) return ret -def generate(prompt, memory="", images=[], max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.0, rep_pen_range=128, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False, dynatemp_range=0.0, dynatemp_exponent=1.0, smoothing_factor=0.0, logit_biases={}, render_special=False, banned_tokens=[]): +def generate(prompt, memory="", images=[], max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.0, rep_pen_range=128, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False, dynatemp_range=0.0, dynatemp_exponent=1.0, smoothing_factor=0.0, logit_biases={}, render_special=False, banned_tokens=[], bypass_eos_token=False): global maxctx, args, currentusergenkey, totalgens, pendingabortkey inputs = generation_inputs() inputs.prompt = prompt.encode("UTF-8") @@ -435,6 +436,7 @@ def generate(prompt, memory="", images=[], max_length=32, max_context_length=512 inputs.grammar = grammar.encode("UTF-8") inputs.grammar_retain_state = grammar_retain_state inputs.allow_eos_token = not use_default_badwordsids + inputs.bypass_eos_token = bypass_eos_token inputs.render_special = render_special if mirostat in (1, 2): inputs.mirostat = mirostat @@ -823,6 +825,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): logit_biases=genparams.get('logit_bias', {}), render_special=genparams.get('render_special', False), banned_tokens=genparams.get('banned_tokens', []), + bypass_eos_token=genparams.get('bypass_eos', False), ) genout = {"text":"","status":-1,"stopreason":-1}