diff --git a/expose.cpp b/expose.cpp index 1a9b2437e..d456200e3 100644 --- a/expose.cpp +++ b/expose.cpp @@ -219,7 +219,7 @@ extern "C" return gpttype_generate(inputs); } - bool sd_load_model(const sd_load_model_inputs inputs) + sd_load_model_outputs sd_load_model(const sd_load_model_inputs inputs) { return sdtype_load_model(inputs); } diff --git a/expose.h b/expose.h index 9d03b30c4..6b7b2fed5 100644 --- a/expose.h +++ b/expose.h @@ -172,6 +172,11 @@ struct sd_load_model_inputs const bool quiet = false; const int debugmode = 0; }; +struct sd_load_model_outputs +{ + int status = -1; + const char * model_version = ""; // SDVersion +}; struct sd_generation_inputs { const char * prompt = nullptr; diff --git a/koboldcpp.py b/koboldcpp.py index 2643dc8d8..8095ffae2 100644 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -66,6 +66,7 @@ using_gui_launcher = False handle = None friendlymodelname = "inactive" friendlysdmodelname = "inactive" +sdmodelversion = "" friendlyembeddingsmodelname = "inactive" lastgeneratedcomfyimg = b'' fullsdmodelpath = "" #if empty, it's not initialized @@ -278,6 +279,10 @@ class sd_load_model_inputs(ctypes.Structure): ("quiet", ctypes.c_bool), ("debugmode", ctypes.c_int)] +class sd_load_model_outputs(ctypes.Structure): + _fields_ = [("status", ctypes.c_int), + ("model_version", ctypes.c_char_p)] + class sd_generation_inputs(ctypes.Structure): _fields_ = [("prompt", ctypes.c_char_p), ("negative_prompt", ctypes.c_char_p), @@ -540,7 +545,7 @@ def init_library(): handle.load_state_kv.restype = ctypes.c_bool handle.clear_state_kv.restype = ctypes.c_bool handle.sd_load_model.argtypes = [sd_load_model_inputs] - handle.sd_load_model.restype = ctypes.c_bool + handle.sd_load_model.restype = sd_load_model_outputs handle.sd_generate.argtypes = [sd_generation_inputs] handle.sd_generate.restype = sd_generation_outputs handle.whisper_load_model.argtypes = [whisper_load_model_inputs] @@ -6811,11 +6816,14 @@ def kcpp_main_process(launch_args, g_memory=None, gui_launcher=False): friendlysdmodelname = os.path.basename(imgmodel) friendlysdmodelname = os.path.splitext(friendlysdmodelname)[0] friendlysdmodelname = sanitize_string(friendlysdmodelname) - loadok = sd_load_model(imgmodel,imgvae,imglora,imgt5xxl,imgclipl,imgclipg) + ret = sd_load_model(imgmodel,imgvae,imglora,imgt5xxl,imgclipl,imgclipg) + loadok = (ret.status == 0) + sdmodelversion = ret.model_version.decode("UTF-8","ignore") print("Load Image Model OK: " + str(loadok)) if not loadok: exitcounter = 999 exit_with_error(3,"Could not load image model: " + imgmodel) + print("Image Model Type: " + sdmodelversion) #handle whisper model if args.whispermodel and args.whispermodel!="": diff --git a/model_adapter.h b/model_adapter.h index d5602cd40..55cca0f25 100644 --- a/model_adapter.h +++ b/model_adapter.h @@ -103,7 +103,7 @@ std::vector gpttype_get_token_arr(const std::string & input, bool addbos); std::string gpttype_detokenize(const std::vector & input, bool render_special); const std::vector gpttype_get_top_picks_data(); -bool sdtype_load_model(const sd_load_model_inputs inputs); +sd_load_model_outputs sdtype_load_model(const sd_load_model_inputs inputs); sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs); bool whispertype_load_model(const whisper_load_model_inputs inputs); @@ -135,4 +135,4 @@ size_t gpttype_calc_old_state_kv(int slot); size_t gpttype_calc_old_state_tokencount(int slot); size_t gpttype_save_state_kv(int slot); bool gpttype_load_state_kv(int slot); -bool gpttype_clear_state_kv(bool shrink); \ No newline at end of file +bool gpttype_clear_state_kv(bool shrink); diff --git a/otherarch/sdcpp/sdtype_adapter.cpp b/otherarch/sdcpp/sdtype_adapter.cpp index 36ebe317c..d99bc14bc 100644 --- a/otherarch/sdcpp/sdtype_adapter.cpp +++ b/otherarch/sdcpp/sdtype_adapter.cpp @@ -106,6 +106,27 @@ struct SDParams { float skip_layer_end = 0.2f; }; +static const char * sdversion_name (enum SDVersion version) +{ + // TODO: stable-diffusion.h should expose model_version_to_str or equivalent + static const char * model_version_to_str[] = { + "SD 1.x", + "SD 1.x Inpaint", + "SD 2.x", + "SD 2.x Inpaint", + "SDXL", + "SDXL Inpaint", + "SVD", + "SD3.x", + "Flux", + "Flux Fill" + }; + unsigned int idx = (unsigned int) version; + if (idx < (sizeof model_version_to_str / sizeof model_version_to_str[0])) + return model_version_to_str[idx]; + return "UNKNOWN"; +} + //shared int total_img_gens = 0; @@ -122,7 +143,9 @@ static bool notiling = false; static bool sd_is_quiet = false; static std::string sdmodelfilename = ""; -bool sdtype_load_model(const sd_load_model_inputs inputs) { +sd_load_model_outputs sdtype_load_model(const sd_load_model_inputs inputs) { + sd_load_model_outputs output; + sd_is_quiet = inputs.quiet; set_sd_quiet(sd_is_quiet); executable_path = inputs.executable_path; @@ -260,7 +283,7 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) { if (sd_ctx == NULL) { printf("\nError: KCPP SD Failed to create context!\nIf using Flux/SD3.5, make sure you have ALL files required (e.g. VAE, T5, Clip...) or baked in!\n"); - return false; + return output; } std::filesystem::path mpath(inputs.model_filename); @@ -273,7 +296,9 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) { sd_ctx->sd->apply_lora_from_file(lorafilename,inputs.lora_multiplier); } - return true; + output.status = 0; + output.model_version = sdversion_name(sd_ctx->sd->version); + return output; }