diff --git a/koboldcpp.py b/koboldcpp.py index dc4d5f785..8ac20c262 100644 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -775,7 +775,7 @@ def load_model(model_filename): ret = handle.load_model(inputs) return ret -def generate(prompt, memory="", images=[], max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.0, rep_pen_range=128, rep_pen_slope=1.0, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, dry_multiplier=0.0, dry_base=1.75, dry_allowed_length=2, dry_penalty_last_n=0, dry_sequence_breakers=[], sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False, dynatemp_range=0.0, dynatemp_exponent=1.0, smoothing_factor=0.0, logit_biases={}, render_special=False, banned_tokens=[], bypass_eos_token=False): +def generate(prompt, memory="", images=[], max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.0, rep_pen_range=128, rep_pen_slope=1.0, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, dry_multiplier=0.0, dry_base=1.75, dry_allowed_length=2, dry_penalty_last_n=0, dry_sequence_breakers=[], sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], ban_eos_token=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False, dynatemp_range=0.0, dynatemp_exponent=1.0, smoothing_factor=0.0, logit_biases={}, render_special=False, banned_tokens=[], bypass_eos_token=False): global maxctx, args, currentusergenkey, totalgens, pendingabortkey inputs = generation_inputs() inputs.prompt = prompt.encode("UTF-8") @@ -814,7 +814,7 @@ def generate(prompt, memory="", images=[], max_length=32, max_context_length=512 inputs.smoothing_factor = smoothing_factor inputs.grammar = grammar.encode("UTF-8") inputs.grammar_retain_state = grammar_retain_state - inputs.allow_eos_token = not use_default_badwordsids + inputs.allow_eos_token = not ban_eos_token inputs.bypass_eos_token = bypass_eos_token inputs.render_special = render_special if mirostat in (1, 2): @@ -1078,6 +1078,8 @@ def transform_genparams(genparams, api_format): rp3 = genparams.get('rep_pen', 1.0) rp_max = max(rp1,rp2,rp3) genparams["rep_pen"] = rp_max + if "use_default_badwordsids" in genparams and not ("ban_eos_token" in genparams): + genparams["ban_eos_token"] = genparams.get('use_default_badwordsids', False) if api_format==1: genparams["prompt"] = genparams.get('text', "") @@ -1085,8 +1087,7 @@ def transform_genparams(genparams, api_format): genparams["max_length"] = genparams.get('max', 150) elif api_format==2: - if "ignore_eos" in genparams and not ("use_default_badwordsids" in genparams): - genparams["use_default_badwordsids"] = genparams.get('ignore_eos', False) + pass elif api_format==3 or api_format==4: genparams["max_length"] = genparams.get('max_tokens', (350 if api_format==4 else 150)) @@ -1099,7 +1100,6 @@ def transform_genparams(genparams, api_format): genparams["stop_sequence"] = [genparams.get('stop')] genparams["sampler_seed"] = tryparseint(genparams.get('seed', -1)) - genparams["use_default_badwordsids"] = genparams.get('ignore_eos', False) genparams["mirostat"] = genparams.get('mirostat_mode', 0) if api_format==4: @@ -1297,7 +1297,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): sampler_order=genparams.get('sampler_order', [6,0,1,3,4,2,5]), seed=tryparseint(genparams.get('sampler_seed', -1)), stop_sequence=genparams.get('stop_sequence', []), - use_default_badwordsids=genparams.get('use_default_badwordsids', False), + ban_eos_token=genparams.get('ban_eos_token', False), stream_sse=stream_flag, grammar=genparams.get('grammar', ''), grammar_retain_state = genparams.get('grammar_retain_state', False), @@ -1517,7 +1517,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): top_k = int(parsed_dict['top_k'][0]) if 'top_k' in parsed_dict else 100 top_p = float(parsed_dict['top_p'][0]) if 'top_p' in parsed_dict else 0.9 rep_pen = float(parsed_dict['rep_pen'][0]) if 'rep_pen' in parsed_dict else 1.0 - use_default_badwordsids = int(parsed_dict['use_default_badwordsids'][0]) if 'use_default_badwordsids' in parsed_dict else 0 + ban_eos_token = int(parsed_dict['ban_eos_token'][0]) if 'ban_eos_token' in parsed_dict else 0 gencommand = (parsed_dict['generate'][0] if 'generate' in parsed_dict else "")=="Generate" if modelbusy.locked(): @@ -1531,7 +1531,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): epurl = f"http://localhost:{args.port}" if args.host!="": epurl = f"http://{args.host}:{args.port}" - gen_payload = {"prompt": prompt,"max_length": max_length,"temperature": temperature,"prompt": prompt,"top_k": top_k,"top_p": top_p,"rep_pen": rep_pen,"use_default_badwordsids":use_default_badwordsids} + gen_payload = {"prompt": prompt,"max_length": max_length,"temperature": temperature,"prompt": prompt,"top_k": top_k,"top_p": top_p,"rep_pen": rep_pen,"ban_eos_token":ban_eos_token} respjson = make_url_request(f'{epurl}/api/v1/generate', gen_payload) reply = html.escape(respjson["results"][0]["text"]) status = "Generation Completed" @@ -1568,7 +1568,7 @@ Enter Prompt:



-
+
(Please be patient)
@@ -4040,7 +4040,7 @@ def main(launch_args,start_server=True): benchprompt = "1111111111111111" for i in range(0,14): #generate massive prompt benchprompt += benchprompt - genout = generate(benchprompt,memory="",images=[],max_length=benchlen,max_context_length=benchmaxctx,temperature=0.1,top_k=1,rep_pen=1,use_default_badwordsids=True) + genout = generate(benchprompt,memory="",images=[],max_length=benchlen,max_context_length=benchmaxctx,temperature=0.1,top_k=1,rep_pen=1,ban_eos_token=True) result = genout['text'] result = (result[:5] if len(result)>5 else "") t_pp = float(handle.get_last_process_time())*float(benchmaxctx-benchlen)*0.001