fallback flux loader

This commit is contained in:
Concedo 2024-11-07 15:55:43 +08:00
parent c9977a5cb5
commit 262437f393
3 changed files with 25 additions and 1 deletions

View file

@ -1366,6 +1366,16 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s
return true; return true;
} }
bool ModelLoader::has_diffusion_model_tensors()
{
for (auto& tensor_storage : tensor_storages) {
if (tensor_storage.name.find("model.diffusion_model.") != std::string::npos) {
return true;
}
}
return false;
}
SDVersion ModelLoader::get_sd_version() { SDVersion ModelLoader::get_sd_version() {
TensorStorage token_embedding_weight; TensorStorage token_embedding_weight;
bool is_flux = false; bool is_flux = false;

View file

@ -149,6 +149,7 @@ protected:
public: public:
bool init_from_file(const std::string& file_path, const std::string& prefix = ""); bool init_from_file(const std::string& file_path, const std::string& prefix = "");
bool has_diffusion_model_tensors();
SDVersion get_sd_version(); SDVersion get_sd_version();
ggml_type get_sd_wtype(); ggml_type get_sd_wtype();
ggml_type get_conditioner_wtype(); ggml_type get_conditioner_wtype();

View file

@ -232,8 +232,21 @@ public:
} }
version = model_loader.get_sd_version(); version = model_loader.get_sd_version();
if (version == VERSION_COUNT && model_path.size() > 0 && clip_l_path.size() > 0 && diffusion_model_path.size() == 0 && t5xxl_path.size() > 0) {
bool endswithsafetensors = (model_path.rfind(".safetensors") == model_path.size() - 12);
if(endswithsafetensors && !model_loader.has_diffusion_model_tensors())
{
LOG_INFO("SD Diffusion Model tensors missing! Fallback trying alternative tensor names...\n");
if (!model_loader.init_from_file(model_path, "model.diffusion_model.")) {
LOG_WARN("loading diffusion model from '%s' failed", model_path.c_str());
}
version = model_loader.get_sd_version();
}
}
if (version == VERSION_COUNT) { if (version == VERSION_COUNT) {
LOG_ERROR("get sd version from file failed: '%s'", model_path.c_str()); LOG_ERROR("Error: get SD version from file failed: '%s'", model_path.c_str());
return false; return false;
} }