diff --git a/otherarch/sdcpp/model.cpp b/otherarch/sdcpp/model.cpp index 73458320d..80fdd78cf 100644 --- a/otherarch/sdcpp/model.cpp +++ b/otherarch/sdcpp/model.cpp @@ -1366,6 +1366,16 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s return true; } +bool ModelLoader::has_diffusion_model_tensors() +{ + for (auto& tensor_storage : tensor_storages) { + if (tensor_storage.name.find("model.diffusion_model.") != std::string::npos) { + return true; + } + } + return false; +} + SDVersion ModelLoader::get_sd_version() { TensorStorage token_embedding_weight; bool is_flux = false; diff --git a/otherarch/sdcpp/model.h b/otherarch/sdcpp/model.h index 041245e37..f890db67f 100644 --- a/otherarch/sdcpp/model.h +++ b/otherarch/sdcpp/model.h @@ -149,6 +149,7 @@ protected: public: bool init_from_file(const std::string& file_path, const std::string& prefix = ""); + bool has_diffusion_model_tensors(); SDVersion get_sd_version(); ggml_type get_sd_wtype(); ggml_type get_conditioner_wtype(); diff --git a/otherarch/sdcpp/stable-diffusion.cpp b/otherarch/sdcpp/stable-diffusion.cpp index 877b8f878..2c87926d6 100644 --- a/otherarch/sdcpp/stable-diffusion.cpp +++ b/otherarch/sdcpp/stable-diffusion.cpp @@ -232,8 +232,21 @@ public: } version = model_loader.get_sd_version(); + + if (version == VERSION_COUNT && model_path.size() > 0 && clip_l_path.size() > 0 && diffusion_model_path.size() == 0 && t5xxl_path.size() > 0) { + bool endswithsafetensors = (model_path.rfind(".safetensors") == model_path.size() - 12); + if(endswithsafetensors && !model_loader.has_diffusion_model_tensors()) + { + LOG_INFO("SD Diffusion Model tensors missing! Fallback trying alternative tensor names...\n"); + if (!model_loader.init_from_file(model_path, "model.diffusion_model.")) { + LOG_WARN("loading diffusion model from '%s' failed", model_path.c_str()); + } + version = model_loader.get_sd_version(); + } + } + if (version == VERSION_COUNT) { - LOG_ERROR("get sd version from file failed: '%s'", model_path.c_str()); + LOG_ERROR("Error: get SD version from file failed: '%s'", model_path.c_str()); return false; }