model downloading for new params

This commit is contained in:
Concedo 2024-11-07 14:41:25 +08:00
parent 628dcd640e
commit c9977a5cb5

View file

@ -2428,7 +2428,7 @@ def show_gui():
gtooltip_box.withdraw() gtooltip_box.withdraw()
gtooltip_box.overrideredirect(True) gtooltip_box.overrideredirect(True)
gtooltip_label = ctk.CTkLabel(gtooltip_box, text=tooltip_text, text_color="#000000", fg_color="#ffffe0") gtooltip_label = ctk.CTkLabel(gtooltip_box, text=tooltip_text, text_color="#000000", fg_color="#ffffe0")
gtooltip_label.pack(expand=True, padx=2, pady=1) gtooltip_label.pack(expand=True, ipadx=2, ipady=1)
else: else:
gtooltip_label.configure(text=tooltip_text) gtooltip_label.configure(text=tooltip_text)
@ -3989,7 +3989,7 @@ def sanitize_string(input_string):
sanitized_string = re.sub( r'[^\w\d\.\-_]', '', input_string) sanitized_string = re.sub( r'[^\w\d\.\-_]', '', input_string)
return sanitized_string return sanitized_string
def download_model_from_url(url): #returns path to downloaded model when done def download_model_from_url_internal(url): #returns path to downloaded model when done
import subprocess import subprocess
mdlfilename = os.path.basename(url) mdlfilename = os.path.basename(url)
#check if file already exists #check if file already exists
@ -4006,6 +4006,19 @@ def download_model_from_url(url): #returns path to downloaded model when done
print(f"Download {mdlfilename} completed.", flush=True) print(f"Download {mdlfilename} completed.", flush=True)
return mdlfilename return mdlfilename
return None return None
def download_model_from_url(url,permitted_types=[".gguf",".safetensors"]):
if url and url!="":
if url.endswith("?download=true"):
url = url.replace("?download=true","")
end_ext_ok = False
for t in permitted_types:
if url.endswith(t):
end_ext_ok = True
break
if ((url.startswith("http://") or url.startswith("https://")) and end_ext_ok):
dlfile = download_model_from_url_internal(url)
return dlfile
return None
def main(launch_args,start_server=True): def main(launch_args,start_server=True):
global embedded_kailite, embedded_kcpp_docs, embedded_kcpp_sdui global embedded_kailite, embedded_kcpp_docs, embedded_kcpp_sdui
@ -4030,10 +4043,8 @@ def main(launch_args,start_server=True):
if args.config and len(args.config)==1: if args.config and len(args.config)==1:
cfgname = args.config[0] cfgname = args.config[0]
if cfgname.endswith("?download=true"): if isinstance(cfgname, str):
cfgname = cfgname.replace("?download=true","") dlfile = download_model_from_url(cfgname,[".kcpps",".kcppt"])
if isinstance(cfgname, str) and (cfgname.startswith("http://") or cfgname.startswith("https://")) and (cfgname.endswith(".kcpps") or cfgname.endswith(".kcppt")):
dlfile = download_model_from_url(cfgname)
if dlfile: if dlfile:
cfgname = dlfile cfgname = dlfile
if isinstance(cfgname, str) and os.path.exists(cfgname): if isinstance(cfgname, str) and os.path.exists(cfgname):
@ -4144,33 +4155,37 @@ def main(launch_args,start_server=True):
# handle model downloads if needed # handle model downloads if needed
if args.model_param and args.model_param!="": if args.model_param and args.model_param!="":
if args.model_param.endswith("?download=true"): dlfile = download_model_from_url(args.model_param,[".gguf",".bin"])
args.model_param = args.model_param.replace("?download=true","") if dlfile:
if (args.model_param.startswith("http://") or args.model_param.startswith("https://")) and (args.model_param.endswith(".gguf") or args.model_param.endswith(".bin")): args.model_param = dlfile
dlfile = download_model_from_url(args.model_param)
if dlfile:
args.model_param = dlfile
if args.sdmodel and args.sdmodel!="": if args.sdmodel and args.sdmodel!="":
if args.sdmodel.endswith("?download=true"): dlfile = download_model_from_url(args.sdmodel,[".gguf",".safetensors"])
args.sdmodel = args.sdmodel.replace("?download=true","") if dlfile:
if (args.sdmodel.startswith("http://") or args.sdmodel.startswith("https://")) and (args.sdmodel.endswith(".gguf") or args.sdmodel.endswith(".safetensors")): args.sdmodel = dlfile
dlfile = download_model_from_url(args.sdmodel) if args.sdt5xxl and args.sdt5xxl!="":
if dlfile: dlfile = download_model_from_url(args.sdt5xxl,[".safetensors"])
args.sdmodel = dlfile if dlfile:
args.sdt5xxl = dlfile
if args.sdclipl and args.sdclipl!="":
dlfile = download_model_from_url(args.sdclipl,[".safetensors"])
if dlfile:
args.sdclipl = dlfile
if args.sdclipg and args.sdclipg!="":
dlfile = download_model_from_url(args.sdclipg,[".safetensors"])
if dlfile:
args.sdclipg = dlfile
if args.sdvae and args.sdvae!="":
dlfile = download_model_from_url(args.sdvae,[".safetensors"])
if dlfile:
args.sdvae = dlfile
if args.mmproj and args.mmproj!="": if args.mmproj and args.mmproj!="":
if args.mmproj.endswith("?download=true"): dlfile = download_model_from_url(args.mmproj,[".gguf"])
args.mmproj = args.mmproj.replace("?download=true","") if dlfile:
if (args.mmproj.startswith("http://") or args.mmproj.startswith("https://")) and (args.mmproj.endswith(".gguf")): args.mmproj = dlfile
dlfile = download_model_from_url(args.mmproj)
if dlfile:
args.mmproj = dlfile
if args.whispermodel and args.whispermodel!="": if args.whispermodel and args.whispermodel!="":
if args.whispermodel.endswith("?download=true"): dlfile = download_model_from_url(args.whispermodel,[".gguf",".bin"])
args.whispermodel = args.whispermodel.replace("?download=true","") if dlfile:
if (args.whispermodel.startswith("http://") or args.whispermodel.startswith("https://")) and (args.whispermodel.endswith(".gguf") or args.whispermodel.endswith(".bin")): args.whispermodel = dlfile
dlfile = download_model_from_url(args.whispermodel)
if dlfile:
args.whispermodel = dlfile
# sanitize and replace the default vanity name. remember me.... # sanitize and replace the default vanity name. remember me....
if args.model_param and args.model_param!="": if args.model_param and args.model_param!="":