tidy up and refactor code to support old flags

This commit is contained in:
Concedo 2024-05-10 16:50:53 +08:00
parent eccc2ddca2
commit dbe72b959e
4 changed files with 88 additions and 173 deletions

View file

@ -45,7 +45,6 @@ class load_model_inputs(ctypes.Structure):
("mmproj_filename", ctypes.c_char_p),
("use_mmap", ctypes.c_bool),
("use_mlock", ctypes.c_bool),
("use_smartcontext", ctypes.c_bool),
("use_contextshift", ctypes.c_bool),
("clblast_info", ctypes.c_int),
("cublas_info", ctypes.c_int),
@ -372,7 +371,6 @@ def load_model(model_filename):
inputs.lora_base = args.lora[1].encode("UTF-8")
inputs.mmproj_filename = args.mmproj.encode("UTF-8") if args.mmproj else "".encode("UTF-8")
inputs.use_smartcontext = args.smartcontext
inputs.use_contextshift = (0 if args.noshift else 1)
inputs.flash_attention = args.flashattention
inputs.blasbatchsize = args.blasbatchsize
@ -519,14 +517,6 @@ def sd_load_model(model_filename):
inputs.model_filename = model_filename.encode("UTF-8")
thds = args.threads
quant = 0
#todo: remove this
if args.sdconfig and len(args.sdconfig) > 2:
show_deprecated_warning()
sdt = int(args.sdconfig[2])
if sdt > 0:
thds = sdt
if args.sdconfig and len(args.sdconfig) > 3:
quant = (1 if args.sdconfig[3]=="quant" else 0)
if args.sdthreads and args.sdthreads > 0:
sdt = int(args.sdthreads)
@ -565,25 +555,10 @@ def sd_generate(genparams):
width = (64 if width < 64 else width)
height = (64 if height < 64 else height)
#todo: remove this
if args.sdconfig and len(args.sdconfig)>1:
show_deprecated_warning()
if args.sdconfig[1]=="quick":
cfg_scale = 1
sample_steps = 7
sample_method = "dpm++ 2m karras"
reslimit = 512
print("\nSDConfig: Quick Mode (Low Quality). Step counts, resolution, sampler, and cfg scale are fixed.")
elif args.sdconfig[1]=="clamped":
sample_steps = (40 if sample_steps > 40 else sample_steps)
reslimit = 512
print("\nSDConfig: Clamped Mode (For Shared Use). Step counts and resolution are clamped.")
if args.sdclamped:
sample_steps = (40 if sample_steps > 40 else sample_steps)
reslimit = 512
print("\nSDConfig: Clamped Mode (For Shared Use). Step counts and resolution are clamped.")
print("\nImgGen: Clamped Mode (For Shared Use). Step counts and resolution are clamped.")
biggest = max(width,height)
if biggest > reslimit:
@ -682,6 +657,7 @@ last_req_time = time.time()
last_non_horde_req_time = time.time()
currfinishreason = "null"
using_gui_launcher = False
using_outdated_flags = False
def transform_genparams(genparams, api_format):
#api format 1=basic,2=kai,3=oai,4=oai-chat,5=interrogate
@ -1542,7 +1518,7 @@ def show_new_gui():
root.quit()
if args.model_param and args.model_param!="" and args.model_param.lower().endswith('.kcpps'):
loadconfigfile(args.model_param)
if not args.model_param and not args.sdconfig and not args.sdmodel:
if not args.model_param and not args.sdmodel:
global exitcounter
exitcounter = 999
print("\nNo ggml model or kcpps file was selected. Exiting.")
@ -1695,7 +1671,6 @@ def show_new_gui():
contextshift = ctk.IntVar(value=1)
remotetunnel = ctk.IntVar(value=0)
smartcontext = ctk.IntVar()
flashattention = ctk.IntVar(value=0)
context_var = ctk.IntVar()
customrope_var = ctk.IntVar()
@ -1979,10 +1954,7 @@ def show_new_gui():
gpulayers_var.trace("w", changed_gpulayers)
def togglectxshift(a,b,c):
if contextshift.get()==0:
smartcontextbox.grid(row=1, column=0, padx=8, pady=1, stick="nw")
else:
smartcontextbox.grid_forget()
pass
def guibench():
args.benchmark = "stdout"
@ -2136,7 +2108,6 @@ def show_new_gui():
# Tokens Tab
tokens_tab = tabcontent["Tokens"]
# tokens checkboxes
smartcontextbox = makecheckbox(tokens_tab, "Use SmartContext", smartcontext, 1,tooltiptxt="Uses SmartContext. Now considered outdated and not recommended.\nCheck the wiki for more info.")
makecheckbox(tokens_tab, "Use ContextShift", contextshift, 2,tooltiptxt="Uses Context Shifting to reduce reprocessing.\nRecommended. Check the wiki for more info.", command=togglectxshift)
togglectxshift(1,1,1)
@ -2233,7 +2204,6 @@ def show_new_gui():
args.launch = launchbrowser.get()==1
args.highpriority = highpriority.get()==1
args.nommap = disablemmap.get()==1
args.smartcontext = smartcontext.get()==1
args.flashattention = flashattention.get()==1
args.noshift = contextshift.get()==0
args.remotetunnel = remotetunnel.get()==1
@ -2319,6 +2289,8 @@ def show_new_gui():
args.sdquant = True
def import_vars(dict):
dict = convert_outdated_args(dict)
if "threads" in dict:
threads_var.set(dict["threads"])
usemlock.set(1 if "usemlock" in dict and dict["usemlock"] else 0)
@ -2327,7 +2299,6 @@ def show_new_gui():
launchbrowser.set(1 if "launch" in dict and dict["launch"] else 0)
highpriority.set(1 if "highpriority" in dict and dict["highpriority"] else 0)
disablemmap.set(1 if "nommap" in dict and dict["nommap"] else 0)
smartcontext.set(1 if "smartcontext" in dict and dict["smartcontext"] else 0)
flashattention.set(1 if "flashattention" in dict and dict["flashattention"] else 0)
contextshift.set(0 if "noshift" in dict and dict["noshift"] else 1)
remotetunnel.set(1 if "remotetunnel" in dict and dict["remotetunnel"] else 0)
@ -2409,12 +2380,12 @@ def show_new_gui():
if "blasbatchsize" in dict and dict["blasbatchsize"]:
blas_size_var.set(blasbatchsize_values.index(str(dict["blasbatchsize"])))
if "forceversion" in dict and dict["forceversion"]:
version_var.set(str(dict["forceversion"]))
if "model_param" in dict and dict["model_param"]:
model_var.set(dict["model_param"])
version_var.set(str(dict["forceversion"]) if ("forceversion" in dict and dict["forceversion"]) else "0")
model_var.set(dict["model_param"] if ("model_param" in dict and dict["model_param"]) else "")
lora_var.set("")
lora_base_var.set("")
if "lora" in dict and dict["lora"]:
if len(dict["lora"]) > 1:
lora_var.set(dict["lora"][0])
@ -2422,73 +2393,33 @@ def show_new_gui():
else:
lora_var.set(dict["lora"][0])
if "mmproj" in dict and dict["mmproj"]:
mmproj_var.set(dict["mmproj"])
mmproj_var.set(dict["mmproj"] if ("mmproj" in dict and dict["mmproj"]) else "")
ssl_cert_var.set("")
ssl_key_var.set("")
if "ssl" in dict and dict["ssl"]:
if len(dict["ssl"]) == 2:
ssl_cert_var.set(dict["ssl"][0])
ssl_key_var.set(dict["ssl"][1])
if "password" in dict and dict["password"]:
password_var.set(dict["password"])
password_var.set(dict["password"] if ("password" in dict and dict["password"]) else "")
preloadstory_var.set(dict["preloadstory"] if ("preloadstory" in dict and dict["preloadstory"]) else "")
chatcompletionsadapter_var.set(dict["chatcompletionsadapter"] if ("chatcompletionsadapter" in dict and dict["chatcompletionsadapter"]) else "")
port_var.set(dict["port_param"] if ("port_param" in dict and dict["port_param"]) else defaultport)
host_var.set(dict["host"] if ("host" in dict and dict["host"]) else "")
multiuser_var.set(dict["multiuser"] if ("multiuser" in dict) else 1)
if "preloadstory" in dict and dict["preloadstory"]:
preloadstory_var.set(dict["preloadstory"])
if "chatcompletionsadapter" in dict and dict["chatcompletionsadapter"]:
chatcompletionsadapter_var.set(dict["chatcompletionsadapter"])
if "port_param" in dict and dict["port_param"]:
port_var.set(dict["port_param"])
if "host" in dict and dict["host"]:
host_var.set(dict["host"])
if "multiuser" in dict:
multiuser_var.set(dict["multiuser"])
# todo: remove these
if "hordeconfig" in dict and dict["hordeconfig"] and len(dict["hordeconfig"]) > 1:
horde_name_var.set(dict["hordeconfig"][0])
horde_gen_var.set(dict["hordeconfig"][1])
horde_context_var.set(dict["hordeconfig"][2])
if len(dict["hordeconfig"]) > 4:
horde_apikey_var.set(dict["hordeconfig"][3])
horde_workername_var.set(dict["hordeconfig"][4])
usehorde_var.set("1")
show_deprecated_warning()
if "sdconfig" in dict and dict["sdconfig"] and len(dict["sdconfig"]) > 0:
sd_model_var.set(dict["sdconfig"][0])
if len(dict["sdconfig"]) > 1:
sd_clamped_var.set(1 if dict["sdconfig"][1]=="quick" else 0)
if len(dict["sdconfig"]) > 2:
sd_threads_var.set(str(dict["sdconfig"][2]))
if len(dict["sdconfig"]) > 3:
sd_quant_var.set(str(dict["sdconfig"][3])=="quant")
show_deprecated_warning()
if "hordemodelname" in dict and dict["hordemodelname"]:
horde_name_var.set(dict["hordemodelname"])
if "hordemaxctx" in dict and dict["hordemaxctx"]:
horde_context_var.set(dict["hordemaxctx"])
if "hordegenlen" in dict and dict["hordegenlen"]:
horde_gen_var.set(dict["hordegenlen"])
if "hordekey" in dict and dict["hordekey"]:
horde_apikey_var.set(dict["hordekey"])
usehorde_var.set(1)
if "hordeworkername" in dict and dict["hordeworkername"]:
horde_workername_var.set(dict["hordeworkername"])
if "sdmodel" in dict and dict["sdmodel"]:
sd_model_var.set(dict["sdmodel"])
if "sdclamped" in dict and dict["sdclamped"]:
sd_clamped_var.set(1)
if "sdthreads" in dict and dict["sdthreads"]:
sd_threads_var.set(str(dict["sdthreads"]))
if "sdquant" in dict and dict["sdquant"]:
sd_quant_var.set(1)
horde_name_var.set(dict["hordemodelname"] if ("hordemodelname" in dict and dict["hordemodelname"]) else "koboldcpp")
horde_context_var.set(dict["hordemaxctx"] if ("hordemaxctx" in dict and dict["hordemaxctx"]) else maxhordectx)
horde_gen_var.set(dict["hordegenlen"] if ("hordegenlen" in dict and dict["hordegenlen"]) else maxhordelen)
horde_apikey_var.set(dict["hordekey"] if ("hordekey" in dict and dict["hordekey"]) else "")
horde_workername_var.set(dict["hordeworkername"] if ("hordeworkername" in dict and dict["hordeworkername"]) else "")
usehorde_var.set(1 if ("hordekey" in dict and dict["hordekey"]) else 0)
sd_model_var.set(dict["sdmodel"] if ("sdmodel" in dict and dict["sdmodel"]) else "")
sd_clamped_var.set(1 if ("sdclamped" in dict and dict["sdclamped"]) else 0)
sd_threads_var.set(str(dict["sdthreads"]) if ("sdthreads" in dict and dict["sdthreads"]) else str(default_threads))
sd_quant_var.set(1 if ("sdquant" in dict and dict["sdquant"]) else 0)
def save_config():
file_type = [("KoboldCpp Settings", "*.kcpps")]
@ -2548,7 +2479,7 @@ def show_new_gui():
# processing vars
export_vars()
if not args.model_param and not args.sdconfig and not args.sdmodel:
if not args.model_param and not args.sdmodel:
exitcounter = 999
print("\nNo text or image model file was selected. Exiting.")
time.sleep(3)
@ -2760,15 +2691,51 @@ def run_horde_worker(args, api_key, worker_name):
time.sleep(3)
sys.exit(2)
# todo: remove this
def show_deprecated_warning():
    """Print the fixed notice about the deprecated launch flags to stdout."""
    # One print call, one line per entry: the emitted text is the same as
    # printing each entry separately.
    notice_lines = (
        "\n=== !!! IMPORTANT WARNING !!! ===",
        "The flags --smartcontext, --hordeconfig and --sdconfig have been DEPRECATED and will be REMOVED soon.",
        "Please use the single-parameter flags instead, e.g. --hordekey, --hordemodelname, --hordemaxctx, --sdmodel, --sdquant, etc",
        "SmartContext will only be applied when contextshift is selected on a model that does not support it.",
        "For more information on these flags, please check --help",
        "If you are using the GUI launcher, simply re-saving your config again will solve this warning.",
        "=== !!! IMPORTANT WARNING !!! ===",
    )
    print(*notice_lines, sep="\n")
def convert_outdated_args(args):
    """Map the deprecated --sdconfig / --hordeconfig values onto the newer flags.

    Parameters:
        args: either an ``argparse.Namespace`` or a plain dict of settings.
            It is updated in place.

    Returns:
        The same ``args`` object that was passed in.

    Side effects:
        Resets the module-level ``using_outdated_flags``, sets it to True when
        a deprecated flag is present, then calls ``check_deprecation_warning()``.

    Raises:
        ValueError: if a legacy numeric field cannot be parsed by ``int()``.
    """
    global using_outdated_flags
    using_outdated_flags = False

    # vars() returns the namespace's own attribute dict, so writes below are
    # visible on the namespace that is returned.
    settings = vars(args) if isinstance(args, argparse.Namespace) else args

    sdconfig = settings.get("sdconfig")
    if sdconfig:
        using_outdated_flags = True
        settings["sdmodel"] = sdconfig[0]
        # Only the "clamped" and "quick" legacy modes restricted generation;
        # any other mode value must not switch clamping on.
        if len(sdconfig) > 1 and sdconfig[1] in ("clamped", "quick"):
            settings["sdclamped"] = True
        if len(sdconfig) > 2:
            settings["sdthreads"] = int(sdconfig[2])
        if len(sdconfig) > 3:
            settings["sdquant"] = sdconfig[3] == "quant"

    hordeconfig = settings.get("hordeconfig")
    if hordeconfig and hordeconfig[0] != "":
        using_outdated_flags = True
        settings["hordemodelname"] = hordeconfig[0]
        if len(hordeconfig) > 1:
            settings["hordegenlen"] = int(hordeconfig[1])
        if len(hordeconfig) > 2:
            settings["hordemaxctx"] = int(hordeconfig[2])
        # key and worker name are only taken when both are supplied
        if len(hordeconfig) > 4:
            settings["hordekey"] = hordeconfig[3]
            settings["hordeworkername"] = hordeconfig[4]

    check_deprecation_warning()
    return args
def check_deprecation_warning():
    """Print the outdated-flags banner to stdout when ``using_outdated_flags`` is set.

    Does nothing when the module-level ``using_outdated_flags`` is falsy.
    """
    # Deliberately naggy, to push users towards the new flags. It can be
    # removed at your own risk, but deprecated flags get no troubleshooting
    # or support.
    if not using_outdated_flags:
        return
    banner_lines = (
        "\n=== !!! IMPORTANT WARNING !!! ===",
        "You are using one or more OUTDATED config files or launch flags!",
        "--smartcontext, --hordeconfig and --sdconfig have been DEPRECATED and MAY be REMOVED in future.",
        "They will still work for now, but you SHOULD switch to the updated flags instead, to avoid future issues.",
        "New flags are: --hordemodelname --hordeworkername --hordekey --hordemaxctx --hordegenlen --sdmodel --sdthreads --sdquant --sdclamped",
        "For more information on these flags, please check --help",
        "> If you are using the GUI launcher, simply re-saving your config again will get rid of this warning.",
        "=== !!! IMPORTANT WARNING !!! ===\n",
    )
    print(*banner_lines, sep="\n")
def setuptunnel():
# This script will help setup a cloudflared tunnel for accessing KoboldCpp over the internet
@ -2960,7 +2927,6 @@ def sanitize_string(input_string):
def main(launch_args,start_server=True):
global embedded_kailite, embedded_kcpp_docs, embedded_kcpp_sdui
global libname, args, friendlymodelname, friendlysdmodelname, fullsdmodelpath, mmprojpath, password
args = launch_args
#perform some basic cleanup of old temporary directories
try:
@ -2968,7 +2934,7 @@ def main(launch_args,start_server=True):
except Exception as e:
print(f"Error cleaning up orphaned pyinstaller dirs: {e}")
args = launch_args
if args.config and len(args.config)==1:
if isinstance(args.config[0], str) and os.path.exists(args.config[0]):
loadconfigfile(args.config[0])
@ -2980,6 +2946,7 @@ def main(launch_args,start_server=True):
print("Specified kcpp config file invalid or not found.")
time.sleep(3)
sys.exit(2)
args = convert_outdated_args(args)
#positional handling for kcpps files (drag and drop)
if args.model_param and args.model_param!="" and args.model_param.lower().endswith('.kcpps'):
@ -2988,7 +2955,7 @@ def main(launch_args,start_server=True):
if not args.model_param:
args.model_param = args.model
if not args.model_param and not args.sdconfig and not args.sdmodel:
if not args.model_param and not args.sdmodel:
#give them a chance to pick a file
print("For command line arguments, please refer to --help")
print("***")
@ -3032,22 +2999,7 @@ def main(launch_args,start_server=True):
newmdldisplayname = os.path.splitext(newmdldisplayname)[0]
friendlymodelname = "koboldcpp/" + sanitize_string(newmdldisplayname)
# todo: remove these
global maxhordelen, maxhordectx, showdebug
if args.hordeconfig and args.hordeconfig[0]!="":
show_deprecated_warning()
friendlymodelname = args.hordeconfig[0]
if args.debugmode == 1:
friendlymodelname = "debug-" + friendlymodelname
if not friendlymodelname.startswith("koboldcpp/"):
friendlymodelname = "koboldcpp/" + friendlymodelname
if len(args.hordeconfig) > 1:
maxhordelen = int(args.hordeconfig[1])
if len(args.hordeconfig) > 2:
maxhordectx = int(args.hordeconfig[2])
if args.debugmode == 0:
args.debugmode = -1
if args.hordemodelname and args.hordemodelname!="":
friendlymodelname = args.hordemodelname
if args.debugmode == 1:
@ -3175,33 +3127,6 @@ def main(launch_args,start_server=True):
sys.exit(3)
#handle loading image model
#todo: remove this
if args.sdconfig:
show_deprecated_warning()
imgmodel = args.sdconfig[0]
if not imgmodel or not os.path.exists(imgmodel):
print(f"Cannot find image model file: {imgmodel}")
if args.ignoremissing:
print(f"Ignoring missing sdconfig img model file...")
args.sdconfig = None
else:
exitcounter = 999
time.sleep(3)
sys.exit(2)
else:
imgmodel = os.path.abspath(imgmodel)
fullsdmodelpath = imgmodel
friendlysdmodelname = os.path.basename(imgmodel)
friendlysdmodelname = os.path.splitext(friendlysdmodelname)[0]
friendlysdmodelname = sanitize_string(friendlysdmodelname)
loadok = sd_load_model(imgmodel)
print("Load Image Model OK: " + str(loadok))
if not loadok:
exitcounter = 999
print("Could not load image model: " + imgmodel)
time.sleep(3)
sys.exit(3)
if args.sdmodel and args.sdmodel!="":
imgmodel = args.sdmodel
if not imgmodel or not os.path.exists(imgmodel):
@ -3277,7 +3202,7 @@ def main(launch_args,start_server=True):
if not args.remotetunnel:
print(f"Starting Kobold API on port {args.port} at {epurl}/api/")
print(f"Starting OpenAI Compatible API on port {args.port} at {epurl}/v1/")
if args.sdconfig or args.sdmodel:
if args.sdmodel:
print(f"StableUI is available at {epurl}/sdui/")
if args.launch:
@ -3287,13 +3212,7 @@ def main(launch_args,start_server=True):
except:
print("--launch was set, but could not launch web browser automatically.")
#todo: remove this
if args.hordeconfig and len(args.hordeconfig)>4:
show_deprecated_warning()
horde_thread = threading.Thread(target=run_horde_worker,args=(args,args.hordeconfig[3],args.hordeconfig[4]))
horde_thread.daemon = True
horde_thread.start()
elif args.hordekey and args.hordekey!="" and args.hordemodelname and args.hordemodelname!="" and args.hordeworkername and args.hordeworkername!="":
if args.hordekey and args.hordekey!="" and args.hordemodelname and args.hordemodelname!="" and args.hordeworkername and args.hordeworkername!="":
horde_thread = threading.Thread(target=run_horde_worker,args=(args,args.hordekey,args.hordeworkername))
horde_thread.daemon = True
horde_thread.start()
@ -3363,11 +3282,8 @@ def main(launch_args,start_server=True):
print("Press ENTER key to exit.", flush=True)
input()
check_deprecation_warning()
if start_server:
#todo: remove in next version
if args.hordeconfig or args.sdconfig or args.smartcontext:
show_deprecated_warning()
if args.remotetunnel:
setuptunnel()
else:
@ -3439,6 +3355,7 @@ if __name__ == '__main__':
parser.add_argument("--tensor_split", help="For CUDA and Vulkan only, ratio to split tensors across multiple GPUs, space-separated list of proportions, e.g. 7 3", metavar=('[Ratios]'), type=float, nargs='+')
parser.add_argument("--contextsize", help="Controls the memory allocated for maximum context size, only change if you need more RAM for big contexts. (default 2048). Supported values are [256,512,1024,2048,3072,4096,6144,8192,12288,16384,24576,32768,49152,65536,98304,131072]. IF YOU USE ANYTHING ELSE YOU ARE ON YOUR OWN.",metavar=('[256,512,1024,2048,3072,4096,6144,8192,12288,16384,24576,32768,49152,65536,98304,131072]'), type=check_range(int,256,262144), default=2048)
parser.add_argument("--ropeconfig", help="If set, uses customized RoPE scaling from configured frequency scale and frequency base (e.g. --ropeconfig 0.25 10000). Otherwise, uses NTK-Aware scaling set automatically based on context size. For linear rope, simply set the freq-scale and ignore the freq-base",metavar=('[rope-freq-scale]', '[rope-freq-base]'), default=[0.0, 10000.0], type=float, nargs='+')
#more advanced params
parser.add_argument("--blasbatchsize", help="Sets the batch size used in BLAS processing (default 512). Setting it to -1 disables BLAS mode, but keeps other benefits like GPU offload.", type=int,choices=[-1,32,64,128,256,512,1024,2048], default=512)
parser.add_argument("--blasthreads", help="Use a different number of threads during BLAS if specified. Otherwise, has the same value as --threads",metavar=('[threads]'), type=int, default=0)
@ -3477,8 +3394,8 @@ if __name__ == '__main__':
parser.add_argument("--sdquant", help="If specified, loads the model quantized to save memory.", action='store_true')
parser.add_argument("--sdclamped", help="If specified, limit generation steps and resolution settings for shared use.", action='store_true')
parser.add_argument("--smartcontext", help="!!! THIS COMMAND IS DEPRECATED AND WILL BE REMOVED !!!", action='store_true')
parser.add_argument("--hordeconfig", help="!!! THIS COMMAND IS DEPRECATED AND WILL BE REMOVED !!!", nargs='+')
parser.add_argument("--sdconfig", help="!!! THIS COMMAND IS DEPRECATED AND WILL BE REMOVED !!!", nargs='+')
parser.add_argument("--smartcontext", help="!!! THIS COMMAND IS DEPRECATED AND SHOULD NOT BE USED !!!", action='store_true')
parser.add_argument("--hordeconfig", help="!!! THIS COMMAND IS DEPRECATED AND SHOULD NOT BE USED !!!", nargs='+')
parser.add_argument("--sdconfig", help="!!! THIS COMMAND IS DEPRECATED AND SHOULD NOT BE USED !!!", nargs='+')
main(parser.parse_args(),start_server=True)