mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-07 17:22:04 +00:00
sd model model replacement logic: adjusted approach for easy merge
This commit is contained in:
parent
25216a0793
commit
4a5c903718
1 changed files with 146 additions and 103 deletions
|
|
@ -285,7 +285,10 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
bool init(const sd_ctx_params_t* sd_ctx_params) {
|
||||
bool init(const sd_ctx_params_t* sd_ctx_params_kcpp) {
|
||||
// kcpp make sd_ctx_params mutable
|
||||
sd_ctx_params_t sd_ctx_params_local = *sd_ctx_params_kcpp;
|
||||
sd_ctx_params_t *sd_ctx_params = &sd_ctx_params_local;
|
||||
n_threads = sd_ctx_params->n_threads;
|
||||
vae_decode_only = sd_ctx_params->vae_decode_only;
|
||||
free_params_immediately = sd_ctx_params->free_params_immediately;
|
||||
|
|
@ -304,10 +307,13 @@ public:
|
|||
|
||||
init_backend();
|
||||
|
||||
std::string taesd_path_fixed = taesd_path;
|
||||
std::string t5_path_fixed = SAFE_STR(sd_ctx_params->t5xxl_path);
|
||||
std::string clipl_path_fixed = SAFE_STR(sd_ctx_params->clip_l_path);
|
||||
std::string clipg_path_fixed = SAFE_STR(sd_ctx_params->clip_g_path);
|
||||
std::string clip_vision_fixed = SAFE_STR(sd_ctx_params->clip_vision_path);
|
||||
std::string clipg_path_fixed = SAFE_STR(sd_ctx_params->clip_g_path);
|
||||
std::string clipl_path_fixed = SAFE_STR(sd_ctx_params->clip_l_path);
|
||||
std::string llm_path_fixed = SAFE_STR(sd_ctx_params->llm_path);
|
||||
std::string llm_vision_path_fixed = SAFE_STR(sd_ctx_params->llm_vision_path);
|
||||
std::string t5_path_fixed = SAFE_STR(sd_ctx_params->t5xxl_path);
|
||||
std::string taesd_path_fixed = taesd_path;
|
||||
|
||||
ModelLoader model_loader;
|
||||
|
||||
|
|
@ -333,7 +339,9 @@ public:
|
|||
}
|
||||
|
||||
bool is_unet = sd_version_is_unet(model_loader.get_sd_version());
|
||||
int tempver = model_loader.get_sd_version();
|
||||
|
||||
// begin kcpp replacements
|
||||
SDVersion tempver = model_loader.get_sd_version();
|
||||
|
||||
// kcpp fallback to separate diffusion model passed as model
|
||||
if (tempver == VERSION_COUNT &&
|
||||
|
|
@ -365,24 +373,19 @@ public:
|
|||
tempver = model_loader.get_sd_version();
|
||||
}
|
||||
|
||||
bool iswan = (tempver==VERSION_WAN2 || tempver==VERSION_WAN2_2_I2V || tempver==VERSION_WAN2_2_TI2V);
|
||||
bool isqwenimg = (tempver==VERSION_QWEN_IMAGE);
|
||||
bool iszimg = (tempver==VERSION_Z_IMAGE);
|
||||
bool isflux2 = (tempver==VERSION_FLUX2);
|
||||
bool isflux2k = (tempver==VERSION_FLUX2_KLEIN);
|
||||
bool iswan = sd_version_is_wan(tempver);
|
||||
bool is_wan21 = sd_version_is_wan(tempver) && tempver != VERSION_WAN2_2_TI2V;
|
||||
bool is_qwenimg = sd_version_is_qwen_image(tempver);
|
||||
bool iszimg = sd_version_is_z_image(tempver);
|
||||
bool isflux2 = sd_version_is_flux2(tempver);
|
||||
bool is_ovis = (tempver==VERSION_OVIS_IMAGE);
|
||||
bool is_anima = (tempver==VERSION_ANIMA);
|
||||
bool conditioner_is_llm = (isqwenimg||iszimg||isflux2||isflux2k||is_ovis||is_anima);
|
||||
bool is_anima = sd_version_is_anima(tempver);
|
||||
bool conditioner_is_llm = (is_qwenimg || iszimg || isflux2 || is_ovis || is_anima);
|
||||
|
||||
//kcpp qol fallback: if qwen image, and they loaded the qwen2vl llm as t5 by mistake
|
||||
//kcpp qol fallback: if a llm was loaded as t5 by mistake
|
||||
if(conditioner_is_llm && t5_path_fixed!="")
|
||||
{
|
||||
if(clipl_path_fixed=="" && clipg_path_fixed=="")
|
||||
{
|
||||
clipl_path_fixed = t5_path_fixed;
|
||||
t5_path_fixed = "";
|
||||
}
|
||||
else if(clipl_path_fixed=="" && clipg_path_fixed!="")
|
||||
if(clipl_path_fixed=="")
|
||||
{
|
||||
clipl_path_fixed = t5_path_fixed;
|
||||
t5_path_fixed = "";
|
||||
|
|
@ -399,39 +402,131 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
if (clipl_path_fixed!="") {
|
||||
LOG_INFO("loading clip_l from '%s'", clipl_path_fixed.c_str());
|
||||
std::string prefix = is_unet ? "cond_stage_model.transformer." : "text_encoders.clip_l.transformer.";
|
||||
if(iswan)
|
||||
//settle clip-l replacements
|
||||
if (clipl_path_fixed!="")
|
||||
{
|
||||
if(conditioner_is_llm && llm_path_fixed=="")
|
||||
{
|
||||
prefix = "cond_stage_model.transformer.";
|
||||
LOG_INFO("swap clip_vision from '%s'", clipl_path_fixed.c_str());
|
||||
llm_path_fixed = clipl_path_fixed;
|
||||
clipl_path_fixed = "";
|
||||
}
|
||||
if(conditioner_is_llm)
|
||||
else if(iswan)
|
||||
{
|
||||
prefix = "text_encoders.llm.";
|
||||
LOG_INFO("swap llm from '%s'", clipl_path_fixed.c_str());
|
||||
}
|
||||
if (!model_loader.init_from_file(clipl_path_fixed.c_str(), prefix)) {
|
||||
LOG_WARN("loading clip_l from '%s' failed", clipl_path_fixed.c_str());
|
||||
if(t5_path_fixed=="")
|
||||
{
|
||||
t5_path_fixed = clipl_path_fixed;
|
||||
clipl_path_fixed = "";
|
||||
} else if (t5_path_fixed != "" && clip_vision_fixed == "") {
|
||||
clip_vision_fixed = clipl_path_fixed;
|
||||
clipl_path_fixed = "";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (clipg_path_fixed!="") {
|
||||
LOG_INFO("loading clip_g from '%s'", clipg_path_fixed.c_str());
|
||||
//settle clip-g replacements
|
||||
if (clipg_path_fixed!="")
|
||||
{
|
||||
if(iswan && clip_vision_fixed=="")
|
||||
{
|
||||
clip_vision_fixed = clipg_path_fixed;
|
||||
clipg_path_fixed = "";
|
||||
}
|
||||
else if(is_qwenimg && llm_vision_path_fixed=="")
|
||||
{
|
||||
llm_vision_path_fixed = clipg_path_fixed;
|
||||
clipg_path_fixed = "";
|
||||
}
|
||||
}
|
||||
|
||||
//settle possible inversions for mmproj
|
||||
if(llm_vision_path_fixed!="" && llm_path_fixed!="")
|
||||
{
|
||||
if(toLowerCase(llm_vision_path_fixed).find("mmproj") == std::string::npos &&
|
||||
toLowerCase(llm_path_fixed).find("mmproj") != std::string::npos)
|
||||
{
|
||||
std::string tmp = llm_path_fixed;
|
||||
llm_path_fixed = llm_vision_path_fixed;
|
||||
llm_vision_path_fixed = tmp;
|
||||
}
|
||||
}
|
||||
|
||||
//settle tae replacements
|
||||
if(taesd_path_fixed != "")
|
||||
{
|
||||
std::string to_search = "taesd.embd";
|
||||
std::string to_replace = "";
|
||||
if(sd_version_is_sd1(tempver) || sd_version_is_sd2(tempver))
|
||||
{
|
||||
to_replace = "taesd.embd";
|
||||
}
|
||||
else if(sd_version_is_sdxl(tempver))
|
||||
{
|
||||
to_replace = "taesd_xl.embd";
|
||||
}
|
||||
else if(sd_version_is_flux(tempver)||sd_version_is_z_image(tempver)||tempver == VERSION_OVIS_IMAGE)
|
||||
{
|
||||
to_replace = "taesd_f.embd";
|
||||
}
|
||||
else if(sd_version_is_sd3(tempver))
|
||||
{
|
||||
to_replace = "taesd_3.embd";
|
||||
}
|
||||
else if(sd_version_is_flux2(tempver))
|
||||
{
|
||||
to_replace = "taesd_f2.embd";
|
||||
}
|
||||
else if(is_wan21||is_qwenimg||sd_version_is_anima(tempver))
|
||||
{
|
||||
to_replace = "taesd_w21.embd";
|
||||
}
|
||||
|
||||
if(to_replace!="")
|
||||
{
|
||||
size_t pos = taesd_path_fixed.find(to_search);
|
||||
if (pos != std::string::npos) {
|
||||
taesd_path_fixed.replace(pos, to_search.length(), to_replace);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("\nCannot use TAESD: Unknown tempver %d. TAESD Disabled!\n",tempver);
|
||||
taesd_path_fixed = "";
|
||||
}
|
||||
if (taesd_path_fixed != "" && !file_exists(taesd_path_fixed))
|
||||
{
|
||||
printf("\nCannot use TAESD: \"%s\" not found. TAESD Disabled!\n", taesd_path_fixed.c_str());
|
||||
taesd_path_fixed = "";
|
||||
}
|
||||
}
|
||||
|
||||
sd_ctx_params->clip_g_path = clipg_path_fixed.c_str();
|
||||
sd_ctx_params->clip_l_path = clipl_path_fixed.c_str();
|
||||
sd_ctx_params->clip_vision_path = clip_vision_fixed.c_str();
|
||||
sd_ctx_params->llm_path = llm_path_fixed.c_str();
|
||||
sd_ctx_params->llm_vision_path = llm_vision_path_fixed.c_str();
|
||||
sd_ctx_params->t5xxl_path = t5_path_fixed.c_str();
|
||||
taesd_path = taesd_path_fixed;
|
||||
use_tiny_autoencoder = (taesd_path != "");
|
||||
//debug print
|
||||
// printf("\n\nclip_g: %s\nclip_l: %s\nclip_vision: %s\nllm: %s\nllm_vision: %s\nt5xxl: %s\ntaesd: %s\n",
|
||||
// sd_ctx_params->clip_g_path, sd_ctx_params->clip_l_path, sd_ctx_params->clip_vision_path,
|
||||
// sd_ctx_params->llm_path, sd_ctx_params->llm_vision_path, sd_ctx_params->t5xxl_path,
|
||||
// taesd_path.c_str());
|
||||
// end kcpp replacements
|
||||
|
||||
if (strlen(SAFE_STR(sd_ctx_params->clip_l_path)) > 0) {
|
||||
LOG_INFO("loading clip_l from '%s'", sd_ctx_params->clip_l_path);
|
||||
std::string prefix = is_unet ? "cond_stage_model.transformer." : "text_encoders.clip_l.transformer.";
|
||||
if (!model_loader.init_from_file(sd_ctx_params->clip_l_path, prefix)) {
|
||||
LOG_WARN("loading clip_l from '%s' failed", sd_ctx_params->clip_l_path);
|
||||
}
|
||||
}
|
||||
|
||||
if (strlen(SAFE_STR(sd_ctx_params->clip_g_path)) > 0) {
|
||||
LOG_INFO("loading clip_g from '%s'", sd_ctx_params->clip_g_path);
|
||||
std::string prefix = is_unet ? "cond_stage_model.1.transformer." : "text_encoders.clip_g.transformer.";
|
||||
if(iswan)
|
||||
{
|
||||
prefix = "cond_stage_model.transformer.";
|
||||
LOG_INFO("swap clip_vision from '%s'", clipg_path_fixed.c_str());
|
||||
}
|
||||
if(isqwenimg)
|
||||
{
|
||||
prefix = "text_encoders.llm.visual.";
|
||||
LOG_INFO("swap llm mmproj from '%s'", clipg_path_fixed.c_str());
|
||||
}
|
||||
if (!model_loader.init_from_file(clipg_path_fixed.c_str(), prefix)) {
|
||||
LOG_WARN("loading clip_g from '%s' failed", clipg_path_fixed.c_str());
|
||||
if (!model_loader.init_from_file(sd_ctx_params->clip_g_path, prefix)) {
|
||||
LOG_WARN("loading clip_g from '%s' failed", sd_ctx_params->clip_g_path);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -443,10 +538,10 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
if (t5_path_fixed!="") {
|
||||
LOG_INFO("loading t5xxl from '%s'", t5_path_fixed.c_str());
|
||||
if (!model_loader.init_from_file(t5_path_fixed.c_str(), "text_encoders.t5xxl.transformer.")) {
|
||||
LOG_WARN("loading t5xxl from '%s' failed", t5_path_fixed.c_str());
|
||||
if (strlen(SAFE_STR(sd_ctx_params->t5xxl_path)) > 0) {
|
||||
LOG_INFO("loading t5xxl from '%s'", sd_ctx_params->t5xxl_path);
|
||||
if (!model_loader.init_from_file(sd_ctx_params->t5xxl_path, "text_encoders.t5xxl.transformer.")) {
|
||||
LOG_WARN("loading t5xxl from '%s' failed", sd_ctx_params->t5xxl_path);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -475,7 +570,6 @@ public:
|
|||
model_loader.convert_tensors_name();
|
||||
|
||||
version = model_loader.get_sd_version();
|
||||
|
||||
if (version == VERSION_COUNT) {
|
||||
LOG_ERROR("get sd version from file failed: '%s'", SAFE_STR(sd_ctx_params->model_path));
|
||||
return false;
|
||||
|
|
@ -484,57 +578,6 @@ public:
|
|||
auto& tensor_storage_map = model_loader.get_tensor_storage_map();
|
||||
|
||||
LOG_INFO("Version: %s ", model_version_to_str[version]);
|
||||
|
||||
if(use_tiny_autoencoder) // kcpp
|
||||
{
|
||||
std::string to_search = "taesd.embd";
|
||||
std::string to_replace = "";
|
||||
if(sd_version_is_sd1(version) || sd_version_is_sd2(version))
|
||||
{
|
||||
to_replace = "taesd.embd";
|
||||
}
|
||||
else if(sd_version_is_sdxl(version))
|
||||
{
|
||||
to_replace = "taesd_xl.embd";
|
||||
}
|
||||
else if(sd_version_is_flux(version)||sd_version_is_z_image(version)||version == VERSION_OVIS_IMAGE)
|
||||
{
|
||||
to_replace = "taesd_f.embd";
|
||||
}
|
||||
else if(sd_version_is_sd3(version))
|
||||
{
|
||||
to_replace = "taesd_3.embd";
|
||||
}
|
||||
else if(sd_version_is_flux2(version))
|
||||
{
|
||||
to_replace = "taesd_f2.embd";
|
||||
}
|
||||
else if((sd_version_is_wan(version) && version != VERSION_WAN2_2_TI2V)||sd_version_is_qwen_image(version)||sd_version_is_anima(version))
|
||||
{
|
||||
to_replace = "taesd_w21.embd";
|
||||
}
|
||||
|
||||
if(to_replace!="")
|
||||
{
|
||||
size_t pos = taesd_path_fixed.find(to_search);
|
||||
if (pos != std::string::npos) {
|
||||
taesd_path_fixed.replace(pos, to_search.length(), to_replace);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("\nCannot use TAESD: Unknown version %d. TAESD Disabled!\n",version);
|
||||
taesd_path_fixed = "";
|
||||
use_tiny_autoencoder = false;
|
||||
}
|
||||
if (use_tiny_autoencoder && !file_exists(taesd_path_fixed))
|
||||
{
|
||||
printf("\nCannot use TAESD: \"%s\" not found. TAESD Disabled!\n", taesd_path_fixed.c_str());
|
||||
taesd_path_fixed = "";
|
||||
use_tiny_autoencoder = false;
|
||||
}
|
||||
}
|
||||
|
||||
ggml_type wtype = (int)sd_ctx_params->wtype < std::min<int>(SD_TYPE_COUNT, GGML_TYPE_COUNT)
|
||||
? (ggml_type)sd_ctx_params->wtype
|
||||
: GGML_TYPE_COUNT;
|
||||
|
|
@ -1025,7 +1068,7 @@ public:
|
|||
vae_params_mem_size = first_stage_model->get_params_buffer_size();
|
||||
}
|
||||
if (use_tiny_autoencoder || version == VERSION_SDXS) {
|
||||
if (use_tiny_autoencoder && !tae_first_stage->load_from_file(taesd_path_fixed, n_threads)) {
|
||||
if (use_tiny_autoencoder && !tae_first_stage->load_from_file(taesd_path, n_threads)) {
|
||||
return false;
|
||||
}
|
||||
use_tiny_autoencoder = true; // now the processing is identical for VERSION_SDXS
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue