mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 09:04:36 +00:00
handle explicit null
This commit is contained in:
parent
cb1c182673
commit
4dcd215b27
1 changed files with 46 additions and 42 deletions
88
koboldcpp.py
88
koboldcpp.py
|
@ -594,16 +594,20 @@ def end_trim_to_sentence(input_text):
|
|||
return input_text[:last + 1].strip()
|
||||
return input_text.strip()
|
||||
|
||||
def tryparseint(value):
|
||||
def tryparseint(value,fallback):
|
||||
if value is None:
|
||||
return fallback
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
return value
|
||||
def tryparsefloat(value):
|
||||
return fallback
|
||||
def tryparsefloat(value,fallback):
|
||||
if value is None:
|
||||
return fallback
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
return value
|
||||
return fallback
|
||||
|
||||
def is_incomplete_utf8_sequence(byte_seq): #note, this will only flag INCOMPLETE sequences, corrupted ones will be ignored.
|
||||
try:
|
||||
|
@ -1237,32 +1241,32 @@ def generate(genparams, stream_flag=False):
|
|||
prompt = genparams.get('prompt', "")
|
||||
memory = genparams.get('memory', "")
|
||||
images = genparams.get('images', [])
|
||||
max_context_length = int(genparams.get('max_context_length', maxctx))
|
||||
max_length = int(genparams.get('max_length', args.defaultgenamt))
|
||||
temperature = float(genparams.get('temperature', 0.75))
|
||||
top_k = int(genparams.get('top_k', 100))
|
||||
top_a = float(genparams.get('top_a', 0.0))
|
||||
top_p = float(genparams.get('top_p', 0.92))
|
||||
min_p = float(genparams.get('min_p', 0.0))
|
||||
typical_p = float(genparams.get('typical', 1.0))
|
||||
tfs = float(genparams.get('tfs', 1.0))
|
||||
nsigma = float(genparams.get('nsigma', 0.0))
|
||||
rep_pen = float(genparams.get('rep_pen', 1.0))
|
||||
rep_pen_range = int(genparams.get('rep_pen_range', 320))
|
||||
rep_pen_slope = float(genparams.get('rep_pen_slope', 1.0))
|
||||
presence_penalty = float(genparams.get('presence_penalty', 0.0))
|
||||
mirostat = int(genparams.get('mirostat', 0))
|
||||
mirostat_tau = float(genparams.get('mirostat_tau', 5.0))
|
||||
mirostat_eta = float(genparams.get('mirostat_eta', 0.1))
|
||||
dry_multiplier = float(genparams.get('dry_multiplier', 0.0))
|
||||
dry_base = float(genparams.get('dry_base', 1.75))
|
||||
dry_allowed_length = int(genparams.get('dry_allowed_length', 2))
|
||||
dry_penalty_last_n = int(genparams.get('dry_penalty_last_n', 320))
|
||||
max_context_length = tryparseint(genparams.get('max_context_length', maxctx),maxctx)
|
||||
max_length = tryparseint(genparams.get('max_length', args.defaultgenamt),args.defaultgenamt)
|
||||
temperature = tryparsefloat(genparams.get('temperature', 0.75),0.75)
|
||||
top_k = tryparseint(genparams.get('top_k', 100),100)
|
||||
top_a = tryparsefloat(genparams.get('top_a', 0.0),0.0)
|
||||
top_p = tryparsefloat(genparams.get('top_p', 0.92),0.92)
|
||||
min_p = tryparsefloat(genparams.get('min_p', 0.0),0.0)
|
||||
typical_p = tryparsefloat(genparams.get('typical', 1.0),1.0)
|
||||
tfs = tryparsefloat(genparams.get('tfs', 1.0),1.0)
|
||||
nsigma = tryparsefloat(genparams.get('nsigma', 0.0),0.0)
|
||||
rep_pen = tryparsefloat(genparams.get('rep_pen', 1.0),1.0)
|
||||
rep_pen_range = tryparseint(genparams.get('rep_pen_range', 320),320)
|
||||
rep_pen_slope = tryparsefloat(genparams.get('rep_pen_slope', 1.0),1.0)
|
||||
presence_penalty = tryparsefloat(genparams.get('presence_penalty', 0.0),0.0)
|
||||
mirostat = tryparseint(genparams.get('mirostat', 0),0)
|
||||
mirostat_tau = tryparsefloat(genparams.get('mirostat_tau', 5.0),5.0)
|
||||
mirostat_eta = tryparsefloat(genparams.get('mirostat_eta', 0.1),0.1)
|
||||
dry_multiplier = tryparsefloat(genparams.get('dry_multiplier', 0.0),0.0)
|
||||
dry_base = tryparsefloat(genparams.get('dry_base', 1.75),1.75)
|
||||
dry_allowed_length = tryparseint(genparams.get('dry_allowed_length', 2),2)
|
||||
dry_penalty_last_n = tryparseint(genparams.get('dry_penalty_last_n', 320),320)
|
||||
dry_sequence_breakers = genparams.get('dry_sequence_breakers', [])
|
||||
xtc_threshold = float(genparams.get('xtc_threshold', 0.2))
|
||||
xtc_probability = float(genparams.get('xtc_probability', 0))
|
||||
xtc_threshold = tryparsefloat(genparams.get('xtc_threshold', 0.2),0.2)
|
||||
xtc_probability = tryparsefloat(genparams.get('xtc_probability', 0),0)
|
||||
sampler_order = genparams.get('sampler_order', [6, 0, 1, 3, 4, 2, 5])
|
||||
seed = tryparseint(genparams.get('sampler_seed', -1))
|
||||
seed = tryparseint(genparams.get('sampler_seed', -1),-1)
|
||||
stop_sequence = genparams.get('stop_sequence', [])
|
||||
ban_eos_token = genparams.get('ban_eos_token', False)
|
||||
stream_sse = stream_flag
|
||||
|
@ -1278,9 +1282,9 @@ def generate(genparams, stream_flag=False):
|
|||
grammar_retain_state = genparams.get('grammar_retain_state', False)
|
||||
genkey = genparams.get('genkey', '')
|
||||
trimstop = genparams.get('trim_stop', True)
|
||||
dynatemp_range = float(genparams.get('dynatemp_range', 0.0))
|
||||
dynatemp_exponent = float(genparams.get('dynatemp_exponent', 1.0))
|
||||
smoothing_factor = float(genparams.get('smoothing_factor', 0.0))
|
||||
dynatemp_range = tryparsefloat(genparams.get('dynatemp_range', 0.0),0.0)
|
||||
dynatemp_exponent = tryparsefloat(genparams.get('dynatemp_exponent', 1.0),1.0)
|
||||
smoothing_factor = tryparsefloat(genparams.get('smoothing_factor', 0.0),0.0)
|
||||
logit_biases = genparams.get('logit_bias', {})
|
||||
render_special = genparams.get('render_special', False)
|
||||
banned_strings = genparams.get('banned_strings', []) # SillyTavern uses that name
|
||||
|
@ -1534,14 +1538,14 @@ def sd_generate(genparams):
|
|||
init_images = ("" if (not init_images_arr or len(init_images_arr)==0 or not init_images_arr[0]) else init_images_arr[0])
|
||||
mask = genparams.get("mask", "")
|
||||
flip_mask = genparams.get("inpainting_mask_invert", 0)
|
||||
denoising_strength = tryparsefloat(genparams.get("denoising_strength", 0.6))
|
||||
cfg_scale = tryparsefloat(genparams.get("cfg_scale", 5))
|
||||
sample_steps = tryparseint(genparams.get("steps", 20))
|
||||
width = tryparseint(genparams.get("width", 512))
|
||||
height = tryparseint(genparams.get("height", 512))
|
||||
seed = tryparseint(genparams.get("seed", -1))
|
||||
denoising_strength = tryparsefloat(genparams.get("denoising_strength", 0.6),0.6)
|
||||
cfg_scale = tryparsefloat(genparams.get("cfg_scale", 5),5)
|
||||
sample_steps = tryparseint(genparams.get("steps", 20),20)
|
||||
width = tryparseint(genparams.get("width", 512),512)
|
||||
height = tryparseint(genparams.get("height", 512),512)
|
||||
seed = tryparseint(genparams.get("seed", -1),-1)
|
||||
sample_method = genparams.get("sampler_name", "k_euler_a")
|
||||
clip_skip = tryparseint(genparams.get("clip_skip", -1))
|
||||
clip_skip = tryparseint(genparams.get("clip_skip", -1),-1)
|
||||
|
||||
#clean vars
|
||||
width = width - (width%64)
|
||||
|
@ -2043,7 +2047,7 @@ def transform_genparams(genparams, api_format):
|
|||
else:
|
||||
genparams["stop_sequence"] = [genparams.get('stop')]
|
||||
|
||||
genparams["sampler_seed"] = tryparseint(genparams.get('seed', -1))
|
||||
genparams["sampler_seed"] = tryparseint(genparams.get('seed', -1),-1)
|
||||
genparams["mirostat"] = genparams.get('mirostat_mode', 0)
|
||||
|
||||
if api_format==4 or api_format==7: #handle ollama chat here too
|
||||
|
@ -2248,7 +2252,7 @@ ws ::= | " " | "\n" [ \t]{0,20}
|
|||
if "top_p" in ollamaopts:
|
||||
genparams["top_p"] = ollamaopts.get('top_p', 0.92)
|
||||
if "seed" in ollamaopts:
|
||||
genparams["sampler_seed"] = tryparseint(ollamaopts.get('seed', -1))
|
||||
genparams["sampler_seed"] = tryparseint(ollamaopts.get('seed', -1),-1)
|
||||
if "stop" in ollamaopts:
|
||||
genparams["stop_sequence"] = ollamaopts.get('stop', [])
|
||||
genparams["stop_sequence"].append(user_message_start.strip())
|
||||
|
@ -3132,7 +3136,7 @@ Change Mode<br>
|
|||
loadid = -1
|
||||
try:
|
||||
tempbody = json.loads(body)
|
||||
loadid = tryparseint(tempbody.get('slot', 0))
|
||||
loadid = tryparseint(tempbody.get('slot', 0),0)
|
||||
except Exception:
|
||||
loadid = -1
|
||||
if loadid < 0 or str(loadid) not in savedata_obj:
|
||||
|
@ -3149,7 +3153,7 @@ Change Mode<br>
|
|||
else:
|
||||
try:
|
||||
incoming_story = json.loads(body) # ensure submitted data is valid json
|
||||
slotid = tryparseint(incoming_story.get('slot', -1))
|
||||
slotid = tryparseint(incoming_story.get('slot', -1),-1)
|
||||
dataformat = incoming_story.get('format', "")
|
||||
title = incoming_story.get('title', "")
|
||||
if not title or title=="":
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue