mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 17:14:36 +00:00
refactor some fields
This commit is contained in:
parent
9f2076b4b3
commit
4531ab5465
1 changed file with 10 additions and 10 deletions
20
koboldcpp.py
20
koboldcpp.py
|
@@ -775,7 +775,7 @@ def load_model(model_filename):
|
||||||
ret = handle.load_model(inputs)
|
ret = handle.load_model(inputs)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def generate(prompt, memory="", images=[], max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.0, rep_pen_range=128, rep_pen_slope=1.0, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, dry_multiplier=0.0, dry_base=1.75, dry_allowed_length=2, dry_penalty_last_n=0, dry_sequence_breakers=[], sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False, dynatemp_range=0.0, dynatemp_exponent=1.0, smoothing_factor=0.0, logit_biases={}, render_special=False, banned_tokens=[], bypass_eos_token=False):
|
def generate(prompt, memory="", images=[], max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.0, rep_pen_range=128, rep_pen_slope=1.0, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, dry_multiplier=0.0, dry_base=1.75, dry_allowed_length=2, dry_penalty_last_n=0, dry_sequence_breakers=[], sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], ban_eos_token=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False, dynatemp_range=0.0, dynatemp_exponent=1.0, smoothing_factor=0.0, logit_biases={}, render_special=False, banned_tokens=[], bypass_eos_token=False):
|
||||||
global maxctx, args, currentusergenkey, totalgens, pendingabortkey
|
global maxctx, args, currentusergenkey, totalgens, pendingabortkey
|
||||||
inputs = generation_inputs()
|
inputs = generation_inputs()
|
||||||
inputs.prompt = prompt.encode("UTF-8")
|
inputs.prompt = prompt.encode("UTF-8")
|
||||||
|
@@ -814,7 +814,7 @@ def generate(prompt, memory="", images=[], max_length=32, max_context_length=512
|
||||||
inputs.smoothing_factor = smoothing_factor
|
inputs.smoothing_factor = smoothing_factor
|
||||||
inputs.grammar = grammar.encode("UTF-8")
|
inputs.grammar = grammar.encode("UTF-8")
|
||||||
inputs.grammar_retain_state = grammar_retain_state
|
inputs.grammar_retain_state = grammar_retain_state
|
||||||
inputs.allow_eos_token = not use_default_badwordsids
|
inputs.allow_eos_token = not ban_eos_token
|
||||||
inputs.bypass_eos_token = bypass_eos_token
|
inputs.bypass_eos_token = bypass_eos_token
|
||||||
inputs.render_special = render_special
|
inputs.render_special = render_special
|
||||||
if mirostat in (1, 2):
|
if mirostat in (1, 2):
|
||||||
|
@@ -1078,6 +1078,8 @@ def transform_genparams(genparams, api_format):
|
||||||
rp3 = genparams.get('rep_pen', 1.0)
|
rp3 = genparams.get('rep_pen', 1.0)
|
||||||
rp_max = max(rp1,rp2,rp3)
|
rp_max = max(rp1,rp2,rp3)
|
||||||
genparams["rep_pen"] = rp_max
|
genparams["rep_pen"] = rp_max
|
||||||
|
if "use_default_badwordsids" in genparams and not ("ban_eos_token" in genparams):
|
||||||
|
genparams["ban_eos_token"] = genparams.get('use_default_badwordsids', False)
|
||||||
|
|
||||||
if api_format==1:
|
if api_format==1:
|
||||||
genparams["prompt"] = genparams.get('text', "")
|
genparams["prompt"] = genparams.get('text', "")
|
||||||
|
@@ -1085,8 +1087,7 @@ def transform_genparams(genparams, api_format):
|
||||||
genparams["max_length"] = genparams.get('max', 150)
|
genparams["max_length"] = genparams.get('max', 150)
|
||||||
|
|
||||||
elif api_format==2:
|
elif api_format==2:
|
||||||
if "ignore_eos" in genparams and not ("use_default_badwordsids" in genparams):
|
pass
|
||||||
genparams["use_default_badwordsids"] = genparams.get('ignore_eos', False)
|
|
||||||
|
|
||||||
elif api_format==3 or api_format==4:
|
elif api_format==3 or api_format==4:
|
||||||
genparams["max_length"] = genparams.get('max_tokens', (350 if api_format==4 else 150))
|
genparams["max_length"] = genparams.get('max_tokens', (350 if api_format==4 else 150))
|
||||||
|
@@ -1099,7 +1100,6 @@ def transform_genparams(genparams, api_format):
|
||||||
genparams["stop_sequence"] = [genparams.get('stop')]
|
genparams["stop_sequence"] = [genparams.get('stop')]
|
||||||
|
|
||||||
genparams["sampler_seed"] = tryparseint(genparams.get('seed', -1))
|
genparams["sampler_seed"] = tryparseint(genparams.get('seed', -1))
|
||||||
genparams["use_default_badwordsids"] = genparams.get('ignore_eos', False)
|
|
||||||
genparams["mirostat"] = genparams.get('mirostat_mode', 0)
|
genparams["mirostat"] = genparams.get('mirostat_mode', 0)
|
||||||
|
|
||||||
if api_format==4:
|
if api_format==4:
|
||||||
|
@@ -1297,7 +1297,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
sampler_order=genparams.get('sampler_order', [6,0,1,3,4,2,5]),
|
sampler_order=genparams.get('sampler_order', [6,0,1,3,4,2,5]),
|
||||||
seed=tryparseint(genparams.get('sampler_seed', -1)),
|
seed=tryparseint(genparams.get('sampler_seed', -1)),
|
||||||
stop_sequence=genparams.get('stop_sequence', []),
|
stop_sequence=genparams.get('stop_sequence', []),
|
||||||
use_default_badwordsids=genparams.get('use_default_badwordsids', False),
|
ban_eos_token=genparams.get('ban_eos_token', False),
|
||||||
stream_sse=stream_flag,
|
stream_sse=stream_flag,
|
||||||
grammar=genparams.get('grammar', ''),
|
grammar=genparams.get('grammar', ''),
|
||||||
grammar_retain_state = genparams.get('grammar_retain_state', False),
|
grammar_retain_state = genparams.get('grammar_retain_state', False),
|
||||||
|
@@ -1517,7 +1517,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
top_k = int(parsed_dict['top_k'][0]) if 'top_k' in parsed_dict else 100
|
top_k = int(parsed_dict['top_k'][0]) if 'top_k' in parsed_dict else 100
|
||||||
top_p = float(parsed_dict['top_p'][0]) if 'top_p' in parsed_dict else 0.9
|
top_p = float(parsed_dict['top_p'][0]) if 'top_p' in parsed_dict else 0.9
|
||||||
rep_pen = float(parsed_dict['rep_pen'][0]) if 'rep_pen' in parsed_dict else 1.0
|
rep_pen = float(parsed_dict['rep_pen'][0]) if 'rep_pen' in parsed_dict else 1.0
|
||||||
use_default_badwordsids = int(parsed_dict['use_default_badwordsids'][0]) if 'use_default_badwordsids' in parsed_dict else 0
|
ban_eos_token = int(parsed_dict['ban_eos_token'][0]) if 'ban_eos_token' in parsed_dict else 0
|
||||||
gencommand = (parsed_dict['generate'][0] if 'generate' in parsed_dict else "")=="Generate"
|
gencommand = (parsed_dict['generate'][0] if 'generate' in parsed_dict else "")=="Generate"
|
||||||
|
|
||||||
if modelbusy.locked():
|
if modelbusy.locked():
|
||||||
|
@@ -1531,7 +1531,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
epurl = f"http://localhost:{args.port}"
|
epurl = f"http://localhost:{args.port}"
|
||||||
if args.host!="":
|
if args.host!="":
|
||||||
epurl = f"http://{args.host}:{args.port}"
|
epurl = f"http://{args.host}:{args.port}"
|
||||||
gen_payload = {"prompt": prompt,"max_length": max_length,"temperature": temperature,"prompt": prompt,"top_k": top_k,"top_p": top_p,"rep_pen": rep_pen,"use_default_badwordsids":use_default_badwordsids}
|
gen_payload = {"prompt": prompt,"max_length": max_length,"temperature": temperature,"prompt": prompt,"top_k": top_k,"top_p": top_p,"rep_pen": rep_pen,"ban_eos_token":ban_eos_token}
|
||||||
respjson = make_url_request(f'{epurl}/api/v1/generate', gen_payload)
|
respjson = make_url_request(f'{epurl}/api/v1/generate', gen_payload)
|
||||||
reply = html.escape(respjson["results"][0]["text"])
|
reply = html.escape(respjson["results"][0]["text"])
|
||||||
status = "Generation Completed"
|
status = "Generation Completed"
|
||||||
|
@@ -1568,7 +1568,7 @@ Enter Prompt:<br>
|
||||||
<label>Top-K</label> <input type="text" size="4" value="{top_k}" name="top_k"><br>
|
<label>Top-K</label> <input type="text" size="4" value="{top_k}" name="top_k"><br>
|
||||||
<label>Top-P</label> <input type="text" size="4" value="{top_p}" name="top_p"><br>
|
<label>Top-P</label> <input type="text" size="4" value="{top_p}" name="top_p"><br>
|
||||||
<label>Rep. Pen</label> <input type="text" size="4" value="{rep_pen}" name="rep_pen"><br>
|
<label>Rep. Pen</label> <input type="text" size="4" value="{rep_pen}" name="rep_pen"><br>
|
||||||
<label>Ignore EOS</label> <input type="checkbox" name="use_default_badwordsids" value="1" {"checked" if use_default_badwordsids else ""}><br>
|
<label>Prevent EOS</label> <input type="checkbox" name="ban_eos_token" value="1" {"checked" if ban_eos_token else ""}><br>
|
||||||
<input type="submit" name="generate" value="Generate"> (Please be patient)
|
<input type="submit" name="generate" value="Generate"> (Please be patient)
|
||||||
</form>
|
</form>
|
||||||
<form action="/noscript">
|
<form action="/noscript">
|
||||||
|
@@ -4040,7 +4040,7 @@ def main(launch_args,start_server=True):
|
||||||
benchprompt = "1111111111111111"
|
benchprompt = "1111111111111111"
|
||||||
for i in range(0,14): #generate massive prompt
|
for i in range(0,14): #generate massive prompt
|
||||||
benchprompt += benchprompt
|
benchprompt += benchprompt
|
||||||
genout = generate(benchprompt,memory="",images=[],max_length=benchlen,max_context_length=benchmaxctx,temperature=0.1,top_k=1,rep_pen=1,use_default_badwordsids=True)
|
genout = generate(benchprompt,memory="",images=[],max_length=benchlen,max_context_length=benchmaxctx,temperature=0.1,top_k=1,rep_pen=1,ban_eos_token=True)
|
||||||
result = genout['text']
|
result = genout['text']
|
||||||
result = (result[:5] if len(result)>5 else "")
|
result = (result[:5] if len(result)>5 else "")
|
||||||
t_pp = float(handle.get_last_process_time())*float(benchmaxctx-benchlen)*0.001
|
t_pp = float(handle.get_last_process_time())*float(benchmaxctx-benchlen)*0.001
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue