mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 17:14:36 +00:00
try casting the numeric fields read
This commit is contained in:
parent
0e45d3bb7a
commit
6bf0b2d062
1 changed files with 31 additions and 31 deletions
62
koboldcpp.py
62
koboldcpp.py
|
@ -1092,29 +1092,29 @@ def generate(genparams, stream_flag=False):
|
|||
prompt = genparams.get('prompt', "")
|
||||
memory = genparams.get('memory', "")
|
||||
images = genparams.get('images', [])
|
||||
max_context_length = genparams.get('max_context_length', maxctx)
|
||||
max_length = genparams.get('max_length', 200)
|
||||
temperature = genparams.get('temperature', 0.75)
|
||||
top_k = genparams.get('top_k', 100)
|
||||
top_a = genparams.get('top_a', 0.0)
|
||||
top_p = genparams.get('top_p', 0.92)
|
||||
min_p = genparams.get('min_p', 0.0)
|
||||
typical_p = genparams.get('typical', 1.0)
|
||||
tfs = genparams.get('tfs', 1.0)
|
||||
rep_pen = genparams.get('rep_pen', 1.0)
|
||||
rep_pen_range = genparams.get('rep_pen_range', 320)
|
||||
rep_pen_slope = genparams.get('rep_pen_slope', 1.0)
|
||||
presence_penalty = genparams.get('presence_penalty', 0.0)
|
||||
mirostat = genparams.get('mirostat', 0)
|
||||
mirostat_tau = genparams.get('mirostat_tau', 5.0)
|
||||
mirostat_eta = genparams.get('mirostat_eta', 0.1)
|
||||
dry_multiplier = genparams.get('dry_multiplier', 0.0)
|
||||
dry_base = genparams.get('dry_base', 1.75)
|
||||
dry_allowed_length = genparams.get('dry_allowed_length', 2)
|
||||
dry_penalty_last_n = genparams.get('dry_penalty_last_n', 320)
|
||||
max_context_length = int(genparams.get('max_context_length', maxctx))
|
||||
max_length = int(genparams.get('max_length', 200))
|
||||
temperature = float(genparams.get('temperature', 0.75))
|
||||
top_k = int(genparams.get('top_k', 100))
|
||||
top_a = float(genparams.get('top_a', 0.0))
|
||||
top_p = float(genparams.get('top_p', 0.92))
|
||||
min_p = float(genparams.get('min_p', 0.0))
|
||||
typical_p = float(genparams.get('typical', 1.0))
|
||||
tfs = float(genparams.get('tfs', 1.0))
|
||||
rep_pen = float(genparams.get('rep_pen', 1.0))
|
||||
rep_pen_range = int(genparams.get('rep_pen_range', 320))
|
||||
rep_pen_slope = float(genparams.get('rep_pen_slope', 1.0))
|
||||
presence_penalty = float(genparams.get('presence_penalty', 0.0))
|
||||
mirostat = int(genparams.get('mirostat', 0))
|
||||
mirostat_tau = float(genparams.get('mirostat_tau', 5.0))
|
||||
mirostat_eta = float(genparams.get('mirostat_eta', 0.1))
|
||||
dry_multiplier = float(genparams.get('dry_multiplier', 0.0))
|
||||
dry_base = float(genparams.get('dry_base', 1.75))
|
||||
dry_allowed_length = int(genparams.get('dry_allowed_length', 2))
|
||||
dry_penalty_last_n = int(genparams.get('dry_penalty_last_n', 320))
|
||||
dry_sequence_breakers = genparams.get('dry_sequence_breakers', [])
|
||||
xtc_threshold = genparams.get('xtc_threshold', 0.2)
|
||||
xtc_probability = genparams.get('xtc_probability', 0)
|
||||
xtc_threshold = float(genparams.get('xtc_threshold', 0.2))
|
||||
xtc_probability = float(genparams.get('xtc_probability', 0))
|
||||
sampler_order = genparams.get('sampler_order', [6, 0, 1, 3, 4, 2, 5])
|
||||
seed = tryparseint(genparams.get('sampler_seed', -1))
|
||||
stop_sequence = genparams.get('stop_sequence', [])
|
||||
|
@ -1124,9 +1124,9 @@ def generate(genparams, stream_flag=False):
|
|||
grammar_retain_state = genparams.get('grammar_retain_state', False)
|
||||
genkey = genparams.get('genkey', '')
|
||||
trimstop = genparams.get('trim_stop', True)
|
||||
dynatemp_range = genparams.get('dynatemp_range', 0.0)
|
||||
dynatemp_exponent = genparams.get('dynatemp_exponent', 1.0)
|
||||
smoothing_factor = genparams.get('smoothing_factor', 0.0)
|
||||
dynatemp_range = float(genparams.get('dynatemp_range', 0.0))
|
||||
dynatemp_exponent = float(genparams.get('dynatemp_exponent', 1.0))
|
||||
smoothing_factor = float(genparams.get('smoothing_factor', 0.0))
|
||||
logit_biases = genparams.get('logit_bias', {})
|
||||
render_special = genparams.get('render_special', False)
|
||||
banned_strings = genparams.get('banned_strings', []) # SillyTavern uses that name
|
||||
|
@ -1766,9 +1766,9 @@ def transform_genparams(genparams, api_format):
|
|||
global chatcompl_adapter, maxctx
|
||||
#api format 1=basic,2=kai,3=oai,4=oai-chat,5=interrogate,6=ollama,7=ollamachat
|
||||
#alias all nonstandard alternative names for rep pen.
|
||||
rp1 = genparams.get('repeat_penalty', 1.0)
|
||||
rp2 = genparams.get('repetition_penalty', 1.0)
|
||||
rp3 = genparams.get('rep_pen', 1.0)
|
||||
rp1 = float(genparams.get('repeat_penalty', 1.0))
|
||||
rp2 = float(genparams.get('repetition_penalty', 1.0))
|
||||
rp3 = float(genparams.get('rep_pen', 1.0))
|
||||
rp_max = max(rp1,rp2,rp3)
|
||||
genparams["rep_pen"] = rp_max
|
||||
if "use_default_badwordsids" in genparams and "ban_eos_token" not in genparams:
|
||||
|
@ -1777,7 +1777,7 @@ def transform_genparams(genparams, api_format):
|
|||
if api_format==1:
|
||||
genparams["prompt"] = genparams.get('text', "")
|
||||
genparams["top_k"] = int(genparams.get('top_k', 120))
|
||||
genparams["max_length"] = genparams.get('max', 200)
|
||||
genparams["max_length"] = int(genparams.get('max', 200))
|
||||
|
||||
elif api_format==2:
|
||||
pass
|
||||
|
@ -1786,9 +1786,9 @@ def transform_genparams(genparams, api_format):
|
|||
default_adapter = {} if chatcompl_adapter is None else chatcompl_adapter
|
||||
adapter_obj = genparams.get('adapter', default_adapter)
|
||||
default_max_tok = (adapter_obj.get("max_length", 512) if (api_format==4 or api_format==7) else 200)
|
||||
genparams["max_length"] = genparams.get('max_tokens', genparams.get('max_completion_tokens', default_max_tok))
|
||||
genparams["max_length"] = int(genparams.get('max_tokens', genparams.get('max_completion_tokens', default_max_tok)))
|
||||
presence_penalty = genparams.get('presence_penalty', genparams.get('frequency_penalty', 0.0))
|
||||
genparams["presence_penalty"] = presence_penalty
|
||||
genparams["presence_penalty"] = float(presence_penalty)
|
||||
# openai allows either a string or a list as a stop sequence
|
||||
if isinstance(genparams.get('stop',[]), list):
|
||||
genparams["stop_sequence"] = genparams.get('stop', [])
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue