diff --git a/koboldcpp.py b/koboldcpp.py index cce3f1fbd..03f670d4b 100755 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -87,6 +87,7 @@ password = "" #if empty, no auth key required fullwhispermodelpath = "" #if empty, it's not initialized ttsmodelpath = "" #if empty, not initialized embeddingsmodelpath = "" #if empty, not initialized +musicllmmodelpath = "" #if empty, not initialized musicdiffusionmodelpath = "" #if empty, not initialized maxctx = 8192 maxhordectx = 0 #set to whatever maxctx is if 0 @@ -1149,7 +1150,7 @@ def convert_json_to_gbnf(json_obj): return "" def get_capabilities(): - global savedata_obj, has_multiplayer, KcppVersion, friendlymodelname, friendlysdmodelname, fullsdmodelpath, password, fullwhispermodelpath, ttsmodelpath, embeddingsmodelpath, musicdiffusionmodelpath, has_audio_support, has_vision_support, mcp_connections + global savedata_obj, has_multiplayer, KcppVersion, friendlymodelname, friendlysdmodelname, fullsdmodelpath, password, fullwhispermodelpath, ttsmodelpath, embeddingsmodelpath, musicdiffusionmodelpath, musicllmmodelpath, has_audio_support, has_vision_support, mcp_connections has_llm = not (friendlymodelname=="inactive") has_txt2img = not (friendlysdmodelname=="inactive" or fullsdmodelpath=="") has_password = (password!="") @@ -1157,7 +1158,7 @@ def get_capabilities(): has_search = True if args.websearch else False has_tts = (ttsmodelpath!="") has_embeddings = (embeddingsmodelpath!="") - has_music = (musicdiffusionmodelpath!="") + has_music = (musicdiffusionmodelpath!="" or musicllmmodelpath!="") has_guidance = True if args.enableguidance else False has_jinja = True if args.jinja else False has_mcp = True if (args.mcpfile and mcp_connections and len(mcp_connections) > 0) else False @@ -5535,7 +5536,7 @@ def show_gui(): if dlfile: args.model_param = dlfile load_config_cli(args.model_param) - if not args.model_param and not args.sdmodel and not args.whispermodel and not args.ttsmodel and not args.embeddingsmodel and not args.musicdiffusion and not args.mcpfile and not args.nomodel: + if not args.model_param and not args.sdmodel and not args.whispermodel and not args.ttsmodel and not args.embeddingsmodel and not args.musicdiffusion and not args.musicllm and not args.mcpfile and not args.nomodel: global exitcounter exitcounter = 999 exit_with_error(2,"No gguf model or kcpps file was selected. Exiting.") @@ -6704,7 +6705,7 @@ def show_gui(): # launch def guilaunch(): - if model_var.get() == "" and sd_model_var.get() == "" and whisper_model_var.get() == "" and tts_model_var.get() == "" and embeddings_model_var.get() == "" and musicdiffusion_var.get() == "" and nomodel.get()!=1: + if model_var.get() == "" and sd_model_var.get() == "" and whisper_model_var.get() == "" and tts_model_var.get() == "" and embeddings_model_var.get() == "" and musicdiffusion_var.get() == "" and musicllm_var.get() == "" and nomodel.get()!=1: tmp = zentk_askopenfilename(title="Select ggml model .bin or .gguf file") model_var.set(tmp) nonlocal nextstate @@ -7316,7 +7317,7 @@ def show_gui(): kcpp_exporting_template = False export_vars() - if not args.model_param and not args.sdmodel and not args.whispermodel and not args.ttsmodel and not args.embeddingsmodel and not args.musicdiffusion and not args.mcpfile and not args.nomodel: + if not args.model_param and not args.sdmodel and not args.whispermodel and not args.ttsmodel and not args.embeddingsmodel and not args.musicdiffusion and not args.musicllm and not args.mcpfile and not args.nomodel: exitcounter = 999 print("") time.sleep(0.5) @@ -8166,7 +8167,7 @@ def main(launch_args, default_args): return # show the GUI launcher if a model was not provided - if args.showgui or (not args.model_param and not args.sdmodel and not args.whispermodel and not args.ttsmodel and not args.embeddingsmodel and not args.musicdiffusion and not args.mcpfile and not args.nomodel): + if args.showgui or (not args.model_param and not args.sdmodel and not args.whispermodel and not args.ttsmodel and not args.embeddingsmodel and not args.musicdiffusion and not args.musicllm and not args.mcpfile and not args.nomodel): #give them a chance to pick a file print("For command line arguments, please refer to --help") print("***") @@ -8288,7 +8289,7 @@ def main(launch_args, default_args): def kcpp_main_process(launch_args, g_memory=None, gui_launcher=False): global embedded_kailite, embedded_kcpp_docs, embedded_kcpp_sdui, embedded_kailite_gz, embedded_kcpp_docs_gz, embedded_kcpp_sdui_gz, embedded_lcpp_ui_gz, embedded_musicui, embedded_musicui_gz, start_time, exitcounter, global_memory, using_gui_launcher - global libname, args, friendlymodelname, friendlysdmodelname, fullsdmodelpath, password, fullwhispermodelpath, ttsmodelpath, embeddingsmodelpath, musicdiffusionmodelpath, friendlyembeddingsmodelname, has_audio_support, has_vision_support, cached_chat_template + global libname, args, friendlymodelname, friendlysdmodelname, fullsdmodelpath, password, fullwhispermodelpath, ttsmodelpath, embeddingsmodelpath, musicdiffusionmodelpath, musicllmmodelpath, friendlyembeddingsmodelname, has_audio_support, has_vision_support, cached_chat_template start_server = True @@ -8836,27 +8837,55 @@ def kcpp_main_process(launch_args, g_memory=None, gui_launcher=False): exit_with_error(3,"Could not load Embeddings model!") #handle music model - if (args.musicdiffusion and args.musicdiffusion!="") or (args.musicllm and args.musicllm!="") or (args.musicembeddings and args.musicembeddings!="") or (args.musicvae and args.musicvae!=""): - if not os.path.exists(args.musicllm) or not os.path.exists(args.musicembeddings) or not os.path.exists(args.musicdiffusion) or not os.path.exists(args.musicvae): - if args.ignoremissing: - print("Ignoring missing Music model files!") - args.musicllm = None - args.musicembeddings = None - args.musicdiffusion = None - args.musicvae = None + mu_has_llm = True if (args.musicllm and args.musicllm!="") else False + mu_has_embed = True if (args.musicembeddings and args.musicembeddings!="") else False + mu_has_diff = True if (args.musicdiffusion and args.musicdiffusion!="") else False + mu_has_vae = True if (args.musicvae and args.musicvae!="") else False + if mu_has_llm or mu_has_embed or mu_has_diff or mu_has_vae: + if mu_has_llm and not any([mu_has_embed, mu_has_diff, mu_has_vae]): + if not os.path.exists(args.musicllm): + if args.ignoremissing: + print("Ignoring missing Music LLM model file!") + args.musicllm = None + else: + exitcounter = 999 + exit_with_error(2, "Cannot find Music LLM model file!") else: + musicllmpath = os.path.abspath(args.musicllm) + loadok = music_load_model(musicllmpath, "", "", "") + print("Load Music LLM Only OK: " + str(loadok)) + if not loadok: + exitcounter = 999 + exit_with_error(3, "Could not load Music LLM model!") + elif mu_has_diff: + if not (mu_has_embed and mu_has_vae): exitcounter = 999 - exit_with_error(2,"Cannot find music model files or missing a music model. Make sure ALL 4 music models (llm,embed,diffusion and vae) are loaded!") - else: - musicdiffusionmodelpath = os.path.abspath(args.musicdiffusion) - musicembedpath = os.path.abspath(args.musicembeddings) - musicllmpath = os.path.abspath(args.musicllm) - musicvaepath = os.path.abspath(args.musicvae) - loadok = music_load_model(musicllmpath,musicembedpath,musicdiffusionmodelpath,musicvaepath) - print("Load Music Model OK: " + str(loadok)) - if not loadok: - exitcounter = 999 - exit_with_error(3,"Could not load Music model!") + exit_with_error(2,"Invalid config: Music Diffusion requires Music embedding and Music VAE models!") + + paths_to_check = [args.musicdiffusion,args.musicembeddings,args.musicvae] + if mu_has_llm: + paths_to_check.append(args.musicllm) + + if not all(os.path.exists(p) for p in paths_to_check): + if args.ignoremissing: + print("Ignoring missing Music model files!") + args.musicllm = None + args.musicembeddings = None + args.musicdiffusion = None + args.musicvae = None + else: + exitcounter = 999 + exit_with_error(2,"Cannot find required music diffusion/embedding/VAE model files!") + else: + musicdiffusionmodelpath = os.path.abspath(args.musicdiffusion) + musicembedpath = os.path.abspath(args.musicembeddings) + musicvaepath = os.path.abspath(args.musicvae) + musicllmpath = os.path.abspath(args.musicllm) if mu_has_llm else "" + loadok = music_load_model(musicllmpath,musicembedpath,musicdiffusionmodelpath,musicvaepath) + print("Load Music Models OK: " + str(loadok)) + if not loadok: + exitcounter = 999 + exit_with_error(3, "Could not load Music models!") #load embedded lite embddir = os.path.join(os.path.abspath(os.path.dirname(os.path.realpath(__file__))),"embd_res") diff --git a/otherarch/acestep/music_adapter.cpp b/otherarch/acestep/music_adapter.cpp index fab0bcb3b..4e4a27fcc 100644 --- a/otherarch/acestep/music_adapter.cpp +++ b/otherarch/acestep/music_adapter.cpp @@ -22,7 +22,8 @@ static int musicdebugmode = 0; static bool music_is_quiet = false; -static bool musicgen_loaded = false; +static bool musicgen_llm_loaded = false; +static bool musicgen_diffusion_loaded = false; static std::string musicvulkandeviceenv; static std::string music_output_json_str = ""; @@ -60,39 +61,48 @@ bool musictype_load_model(const music_load_model_inputs inputs) printf("\nLoading Music Gen LLM Model: %s\nLoading Music Gen Embed Model: %s\nLoading Music Gen Diffusion Model: %s\nLoading Music Gen VAE Model: %s\n", musicllm_filename.c_str(),musicembedding_filename.c_str(),musicdiffusion_filename.c_str(),musicvae_filename.c_str()); musicdebugmode = inputs.debugmode; - - bool ok = load_acestep_lm(musicllm_filename,lowvram,musicdebugmode); - if (!ok) { - printf("\nFailed to load Music Gen LM Model!\n"); - return false; - } - if(lowvram) + bool ok = false; + if(musicllm_filename!="") { - unload_acestep_lm(); + ok = load_acestep_lm(musicllm_filename,lowvram,musicdebugmode); + if (!ok) { + printf("\nFailed to load Music Gen LM Model!\n"); + return false; + } + if(lowvram) + { + unload_acestep_lm(); + } + musicgen_llm_loaded = ok; } - ok = load_acestep_dit(musicembedding_filename,musicdiffusion_filename,musicvae_filename,lowvram); - if (!ok) { - printf("\nFailed to load Music Gen Diffusion, Embed or VAE Model!\n"); - return false; - } - if(lowvram) + if(musicdiffusion_filename!="") { - unload_acestep_dit_core(); - unload_acestep_dit_others(); + ok = load_acestep_dit(musicembedding_filename,musicdiffusion_filename,musicvae_filename,lowvram); + if (!ok) { + printf("\nFailed to load Music Gen Diffusion, Embed or VAE Model!\n"); + return false; + } + if(lowvram) + { + unload_acestep_dit_core(); + unload_acestep_dit_others(); + } + musicgen_diffusion_loaded = ok; } - musicgen_loaded = true; - - printf("\nMusic Gen Load Complete.\n"); - return true; + if(ok) + { + printf("\nMusic Gen Load Complete.\n"); + } + return ok; } music_generation_outputs musictype_generate(const music_generation_inputs inputs) { music_generation_outputs output; - if(!musicgen_loaded) + if(!musicgen_llm_loaded && !musicgen_diffusion_loaded) { printf("\nWarning: KCPP music gen not initialized!\n"); output.status = 0; @@ -101,7 +111,7 @@ music_generation_outputs musictype_generate(const music_generation_inputs inputs return output; } - if (inputs.is_planner_mode) { + if (inputs.is_planner_mode && musicgen_llm_loaded) { if (!music_is_quiet) { printf("\nMusic Gen Generating Codes..."); } @@ -120,7 +130,9 @@ music_generation_outputs musictype_generate(const music_generation_inputs inputs if (!music_is_quiet) { printf("\nMusic Gen Codes Done:\n%s\n",music_output_json_str.c_str()); } - } else { + } + else if (!inputs.is_planner_mode && musicgen_diffusion_loaded) + { if (!music_is_quiet) { printf("\nMusic Gen Generating Audio..."); } @@ -140,6 +152,14 @@ music_generation_outputs musictype_generate(const music_generation_inputs inputs printf("\nMusic Gen Audio Done\n"); } } + else + { + printf("\nWarning: KCPP music gen missing requested model (Make sure it was loaded)!\n"); + output.status = 0; + output.music_output_json = ""; + output.data = ""; + return output; + } return output; }