From c9977a5cb5cc755aeb53a1d14bb49436ec49aa29 Mon Sep 17 00:00:00 2001 From: Concedo <39025047+LostRuins@users.noreply.github.com> Date: Thu, 7 Nov 2024 14:41:25 +0800 Subject: [PATCH] model downloading for new params --- koboldcpp.py | 75 +++++++++++++++++++++++++++++++--------------------- 1 file changed, 45 insertions(+), 30 deletions(-) diff --git a/koboldcpp.py b/koboldcpp.py index ad93223a6..04ff99354 100644 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -2428,7 +2428,7 @@ def show_gui(): gtooltip_box.withdraw() gtooltip_box.overrideredirect(True) gtooltip_label = ctk.CTkLabel(gtooltip_box, text=tooltip_text, text_color="#000000", fg_color="#ffffe0") - gtooltip_label.pack(expand=True, padx=2, pady=1) + gtooltip_label.pack(expand=True, ipadx=2, ipady=1) else: gtooltip_label.configure(text=tooltip_text) @@ -3989,7 +3989,7 @@ def sanitize_string(input_string): sanitized_string = re.sub( r'[^\w\d\.\-_]', '', input_string) return sanitized_string -def download_model_from_url(url): #returns path to downloaded model when done +def download_model_from_url_internal(url): #returns path to downloaded model when done import subprocess mdlfilename = os.path.basename(url) #check if file already exists @@ -4006,6 +4006,19 @@ def download_model_from_url(url): #returns path to downloaded model when done print(f"Download {mdlfilename} completed.", flush=True) return mdlfilename return None +def download_model_from_url(url,permitted_types=[".gguf",".safetensors"]): + if url and url!="": + if url.endswith("?download=true"): + url = url.replace("?download=true","") + end_ext_ok = False + for t in permitted_types: + if url.endswith(t): + end_ext_ok = True + break + if ((url.startswith("http://") or url.startswith("https://")) and end_ext_ok): + dlfile = download_model_from_url_internal(url) + return dlfile + return None def main(launch_args,start_server=True): global embedded_kailite, embedded_kcpp_docs, embedded_kcpp_sdui @@ -4030,10 +4043,8 @@ def main(launch_args,start_server=True): if args.config and len(args.config)==1: cfgname = args.config[0] - if cfgname.endswith("?download=true"): - cfgname = cfgname.replace("?download=true","") - if isinstance(cfgname, str) and (cfgname.startswith("http://") or cfgname.startswith("https://")) and (cfgname.endswith(".kcpps") or cfgname.endswith(".kcppt")): - dlfile = download_model_from_url(cfgname) + if isinstance(cfgname, str): + dlfile = download_model_from_url(cfgname,[".kcpps",".kcppt"]) if dlfile: cfgname = dlfile if isinstance(cfgname, str) and os.path.exists(cfgname): @@ -4144,33 +4155,37 @@ def main(launch_args,start_server=True): # handle model downloads if needed if args.model_param and args.model_param!="": - if args.model_param.endswith("?download=true"): - args.model_param = args.model_param.replace("?download=true","") - if (args.model_param.startswith("http://") or args.model_param.startswith("https://")) and (args.model_param.endswith(".gguf") or args.model_param.endswith(".bin")): - dlfile = download_model_from_url(args.model_param) - if dlfile: - args.model_param = dlfile + dlfile = download_model_from_url(args.model_param,[".gguf",".bin"]) + if dlfile: + args.model_param = dlfile if args.sdmodel and args.sdmodel!="": - if args.sdmodel.endswith("?download=true"): - args.sdmodel = args.sdmodel.replace("?download=true","") - if (args.sdmodel.startswith("http://") or args.sdmodel.startswith("https://")) and (args.sdmodel.endswith(".gguf") or args.sdmodel.endswith(".safetensors")): - dlfile = download_model_from_url(args.sdmodel) - if dlfile: - args.sdmodel = dlfile + dlfile = download_model_from_url(args.sdmodel,[".gguf",".safetensors"]) + if dlfile: + args.sdmodel = dlfile + if args.sdt5xxl and args.sdt5xxl!="": + dlfile = download_model_from_url(args.sdt5xxl,[".safetensors"]) + if dlfile: + args.sdt5xxl = dlfile + if args.sdclipl and args.sdclipl!="": + dlfile = download_model_from_url(args.sdclipl,[".safetensors"]) + if dlfile: + args.sdclipl = dlfile + if args.sdclipg and args.sdclipg!="": + dlfile = download_model_from_url(args.sdclipg,[".safetensors"]) + if dlfile: + args.sdclipg = dlfile + if args.sdvae and args.sdvae!="": + dlfile = download_model_from_url(args.sdvae,[".safetensors"]) + if dlfile: + args.sdvae = dlfile if args.mmproj and args.mmproj!="": - if args.mmproj.endswith("?download=true"): - args.mmproj = args.mmproj.replace("?download=true","") - if (args.mmproj.startswith("http://") or args.mmproj.startswith("https://")) and (args.mmproj.endswith(".gguf")): - dlfile = download_model_from_url(args.mmproj) - if dlfile: - args.mmproj = dlfile + dlfile = download_model_from_url(args.mmproj,[".gguf"]) + if dlfile: + args.mmproj = dlfile if args.whispermodel and args.whispermodel!="": - if args.whispermodel.endswith("?download=true"): - args.whispermodel = args.whispermodel.replace("?download=true","") - if (args.whispermodel.startswith("http://") or args.whispermodel.startswith("https://")) and (args.whispermodel.endswith(".gguf") or args.whispermodel.endswith(".bin")): - dlfile = download_model_from_url(args.whispermodel) - if dlfile: - args.whispermodel = dlfile + dlfile = download_model_from_url(args.whispermodel,[".gguf",".bin"]) + if dlfile: + args.whispermodel = dlfile # sanitize and replace the default vanity name. remember me.... if args.model_param and args.model_param!="":