From 6bf0b2d062e566ddf954add90b20973200953e47 Mon Sep 17 00:00:00 2001
From: Concedo <39025047+LostRuins@users.noreply.github.com>
Date: Tue, 28 Jan 2025 17:43:28 +0800
Subject: [PATCH] try casting the numeric fields read

---
 koboldcpp.py | 62 ++++++++++++++++++++++++++--------------------------
 1 file changed, 31 insertions(+), 31 deletions(-)

diff --git a/koboldcpp.py b/koboldcpp.py
index 1745abc78..366b5b09b 100644
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -1092,29 +1092,29 @@ def generate(genparams, stream_flag=False):
     prompt = genparams.get('prompt', "")
     memory = genparams.get('memory', "")
     images = genparams.get('images', [])
-    max_context_length = genparams.get('max_context_length', maxctx)
-    max_length = genparams.get('max_length', 200)
-    temperature = genparams.get('temperature', 0.75)
-    top_k = genparams.get('top_k', 100)
-    top_a = genparams.get('top_a', 0.0)
-    top_p = genparams.get('top_p', 0.92)
-    min_p = genparams.get('min_p', 0.0)
-    typical_p = genparams.get('typical', 1.0)
-    tfs = genparams.get('tfs', 1.0)
-    rep_pen = genparams.get('rep_pen', 1.0)
-    rep_pen_range = genparams.get('rep_pen_range', 320)
-    rep_pen_slope = genparams.get('rep_pen_slope', 1.0)
-    presence_penalty = genparams.get('presence_penalty', 0.0)
-    mirostat = genparams.get('mirostat', 0)
-    mirostat_tau = genparams.get('mirostat_tau', 5.0)
-    mirostat_eta = genparams.get('mirostat_eta', 0.1)
-    dry_multiplier = genparams.get('dry_multiplier', 0.0)
-    dry_base = genparams.get('dry_base', 1.75)
-    dry_allowed_length = genparams.get('dry_allowed_length', 2)
-    dry_penalty_last_n = genparams.get('dry_penalty_last_n', 320)
+    max_context_length = int(genparams.get('max_context_length', maxctx))
+    max_length = int(genparams.get('max_length', 200))
+    temperature = float(genparams.get('temperature', 0.75))
+    top_k = int(genparams.get('top_k', 100))
+    top_a = float(genparams.get('top_a', 0.0))
+    top_p = float(genparams.get('top_p', 0.92))
+    min_p = float(genparams.get('min_p', 0.0))
+    typical_p = float(genparams.get('typical', 1.0))
+    tfs = float(genparams.get('tfs', 1.0))
+    rep_pen = float(genparams.get('rep_pen', 1.0))
+    rep_pen_range = int(genparams.get('rep_pen_range', 320))
+    rep_pen_slope = float(genparams.get('rep_pen_slope', 1.0))
+    presence_penalty = float(genparams.get('presence_penalty', 0.0))
+    mirostat = int(genparams.get('mirostat', 0))
+    mirostat_tau = float(genparams.get('mirostat_tau', 5.0))
+    mirostat_eta = float(genparams.get('mirostat_eta', 0.1))
+    dry_multiplier = float(genparams.get('dry_multiplier', 0.0))
+    dry_base = float(genparams.get('dry_base', 1.75))
+    dry_allowed_length = int(genparams.get('dry_allowed_length', 2))
+    dry_penalty_last_n = int(genparams.get('dry_penalty_last_n', 320))
     dry_sequence_breakers = genparams.get('dry_sequence_breakers', [])
-    xtc_threshold = genparams.get('xtc_threshold', 0.2)
-    xtc_probability = genparams.get('xtc_probability', 0)
+    xtc_threshold = float(genparams.get('xtc_threshold', 0.2))
+    xtc_probability = float(genparams.get('xtc_probability', 0))
     sampler_order = genparams.get('sampler_order', [6, 0, 1, 3, 4, 2, 5])
     seed = tryparseint(genparams.get('sampler_seed', -1))
     stop_sequence = genparams.get('stop_sequence', [])
@@ -1124,9 +1124,9 @@ def generate(genparams, stream_flag=False):
     grammar_retain_state = genparams.get('grammar_retain_state', False)
     genkey = genparams.get('genkey', '')
     trimstop = genparams.get('trim_stop', True)
-    dynatemp_range = genparams.get('dynatemp_range', 0.0)
-    dynatemp_exponent = genparams.get('dynatemp_exponent', 1.0)
-    smoothing_factor = genparams.get('smoothing_factor', 0.0)
+    dynatemp_range = float(genparams.get('dynatemp_range', 0.0))
+    dynatemp_exponent = float(genparams.get('dynatemp_exponent', 1.0))
+    smoothing_factor = float(genparams.get('smoothing_factor', 0.0))
     logit_biases = genparams.get('logit_bias', {})
     render_special = genparams.get('render_special', False)
     banned_strings = genparams.get('banned_strings', []) # SillyTavern uses that name
@@ -1766,9 +1766,9 @@ def transform_genparams(genparams, api_format):
     global chatcompl_adapter, maxctx
     #api format 1=basic,2=kai,3=oai,4=oai-chat,5=interrogate,6=ollama,7=ollamachat
     #alias all nonstandard alternative names for rep pen.
-    rp1 = genparams.get('repeat_penalty', 1.0)
-    rp2 = genparams.get('repetition_penalty', 1.0)
-    rp3 = genparams.get('rep_pen', 1.0)
+    rp1 = float(genparams.get('repeat_penalty', 1.0))
+    rp2 = float(genparams.get('repetition_penalty', 1.0))
+    rp3 = float(genparams.get('rep_pen', 1.0))
     rp_max = max(rp1,rp2,rp3)
     genparams["rep_pen"] = rp_max
     if "use_default_badwordsids" in genparams and "ban_eos_token" not in genparams:
@@ -1777,7 +1777,7 @@ def transform_genparams(genparams, api_format):
     if api_format==1:
         genparams["prompt"] = genparams.get('text', "")
         genparams["top_k"] = int(genparams.get('top_k', 120))
-        genparams["max_length"] = genparams.get('max', 200)
+        genparams["max_length"] = int(genparams.get('max', 200))
 
     elif api_format==2:
         pass
@@ -1786,9 +1786,9 @@ def transform_genparams(genparams, api_format):
         default_adapter = {} if chatcompl_adapter is None else chatcompl_adapter
         adapter_obj = genparams.get('adapter', default_adapter)
         default_max_tok = (adapter_obj.get("max_length", 512) if (api_format==4 or api_format==7) else 200)
-        genparams["max_length"] = genparams.get('max_tokens', genparams.get('max_completion_tokens', default_max_tok))
+        genparams["max_length"] = int(genparams.get('max_tokens', genparams.get('max_completion_tokens', default_max_tok)))
         presence_penalty = genparams.get('presence_penalty', genparams.get('frequency_penalty', 0.0))
-        genparams["presence_penalty"] = presence_penalty
+        genparams["presence_penalty"] = float(presence_penalty)
         # openai allows either a string or a list as a stop sequence
         if isinstance(genparams.get('stop',[]), list):
             genparams["stop_sequence"] = genparams.get('stop', [])