diff --git a/class.py b/class.py
index 0478bef5b..531c1ddaf 100644
--- a/class.py
+++ b/class.py
@@ -270,7 +270,7 @@ class model_backend(InferenceModel):
         port=5001, port_param=5001, host='', launch=False, lora=None, threads=self.kcpp_threads, blasthreads=self.kcpp_threads, psutil_set_threads=False,
         highpriority=False, contextsize=self.kcpp_ctxsize, blasbatchsize=self.kcpp_blasbatchsize, ropeconfig=[self.kcpp_ropescale, self.kcpp_ropebase],
         stream=False, smartcontext=self.kcpp_smartcontext,
-        unbantokens=False, bantokens=None, usemirostat=None, forceversion=0, nommap=self.kcpp_nommap,
+        usemirostat=None, forceversion=0, nommap=self.kcpp_nommap,
         usemlock=False, noavx2=self.kcpp_noavx2, debugmode=self.kcpp_debugmode, skiplauncher=True, hordeconfig=None, noblas=self.kcpp_noblas,
         useclblast=self.kcpp_useclblast, usecublas=self.kcpp_usecublas, usevulkan=self.kcpp_usevulkan, gpulayers=self.kcpp_gpulayers, tensor_split=self.kcpp_tensor_split, config=None,
         onready='', multiuser=False, foreground=False, preloadstory=None, noshift=False, remotetunnel=False, ssl=False, benchmark=None, nocertify=False, sdconfig=None, mmproj=None,
diff --git a/expose.h b/expose.h
index fa064b96e..1ef865e73 100644
--- a/expose.h
+++ b/expose.h
@@ -55,7 +55,6 @@ struct load_model_inputs
     const int gpulayers = 0;
     const float rope_freq_scale = 1.0f;
     const float rope_freq_base = 10000.0f;
-    const char * banned_tokens[ban_token_max];
     const float tensor_split[tensor_split_max];
 };
 struct generation_inputs
@@ -92,7 +91,7 @@ struct generation_inputs
     const float dynatemp_exponent = 1.0f;
     const float smoothing_factor = 0.0f;
     const logit_bias logit_biases[logit_bias_max];
-
+    const char * banned_tokens[ban_token_max];
 };
 struct generation_outputs
 {
diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index 1adf603a1..3eeeec5dc 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -837,17 +837,6 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
     gptj_ctx_v3.hparams.rope_freq_scale = neox_ctx_v3.hparams.rope_freq_scale = rope_freq_scale;
     gptj_ctx_v3.hparams.rope_freq_base = neox_ctx_v3.hparams.rope_freq_base = rope_freq_base;

-    //handle custom token bans
-    banned_tokens.clear();
-    for(int x=0;x<ban_token_max;++x)
-    {
-        std::string word = inputs.banned_tokens[x];
-        if(word!="")
-        {
-            banned_tokens.push_back(word);
-        }
-    }
@@ ... @@
+    //handle custom token bans
+    banned_tokens.clear();
+    for(int x=0;x<ban_token_max;++x)
+    {
+        std::string word = inputs.banned_tokens[x];
+        if(word!="")
+        {
+            banned_tokens.push_back(word);
+        }
+    }
+    banned_token_ids.clear();
+    if(banned_tokens.size()>0)
+    {
+        if(debugmode==1)
+        {
+            printf("\nBanning %zu token sequences...",banned_tokens.size());
+        }
+        for(int v=0;v<n_vocab;++v)
+        {
+            std::string word = FileFormatTokenizeID(v,file_format,true);
+            for(int i=0;i<banned_tokens.size();++i)
+            {
+                if (word.find(banned_tokens[i]) != std::string::npos)
+                {
+                    banned_token_ids.push_back(v);
+                }
+            }
+        }
+        if(debugmode==1)
+        {
+            printf("\nBanned a total of %zu tokens.\n",banned_token_ids.size());
+        }
+    }
@@ ... @@
-    if(banned_token_ids.size()==0 && banned_tokens.size()>0)
-    {
-        printf("\n[First Run] Banning %zu token sequences...",banned_tokens.size());
-        for(int v=0;v<n_vocab;++v)
-        {
-            std::string word = FileFormatTokenizeID(v,file_format,true);
-            for(int i=0;i<banned_tokens.size();++i)
-            {
-                if (word.find(banned_tokens[i]) != std::string::npos)
-                {
-                    banned_token_ids.push_back(v);
-                }
-            }
-        }
-        printf("\nBanned a total of %zu tokens.\n",banned_token_ids.size());
-    }
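The functional effect of the changes above: token-ban substrings now travel with each generation request instead of being fixed at model load, so the adapter rebuilds banned_token_ids by rescanning the vocabulary on every request that supplies bans. A rough Python sketch of what that scan does; detokenize and vocab_size are illustrative stand-ins for FileFormatTokenizeID() and n_vocab:

    def build_banned_token_ids(detokenize, vocab_size, banned_substrings):
        # Collect every vocab token whose string form contains any banned
        # substring; these ids are then excluded during sampling.
        banned_ids = []
        for token_id in range(vocab_size):
            piece = detokenize(token_id)
            if any(sub in piece for sub in banned_substrings):
                banned_ids.append(token_id)
        return banned_ids

The scan is O(vocab x bans) per request, so callers that reuse the same ban list pay the cost each time; that is the trade-off for making bans a per-request sampler setting.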
diff --git a/klite.embd b/klite.embd
--- a/klite.embd
+++ b/klite.embd
@@ -299,7 +299,7 @@ Current version: 135
     padding-right: 10px;
 }

-#extrastopseq, #anotetemplate {
+.inlineinput {
     background-color: #404040;
     color: #ffffff;
     resize: none;
@@ -3603,6 +3603,7 @@ Current version: 135
 var current_anote = ""; //stored author note
 var current_anotetemplate = "[Author\'s note: <|>]";
 var extrastopseq = "";
+var tokenbans = "";
 var anote_strength = 320; //distance from end
 var newlineaftermemory = true;
 var current_wi = []; //each item stores a wi object.
@@ -3738,8 +3739,8 @@ Current version: 135
     passed_ai_warning: false, //used to store AI safety panel acknowledgement state
     entersubmit: true, //enter sends the prompt
-    max_context_length: 1600,
-    max_length: 120,
+    max_context_length: 1800,
+    max_length: 140,
     auto_ctxlen: true,
     auto_genamt: true,
     rep_pen: 1.1,
@@ -5095,6 +5096,7 @@ Current version: 135
     //extra unofficial fields for the story
     new_save_storyobj.extrastopseq = extrastopseq;
+    new_save_storyobj.tokenbans = tokenbans;
     new_save_storyobj.anotestr = anote_strength;
     new_save_storyobj.wisearchdepth = wi_searchdepth;
     new_save_storyobj.wiinsertlocation = wi_insertlocation;
@@ -5271,6 +5273,7 @@ Current version: 135
     let old_current_memory = current_memory;
     let old_current_wi = current_wi;
     let old_extrastopseq = extrastopseq;
+    let old_tokenbans = tokenbans;
     let old_notes = personal_notes;
     let old_regexreplace_data = regexreplace_data;
@@ -5325,6 +5328,10 @@ Current version: 135
     if (storyobj.extrastopseq) {
         extrastopseq = storyobj.extrastopseq;
     }
+    if(storyobj.tokenbans)
+    {
+        tokenbans = storyobj.tokenbans;
+    }
     if (storyobj.anotestr) {
         anote_strength = storyobj.anotestr;
     }
@@ -5416,6 +5423,7 @@ Current version: 135
     {
         extrastopseq = old_extrastopseq;
         regexreplace_data = old_regexreplace_data;
+        tokenbans = old_tokenbans;
     }

     if (storyobj.savedsettings && storyobj.savedsettings != "")
@@ -6746,25 +6754,53 @@ Current version: 135
     },false,true);
 }
-var pendinglogitbias = {};
-function set_logit_bias()
+function expand_tokens_section(targetid)
 {
-    inputBox("Enter OpenAI-formatted logit bias dictionary. Each key is the integer token IDs and their values are the biases (-100.0 to 100.0)<br>Input is a JSON object, reference here.<br>Leave blank to disable.<br>","Set Logit Biases",JSON.stringify(pendinglogitbias),"Enter JSON Object",()=>{
-        let userinput = getInputBoxValue().trim();
-        if(userinput=="")
-        {
-            pendinglogitbias = {};
-        }
-        else
-        {
-            try
-            {
-                pendinglogitbias = JSON.parse(userinput);
-            }
-            catch(e)
-            {
-                msgbox("Your logit bias JSON dictionary was not correctly formatted!");
-            }
-        }
-    },false,true);
-}
+    let tablist = ["expandregexreplace","expandtokenbans","expandlogitbias"];
+
+    for(let i=0;i<tablist.length;++i)
+    {
+        if(tablist[i]==targetid)
+        {
+            document.getElementById(tablist[i]).classList.toggle("hidden");
+        }
+        else
+        {
+            document.getElementById(tablist[i]).classList.add("hidden");
+        }
+    }
+}
+function add_token_ban()
+{
+    inputBox("Enter a token substring to ban. ALL vocab tokens containing it will be removed. If you want multiple sequences, separate them with the following delimiter: ||$||","Ban Token Substrings","","Enter a Banned String",()=>{
+        let userinput = getInputBoxValue();
+        if(userinput.trim()!="")
+        {
+            let ov = document.getElementById("tokenbans").value;
+            if(ov!="")
+            {
+                ov += "||$||";
+            }
+            ov += userinput.trim();
+            document.getElementById("tokenbans").value = ov;
+        }
+    },false);
+}
+var msgboxOnDone = hide_msgbox;
 function hide_msgbox() {
     //hide msgbox ONLY
@@ -8919,8 +8972,20 @@ Current version: 135
     current_anotetemplate = document.getElementById("anotetemplate").value;
     anote_strength = document.getElementById("anote_strength").value;
     extrastopseq = document.getElementById("extrastopseq").value;
+    tokenbans = document.getElementById("tokenbans").value;
     newlineaftermemory = (document.getElementById("newlineaftermemory").checked?true:false);
-    logitbiasdict = pendinglogitbias;
+    try
+    {
+        let lb = document.getElementById("logitbiastxtarea").value;
+        let dict = {};
+        if(lb!="")
+        {
+            dict = JSON.parse(lb);
+        }
+        logitbiasdict = dict;
+    } catch (e) {
+        console.log("Your logit bias JSON dictionary was not correctly formatted!");
+    }
     regexreplace_data = [];
     for(let i=0;i<num_regex_rows;++i)
@@ ... @@ Current version: 135
+    function get_token_bans()
+    {
+        let seqs = [];
+        if (tokenbans != "") {
+            let srep = tokenbans.split("||$||");
+            if (srep.length > 0 && !seqs) {
+                seqs = [];
+            }
+            for (let i = 0; i < srep.length; ++i) {
+                if (srep[i] && srep[i] != "") {
+                    seqs.push(srep[i]);
+                }
+            }
+        }
+        return seqs;
+    }
+
     function dispatch_submit_generation(submit_payload, input_was_empty) //if input is not empty, always unban eos
     {
         console.log(submit_payload);
@@ -10367,6 +10451,7 @@ Current version: 135
         submit_payload.params.dynatemp_range = localsettings.dynatemp_range;
         submit_payload.params.dynatemp_exponent = localsettings.dynatemp_exponent;
         submit_payload.params.smoothing_factor = localsettings.smoothing_factor;
+        submit_payload.params.banned_tokens = get_token_bans();
     }
     //presence pen and logit bias for OAI and newer kcpp
     if((custom_kobold_endpoint != "" && is_using_kcpp_with_mirostat()) || custom_oai_endpoint!="")
@@ -13165,8 +13250,10 @@ Current version: 135
     document.getElementById("anotetemplate").value = current_anotetemplate;
     document.getElementById("anote_strength").value = anote_strength;
     document.getElementById("extrastopseq").value = extrastopseq;
+    document.getElementById("tokenbans").value = tokenbans;
     document.getElementById("newlineaftermemory").checked = (newlineaftermemory?true:false);
-    pendinglogitbias = logitbiasdict;
+    document.getElementById("logitbiastxtarea").value = JSON.stringify(logitbiasdict,null,2);
+
     if(custom_kobold_endpoint!="" || !is_using_custom_ep() )
     {
         document.getElementById("noextrastopseq").classList.add("hidden");
@@ -13183,7 +13270,16 @@ Current version: 135
     //setup regex replacers
     populate_regex_replacers();

-    document.getElementById("btnlogitbias").disabled = !is_using_custom_ep();
+    if(is_using_custom_ep())
+    {
+        document.getElementById("nologitbias").classList.add("hidden");
+        document.getElementById("notokenbans").classList.add("hidden");
+    }
+    else
+    {
+        document.getElementById("nologitbias").classList.remove("hidden");
+        document.getElementById("notokenbans").classList.remove("hidden");
+    }
 }
@@ -15291,7 +15387,7 @@ Current version: 135
[Settings-panel markup hunks; only the text content is recoverable. The #extrastopseq and #anotetemplate inputs move to the new shared .inlineinput class. The old single-button "Logit Biases" row (helptext: "Specify a dictionary of token IDs to modify the probability of occuring.") and the "Custom Regex Replace" section are replaced by three collapsible sections toggled via expand_tokens_section(): "Logit Biases", now an inline JSON textarea (#logitbiastxtarea) with unavailable-marker #nologitbias; a new "Token Bans" section with textarea #tokenbans, marker #notokenbans, an Add button wired to add_token_ban(), and helptext: "Outright removal for ANY tokens containing a specific substring from model vocab. If you want multiple sequences, separate them with the following delimiter: ||$||"; and "Regex Replace" (formerly "Custom Regex Replace") with helptext: "Allows transforming incoming text with regex patterns, modifying all matches. Replacements will be applied in sequence."]
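On the client side, the new Token Bans field is stored as one ||$||-delimited string, split into an array by get_token_bans(), and sent as banned_tokens in the request params. A sketch of the equivalent raw API call against a local koboldcpp instance (default port 5001; the prompt text and ban list here are made-up examples):

    import requests

    payload = {
        "prompt": "The quick brown fox",
        "max_length": 140,
        # Substrings, not token IDs: every vocab token containing any of
        # these strings is banned for this request only.
        "banned_tokens": ["fox", "Fox"],
    }
    r = requests.post("http://localhost:5001/api/v1/generate", json=payload)
    print(r.json()["results"][0]["text"])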
diff --git a/koboldcpp.py b/koboldcpp.py
index aadf9a8a8..7557a3fd8 100644
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -56,7 +56,6 @@ class load_model_inputs(ctypes.Structure):
                 ("gpulayers", ctypes.c_int),
                 ("rope_freq_scale", ctypes.c_float),
                 ("rope_freq_base", ctypes.c_float),
-                ("banned_tokens", ctypes.c_char_p * ban_token_max),
                 ("tensor_split", ctypes.c_float * tensor_split_max)]

 class generation_inputs(ctypes.Structure):
@@ -91,7 +90,8 @@ class generation_inputs(ctypes.Structure):
                 ("dynatemp_range", ctypes.c_float),
                 ("dynatemp_exponent", ctypes.c_float),
                 ("smoothing_factor", ctypes.c_float),
-                ("logit_biases", logit_bias * logit_bias_max)]
+                ("logit_biases", logit_bias * logit_bias_max),
+                ("banned_tokens", ctypes.c_char_p * ban_token_max)]

 class generation_outputs(ctypes.Structure):
     _fields_ = [("status", ctypes.c_int),
@@ -391,16 +391,10 @@ def load_model(model_filename):
     inputs.executable_path = (getdirpath()+"/").encode("UTF-8")
     inputs.debugmode = args.debugmode
-    banned_tokens = args.bantokens
-    for n in range(ban_token_max):
-        if not banned_tokens or n >= len(banned_tokens):
-            inputs.banned_tokens[n] = "".encode("UTF-8")
-        else:
-            inputs.banned_tokens[n] = banned_tokens[n].encode("UTF-8")
     ret = handle.load_model(inputs)
     return ret

-def generate(prompt, memory="", images=[], max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.0, rep_pen_range=128, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False, dynatemp_range=0.0, dynatemp_exponent=1.0, smoothing_factor=0.0, logit_biases={}, render_special=False):
+def generate(prompt, memory="", images=[], max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.0, rep_pen_range=128, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False, dynatemp_range=0.0, dynatemp_exponent=1.0, smoothing_factor=0.0, logit_biases={}, render_special=False, banned_tokens=[]):
     global maxctx, args, currentusergenkey, totalgens, pendingabortkey
     inputs = generation_inputs()
     inputs.prompt = prompt.encode("UTF-8")
@@ -487,6 +481,12 @@ def generate(prompt, memory="", images=[], max_length=32, max_context_length=512
             inputs.logit_biases[n] = logit_bias(-1, 0.0)
             print(f"Skipped unparsable logit bias:{ex}")

+    for n in range(ban_token_max):
+        if not banned_tokens or n >= len(banned_tokens):
+            inputs.banned_tokens[n] = "".encode("UTF-8")
+        else:
+            inputs.banned_tokens[n] = banned_tokens[n].encode("UTF-8")
+
     currentusergenkey = genkey
     totalgens += 1
     #early exit if aborted
@@ -672,6 +672,10 @@ def transform_genparams(genparams, api_format):
         genparams["top_k"] = int(genparams.get('top_k', 120))
         genparams["max_length"] = genparams.get('max', 100)

+    elif api_format==2:
+        if "ignore_eos" in genparams and not ("use_default_badwordsids" in genparams):
+            genparams["use_default_badwordsids"] = genparams.get('ignore_eos', False)
+
     elif api_format==3 or api_format==4:
         genparams["max_length"] = genparams.get('max_tokens', 100)
         presence_penalty = genparams.get('presence_penalty', genparams.get('frequency_penalty', 0.0))
@@ -813,6 +817,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                 smoothing_factor=genparams.get('smoothing_factor', 0.0),
                 logit_biases=genparams.get('logit_bias', {}),
                 render_special=genparams.get('render_special', False),
+                banned_tokens=genparams.get('banned_tokens', []),
             )

     genout = {"text":"","status":-1,"stopreason":-1}
@@ -3281,7 +3286,6 @@ if __name__ == '__main__':
     parser.add_argument("--lora", help="LLAMA models only, applies a lora file on top of model. Experimental.", metavar=('[lora_filename]', '[lora_base]'), nargs='+')
     parser.add_argument("--smartcontext", help="Reserving a portion of context to try processing less frequently.", action='store_true')
     parser.add_argument("--noshift", help="If set, do not attempt to Trim and Shift the GGUF context.", action='store_true')
-    parser.add_argument("--bantokens", help="You can manually specify a list of token SUBSTRINGS that the AI cannot use. This bans ALL instances of that substring.", metavar=('[token_substrings]'), nargs='+')
     parser.add_argument("--forceversion", help="If the model file format detection fails (e.g. rogue modified model) you can set this to override the detected format (enter desired version, e.g. 401 for GPTNeoX-Type2).",metavar=('[version]'), type=int, default=0)
     parser.add_argument("--nommap", help="If set, do not use mmap to load newer models", action='store_true')
     parser.add_argument("--usemlock", help="For Apple Systems. Force system to keep model in RAM rather than swapping or compressing", action='store_true')
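With --bantokens removed, launch-time bans no longer exist; the only path is the per-request banned_tokens field handled above. Separately, api_format==2 (the standard KoboldAI API) gains a small compatibility shim: an ignore_eos flag from the client is mapped onto use_default_badwordsids unless the caller already set the latter. A condensed sketch of that mapping:

    def shim_ignore_eos(genparams: dict) -> dict:
        # An explicit use_default_badwordsids always wins; otherwise
        # inherit the client's ignore_eos setting.
        if "ignore_eos" in genparams and "use_default_badwordsids" not in genparams:
            genparams["use_default_badwordsids"] = genparams.get("ignore_eos", False)
        return genparams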