From 7a5983399bad686b73eaf3464a47bc0ae88eceb2 Mon Sep 17 00:00:00 2001 From: Concedo <39025047+LostRuins@users.noreply.github.com> Date: Wed, 15 May 2024 23:49:01 +0800 Subject: [PATCH] try to fix lora naming issues --- otherarch/sdcpp/ggml_extend.hpp | 4 +- otherarch/sdcpp/lora.hpp | 11 +++- otherarch/sdcpp/model.cpp | 108 ++++++++++++++++++++++++++------ otherarch/sdcpp/util.cpp | 7 +++ otherarch/sdcpp/util.h | 1 + 5 files changed, 109 insertions(+), 22 deletions(-) diff --git a/otherarch/sdcpp/ggml_extend.hpp b/otherarch/sdcpp/ggml_extend.hpp index 71f91fe58..3851b0e57 100644 --- a/otherarch/sdcpp/ggml_extend.hpp +++ b/otherarch/sdcpp/ggml_extend.hpp @@ -701,8 +701,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_timestep_embedding( // virtual struct ggml_cgraph* get_ggml_cgraph() = 0; // }; -#define MAX_PARAMS_TENSOR_NUM 10240 -#define MAX_GRAPH_SIZE 10240 +#define MAX_PARAMS_TENSOR_NUM 15360 +#define MAX_GRAPH_SIZE 15360 struct GGMLModule { protected: diff --git a/otherarch/sdcpp/lora.hpp b/otherarch/sdcpp/lora.hpp index 734635b66..36b5ee638 100644 --- a/otherarch/sdcpp/lora.hpp +++ b/otherarch/sdcpp/lora.hpp @@ -82,7 +82,16 @@ struct LoraModel : public GGMLModule { } k_tensor = k_tensor.substr(0, k_pos); replace_all_chars(k_tensor, '.', '_'); - std::string lora_up_name = "lora." + k_tensor + ".lora_up.weight"; + // LOG_DEBUG("k_tensor %s", k_tensor.c_str()); + std::string lora_up_name = "lora." + k_tensor + ".lora_up.weight"; + if (lora_tensors.find(lora_up_name) == lora_tensors.end()) { + if (k_tensor == "model_diffusion_model_output_blocks_2_2_conv") { + // fix for some sdxl lora, like lcm-lora-xl + k_tensor = "model_diffusion_model_output_blocks_2_1_conv"; + lora_up_name = "lora." + k_tensor + ".lora_up.weight"; + } + } + std::string lora_down_name = "lora." + k_tensor + ".lora_down.weight"; std::string alpha_name = "lora." + k_tensor + ".alpha"; std::string scale_name = "lora." + k_tensor + ".scale"; diff --git a/otherarch/sdcpp/model.cpp b/otherarch/sdcpp/model.cpp index 07ed53d19..b1d6b26ee 100644 --- a/otherarch/sdcpp/model.cpp +++ b/otherarch/sdcpp/model.cpp @@ -108,14 +108,14 @@ std::unordered_map open_clip_to_hf_clip_model = { {"model.positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"}, {"model.token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"}, {"model.text_projection", "transformer.text_model.text_projection"}, - {"model.visual.class_embedding", "transformer.visual_model.embeddings.class_embedding"}, - {"model.visual.conv1.weight", "transformer.visual_model.embeddings.patch_embedding.weight"}, - {"model.visual.ln_post.bias", "transformer.visual_model.post_layernorm.bias"}, - {"model.visual.ln_post.weight", "transformer.visual_model.post_layernorm.weight"}, - {"model.visual.ln_pre.bias", "transformer.visual_model.pre_layernorm.bias"}, - {"model.visual.ln_pre.weight", "transformer.visual_model.pre_layernorm.weight"}, - {"model.visual.positional_embedding", "transformer.visual_model.embeddings.position_embedding.weight"}, - {"model.visual.proj", "transformer.visual_model.visual_projection"}, + {"model.visual.class_embedding", "transformer.vision_model.embeddings.class_embedding"}, + {"model.visual.conv1.weight", "transformer.vision_model.embeddings.patch_embedding.weight"}, + {"model.visual.ln_post.bias", "transformer.vision_model.post_layernorm.bias"}, + {"model.visual.ln_post.weight", "transformer.vision_model.post_layernorm.weight"}, + {"model.visual.ln_pre.bias", "transformer.vision_model.pre_layernorm.bias"}, + {"model.visual.ln_pre.weight", "transformer.vision_model.pre_layernorm.weight"}, + {"model.visual.positional_embedding", "transformer.vision_model.embeddings.position_embedding.weight"}, + {"model.visual.proj", "transformer.visual_projection.weight"}, }; std::unordered_map open_clip_to_hk_clip_resblock = { @@ -157,6 +157,10 @@ std::string convert_open_clip_to_hf_clip(const std::string& name) { } else if (starts_with(new_name, "cond_stage_model.")) { prefix = "cond_stage_model."; new_name = new_name.substr(strlen("cond_stage_model.")); + } else if (ends_with(new_name, "vision_model.visual_projection.weight")) { + prefix = new_name.substr(0, new_name.size() - strlen("vision_model.visual_projection.weight")); + new_name = prefix + "visual_projection.weight"; + return new_name; } else { return new_name; } @@ -186,7 +190,7 @@ std::string convert_open_clip_to_hf_clip(const std::string& name) { replace_suffix(); open_clip_resblock_prefix = "model.visual.transformer.resblocks."; - hf_clip_resblock_prefix = "transformer.visual_model.encoder.layers."; + hf_clip_resblock_prefix = "transformer.vision_model.encoder.layers."; replace_suffix(); @@ -200,6 +204,25 @@ std::string convert_vae_decoder_name(const std::string& name) { return name; } +/* If not a SDXL LoRA the unet" prefix will have already been replaced by this + * point and "te2" and "te1" don't seem to appear in non-SDXL only "te_" */ +std::string convert_sdxl_lora_name(std::string tensor_name) { + const std::pair sdxl_lora_name_lookup[] = { + {"unet", "model_diffusion_model"}, + {"te2", "cond_stage_model_1_transformer"}, + {"te1", "cond_stage_model_transformer"}, + {"text_encoder_2", "cond_stage_model_1_transformer"}, + {"text_encoder", "cond_stage_model_transformer"}, + }; + for (auto& pair_i : sdxl_lora_name_lookup) { + if (tensor_name.compare(0, pair_i.first.length(), pair_i.first) == 0) { + tensor_name = std::regex_replace(tensor_name, std::regex(pair_i.first), pair_i.second); + break; + } + } + return tensor_name; +} + std::unordered_map> suffix_conversion_underline = { { "attentions", @@ -248,7 +271,7 @@ std::unordered_map> su }, }; -std::string convert_diffusers_name_to_compvis(const std::string& key, char seq) { +std::string convert_diffusers_name_to_compvis(std::string key, char seq) { std::vector m; auto match = [](std::vector& match_list, const std::regex& regex, const std::string& key) { @@ -282,6 +305,11 @@ std::string convert_diffusers_name_to_compvis(const std::string& key, char seq) return inner_key; }; + // convert attn to out + if (ends_with(key, "to_out")) { + key += format("%c0", seq); + } + // unet if (match(m, std::regex(format("unet%cconv_in(.*)", seq)), key)) { return format("model%cdiffusion_model%cinput_blocks%c0%c0", seq, seq, seq, seq) + m[0]; @@ -391,8 +419,8 @@ std::string convert_diffusers_name_to_compvis(const std::string& key, char seq) } std::string convert_tensor_name(const std::string& name) { - std::string new_name; - if (starts_with(name, "cond_stage_model.") || starts_with(name, "conditioner.embedders.")) { + std::string new_name = name; + if (starts_with(name, "cond_stage_model.") || starts_with(name, "conditioner.embedders.") || ends_with(name, ".vision_model.visual_projection.weight")) { new_name = convert_open_clip_to_hf_clip(name); } else if (starts_with(name, "first_stage_model.decoder")) { new_name = convert_vae_decoder_name(name); @@ -406,8 +434,12 @@ std::string convert_tensor_name(const std::string& name) { if (pos != std::string::npos) { std::string name_without_network_parts = name.substr(5, pos - 5); std::string network_part = name.substr(pos + 1); + // LOG_DEBUG("%s %s", name_without_network_parts.c_str(), network_part.c_str()); std::string new_key = convert_diffusers_name_to_compvis(name_without_network_parts, '_'); + /* For dealing with the new SDXL LoRA tensor naming convention */ + new_key = convert_sdxl_lora_name(new_key); + if (new_key.empty()) { new_name = name; } else { @@ -416,6 +448,33 @@ std::string convert_tensor_name(const std::string& name) { } else { new_name = name; } + } else if (contains(name, "lora_up") || contains(name, "lora_down") || + contains(name, "lora.up") || contains(name, "lora.down") || + contains(name, "lora_linear")) { + size_t pos = new_name.find(".processor"); + if (pos != std::string::npos) { + new_name.replace(pos, strlen(".processor"), ""); + } + pos = new_name.rfind("lora"); + if (pos != std::string::npos) { + std::string name_without_network_parts = new_name.substr(0, pos - 1); + std::string network_part = new_name.substr(pos); + // LOG_DEBUG("%s %s", name_without_network_parts.c_str(), network_part.c_str()); + std::string new_key = convert_diffusers_name_to_compvis(name_without_network_parts, '.'); + new_key = convert_sdxl_lora_name(new_key); + replace_all_chars(new_key, '.', '_'); + size_t npos = network_part.rfind("_linear_layer"); + if (npos != std::string::npos) { + network_part.replace(npos, strlen("_linear_layer"), ""); + } + if (starts_with(network_part, "lora.")) { + network_part = "lora_" + network_part.substr(5); + } + if (new_key.size() > 0) { + new_name = "lora." + new_key + "." + network_part; + } + // LOG_DEBUG("new name: %s", new_name.c_str()); + } } else if (starts_with(name, "unet") || starts_with(name, "vae") || starts_with(name, "te")) { // for diffuser size_t pos = name.find_last_of('.'); if (pos != std::string::npos) { @@ -832,8 +891,12 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const } } - TensorStorage tensor_storage(prefix + name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin); + // ggml_n_dims returns 1 for scalars + if (n_dims == 0) { + n_dims = 1; + } + TensorStorage tensor_storage(prefix + name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin); tensor_storage.reverse_ne(); size_t tensor_data_size = end - begin; @@ -1172,7 +1235,9 @@ bool ModelLoader::parse_data_pkl(uint8_t* buffer, if (reader.phase == PickleTensorReader::READ_DIMENS) { reader.tensor_storage.reverse_ne(); reader.tensor_storage.file_index = file_index; - reader.tensor_storage.name = prefix + reader.tensor_storage.name; + // if(strcmp(prefix.c_str(), "scarlett") == 0) + // printf(" got tensor %s \n ", reader.tensor_storage.name.c_str()); + reader.tensor_storage.name = prefix + reader.tensor_storage.name; tensor_storages.push_back(reader.tensor_storage); // LOG_DEBUG("%s", reader.tensor_storage.name.c_str()); // reset @@ -1275,7 +1340,8 @@ std::string ModelLoader::load_merges() { return merges_utf8_str; } -void remove_duplicates(std::vector& vec) { +std::vector remove_duplicates(const std::vector& vec) { + std::vector res; std::unordered_map name_to_index_map; for (size_t i = 0; i < vec.size(); ++i) { @@ -1283,13 +1349,16 @@ void remove_duplicates(std::vector& vec) { auto it = name_to_index_map.find(current_name); if (it != name_to_index_map.end()) { - vec[it->second] = vec[i]; + res[it->second] = vec[i]; } else { name_to_index_map[current_name] = i; + res.push_back(vec[i]); } } - vec.resize(name_to_index_map.size()); + // vec.resize(name_to_index_map.size()); + + return res; } bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend_t backend) { @@ -1303,7 +1372,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend preprocess_tensor(tensor_storage, processed_tensor_storages); } - remove_duplicates(processed_tensor_storages); + std::vector dedup = remove_duplicates(processed_tensor_storages); + processed_tensor_storages = dedup; + bool success = true; for (size_t file_index = 0; file_index < file_paths_.size(); file_index++) { std::string file_path = file_paths_[file_index]; @@ -1365,7 +1436,6 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend if (tensor_storage.file_index != file_index) { continue; } - ggml_tensor* dst_tensor = NULL; success = on_new_tensor_cb(tensor_storage, &dst_tensor); diff --git a/otherarch/sdcpp/util.cpp b/otherarch/sdcpp/util.cpp index 0811ec0b5..47ae7774c 100644 --- a/otherarch/sdcpp/util.cpp +++ b/otherarch/sdcpp/util.cpp @@ -43,6 +43,13 @@ bool starts_with(const std::string& str, const std::string& start) { return false; } +bool contains(const std::string& str, const std::string& substr) { + if (str.find(substr) != std::string::npos) { + return true; + } + return false; +} + void replace_all_chars(std::string& str, char target, char replacement) { for (size_t i = 0; i < str.length(); ++i) { if (str[i] == target) { diff --git a/otherarch/sdcpp/util.h b/otherarch/sdcpp/util.h index 837f20994..6a42568f6 100644 --- a/otherarch/sdcpp/util.h +++ b/otherarch/sdcpp/util.h @@ -8,6 +8,7 @@ bool ends_with(const std::string& str, const std::string& ending); bool starts_with(const std::string& str, const std::string& start); +bool contains(const std::string& str, const std::string& substr); std::string format(const char* fmt, ...);