From 4a5c903718f7038ac8aed82ea7806e9a41f09212 Mon Sep 17 00:00:00 2001 From: Concedo <39025047+LostRuins@users.noreply.github.com> Date: Thu, 26 Mar 2026 21:57:42 +0800 Subject: [PATCH] sd model model replacement logic: adjusted approach for easy merge --- otherarch/sdcpp/stable-diffusion.cpp | 249 ++++++++++++++++----------- 1 file changed, 146 insertions(+), 103 deletions(-) diff --git a/otherarch/sdcpp/stable-diffusion.cpp b/otherarch/sdcpp/stable-diffusion.cpp index 66cbadfd5..5841414d4 100644 --- a/otherarch/sdcpp/stable-diffusion.cpp +++ b/otherarch/sdcpp/stable-diffusion.cpp @@ -285,7 +285,10 @@ public: } } - bool init(const sd_ctx_params_t* sd_ctx_params) { + bool init(const sd_ctx_params_t* sd_ctx_params_kcpp) { + // kcpp make sd_ctx_params mutable + sd_ctx_params_t sd_ctx_params_local = *sd_ctx_params_kcpp; + sd_ctx_params_t *sd_ctx_params = &sd_ctx_params_local; n_threads = sd_ctx_params->n_threads; vae_decode_only = sd_ctx_params->vae_decode_only; free_params_immediately = sd_ctx_params->free_params_immediately; @@ -304,10 +307,13 @@ public: init_backend(); - std::string taesd_path_fixed = taesd_path; - std::string t5_path_fixed = SAFE_STR(sd_ctx_params->t5xxl_path); - std::string clipl_path_fixed = SAFE_STR(sd_ctx_params->clip_l_path); - std::string clipg_path_fixed = SAFE_STR(sd_ctx_params->clip_g_path); + std::string clip_vision_fixed = SAFE_STR(sd_ctx_params->clip_vision_path); + std::string clipg_path_fixed = SAFE_STR(sd_ctx_params->clip_g_path); + std::string clipl_path_fixed = SAFE_STR(sd_ctx_params->clip_l_path); + std::string llm_path_fixed = SAFE_STR(sd_ctx_params->llm_path); + std::string llm_vision_path_fixed = SAFE_STR(sd_ctx_params->llm_vision_path); + std::string t5_path_fixed = SAFE_STR(sd_ctx_params->t5xxl_path); + std::string taesd_path_fixed = taesd_path; ModelLoader model_loader; @@ -333,7 +339,9 @@ public: } bool is_unet = sd_version_is_unet(model_loader.get_sd_version()); - int tempver = model_loader.get_sd_version(); + + // begin kcpp replacements + SDVersion tempver = model_loader.get_sd_version(); // kcpp fallback to separate diffusion model passed as model if (tempver == VERSION_COUNT && @@ -365,24 +373,19 @@ public: tempver = model_loader.get_sd_version(); } - bool iswan = (tempver==VERSION_WAN2 || tempver==VERSION_WAN2_2_I2V || tempver==VERSION_WAN2_2_TI2V); - bool isqwenimg = (tempver==VERSION_QWEN_IMAGE); - bool iszimg = (tempver==VERSION_Z_IMAGE); - bool isflux2 = (tempver==VERSION_FLUX2); - bool isflux2k = (tempver==VERSION_FLUX2_KLEIN); + bool iswan = sd_version_is_wan(tempver); + bool is_wan21 = sd_version_is_wan(tempver) && tempver != VERSION_WAN2_2_TI2V; + bool is_qwenimg = sd_version_is_qwen_image(tempver); + bool iszimg = sd_version_is_z_image(tempver); + bool isflux2 = sd_version_is_flux2(tempver); bool is_ovis = (tempver==VERSION_OVIS_IMAGE); - bool is_anima = (tempver==VERSION_ANIMA); - bool conditioner_is_llm = (isqwenimg||iszimg||isflux2||isflux2k||is_ovis||is_anima); + bool is_anima = sd_version_is_anima(tempver); + bool conditioner_is_llm = (is_qwenimg || iszimg || isflux2 || is_ovis || is_anima); - //kcpp qol fallback: if qwen image, and they loaded the qwen2vl llm as t5 by mistake + //kcpp qol fallback: if a llm was loaded as t5 by mistake if(conditioner_is_llm && t5_path_fixed!="") { - if(clipl_path_fixed=="" && clipg_path_fixed=="") - { - clipl_path_fixed = t5_path_fixed; - t5_path_fixed = ""; - } - else if(clipl_path_fixed=="" && clipg_path_fixed!="") + if(clipl_path_fixed=="") { clipl_path_fixed = t5_path_fixed; t5_path_fixed = ""; @@ -399,39 +402,131 @@ public: } } - if (clipl_path_fixed!="") { - LOG_INFO("loading clip_l from '%s'", clipl_path_fixed.c_str()); - std::string prefix = is_unet ? "cond_stage_model.transformer." : "text_encoders.clip_l.transformer."; - if(iswan) + //settle clip-l replacements + if (clipl_path_fixed!="") + { + if(conditioner_is_llm && llm_path_fixed=="") { - prefix = "cond_stage_model.transformer."; - LOG_INFO("swap clip_vision from '%s'", clipl_path_fixed.c_str()); + llm_path_fixed = clipl_path_fixed; + clipl_path_fixed = ""; } - if(conditioner_is_llm) + else if(iswan) { - prefix = "text_encoders.llm."; - LOG_INFO("swap llm from '%s'", clipl_path_fixed.c_str()); - } - if (!model_loader.init_from_file(clipl_path_fixed.c_str(), prefix)) { - LOG_WARN("loading clip_l from '%s' failed", clipl_path_fixed.c_str()); + if(t5_path_fixed=="") + { + t5_path_fixed = clipl_path_fixed; + clipl_path_fixed = ""; + } else if (t5_path_fixed != "" && clip_vision_fixed == "") { + clip_vision_fixed = clipl_path_fixed; + clipl_path_fixed = ""; + } } } - if (clipg_path_fixed!="") { - LOG_INFO("loading clip_g from '%s'", clipg_path_fixed.c_str()); + //settle clip-g replacements + if (clipg_path_fixed!="") + { + if(iswan && clip_vision_fixed=="") + { + clip_vision_fixed = clipg_path_fixed; + clipg_path_fixed = ""; + } + else if(is_qwenimg && llm_vision_path_fixed=="") + { + llm_vision_path_fixed = clipg_path_fixed; + clipg_path_fixed = ""; + } + } + + //settle possible inversions for mmproj + if(llm_vision_path_fixed!="" && llm_path_fixed!="") + { + if(toLowerCase(llm_vision_path_fixed).find("mmproj") == std::string::npos && + toLowerCase(llm_path_fixed).find("mmproj") != std::string::npos) + { + std::string tmp = llm_path_fixed; + llm_path_fixed = llm_vision_path_fixed; + llm_vision_path_fixed = tmp; + } + } + + //settle tae replacements + if(taesd_path_fixed != "") + { + std::string to_search = "taesd.embd"; + std::string to_replace = ""; + if(sd_version_is_sd1(tempver) || sd_version_is_sd2(tempver)) + { + to_replace = "taesd.embd"; + } + else if(sd_version_is_sdxl(tempver)) + { + to_replace = "taesd_xl.embd"; + } + else if(sd_version_is_flux(tempver)||sd_version_is_z_image(tempver)||tempver == VERSION_OVIS_IMAGE) + { + to_replace = "taesd_f.embd"; + } + else if(sd_version_is_sd3(tempver)) + { + to_replace = "taesd_3.embd"; + } + else if(sd_version_is_flux2(tempver)) + { + to_replace = "taesd_f2.embd"; + } + else if(is_wan21||is_qwenimg||sd_version_is_anima(tempver)) + { + to_replace = "taesd_w21.embd"; + } + + if(to_replace!="") + { + size_t pos = taesd_path_fixed.find(to_search); + if (pos != std::string::npos) { + taesd_path_fixed.replace(pos, to_search.length(), to_replace); + } + } + else + { + printf("\nCannot use TAESD: Unknown tempver %d. TAESD Disabled!\n",tempver); + taesd_path_fixed = ""; + } + if (taesd_path_fixed != "" && !file_exists(taesd_path_fixed)) + { + printf("\nCannot use TAESD: \"%s\" not found. TAESD Disabled!\n", taesd_path_fixed.c_str()); + taesd_path_fixed = ""; + } + } + + sd_ctx_params->clip_g_path = clipg_path_fixed.c_str(); + sd_ctx_params->clip_l_path = clipl_path_fixed.c_str(); + sd_ctx_params->clip_vision_path = clip_vision_fixed.c_str(); + sd_ctx_params->llm_path = llm_path_fixed.c_str(); + sd_ctx_params->llm_vision_path = llm_vision_path_fixed.c_str(); + sd_ctx_params->t5xxl_path = t5_path_fixed.c_str(); + taesd_path = taesd_path_fixed; + use_tiny_autoencoder = (taesd_path != ""); + //debug print + // printf("\n\nclip_g: %s\nclip_l: %s\nclip_vision: %s\nllm: %s\nllm_vision: %s\nt5xxl: %s\ntaesd: %s\n", + // sd_ctx_params->clip_g_path, sd_ctx_params->clip_l_path, sd_ctx_params->clip_vision_path, + // sd_ctx_params->llm_path, sd_ctx_params->llm_vision_path, sd_ctx_params->t5xxl_path, + // taesd_path.c_str()); + // end kcpp replacements + + if (strlen(SAFE_STR(sd_ctx_params->clip_l_path)) > 0) { + LOG_INFO("loading clip_l from '%s'", sd_ctx_params->clip_l_path); + std::string prefix = is_unet ? "cond_stage_model.transformer." : "text_encoders.clip_l.transformer."; + if (!model_loader.init_from_file(sd_ctx_params->clip_l_path, prefix)) { + LOG_WARN("loading clip_l from '%s' failed", sd_ctx_params->clip_l_path); + } + } + + if (strlen(SAFE_STR(sd_ctx_params->clip_g_path)) > 0) { + LOG_INFO("loading clip_g from '%s'", sd_ctx_params->clip_g_path); std::string prefix = is_unet ? "cond_stage_model.1.transformer." : "text_encoders.clip_g.transformer."; - if(iswan) - { - prefix = "cond_stage_model.transformer."; - LOG_INFO("swap clip_vision from '%s'", clipg_path_fixed.c_str()); - } - if(isqwenimg) - { - prefix = "text_encoders.llm.visual."; - LOG_INFO("swap llm mmproj from '%s'", clipg_path_fixed.c_str()); - } - if (!model_loader.init_from_file(clipg_path_fixed.c_str(), prefix)) { - LOG_WARN("loading clip_g from '%s' failed", clipg_path_fixed.c_str()); + if (!model_loader.init_from_file(sd_ctx_params->clip_g_path, prefix)) { + LOG_WARN("loading clip_g from '%s' failed", sd_ctx_params->clip_g_path); } } @@ -443,10 +538,10 @@ public: } } - if (t5_path_fixed!="") { - LOG_INFO("loading t5xxl from '%s'", t5_path_fixed.c_str()); - if (!model_loader.init_from_file(t5_path_fixed.c_str(), "text_encoders.t5xxl.transformer.")) { - LOG_WARN("loading t5xxl from '%s' failed", t5_path_fixed.c_str()); + if (strlen(SAFE_STR(sd_ctx_params->t5xxl_path)) > 0) { + LOG_INFO("loading t5xxl from '%s'", sd_ctx_params->t5xxl_path); + if (!model_loader.init_from_file(sd_ctx_params->t5xxl_path, "text_encoders.t5xxl.transformer.")) { + LOG_WARN("loading t5xxl from '%s' failed", sd_ctx_params->t5xxl_path); } } @@ -475,7 +570,6 @@ public: model_loader.convert_tensors_name(); version = model_loader.get_sd_version(); - if (version == VERSION_COUNT) { LOG_ERROR("get sd version from file failed: '%s'", SAFE_STR(sd_ctx_params->model_path)); return false; @@ -484,57 +578,6 @@ public: auto& tensor_storage_map = model_loader.get_tensor_storage_map(); LOG_INFO("Version: %s ", model_version_to_str[version]); - - if(use_tiny_autoencoder) // kcpp - { - std::string to_search = "taesd.embd"; - std::string to_replace = ""; - if(sd_version_is_sd1(version) || sd_version_is_sd2(version)) - { - to_replace = "taesd.embd"; - } - else if(sd_version_is_sdxl(version)) - { - to_replace = "taesd_xl.embd"; - } - else if(sd_version_is_flux(version)||sd_version_is_z_image(version)||version == VERSION_OVIS_IMAGE) - { - to_replace = "taesd_f.embd"; - } - else if(sd_version_is_sd3(version)) - { - to_replace = "taesd_3.embd"; - } - else if(sd_version_is_flux2(version)) - { - to_replace = "taesd_f2.embd"; - } - else if((sd_version_is_wan(version) && version != VERSION_WAN2_2_TI2V)||sd_version_is_qwen_image(version)||sd_version_is_anima(version)) - { - to_replace = "taesd_w21.embd"; - } - - if(to_replace!="") - { - size_t pos = taesd_path_fixed.find(to_search); - if (pos != std::string::npos) { - taesd_path_fixed.replace(pos, to_search.length(), to_replace); - } - } - else - { - printf("\nCannot use TAESD: Unknown version %d. TAESD Disabled!\n",version); - taesd_path_fixed = ""; - use_tiny_autoencoder = false; - } - if (use_tiny_autoencoder && !file_exists(taesd_path_fixed)) - { - printf("\nCannot use TAESD: \"%s\" not found. TAESD Disabled!\n", taesd_path_fixed.c_str()); - taesd_path_fixed = ""; - use_tiny_autoencoder = false; - } - } - ggml_type wtype = (int)sd_ctx_params->wtype < std::min(SD_TYPE_COUNT, GGML_TYPE_COUNT) ? (ggml_type)sd_ctx_params->wtype : GGML_TYPE_COUNT; @@ -1025,7 +1068,7 @@ public: vae_params_mem_size = first_stage_model->get_params_buffer_size(); } if (use_tiny_autoencoder || version == VERSION_SDXS) { - if (use_tiny_autoencoder && !tae_first_stage->load_from_file(taesd_path_fixed, n_threads)) { + if (use_tiny_autoencoder && !tae_first_stage->load_from_file(taesd_path, n_threads)) { return false; } use_tiny_autoencoder = true; // now the processing is identical for VERSION_SDXS