refactor some fields (rename use_default_badwordsids to ban_eos_token)

This commit is contained in:
Concedo 2024-07-27 00:04:29 +08:00
parent 9f2076b4b3
commit 4531ab5465

View file

@@ -775,7 +775,7 @@ def load_model(model_filename):
ret = handle.load_model(inputs)
return ret
def generate(prompt, memory="", images=[], max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.0, rep_pen_range=128, rep_pen_slope=1.0, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, dry_multiplier=0.0, dry_base=1.75, dry_allowed_length=2, dry_penalty_last_n=0, dry_sequence_breakers=[], sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False, dynatemp_range=0.0, dynatemp_exponent=1.0, smoothing_factor=0.0, logit_biases={}, render_special=False, banned_tokens=[], bypass_eos_token=False):
def generate(prompt, memory="", images=[], max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.0, rep_pen_range=128, rep_pen_slope=1.0, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, dry_multiplier=0.0, dry_base=1.75, dry_allowed_length=2, dry_penalty_last_n=0, dry_sequence_breakers=[], sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], ban_eos_token=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False, dynatemp_range=0.0, dynatemp_exponent=1.0, smoothing_factor=0.0, logit_biases={}, render_special=False, banned_tokens=[], bypass_eos_token=False):
global maxctx, args, currentusergenkey, totalgens, pendingabortkey
inputs = generation_inputs()
inputs.prompt = prompt.encode("UTF-8")
@@ -814,7 +814,7 @@ def generate(prompt, memory="", images=[], max_length=32, max_context_length=512
inputs.smoothing_factor = smoothing_factor
inputs.grammar = grammar.encode("UTF-8")
inputs.grammar_retain_state = grammar_retain_state
inputs.allow_eos_token = not use_default_badwordsids
inputs.allow_eos_token = not ban_eos_token
inputs.bypass_eos_token = bypass_eos_token
inputs.render_special = render_special
if mirostat in (1, 2):
@@ -1078,6 +1078,8 @@ def transform_genparams(genparams, api_format):
rp3 = genparams.get('rep_pen', 1.0)
rp_max = max(rp1,rp2,rp3)
genparams["rep_pen"] = rp_max
if "use_default_badwordsids" in genparams and not ("ban_eos_token" in genparams):
genparams["ban_eos_token"] = genparams.get('use_default_badwordsids', False)
if api_format==1:
genparams["prompt"] = genparams.get('text', "")
@@ -1085,8 +1087,7 @@ def transform_genparams(genparams, api_format):
genparams["max_length"] = genparams.get('max', 150)
elif api_format==2:
if "ignore_eos" in genparams and not ("use_default_badwordsids" in genparams):
genparams["use_default_badwordsids"] = genparams.get('ignore_eos', False)
pass
elif api_format==3 or api_format==4:
genparams["max_length"] = genparams.get('max_tokens', (350 if api_format==4 else 150))
@@ -1099,7 +1100,6 @@ def transform_genparams(genparams, api_format):
genparams["stop_sequence"] = [genparams.get('stop')]
genparams["sampler_seed"] = tryparseint(genparams.get('seed', -1))
genparams["use_default_badwordsids"] = genparams.get('ignore_eos', False)
genparams["mirostat"] = genparams.get('mirostat_mode', 0)
if api_format==4:
@@ -1297,7 +1297,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
sampler_order=genparams.get('sampler_order', [6,0,1,3,4,2,5]),
seed=tryparseint(genparams.get('sampler_seed', -1)),
stop_sequence=genparams.get('stop_sequence', []),
use_default_badwordsids=genparams.get('use_default_badwordsids', False),
ban_eos_token=genparams.get('ban_eos_token', False),
stream_sse=stream_flag,
grammar=genparams.get('grammar', ''),
grammar_retain_state = genparams.get('grammar_retain_state', False),
@@ -1517,7 +1517,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
top_k = int(parsed_dict['top_k'][0]) if 'top_k' in parsed_dict else 100
top_p = float(parsed_dict['top_p'][0]) if 'top_p' in parsed_dict else 0.9
rep_pen = float(parsed_dict['rep_pen'][0]) if 'rep_pen' in parsed_dict else 1.0
use_default_badwordsids = int(parsed_dict['use_default_badwordsids'][0]) if 'use_default_badwordsids' in parsed_dict else 0
ban_eos_token = int(parsed_dict['ban_eos_token'][0]) if 'ban_eos_token' in parsed_dict else 0
gencommand = (parsed_dict['generate'][0] if 'generate' in parsed_dict else "")=="Generate"
if modelbusy.locked():
@@ -1531,7 +1531,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
epurl = f"http://localhost:{args.port}"
if args.host!="":
epurl = f"http://{args.host}:{args.port}"
gen_payload = {"prompt": prompt,"max_length": max_length,"temperature": temperature,"prompt": prompt,"top_k": top_k,"top_p": top_p,"rep_pen": rep_pen,"use_default_badwordsids":use_default_badwordsids}
gen_payload = {"prompt": prompt,"max_length": max_length,"temperature": temperature,"prompt": prompt,"top_k": top_k,"top_p": top_p,"rep_pen": rep_pen,"ban_eos_token":ban_eos_token}
respjson = make_url_request(f'{epurl}/api/v1/generate', gen_payload)
reply = html.escape(respjson["results"][0]["text"])
status = "Generation Completed"
@@ -1568,7 +1568,7 @@ Enter Prompt:<br>
<label>Top-K</label> <input type="text" size="4" value="{top_k}" name="top_k"><br>
<label>Top-P</label> <input type="text" size="4" value="{top_p}" name="top_p"><br>
<label>Rep. Pen</label> <input type="text" size="4" value="{rep_pen}" name="rep_pen"><br>
<label>Ignore EOS</label> <input type="checkbox" name="use_default_badwordsids" value="1" {"checked" if use_default_badwordsids else ""}><br>
<label>Prevent EOS</label> <input type="checkbox" name="ban_eos_token" value="1" {"checked" if ban_eos_token else ""}><br>
<input type="submit" name="generate" value="Generate"> (Please be patient)
</form>
<form action="/noscript">
@ -4040,7 +4040,7 @@ def main(launch_args,start_server=True):
benchprompt = "1111111111111111"
for i in range(0,14): #generate massive prompt
benchprompt += benchprompt
genout = generate(benchprompt,memory="",images=[],max_length=benchlen,max_context_length=benchmaxctx,temperature=0.1,top_k=1,rep_pen=1,use_default_badwordsids=True)
genout = generate(benchprompt,memory="",images=[],max_length=benchlen,max_context_length=benchmaxctx,temperature=0.1,top_k=1,rep_pen=1,ban_eos_token=True)
result = genout['text']
result = (result[:5] if len(result)>5 else "")
t_pp = float(handle.get_last_process_time())*float(benchmaxctx-benchlen)*0.001