diff --git a/koboldcpp.py b/koboldcpp.py
index dc4d5f785..8ac20c262 100644
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -775,7 +775,7 @@ def load_model(model_filename):
ret = handle.load_model(inputs)
return ret
-def generate(prompt, memory="", images=[], max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.0, rep_pen_range=128, rep_pen_slope=1.0, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, dry_multiplier=0.0, dry_base=1.75, dry_allowed_length=2, dry_penalty_last_n=0, dry_sequence_breakers=[], sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False, dynatemp_range=0.0, dynatemp_exponent=1.0, smoothing_factor=0.0, logit_biases={}, render_special=False, banned_tokens=[], bypass_eos_token=False):
+def generate(prompt, memory="", images=[], max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.0, rep_pen_range=128, rep_pen_slope=1.0, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, dry_multiplier=0.0, dry_base=1.75, dry_allowed_length=2, dry_penalty_last_n=0, dry_sequence_breakers=[], sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], ban_eos_token=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False, dynatemp_range=0.0, dynatemp_exponent=1.0, smoothing_factor=0.0, logit_biases={}, render_special=False, banned_tokens=[], bypass_eos_token=False):
global maxctx, args, currentusergenkey, totalgens, pendingabortkey
inputs = generation_inputs()
inputs.prompt = prompt.encode("UTF-8")
@@ -814,7 +814,7 @@ def generate(prompt, memory="", images=[], max_length=32, max_context_length=512
inputs.smoothing_factor = smoothing_factor
inputs.grammar = grammar.encode("UTF-8")
inputs.grammar_retain_state = grammar_retain_state
- inputs.allow_eos_token = not use_default_badwordsids
+ inputs.allow_eos_token = not ban_eos_token
inputs.bypass_eos_token = bypass_eos_token
inputs.render_special = render_special
if mirostat in (1, 2):
@@ -1078,6 +1078,8 @@ def transform_genparams(genparams, api_format):
rp3 = genparams.get('rep_pen', 1.0)
rp_max = max(rp1,rp2,rp3)
genparams["rep_pen"] = rp_max
+ if "use_default_badwordsids" in genparams and not ("ban_eos_token" in genparams):
+ genparams["ban_eos_token"] = genparams.get('use_default_badwordsids', False)
if api_format==1:
genparams["prompt"] = genparams.get('text', "")
@@ -1085,8 +1087,7 @@ def transform_genparams(genparams, api_format):
genparams["max_length"] = genparams.get('max', 150)
elif api_format==2:
- if "ignore_eos" in genparams and not ("use_default_badwordsids" in genparams):
- genparams["use_default_badwordsids"] = genparams.get('ignore_eos', False)
+ pass
elif api_format==3 or api_format==4:
genparams["max_length"] = genparams.get('max_tokens', (350 if api_format==4 else 150))
@@ -1099,7 +1100,6 @@ def transform_genparams(genparams, api_format):
genparams["stop_sequence"] = [genparams.get('stop')]
genparams["sampler_seed"] = tryparseint(genparams.get('seed', -1))
- genparams["use_default_badwordsids"] = genparams.get('ignore_eos', False)
genparams["mirostat"] = genparams.get('mirostat_mode', 0)
if api_format==4:
@@ -1297,7 +1297,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
sampler_order=genparams.get('sampler_order', [6,0,1,3,4,2,5]),
seed=tryparseint(genparams.get('sampler_seed', -1)),
stop_sequence=genparams.get('stop_sequence', []),
- use_default_badwordsids=genparams.get('use_default_badwordsids', False),
+ ban_eos_token=genparams.get('ban_eos_token', False),
stream_sse=stream_flag,
grammar=genparams.get('grammar', ''),
grammar_retain_state = genparams.get('grammar_retain_state', False),
@@ -1517,7 +1517,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
top_k = int(parsed_dict['top_k'][0]) if 'top_k' in parsed_dict else 100
top_p = float(parsed_dict['top_p'][0]) if 'top_p' in parsed_dict else 0.9
rep_pen = float(parsed_dict['rep_pen'][0]) if 'rep_pen' in parsed_dict else 1.0
- use_default_badwordsids = int(parsed_dict['use_default_badwordsids'][0]) if 'use_default_badwordsids' in parsed_dict else 0
+ ban_eos_token = int(parsed_dict['ban_eos_token'][0]) if 'ban_eos_token' in parsed_dict else 0
gencommand = (parsed_dict['generate'][0] if 'generate' in parsed_dict else "")=="Generate"
if modelbusy.locked():
@@ -1531,7 +1531,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
epurl = f"http://localhost:{args.port}"
if args.host!="":
epurl = f"http://{args.host}:{args.port}"
- gen_payload = {"prompt": prompt,"max_length": max_length,"temperature": temperature,"prompt": prompt,"top_k": top_k,"top_p": top_p,"rep_pen": rep_pen,"use_default_badwordsids":use_default_badwordsids}
+ gen_payload = {"prompt": prompt,"max_length": max_length,"temperature": temperature,"prompt": prompt,"top_k": top_k,"top_p": top_p,"rep_pen": rep_pen,"ban_eos_token":ban_eos_token}
respjson = make_url_request(f'{epurl}/api/v1/generate', gen_payload)
reply = html.escape(respjson["results"][0]["text"])
status = "Generation Completed"
@@ -1568,7 +1568,7 @@ Enter Prompt:
-
+
(Please be patient)