model downloading for new params

This commit is contained in:
Concedo 2024-11-07 14:41:25 +08:00
parent 628dcd640e
commit c9977a5cb5

View file

@ -2428,7 +2428,7 @@ def show_gui():
gtooltip_box.withdraw()
gtooltip_box.overrideredirect(True)
gtooltip_label = ctk.CTkLabel(gtooltip_box, text=tooltip_text, text_color="#000000", fg_color="#ffffe0")
gtooltip_label.pack(expand=True, padx=2, pady=1)
gtooltip_label.pack(expand=True, ipadx=2, ipady=1)
else:
gtooltip_label.configure(text=tooltip_text)
@ -3989,7 +3989,7 @@ def sanitize_string(input_string):
sanitized_string = re.sub( r'[^\w\d\.\-_]', '', input_string)
return sanitized_string
def download_model_from_url(url): #returns path to downloaded model when done
def download_model_from_url_internal(url): #returns path to downloaded model when done
import subprocess
mdlfilename = os.path.basename(url)
#check if file already exists
@ -4006,6 +4006,19 @@ def download_model_from_url(url): #returns path to downloaded model when done
print(f"Download {mdlfilename} completed.", flush=True)
return mdlfilename
return None
def download_model_from_url(url,permitted_types=[".gguf",".safetensors"]):
if url and url!="":
if url.endswith("?download=true"):
url = url.replace("?download=true","")
end_ext_ok = False
for t in permitted_types:
if url.endswith(t):
end_ext_ok = True
break
if ((url.startswith("http://") or url.startswith("https://")) and end_ext_ok):
dlfile = download_model_from_url_internal(url)
return dlfile
return None
def main(launch_args,start_server=True):
global embedded_kailite, embedded_kcpp_docs, embedded_kcpp_sdui
@ -4030,10 +4043,8 @@ def main(launch_args,start_server=True):
if args.config and len(args.config)==1:
cfgname = args.config[0]
if cfgname.endswith("?download=true"):
cfgname = cfgname.replace("?download=true","")
if isinstance(cfgname, str) and (cfgname.startswith("http://") or cfgname.startswith("https://")) and (cfgname.endswith(".kcpps") or cfgname.endswith(".kcppt")):
dlfile = download_model_from_url(cfgname)
if isinstance(cfgname, str):
dlfile = download_model_from_url(cfgname,[".kcpps",".kcppt"])
if dlfile:
cfgname = dlfile
if isinstance(cfgname, str) and os.path.exists(cfgname):
@ -4144,33 +4155,37 @@ def main(launch_args,start_server=True):
# handle model downloads if needed
if args.model_param and args.model_param!="":
if args.model_param.endswith("?download=true"):
args.model_param = args.model_param.replace("?download=true","")
if (args.model_param.startswith("http://") or args.model_param.startswith("https://")) and (args.model_param.endswith(".gguf") or args.model_param.endswith(".bin")):
dlfile = download_model_from_url(args.model_param)
if dlfile:
args.model_param = dlfile
dlfile = download_model_from_url(args.model_param,[".gguf",".bin"])
if dlfile:
args.model_param = dlfile
if args.sdmodel and args.sdmodel!="":
if args.sdmodel.endswith("?download=true"):
args.sdmodel = args.sdmodel.replace("?download=true","")
if (args.sdmodel.startswith("http://") or args.sdmodel.startswith("https://")) and (args.sdmodel.endswith(".gguf") or args.sdmodel.endswith(".safetensors")):
dlfile = download_model_from_url(args.sdmodel)
if dlfile:
args.sdmodel = dlfile
dlfile = download_model_from_url(args.sdmodel,[".gguf",".safetensors"])
if dlfile:
args.sdmodel = dlfile
if args.sdt5xxl and args.sdt5xxl!="":
dlfile = download_model_from_url(args.sdt5xxl,[".safetensors"])
if dlfile:
args.sdt5xxl = dlfile
if args.sdclipl and args.sdclipl!="":
dlfile = download_model_from_url(args.sdclipl,[".safetensors"])
if dlfile:
args.sdclipl = dlfile
if args.sdclipg and args.sdclipg!="":
dlfile = download_model_from_url(args.sdclipg,[".safetensors"])
if dlfile:
args.sdclipg = dlfile
if args.sdvae and args.sdvae!="":
dlfile = download_model_from_url(args.sdvae,[".safetensors"])
if dlfile:
args.sdvae = dlfile
if args.mmproj and args.mmproj!="":
if args.mmproj.endswith("?download=true"):
args.mmproj = args.mmproj.replace("?download=true","")
if (args.mmproj.startswith("http://") or args.mmproj.startswith("https://")) and (args.mmproj.endswith(".gguf")):
dlfile = download_model_from_url(args.mmproj)
if dlfile:
args.mmproj = dlfile
dlfile = download_model_from_url(args.mmproj,[".gguf"])
if dlfile:
args.mmproj = dlfile
if args.whispermodel and args.whispermodel!="":
if args.whispermodel.endswith("?download=true"):
args.whispermodel = args.whispermodel.replace("?download=true","")
if (args.whispermodel.startswith("http://") or args.whispermodel.startswith("https://")) and (args.whispermodel.endswith(".gguf") or args.whispermodel.endswith(".bin")):
dlfile = download_model_from_url(args.whispermodel)
if dlfile:
args.whispermodel = dlfile
dlfile = download_model_from_url(args.whispermodel,[".gguf",".bin"])
if dlfile:
args.whispermodel = dlfile
# sanitize and replace the default vanity name. remember me....
if args.model_param and args.model_param!="":