diff --git a/koboldcpp.py b/koboldcpp.py index 4d7493db5..4d983b5d8 100644 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -594,16 +594,20 @@ def end_trim_to_sentence(input_text): return input_text[:last + 1].strip() return input_text.strip() -def tryparseint(value): +def tryparseint(value,fallback): + if value is None: + return fallback try: return int(value) except ValueError: - return value -def tryparsefloat(value): + return fallback +def tryparsefloat(value,fallback): + if value is None: + return fallback try: return float(value) except ValueError: - return value + return fallback def is_incomplete_utf8_sequence(byte_seq): #note, this will only flag INCOMPLETE sequences, corrupted ones will be ignored. try: @@ -1237,32 +1241,32 @@ def generate(genparams, stream_flag=False): prompt = genparams.get('prompt', "") memory = genparams.get('memory', "") images = genparams.get('images', []) - max_context_length = int(genparams.get('max_context_length', maxctx)) - max_length = int(genparams.get('max_length', args.defaultgenamt)) - temperature = float(genparams.get('temperature', 0.75)) - top_k = int(genparams.get('top_k', 100)) - top_a = float(genparams.get('top_a', 0.0)) - top_p = float(genparams.get('top_p', 0.92)) - min_p = float(genparams.get('min_p', 0.0)) - typical_p = float(genparams.get('typical', 1.0)) - tfs = float(genparams.get('tfs', 1.0)) - nsigma = float(genparams.get('nsigma', 0.0)) - rep_pen = float(genparams.get('rep_pen', 1.0)) - rep_pen_range = int(genparams.get('rep_pen_range', 320)) - rep_pen_slope = float(genparams.get('rep_pen_slope', 1.0)) - presence_penalty = float(genparams.get('presence_penalty', 0.0)) - mirostat = int(genparams.get('mirostat', 0)) - mirostat_tau = float(genparams.get('mirostat_tau', 5.0)) - mirostat_eta = float(genparams.get('mirostat_eta', 0.1)) - dry_multiplier = float(genparams.get('dry_multiplier', 0.0)) - dry_base = float(genparams.get('dry_base', 1.75)) - dry_allowed_length = int(genparams.get('dry_allowed_length', 2)) - dry_penalty_last_n = int(genparams.get('dry_penalty_last_n', 320)) + max_context_length = tryparseint(genparams.get('max_context_length', maxctx),maxctx) + max_length = tryparseint(genparams.get('max_length', args.defaultgenamt),args.defaultgenamt) + temperature = tryparsefloat(genparams.get('temperature', 0.75),0.75) + top_k = tryparseint(genparams.get('top_k', 100),100) + top_a = tryparsefloat(genparams.get('top_a', 0.0),0.0) + top_p = tryparsefloat(genparams.get('top_p', 0.92),0.92) + min_p = tryparsefloat(genparams.get('min_p', 0.0),0.0) + typical_p = tryparsefloat(genparams.get('typical', 1.0),1.0) + tfs = tryparsefloat(genparams.get('tfs', 1.0),1.0) + nsigma = tryparsefloat(genparams.get('nsigma', 0.0),0.0) + rep_pen = tryparsefloat(genparams.get('rep_pen', 1.0),1.0) + rep_pen_range = tryparseint(genparams.get('rep_pen_range', 320),320) + rep_pen_slope = tryparsefloat(genparams.get('rep_pen_slope', 1.0),1.0) + presence_penalty = tryparsefloat(genparams.get('presence_penalty', 0.0),0.0) + mirostat = tryparseint(genparams.get('mirostat', 0),0) + mirostat_tau = tryparsefloat(genparams.get('mirostat_tau', 5.0),5.0) + mirostat_eta = tryparsefloat(genparams.get('mirostat_eta', 0.1),0.1) + dry_multiplier = tryparsefloat(genparams.get('dry_multiplier', 0.0),0.0) + dry_base = tryparsefloat(genparams.get('dry_base', 1.75),1.75) + dry_allowed_length = tryparseint(genparams.get('dry_allowed_length', 2),2) + dry_penalty_last_n = tryparseint(genparams.get('dry_penalty_last_n', 320),320) dry_sequence_breakers = genparams.get('dry_sequence_breakers', []) - xtc_threshold = float(genparams.get('xtc_threshold', 0.2)) - xtc_probability = float(genparams.get('xtc_probability', 0)) + xtc_threshold = tryparsefloat(genparams.get('xtc_threshold', 0.2),0.2) + xtc_probability = tryparsefloat(genparams.get('xtc_probability', 0),0) sampler_order = genparams.get('sampler_order', [6, 0, 1, 3, 4, 2, 5]) - seed = tryparseint(genparams.get('sampler_seed', -1)) + seed = tryparseint(genparams.get('sampler_seed', -1),-1) stop_sequence = genparams.get('stop_sequence', []) ban_eos_token = genparams.get('ban_eos_token', False) stream_sse = stream_flag @@ -1278,9 +1282,9 @@ def generate(genparams, stream_flag=False): grammar_retain_state = genparams.get('grammar_retain_state', False) genkey = genparams.get('genkey', '') trimstop = genparams.get('trim_stop', True) - dynatemp_range = float(genparams.get('dynatemp_range', 0.0)) - dynatemp_exponent = float(genparams.get('dynatemp_exponent', 1.0)) - smoothing_factor = float(genparams.get('smoothing_factor', 0.0)) + dynatemp_range = tryparsefloat(genparams.get('dynatemp_range', 0.0),0.0) + dynatemp_exponent = tryparsefloat(genparams.get('dynatemp_exponent', 1.0),1.0) + smoothing_factor = tryparsefloat(genparams.get('smoothing_factor', 0.0),0.0) logit_biases = genparams.get('logit_bias', {}) render_special = genparams.get('render_special', False) banned_strings = genparams.get('banned_strings', []) # SillyTavern uses that name @@ -1534,14 +1538,14 @@ def sd_generate(genparams): init_images = ("" if (not init_images_arr or len(init_images_arr)==0 or not init_images_arr[0]) else init_images_arr[0]) mask = genparams.get("mask", "") flip_mask = genparams.get("inpainting_mask_invert", 0) - denoising_strength = tryparsefloat(genparams.get("denoising_strength", 0.6)) - cfg_scale = tryparsefloat(genparams.get("cfg_scale", 5)) - sample_steps = tryparseint(genparams.get("steps", 20)) - width = tryparseint(genparams.get("width", 512)) - height = tryparseint(genparams.get("height", 512)) - seed = tryparseint(genparams.get("seed", -1)) + denoising_strength = tryparsefloat(genparams.get("denoising_strength", 0.6),0.6) + cfg_scale = tryparsefloat(genparams.get("cfg_scale", 5),5) + sample_steps = tryparseint(genparams.get("steps", 20),20) + width = tryparseint(genparams.get("width", 512),512) + height = tryparseint(genparams.get("height", 512),512) + seed = tryparseint(genparams.get("seed", -1),-1) sample_method = genparams.get("sampler_name", "k_euler_a") - clip_skip = tryparseint(genparams.get("clip_skip", -1)) + clip_skip = tryparseint(genparams.get("clip_skip", -1),-1) #clean vars width = width - (width%64) @@ -2043,7 +2047,7 @@ def transform_genparams(genparams, api_format): else: genparams["stop_sequence"] = [genparams.get('stop')] - genparams["sampler_seed"] = tryparseint(genparams.get('seed', -1)) + genparams["sampler_seed"] = tryparseint(genparams.get('seed', -1),-1) genparams["mirostat"] = genparams.get('mirostat_mode', 0) if api_format==4 or api_format==7: #handle ollama chat here too @@ -2248,7 +2252,7 @@ ws ::= | " " | "\n" [ \t]{0,20} if "top_p" in ollamaopts: genparams["top_p"] = ollamaopts.get('top_p', 0.92) if "seed" in ollamaopts: - genparams["sampler_seed"] = tryparseint(ollamaopts.get('seed', -1)) + genparams["sampler_seed"] = tryparseint(ollamaopts.get('seed', -1),-1) if "stop" in ollamaopts: genparams["stop_sequence"] = ollamaopts.get('stop', []) genparams["stop_sequence"].append(user_message_start.strip()) @@ -3132,7 +3136,7 @@ Change Mode
loadid = -1 try: tempbody = json.loads(body) - loadid = tryparseint(tempbody.get('slot', 0)) + loadid = tryparseint(tempbody.get('slot', 0),0) except Exception: loadid = -1 if loadid < 0 or str(loadid) not in savedata_obj: @@ -3149,7 +3153,7 @@ Change Mode
else: try: incoming_story = json.loads(body) # ensure submitted data is valid json - slotid = tryparseint(incoming_story.get('slot', -1)) + slotid = tryparseint(incoming_story.get('slot', -1),-1) dataformat = incoming_story.get('format', "") title = incoming_story.get('title', "") if not title or title=="":