try casting the numeric fields read

This commit is contained in:
Concedo 2025-01-28 17:43:28 +08:00
parent 0e45d3bb7a
commit 6bf0b2d062

View file

@ -1092,29 +1092,29 @@ def generate(genparams, stream_flag=False):
prompt = genparams.get('prompt', "")
memory = genparams.get('memory', "")
images = genparams.get('images', [])
max_context_length = genparams.get('max_context_length', maxctx)
max_length = genparams.get('max_length', 200)
temperature = genparams.get('temperature', 0.75)
top_k = genparams.get('top_k', 100)
top_a = genparams.get('top_a', 0.0)
top_p = genparams.get('top_p', 0.92)
min_p = genparams.get('min_p', 0.0)
typical_p = genparams.get('typical', 1.0)
tfs = genparams.get('tfs', 1.0)
rep_pen = genparams.get('rep_pen', 1.0)
rep_pen_range = genparams.get('rep_pen_range', 320)
rep_pen_slope = genparams.get('rep_pen_slope', 1.0)
presence_penalty = genparams.get('presence_penalty', 0.0)
mirostat = genparams.get('mirostat', 0)
mirostat_tau = genparams.get('mirostat_tau', 5.0)
mirostat_eta = genparams.get('mirostat_eta', 0.1)
dry_multiplier = genparams.get('dry_multiplier', 0.0)
dry_base = genparams.get('dry_base', 1.75)
dry_allowed_length = genparams.get('dry_allowed_length', 2)
dry_penalty_last_n = genparams.get('dry_penalty_last_n', 320)
max_context_length = int(genparams.get('max_context_length', maxctx))
max_length = int(genparams.get('max_length', 200))
temperature = float(genparams.get('temperature', 0.75))
top_k = int(genparams.get('top_k', 100))
top_a = float(genparams.get('top_a', 0.0))
top_p = float(genparams.get('top_p', 0.92))
min_p = float(genparams.get('min_p', 0.0))
typical_p = float(genparams.get('typical', 1.0))
tfs = float(genparams.get('tfs', 1.0))
rep_pen = float(genparams.get('rep_pen', 1.0))
rep_pen_range = int(genparams.get('rep_pen_range', 320))
rep_pen_slope = float(genparams.get('rep_pen_slope', 1.0))
presence_penalty = float(genparams.get('presence_penalty', 0.0))
mirostat = int(genparams.get('mirostat', 0))
mirostat_tau = float(genparams.get('mirostat_tau', 5.0))
mirostat_eta = float(genparams.get('mirostat_eta', 0.1))
dry_multiplier = float(genparams.get('dry_multiplier', 0.0))
dry_base = float(genparams.get('dry_base', 1.75))
dry_allowed_length = int(genparams.get('dry_allowed_length', 2))
dry_penalty_last_n = int(genparams.get('dry_penalty_last_n', 320))
dry_sequence_breakers = genparams.get('dry_sequence_breakers', [])
xtc_threshold = genparams.get('xtc_threshold', 0.2)
xtc_probability = genparams.get('xtc_probability', 0)
xtc_threshold = float(genparams.get('xtc_threshold', 0.2))
xtc_probability = float(genparams.get('xtc_probability', 0))
sampler_order = genparams.get('sampler_order', [6, 0, 1, 3, 4, 2, 5])
seed = tryparseint(genparams.get('sampler_seed', -1))
stop_sequence = genparams.get('stop_sequence', [])
@ -1124,9 +1124,9 @@ def generate(genparams, stream_flag=False):
grammar_retain_state = genparams.get('grammar_retain_state', False)
genkey = genparams.get('genkey', '')
trimstop = genparams.get('trim_stop', True)
dynatemp_range = genparams.get('dynatemp_range', 0.0)
dynatemp_exponent = genparams.get('dynatemp_exponent', 1.0)
smoothing_factor = genparams.get('smoothing_factor', 0.0)
dynatemp_range = float(genparams.get('dynatemp_range', 0.0))
dynatemp_exponent = float(genparams.get('dynatemp_exponent', 1.0))
smoothing_factor = float(genparams.get('smoothing_factor', 0.0))
logit_biases = genparams.get('logit_bias', {})
render_special = genparams.get('render_special', False)
banned_strings = genparams.get('banned_strings', []) # SillyTavern uses that name
@ -1766,9 +1766,9 @@ def transform_genparams(genparams, api_format):
global chatcompl_adapter, maxctx
#api format 1=basic,2=kai,3=oai,4=oai-chat,5=interrogate,6=ollama,7=ollamachat
#alias all nonstandard alternative names for rep pen.
rp1 = genparams.get('repeat_penalty', 1.0)
rp2 = genparams.get('repetition_penalty', 1.0)
rp3 = genparams.get('rep_pen', 1.0)
rp1 = float(genparams.get('repeat_penalty', 1.0))
rp2 = float(genparams.get('repetition_penalty', 1.0))
rp3 = float(genparams.get('rep_pen', 1.0))
rp_max = max(rp1,rp2,rp3)
genparams["rep_pen"] = rp_max
if "use_default_badwordsids" in genparams and "ban_eos_token" not in genparams:
@ -1777,7 +1777,7 @@ def transform_genparams(genparams, api_format):
if api_format==1:
genparams["prompt"] = genparams.get('text', "")
genparams["top_k"] = int(genparams.get('top_k', 120))
genparams["max_length"] = genparams.get('max', 200)
genparams["max_length"] = int(genparams.get('max', 200))
elif api_format==2:
pass
@ -1786,9 +1786,9 @@ def transform_genparams(genparams, api_format):
default_adapter = {} if chatcompl_adapter is None else chatcompl_adapter
adapter_obj = genparams.get('adapter', default_adapter)
default_max_tok = (adapter_obj.get("max_length", 512) if (api_format==4 or api_format==7) else 200)
genparams["max_length"] = genparams.get('max_tokens', genparams.get('max_completion_tokens', default_max_tok))
genparams["max_length"] = int(genparams.get('max_tokens', genparams.get('max_completion_tokens', default_max_tok)))
presence_penalty = genparams.get('presence_penalty', genparams.get('frequency_penalty', 0.0))
genparams["presence_penalty"] = presence_penalty
genparams["presence_penalty"] = float(presence_penalty)
# openai allows either a string or a list as a stop sequence
if isinstance(genparams.get('stop',[]), list):
genparams["stop_sequence"] = genparams.get('stop', [])