sd model replacement logic: adjusted approach for easy merge

This commit is contained in:
Concedo 2026-03-26 21:57:42 +08:00
parent 25216a0793
commit 4a5c903718

View file

@ -285,7 +285,10 @@ public:
}
}
bool init(const sd_ctx_params_t* sd_ctx_params) {
bool init(const sd_ctx_params_t* sd_ctx_params_kcpp) {
// kcpp make sd_ctx_params mutable
sd_ctx_params_t sd_ctx_params_local = *sd_ctx_params_kcpp;
sd_ctx_params_t *sd_ctx_params = &sd_ctx_params_local;
n_threads = sd_ctx_params->n_threads;
vae_decode_only = sd_ctx_params->vae_decode_only;
free_params_immediately = sd_ctx_params->free_params_immediately;
@ -304,10 +307,13 @@ public:
init_backend();
std::string taesd_path_fixed = taesd_path;
std::string t5_path_fixed = SAFE_STR(sd_ctx_params->t5xxl_path);
std::string clipl_path_fixed = SAFE_STR(sd_ctx_params->clip_l_path);
std::string clipg_path_fixed = SAFE_STR(sd_ctx_params->clip_g_path);
std::string clip_vision_fixed = SAFE_STR(sd_ctx_params->clip_vision_path);
std::string clipg_path_fixed = SAFE_STR(sd_ctx_params->clip_g_path);
std::string clipl_path_fixed = SAFE_STR(sd_ctx_params->clip_l_path);
std::string llm_path_fixed = SAFE_STR(sd_ctx_params->llm_path);
std::string llm_vision_path_fixed = SAFE_STR(sd_ctx_params->llm_vision_path);
std::string t5_path_fixed = SAFE_STR(sd_ctx_params->t5xxl_path);
std::string taesd_path_fixed = taesd_path;
ModelLoader model_loader;
@ -333,7 +339,9 @@ public:
}
bool is_unet = sd_version_is_unet(model_loader.get_sd_version());
int tempver = model_loader.get_sd_version();
// begin kcpp replacements
SDVersion tempver = model_loader.get_sd_version();
// kcpp fallback to separate diffusion model passed as model
if (tempver == VERSION_COUNT &&
@ -365,24 +373,19 @@ public:
tempver = model_loader.get_sd_version();
}
bool iswan = (tempver==VERSION_WAN2 || tempver==VERSION_WAN2_2_I2V || tempver==VERSION_WAN2_2_TI2V);
bool isqwenimg = (tempver==VERSION_QWEN_IMAGE);
bool iszimg = (tempver==VERSION_Z_IMAGE);
bool isflux2 = (tempver==VERSION_FLUX2);
bool isflux2k = (tempver==VERSION_FLUX2_KLEIN);
bool iswan = sd_version_is_wan(tempver);
bool is_wan21 = sd_version_is_wan(tempver) && tempver != VERSION_WAN2_2_TI2V;
bool is_qwenimg = sd_version_is_qwen_image(tempver);
bool iszimg = sd_version_is_z_image(tempver);
bool isflux2 = sd_version_is_flux2(tempver);
bool is_ovis = (tempver==VERSION_OVIS_IMAGE);
bool is_anima = (tempver==VERSION_ANIMA);
bool conditioner_is_llm = (isqwenimg||iszimg||isflux2||isflux2k||is_ovis||is_anima);
bool is_anima = sd_version_is_anima(tempver);
bool conditioner_is_llm = (is_qwenimg || iszimg || isflux2 || is_ovis || is_anima);
//kcpp qol fallback: if qwen image, and they loaded the qwen2vl llm as t5 by mistake
//kcpp qol fallback: if a llm was loaded as t5 by mistake
if(conditioner_is_llm && t5_path_fixed!="")
{
if(clipl_path_fixed=="" && clipg_path_fixed=="")
{
clipl_path_fixed = t5_path_fixed;
t5_path_fixed = "";
}
else if(clipl_path_fixed=="" && clipg_path_fixed!="")
if(clipl_path_fixed=="")
{
clipl_path_fixed = t5_path_fixed;
t5_path_fixed = "";
@ -399,39 +402,131 @@ public:
}
}
if (clipl_path_fixed!="") {
LOG_INFO("loading clip_l from '%s'", clipl_path_fixed.c_str());
std::string prefix = is_unet ? "cond_stage_model.transformer." : "text_encoders.clip_l.transformer.";
if(iswan)
//settle clip-l replacements
if (clipl_path_fixed!="")
{
if(conditioner_is_llm && llm_path_fixed=="")
{
prefix = "cond_stage_model.transformer.";
LOG_INFO("swap clip_vision from '%s'", clipl_path_fixed.c_str());
llm_path_fixed = clipl_path_fixed;
clipl_path_fixed = "";
}
if(conditioner_is_llm)
else if(iswan)
{
prefix = "text_encoders.llm.";
LOG_INFO("swap llm from '%s'", clipl_path_fixed.c_str());
}
if (!model_loader.init_from_file(clipl_path_fixed.c_str(), prefix)) {
LOG_WARN("loading clip_l from '%s' failed", clipl_path_fixed.c_str());
if(t5_path_fixed=="")
{
t5_path_fixed = clipl_path_fixed;
clipl_path_fixed = "";
} else if (t5_path_fixed != "" && clip_vision_fixed == "") {
clip_vision_fixed = clipl_path_fixed;
clipl_path_fixed = "";
}
}
}
if (clipg_path_fixed!="") {
LOG_INFO("loading clip_g from '%s'", clipg_path_fixed.c_str());
//settle clip-g replacements
if (clipg_path_fixed!="")
{
if(iswan && clip_vision_fixed=="")
{
clip_vision_fixed = clipg_path_fixed;
clipg_path_fixed = "";
}
else if(is_qwenimg && llm_vision_path_fixed=="")
{
llm_vision_path_fixed = clipg_path_fixed;
clipg_path_fixed = "";
}
}
//settle possible inversions for mmproj
if(llm_vision_path_fixed!="" && llm_path_fixed!="")
{
if(toLowerCase(llm_vision_path_fixed).find("mmproj") == std::string::npos &&
toLowerCase(llm_path_fixed).find("mmproj") != std::string::npos)
{
std::string tmp = llm_path_fixed;
llm_path_fixed = llm_vision_path_fixed;
llm_vision_path_fixed = tmp;
}
}
//settle tae replacements
if(taesd_path_fixed != "")
{
std::string to_search = "taesd.embd";
std::string to_replace = "";
if(sd_version_is_sd1(tempver) || sd_version_is_sd2(tempver))
{
to_replace = "taesd.embd";
}
else if(sd_version_is_sdxl(tempver))
{
to_replace = "taesd_xl.embd";
}
else if(sd_version_is_flux(tempver)||sd_version_is_z_image(tempver)||tempver == VERSION_OVIS_IMAGE)
{
to_replace = "taesd_f.embd";
}
else if(sd_version_is_sd3(tempver))
{
to_replace = "taesd_3.embd";
}
else if(sd_version_is_flux2(tempver))
{
to_replace = "taesd_f2.embd";
}
else if(is_wan21||is_qwenimg||sd_version_is_anima(tempver))
{
to_replace = "taesd_w21.embd";
}
if(to_replace!="")
{
size_t pos = taesd_path_fixed.find(to_search);
if (pos != std::string::npos) {
taesd_path_fixed.replace(pos, to_search.length(), to_replace);
}
}
else
{
printf("\nCannot use TAESD: Unknown tempver %d. TAESD Disabled!\n",tempver);
taesd_path_fixed = "";
}
if (taesd_path_fixed != "" && !file_exists(taesd_path_fixed))
{
printf("\nCannot use TAESD: \"%s\" not found. TAESD Disabled!\n", taesd_path_fixed.c_str());
taesd_path_fixed = "";
}
}
sd_ctx_params->clip_g_path = clipg_path_fixed.c_str();
sd_ctx_params->clip_l_path = clipl_path_fixed.c_str();
sd_ctx_params->clip_vision_path = clip_vision_fixed.c_str();
sd_ctx_params->llm_path = llm_path_fixed.c_str();
sd_ctx_params->llm_vision_path = llm_vision_path_fixed.c_str();
sd_ctx_params->t5xxl_path = t5_path_fixed.c_str();
taesd_path = taesd_path_fixed;
use_tiny_autoencoder = (taesd_path != "");
//debug print
// printf("\n\nclip_g: %s\nclip_l: %s\nclip_vision: %s\nllm: %s\nllm_vision: %s\nt5xxl: %s\ntaesd: %s\n",
// sd_ctx_params->clip_g_path, sd_ctx_params->clip_l_path, sd_ctx_params->clip_vision_path,
// sd_ctx_params->llm_path, sd_ctx_params->llm_vision_path, sd_ctx_params->t5xxl_path,
// taesd_path.c_str());
// end kcpp replacements
if (strlen(SAFE_STR(sd_ctx_params->clip_l_path)) > 0) {
LOG_INFO("loading clip_l from '%s'", sd_ctx_params->clip_l_path);
std::string prefix = is_unet ? "cond_stage_model.transformer." : "text_encoders.clip_l.transformer.";
if (!model_loader.init_from_file(sd_ctx_params->clip_l_path, prefix)) {
LOG_WARN("loading clip_l from '%s' failed", sd_ctx_params->clip_l_path);
}
}
if (strlen(SAFE_STR(sd_ctx_params->clip_g_path)) > 0) {
LOG_INFO("loading clip_g from '%s'", sd_ctx_params->clip_g_path);
std::string prefix = is_unet ? "cond_stage_model.1.transformer." : "text_encoders.clip_g.transformer.";
if(iswan)
{
prefix = "cond_stage_model.transformer.";
LOG_INFO("swap clip_vision from '%s'", clipg_path_fixed.c_str());
}
if(isqwenimg)
{
prefix = "text_encoders.llm.visual.";
LOG_INFO("swap llm mmproj from '%s'", clipg_path_fixed.c_str());
}
if (!model_loader.init_from_file(clipg_path_fixed.c_str(), prefix)) {
LOG_WARN("loading clip_g from '%s' failed", clipg_path_fixed.c_str());
if (!model_loader.init_from_file(sd_ctx_params->clip_g_path, prefix)) {
LOG_WARN("loading clip_g from '%s' failed", sd_ctx_params->clip_g_path);
}
}
@ -443,10 +538,10 @@ public:
}
}
if (t5_path_fixed!="") {
LOG_INFO("loading t5xxl from '%s'", t5_path_fixed.c_str());
if (!model_loader.init_from_file(t5_path_fixed.c_str(), "text_encoders.t5xxl.transformer.")) {
LOG_WARN("loading t5xxl from '%s' failed", t5_path_fixed.c_str());
if (strlen(SAFE_STR(sd_ctx_params->t5xxl_path)) > 0) {
LOG_INFO("loading t5xxl from '%s'", sd_ctx_params->t5xxl_path);
if (!model_loader.init_from_file(sd_ctx_params->t5xxl_path, "text_encoders.t5xxl.transformer.")) {
LOG_WARN("loading t5xxl from '%s' failed", sd_ctx_params->t5xxl_path);
}
}
@ -475,7 +570,6 @@ public:
model_loader.convert_tensors_name();
version = model_loader.get_sd_version();
if (version == VERSION_COUNT) {
LOG_ERROR("get sd version from file failed: '%s'", SAFE_STR(sd_ctx_params->model_path));
return false;
@ -484,57 +578,6 @@ public:
auto& tensor_storage_map = model_loader.get_tensor_storage_map();
LOG_INFO("Version: %s ", model_version_to_str[version]);
if(use_tiny_autoencoder) // kcpp
{
std::string to_search = "taesd.embd";
std::string to_replace = "";
if(sd_version_is_sd1(version) || sd_version_is_sd2(version))
{
to_replace = "taesd.embd";
}
else if(sd_version_is_sdxl(version))
{
to_replace = "taesd_xl.embd";
}
else if(sd_version_is_flux(version)||sd_version_is_z_image(version)||version == VERSION_OVIS_IMAGE)
{
to_replace = "taesd_f.embd";
}
else if(sd_version_is_sd3(version))
{
to_replace = "taesd_3.embd";
}
else if(sd_version_is_flux2(version))
{
to_replace = "taesd_f2.embd";
}
else if((sd_version_is_wan(version) && version != VERSION_WAN2_2_TI2V)||sd_version_is_qwen_image(version)||sd_version_is_anima(version))
{
to_replace = "taesd_w21.embd";
}
if(to_replace!="")
{
size_t pos = taesd_path_fixed.find(to_search);
if (pos != std::string::npos) {
taesd_path_fixed.replace(pos, to_search.length(), to_replace);
}
}
else
{
printf("\nCannot use TAESD: Unknown version %d. TAESD Disabled!\n",version);
taesd_path_fixed = "";
use_tiny_autoencoder = false;
}
if (use_tiny_autoencoder && !file_exists(taesd_path_fixed))
{
printf("\nCannot use TAESD: \"%s\" not found. TAESD Disabled!\n", taesd_path_fixed.c_str());
taesd_path_fixed = "";
use_tiny_autoencoder = false;
}
}
ggml_type wtype = (int)sd_ctx_params->wtype < std::min<int>(SD_TYPE_COUNT, GGML_TYPE_COUNT)
? (ggml_type)sd_ctx_params->wtype
: GGML_TYPE_COUNT;
@ -1025,7 +1068,7 @@ public:
vae_params_mem_size = first_stage_model->get_params_buffer_size();
}
if (use_tiny_autoencoder || version == VERSION_SDXS) {
if (use_tiny_autoencoder && !tae_first_stage->load_from_file(taesd_path_fixed, n_threads)) {
if (use_tiny_autoencoder && !tae_first_stage->load_from_file(taesd_path, n_threads)) {
return false;
}
use_tiny_autoencoder = true; // now the processing is identical for VERSION_SDXS