diff --git a/class.py b/class.py
index 33e32e785..242c9cc63 100644
--- a/class.py
+++ b/class.py
@@ -306,11 +306,11 @@ class model_backend(InferenceModel):
         # Store context in memory to use it for comparison with generated content
         utils.koboldai_vars.lastctx = decoded_prompt
 
-        self.input_queue.put({'command': 'generate', 'data': [(decoded_prompt,), {'max_length': max_new, 'max_context_length': utils.koboldai_vars.max_length,
+        self.input_queue.put({'command': 'generate', 'data': {'prompt':decoded_prompt, 'max_length': max_new, 'max_context_length': utils.koboldai_vars.max_length,
                         'temperature': gen_settings.temp, 'top_k': int(gen_settings.top_k), 'top_a': gen_settings.top_a, 'top_p': gen_settings.top_p, 'typical_p': gen_settings.typical, 'tfs': gen_settings.tfs,
                         'rep_pen': gen_settings.rep_pen, 'rep_pen_range': gen_settings.rep_pen_range, "sampler_order": gen_settings.sampler_order, "use_default_badwordsids": utils.koboldai_vars.use_default_badwordsids}
-                        ]})
+                        })
 
         #genresult = koboldcpp.generate(decoded_prompt,"",max_new,utils.koboldai_vars.max_length,
         #gen_settings.temp,int(gen_settings.top_k),gen_settings.top_a,gen_settings.top_p,
diff --git a/koboldcpp.py b/koboldcpp.py
index 2032f7f15..4e77c2d35 100644
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -865,8 +865,51 @@ def load_model(model_filename):
     ret = handle.load_model(inputs)
     return ret
 
-def generate(prompt, memory="", images=[], max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.0, rep_pen_range=128, rep_pen_slope=1.0, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, dry_multiplier=0.0, dry_base=1.75, dry_allowed_length=2, dry_penalty_last_n=0, dry_sequence_breakers=[], sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], ban_eos_token=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False, dynatemp_range=0.0, dynatemp_exponent=1.0, smoothing_factor=0.0, logit_biases={}, render_special=False, banned_tokens=[], bypass_eos_token=False):
+def generate(genparams, is_quiet=False, stream_flag=False):
     global maxctx, args, currentusergenkey, totalgens, pendingabortkey
+
+    prompt = genparams.get('prompt', "")
+    memory = genparams.get('memory', "")
+    images = genparams.get('images', [])
+    max_context_length = genparams.get('max_context_length', maxctx)
+    max_length = genparams.get('max_length', 180)
+    temperature = genparams.get('temperature', 0.7)
+    top_k = genparams.get('top_k', 100)
+    top_a = genparams.get('top_a', 0.0)
+    top_p = genparams.get('top_p', 0.92)
+    min_p = genparams.get('min_p', 0.0)
+    typical_p = genparams.get('typical', 1.0)
+    tfs = genparams.get('tfs', 1.0)
+    rep_pen = genparams.get('rep_pen', 1.0)
+    rep_pen_range = genparams.get('rep_pen_range', 256)
+    rep_pen_slope = genparams.get('rep_pen_slope', 1.0)
+    presence_penalty = genparams.get('presence_penalty', 0.0)
+    mirostat = genparams.get('mirostat', 0)
+    mirostat_tau = genparams.get('mirostat_tau', 5.0)
+    mirostat_eta = genparams.get('mirostat_eta', 0.1)
+    dry_multiplier = genparams.get('dry_multiplier', 0.0)
+    dry_base = genparams.get('dry_base', 1.75)
+    dry_allowed_length = genparams.get('dry_allowed_length', 2)
+    dry_penalty_last_n = genparams.get('dry_penalty_last_n', 0)
+    dry_sequence_breakers = genparams.get('dry_sequence_breakers', [])
+    sampler_order = genparams.get('sampler_order', [6, 0, 1, 3, 4, 2, 5])
+    seed = tryparseint(genparams.get('sampler_seed', -1))
+    stop_sequence = genparams.get('stop_sequence', [])
+    ban_eos_token = genparams.get('ban_eos_token', False)
+    stream_sse = stream_flag
+    grammar = genparams.get('grammar', '')
+    grammar_retain_state = genparams.get('grammar_retain_state', False)
+    genkey = genparams.get('genkey', '')
+    trimstop = genparams.get('trim_stop', False)
+    quiet = is_quiet
+    dynatemp_range = genparams.get('dynatemp_range', 0.0)
+    dynatemp_exponent = genparams.get('dynatemp_exponent', 1.0)
+    smoothing_factor = genparams.get('smoothing_factor', 0.0)
+    logit_biases = genparams.get('logit_bias', {})
+    render_special = genparams.get('render_special', False)
+    banned_tokens = genparams.get('banned_tokens', [])
+    bypass_eos_token = genparams.get('bypass_eos', False)
+
     inputs = generation_inputs()
     inputs.prompt = prompt.encode("UTF-8")
     inputs.memory = memory.encode("UTF-8")
@@ -1362,47 +1405,9 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                 last_non_horde_req_time = time.time()
 
             return generate(
-                prompt=genparams.get('prompt', ""),
-                memory=genparams.get('memory', ""),
-                images=genparams.get('images', []),
-                max_context_length=genparams.get('max_context_length', maxctx),
-                max_length=genparams.get('max_length', 180),
-                temperature=genparams.get('temperature', 0.7),
-                top_k=genparams.get('top_k', 100),
-                top_a=genparams.get('top_a', 0.0),
-                top_p=genparams.get('top_p', 0.92),
-                min_p=genparams.get('min_p', 0.0),
-                typical_p=genparams.get('typical', 1.0),
-                tfs=genparams.get('tfs', 1.0),
-                rep_pen=genparams.get('rep_pen', 1.0),
-                rep_pen_range=genparams.get('rep_pen_range', 256),
-                rep_pen_slope=genparams.get('rep_pen_slope', 1.0),
-                presence_penalty=genparams.get('presence_penalty', 0.0),
-                mirostat=genparams.get('mirostat', 0),
-                mirostat_tau=genparams.get('mirostat_tau', 5.0),
-                mirostat_eta=genparams.get('mirostat_eta', 0.1),
-                dry_multiplier=genparams.get('dry_multiplier', 0.0),
-                dry_base=genparams.get('dry_base', 1.75),
-                dry_allowed_length=genparams.get('dry_allowed_length', 2),
-                dry_penalty_last_n=genparams.get('dry_penalty_last_n', 0),
-                dry_sequence_breakers=genparams.get('dry_sequence_breakers', []),
-                sampler_order=genparams.get('sampler_order', [6,0,1,3,4,2,5]),
-                seed=tryparseint(genparams.get('sampler_seed', -1)),
-                stop_sequence=genparams.get('stop_sequence', []),
-                ban_eos_token=genparams.get('ban_eos_token', False),
-                stream_sse=stream_flag,
-                grammar=genparams.get('grammar', ''),
-                grammar_retain_state = genparams.get('grammar_retain_state', False),
-                genkey=genparams.get('genkey', ''),
-                trimstop=genparams.get('trim_stop', False),
-                quiet=is_quiet,
-                dynatemp_range=genparams.get('dynatemp_range', 0.0),
-                dynatemp_exponent=genparams.get('dynatemp_exponent', 1.0),
-                smoothing_factor=genparams.get('smoothing_factor', 0.0),
-                logit_biases=genparams.get('logit_bias', {}),
-                render_special=genparams.get('render_special', False),
-                banned_tokens=genparams.get('banned_tokens', []),
-                bypass_eos_token=genparams.get('bypass_eos', False),
+                genparams=genparams,
+                is_quiet=is_quiet,
+                stream_flag=stream_flag
             )
 
         genout = {"text": "", "status": -1, "stopreason": -1}
@@ -4237,7 +4242,16 @@ def main(launch_args,start_server=True):
         benchprompt = " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1"
         for i in range(0,14): #generate massive prompt
             benchprompt += benchprompt
-        genout = generate(benchprompt,memory="",images=[],max_length=benchlen,max_context_length=benchmaxctx,temperature=benchtemp,top_k=benchtopk,rep_pen=benchreppen,ban_eos_token=benchbaneos)
+        genp = {
+            "prompt":benchprompt,
+            "max_length":benchlen,
+            "max_context_length":benchmaxctx,
+            "temperature":benchtemp,
+            "top_k":benchtopk,
+            "rep_pen":benchreppen,
+            "ban_eos_token":benchbaneos
+        }
+        genout = generate(genparams=genp)
         result = genout['text']
         if args.prompt and not args.benchmark:
             restore_stdout()
@@ -4300,10 +4314,10 @@ def run_in_queue(launch_args, input_queue, output_queue):
            while not input_queue.empty():
                data = input_queue.get()
                if data['command'] == 'generate':
-                   (args, kwargs) = data['data']
-                   genout = generate(*args, **kwargs)
-                   result = genout['text']
-                   output_queue.put({'command': 'generated text', 'data': result})
+                   pl = data['data']
+                   genout = generate(genparams=pl)
+                   result = genout['text']
+                   output_queue.put({'command': 'generated text', 'data': result})
        time.sleep(0.2)
 
def start_in_seperate_process(launch_args):