handle explicit null

This commit is contained in:
Concedo 2025-04-26 13:06:38 +08:00
parent cb1c182673
commit 4dcd215b27

View file

@ -594,16 +594,20 @@ def end_trim_to_sentence(input_text):
return input_text[:last + 1].strip() return input_text[:last + 1].strip()
return input_text.strip() return input_text.strip()
def tryparseint(value): def tryparseint(value,fallback):
if value is None:
return fallback
try: try:
return int(value) return int(value)
except ValueError: except ValueError:
return value return fallback
def tryparsefloat(value): def tryparsefloat(value,fallback):
if value is None:
return fallback
try: try:
return float(value) return float(value)
except ValueError: except ValueError:
return value return fallback
def is_incomplete_utf8_sequence(byte_seq): #note, this will only flag INCOMPLETE sequences, corrupted ones will be ignored. def is_incomplete_utf8_sequence(byte_seq): #note, this will only flag INCOMPLETE sequences, corrupted ones will be ignored.
try: try:
@ -1237,32 +1241,32 @@ def generate(genparams, stream_flag=False):
prompt = genparams.get('prompt', "") prompt = genparams.get('prompt', "")
memory = genparams.get('memory', "") memory = genparams.get('memory', "")
images = genparams.get('images', []) images = genparams.get('images', [])
max_context_length = int(genparams.get('max_context_length', maxctx)) max_context_length = tryparseint(genparams.get('max_context_length', maxctx),maxctx)
max_length = int(genparams.get('max_length', args.defaultgenamt)) max_length = tryparseint(genparams.get('max_length', args.defaultgenamt),args.defaultgenamt)
temperature = float(genparams.get('temperature', 0.75)) temperature = tryparsefloat(genparams.get('temperature', 0.75),0.75)
top_k = int(genparams.get('top_k', 100)) top_k = tryparseint(genparams.get('top_k', 100),100)
top_a = float(genparams.get('top_a', 0.0)) top_a = tryparsefloat(genparams.get('top_a', 0.0),0.0)
top_p = float(genparams.get('top_p', 0.92)) top_p = tryparsefloat(genparams.get('top_p', 0.92),0.92)
min_p = float(genparams.get('min_p', 0.0)) min_p = tryparsefloat(genparams.get('min_p', 0.0),0.0)
typical_p = float(genparams.get('typical', 1.0)) typical_p = tryparsefloat(genparams.get('typical', 1.0),1.0)
tfs = float(genparams.get('tfs', 1.0)) tfs = tryparsefloat(genparams.get('tfs', 1.0),1.0)
nsigma = float(genparams.get('nsigma', 0.0)) nsigma = tryparsefloat(genparams.get('nsigma', 0.0),0.0)
rep_pen = float(genparams.get('rep_pen', 1.0)) rep_pen = tryparsefloat(genparams.get('rep_pen', 1.0),1.0)
rep_pen_range = int(genparams.get('rep_pen_range', 320)) rep_pen_range = tryparseint(genparams.get('rep_pen_range', 320),320)
rep_pen_slope = float(genparams.get('rep_pen_slope', 1.0)) rep_pen_slope = tryparsefloat(genparams.get('rep_pen_slope', 1.0),1.0)
presence_penalty = float(genparams.get('presence_penalty', 0.0)) presence_penalty = tryparsefloat(genparams.get('presence_penalty', 0.0),0.0)
mirostat = int(genparams.get('mirostat', 0)) mirostat = tryparseint(genparams.get('mirostat', 0),0)
mirostat_tau = float(genparams.get('mirostat_tau', 5.0)) mirostat_tau = tryparsefloat(genparams.get('mirostat_tau', 5.0),5.0)
mirostat_eta = float(genparams.get('mirostat_eta', 0.1)) mirostat_eta = tryparsefloat(genparams.get('mirostat_eta', 0.1),0.1)
dry_multiplier = float(genparams.get('dry_multiplier', 0.0)) dry_multiplier = tryparsefloat(genparams.get('dry_multiplier', 0.0),0.0)
dry_base = float(genparams.get('dry_base', 1.75)) dry_base = tryparsefloat(genparams.get('dry_base', 1.75),1.75)
dry_allowed_length = int(genparams.get('dry_allowed_length', 2)) dry_allowed_length = tryparseint(genparams.get('dry_allowed_length', 2),2)
dry_penalty_last_n = int(genparams.get('dry_penalty_last_n', 320)) dry_penalty_last_n = tryparseint(genparams.get('dry_penalty_last_n', 320),320)
dry_sequence_breakers = genparams.get('dry_sequence_breakers', []) dry_sequence_breakers = genparams.get('dry_sequence_breakers', [])
xtc_threshold = float(genparams.get('xtc_threshold', 0.2)) xtc_threshold = tryparsefloat(genparams.get('xtc_threshold', 0.2),0.2)
xtc_probability = float(genparams.get('xtc_probability', 0)) xtc_probability = tryparsefloat(genparams.get('xtc_probability', 0),0)
sampler_order = genparams.get('sampler_order', [6, 0, 1, 3, 4, 2, 5]) sampler_order = genparams.get('sampler_order', [6, 0, 1, 3, 4, 2, 5])
seed = tryparseint(genparams.get('sampler_seed', -1)) seed = tryparseint(genparams.get('sampler_seed', -1),-1)
stop_sequence = genparams.get('stop_sequence', []) stop_sequence = genparams.get('stop_sequence', [])
ban_eos_token = genparams.get('ban_eos_token', False) ban_eos_token = genparams.get('ban_eos_token', False)
stream_sse = stream_flag stream_sse = stream_flag
@ -1278,9 +1282,9 @@ def generate(genparams, stream_flag=False):
grammar_retain_state = genparams.get('grammar_retain_state', False) grammar_retain_state = genparams.get('grammar_retain_state', False)
genkey = genparams.get('genkey', '') genkey = genparams.get('genkey', '')
trimstop = genparams.get('trim_stop', True) trimstop = genparams.get('trim_stop', True)
dynatemp_range = float(genparams.get('dynatemp_range', 0.0)) dynatemp_range = tryparsefloat(genparams.get('dynatemp_range', 0.0),0.0)
dynatemp_exponent = float(genparams.get('dynatemp_exponent', 1.0)) dynatemp_exponent = tryparsefloat(genparams.get('dynatemp_exponent', 1.0),1.0)
smoothing_factor = float(genparams.get('smoothing_factor', 0.0)) smoothing_factor = tryparsefloat(genparams.get('smoothing_factor', 0.0),0.0)
logit_biases = genparams.get('logit_bias', {}) logit_biases = genparams.get('logit_bias', {})
render_special = genparams.get('render_special', False) render_special = genparams.get('render_special', False)
banned_strings = genparams.get('banned_strings', []) # SillyTavern uses that name banned_strings = genparams.get('banned_strings', []) # SillyTavern uses that name
@ -1534,14 +1538,14 @@ def sd_generate(genparams):
init_images = ("" if (not init_images_arr or len(init_images_arr)==0 or not init_images_arr[0]) else init_images_arr[0]) init_images = ("" if (not init_images_arr or len(init_images_arr)==0 or not init_images_arr[0]) else init_images_arr[0])
mask = genparams.get("mask", "") mask = genparams.get("mask", "")
flip_mask = genparams.get("inpainting_mask_invert", 0) flip_mask = genparams.get("inpainting_mask_invert", 0)
denoising_strength = tryparsefloat(genparams.get("denoising_strength", 0.6)) denoising_strength = tryparsefloat(genparams.get("denoising_strength", 0.6),0.6)
cfg_scale = tryparsefloat(genparams.get("cfg_scale", 5)) cfg_scale = tryparsefloat(genparams.get("cfg_scale", 5),5)
sample_steps = tryparseint(genparams.get("steps", 20)) sample_steps = tryparseint(genparams.get("steps", 20),20)
width = tryparseint(genparams.get("width", 512)) width = tryparseint(genparams.get("width", 512),512)
height = tryparseint(genparams.get("height", 512)) height = tryparseint(genparams.get("height", 512),512)
seed = tryparseint(genparams.get("seed", -1)) seed = tryparseint(genparams.get("seed", -1),-1)
sample_method = genparams.get("sampler_name", "k_euler_a") sample_method = genparams.get("sampler_name", "k_euler_a")
clip_skip = tryparseint(genparams.get("clip_skip", -1)) clip_skip = tryparseint(genparams.get("clip_skip", -1),-1)
#clean vars #clean vars
width = width - (width%64) width = width - (width%64)
@ -2043,7 +2047,7 @@ def transform_genparams(genparams, api_format):
else: else:
genparams["stop_sequence"] = [genparams.get('stop')] genparams["stop_sequence"] = [genparams.get('stop')]
genparams["sampler_seed"] = tryparseint(genparams.get('seed', -1)) genparams["sampler_seed"] = tryparseint(genparams.get('seed', -1),-1)
genparams["mirostat"] = genparams.get('mirostat_mode', 0) genparams["mirostat"] = genparams.get('mirostat_mode', 0)
if api_format==4 or api_format==7: #handle ollama chat here too if api_format==4 or api_format==7: #handle ollama chat here too
@ -2248,7 +2252,7 @@ ws ::= | " " | "\n" [ \t]{0,20}
if "top_p" in ollamaopts: if "top_p" in ollamaopts:
genparams["top_p"] = ollamaopts.get('top_p', 0.92) genparams["top_p"] = ollamaopts.get('top_p', 0.92)
if "seed" in ollamaopts: if "seed" in ollamaopts:
genparams["sampler_seed"] = tryparseint(ollamaopts.get('seed', -1)) genparams["sampler_seed"] = tryparseint(ollamaopts.get('seed', -1),-1)
if "stop" in ollamaopts: if "stop" in ollamaopts:
genparams["stop_sequence"] = ollamaopts.get('stop', []) genparams["stop_sequence"] = ollamaopts.get('stop', [])
genparams["stop_sequence"].append(user_message_start.strip()) genparams["stop_sequence"].append(user_message_start.strip())
@ -3132,7 +3136,7 @@ Change Mode<br>
loadid = -1 loadid = -1
try: try:
tempbody = json.loads(body) tempbody = json.loads(body)
loadid = tryparseint(tempbody.get('slot', 0)) loadid = tryparseint(tempbody.get('slot', 0),0)
except Exception: except Exception:
loadid = -1 loadid = -1
if loadid < 0 or str(loadid) not in savedata_obj: if loadid < 0 or str(loadid) not in savedata_obj:
@ -3149,7 +3153,7 @@ Change Mode<br>
else: else:
try: try:
incoming_story = json.loads(body) # ensure submitted data is valid json incoming_story = json.loads(body) # ensure submitted data is valid json
slotid = tryparseint(incoming_story.get('slot', -1)) slotid = tryparseint(incoming_story.get('slot', -1),-1)
dataformat = incoming_story.get('format', "") dataformat = incoming_story.get('format', "")
title = incoming_story.get('title', "") title = incoming_story.get('title', "")
if not title or title=="": if not title or title=="":