generate now passes the whole genparams object

This commit is contained in:
Concedo 2024-08-11 00:08:13 +08:00
parent 7fab499b79
commit 139ab3d198
2 changed files with 63 additions and 49 deletions

View file

@ -865,8 +865,51 @@ def load_model(model_filename):
ret = handle.load_model(inputs)
return ret
def generate(prompt, memory="", images=[], max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.0, rep_pen_range=128, rep_pen_slope=1.0, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, dry_multiplier=0.0, dry_base=1.75, dry_allowed_length=2, dry_penalty_last_n=0, dry_sequence_breakers=[], sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], ban_eos_token=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False, dynatemp_range=0.0, dynatemp_exponent=1.0, smoothing_factor=0.0, logit_biases={}, render_special=False, banned_tokens=[], bypass_eos_token=False):
def generate(genparams, is_quiet=False, stream_flag=False):
global maxctx, args, currentusergenkey, totalgens, pendingabortkey
prompt = genparams.get('prompt', "")
memory = genparams.get('memory', "")
images = genparams.get('images', [])
max_context_length = genparams.get('max_context_length', maxctx)
max_length = genparams.get('max_length', 180)
temperature = genparams.get('temperature', 0.7)
top_k = genparams.get('top_k', 100)
top_a = genparams.get('top_a', 0.0)
top_p = genparams.get('top_p', 0.92)
min_p = genparams.get('min_p', 0.0)
typical_p = genparams.get('typical', 1.0)
tfs = genparams.get('tfs', 1.0)
rep_pen = genparams.get('rep_pen', 1.0)
rep_pen_range = genparams.get('rep_pen_range', 256)
rep_pen_slope = genparams.get('rep_pen_slope', 1.0)
presence_penalty = genparams.get('presence_penalty', 0.0)
mirostat = genparams.get('mirostat', 0)
mirostat_tau = genparams.get('mirostat_tau', 5.0)
mirostat_eta = genparams.get('mirostat_eta', 0.1)
dry_multiplier = genparams.get('dry_multiplier', 0.0)
dry_base = genparams.get('dry_base', 1.75)
dry_allowed_length = genparams.get('dry_allowed_length', 2)
dry_penalty_last_n = genparams.get('dry_penalty_last_n', 0)
dry_sequence_breakers = genparams.get('dry_sequence_breakers', [])
sampler_order = genparams.get('sampler_order', [6, 0, 1, 3, 4, 2, 5])
seed = tryparseint(genparams.get('sampler_seed', -1))
stop_sequence = genparams.get('stop_sequence', [])
ban_eos_token = genparams.get('ban_eos_token', False)
stream_sse = stream_flag
grammar = genparams.get('grammar', '')
grammar_retain_state = genparams.get('grammar_retain_state', False)
genkey = genparams.get('genkey', '')
trimstop = genparams.get('trim_stop', False)
quiet = is_quiet
dynatemp_range = genparams.get('dynatemp_range', 0.0)
dynatemp_exponent = genparams.get('dynatemp_exponent', 1.0)
smoothing_factor = genparams.get('smoothing_factor', 0.0)
logit_biases = genparams.get('logit_bias', {})
render_special = genparams.get('render_special', False)
banned_tokens = genparams.get('banned_tokens', [])
bypass_eos_token = genparams.get('bypass_eos', False)
inputs = generation_inputs()
inputs.prompt = prompt.encode("UTF-8")
inputs.memory = memory.encode("UTF-8")
@ -1362,47 +1405,9 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
last_non_horde_req_time = time.time()
return generate(
prompt=genparams.get('prompt', ""),
memory=genparams.get('memory', ""),
images=genparams.get('images', []),
max_context_length=genparams.get('max_context_length', maxctx),
max_length=genparams.get('max_length', 180),
temperature=genparams.get('temperature', 0.7),
top_k=genparams.get('top_k', 100),
top_a=genparams.get('top_a', 0.0),
top_p=genparams.get('top_p', 0.92),
min_p=genparams.get('min_p', 0.0),
typical_p=genparams.get('typical', 1.0),
tfs=genparams.get('tfs', 1.0),
rep_pen=genparams.get('rep_pen', 1.0),
rep_pen_range=genparams.get('rep_pen_range', 256),
rep_pen_slope=genparams.get('rep_pen_slope', 1.0),
presence_penalty=genparams.get('presence_penalty', 0.0),
mirostat=genparams.get('mirostat', 0),
mirostat_tau=genparams.get('mirostat_tau', 5.0),
mirostat_eta=genparams.get('mirostat_eta', 0.1),
dry_multiplier=genparams.get('dry_multiplier', 0.0),
dry_base=genparams.get('dry_base', 1.75),
dry_allowed_length=genparams.get('dry_allowed_length', 2),
dry_penalty_last_n=genparams.get('dry_penalty_last_n', 0),
dry_sequence_breakers=genparams.get('dry_sequence_breakers', []),
sampler_order=genparams.get('sampler_order', [6,0,1,3,4,2,5]),
seed=tryparseint(genparams.get('sampler_seed', -1)),
stop_sequence=genparams.get('stop_sequence', []),
ban_eos_token=genparams.get('ban_eos_token', False),
stream_sse=stream_flag,
grammar=genparams.get('grammar', ''),
grammar_retain_state = genparams.get('grammar_retain_state', False),
genkey=genparams.get('genkey', ''),
trimstop=genparams.get('trim_stop', False),
quiet=is_quiet,
dynatemp_range=genparams.get('dynatemp_range', 0.0),
dynatemp_exponent=genparams.get('dynatemp_exponent', 1.0),
smoothing_factor=genparams.get('smoothing_factor', 0.0),
logit_biases=genparams.get('logit_bias', {}),
render_special=genparams.get('render_special', False),
banned_tokens=genparams.get('banned_tokens', []),
bypass_eos_token=genparams.get('bypass_eos', False),
genparams=genparams,
is_quiet=is_quiet,
stream_flag=stream_flag
)
genout = {"text": "", "status": -1, "stopreason": -1}
@ -4237,7 +4242,16 @@ def main(launch_args,start_server=True):
benchprompt = " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1"
for i in range(0,14): #generate massive prompt
benchprompt += benchprompt
genout = generate(benchprompt,memory="",images=[],max_length=benchlen,max_context_length=benchmaxctx,temperature=benchtemp,top_k=benchtopk,rep_pen=benchreppen,ban_eos_token=benchbaneos)
genp = {
"prompt":benchprompt,
"max_length":benchlen,
"max_context_length":benchmaxctx,
"temperature":benchtemp,
"top_k":benchtopk,
"rep_pen":benchreppen,
"ban_eos_token":benchbaneos
}
genout = generate(genparams=genp)
result = genout['text']
if args.prompt and not args.benchmark:
restore_stdout()
@ -4300,10 +4314,10 @@ def run_in_queue(launch_args, input_queue, output_queue):
while not input_queue.empty():
data = input_queue.get()
if data['command'] == 'generate':
(args, kwargs) = data['data']
genout = generate(*args, **kwargs)
result = genout['text']
output_queue.put({'command': 'generated text', 'data': result})
pl = data['data']
genout = generate(genparams=pl)
result = genout['text']
output_queue.put({'command': 'generated text', 'data': result})
time.sleep(0.2)
def start_in_seperate_process(launch_args):