improvements to model downloader and chat completions adapter loader

This commit is contained in:
Concedo 2024-07-04 15:34:08 +08:00
parent 3fdbe3351d
commit 6b0756506b
2 changed files with 29 additions and 25 deletions

View file

@ -3302,16 +3302,27 @@ def main(launch_args,start_server=True):
# try to read chat completions adapter
if args.chatcompletionsadapter:
ccadapter_path = None
adapt_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'kcpp_adapters')
adapt_dir = adapt_dir if os.path.isdir(adapt_dir) else None
if isinstance(args.chatcompletionsadapter, str) and os.path.exists(args.chatcompletionsadapter):
print(f"Loading Chat Completions Adapter...")
with open(args.chatcompletionsadapter, 'r') as f:
ccadapter_path = os.path.abspath(args.chatcompletionsadapter)
elif isinstance(args.chatcompletionsadapter, str) and adapt_dir:
filename = args.chatcompletionsadapter
if not filename.endswith(".json"):
filename += ".json"
premade_adapt_path = os.path.join(adapt_dir,filename)
if os.path.exists(premade_adapt_path):
ccadapter_path = os.path.abspath(premade_adapt_path)
if ccadapter_path:
print(f"Loading Chat Completions Adapter: {ccadapter_path}")
with open(ccadapter_path, 'r') as f:
global chatcompl_adapter
chatcompl_adapter = json.load(f)
print(f"Chat Completions Adapter Loaded")
else:
print(f"Warning: Chat Completions Adapter {args.chatcompletionsadapter} invalid or not found.")
if args.model_param and args.model_param!="":
if args.model_param.endswith("?download=true"):
args.model_param = args.model_param.replace("?download=true","")
@ -3320,13 +3331,16 @@ def main(launch_args,start_server=True):
mdlfilename = os.path.basename(args.model_param)
#check if file already exists
if mdlfilename:
if not os.path.exists(mdlfilename):
print(f"Downloading model from external URL at {args.model_param}")
subprocess.run(f"curl -fL {args.model_param} -o {mdlfilename}", shell=True, capture_output=True, text=True, check=True, encoding='utf-8')
print(f"Download {mdlfilename} completed...", flush=True)
if os.path.exists(mdlfilename) and os.path.getsize(mdlfilename) > 10000000: #10MB trigger
print(f"Model file {mdlfilename} already exists, not redownloading.")
args.model_param = mdlfilename
else:
print(f"Model file {mdlfilename} already exists, not redownloading.")
dl_url = args.model_param
if "https://huggingface.co/" in dl_url and "/blob/main/" in dl_url:
dl_url = dl_url.replace("/blob/main/", "/resolve/main/")
print(f"Downloading model from external URL at {dl_url}")
subprocess.run(f"curl -fL {dl_url} -o {mdlfilename}", shell=True, capture_output=True, text=True, check=True, encoding='utf-8')
print(f"Download {mdlfilename} completed...", flush=True)
args.model_param = mdlfilename
# sanitize and replace the default vanity name. remember me....