Add TAESD support for SDXL models; finish LoRA loading

This commit is contained in:
Concedo 2024-05-14 23:02:56 +08:00
parent 2ee808a747
commit 5ce2fdad24
12 changed files with 59 additions and 13 deletions

View file

@ -121,6 +121,7 @@ public:
schedule_t schedule,
bool control_net_cpu) {
use_tiny_autoencoder = taesd_path.size() > 0;
std::string taesd_path_fixed = taesd_path;
#ifdef SD_USE_CUBLAS
LOG_DEBUG("Using CUDA backend");
backend = ggml_backend_cuda_init(0);
@ -165,6 +166,17 @@ public:
return false;
}
LOG_INFO("Stable Diffusion %s ", model_version_to_str[version]);
if(use_tiny_autoencoder && version==VERSION_XL)
{
std::string to_search = "taesd.embd";
std::string to_replace = "taesd_xl.embd";
size_t pos = taesd_path_fixed.find(to_search);
if (pos != std::string::npos) {
taesd_path_fixed.replace(pos, to_search.length(), to_replace);
}
}
if (wtype == GGML_TYPE_COUNT) {
model_data_type = model_loader.get_sd_wtype();
} else {
@ -175,7 +187,7 @@ public:
if (version == VERSION_XL) {
scale_factor = 0.13025f;
if (vae_path.size() == 0 && taesd_path.size() == 0) {
if (vae_path.size() == 0 && taesd_path_fixed.size() == 0) {
LOG_WARN(
"!!!It looks like you are using SDXL model. "
"If you find that the generated images are completely black, "
@ -287,7 +299,7 @@ public:
if (!use_tiny_autoencoder) {
vae_params_mem_size = first_stage_model->get_params_mem_size();
} else {
if (!tae_first_stage->load_from_file(taesd_path)) {
if (!tae_first_stage->load_from_file(taesd_path_fixed)) {
return false;
}
vae_params_mem_size = tae_first_stage->get_params_mem_size();
@ -390,6 +402,33 @@ public:
return result < -1;
}
// Load a LoRA from an explicit file path and merge its weights into the
// currently loaded model tensors.
//
// Unlike the sibling apply_lora(), which resolves a LoRA *name* against
// lora_model_dir and appends ".safetensors", this variant takes a full
// path and uses it as-is.
//
// lora_path:  path to the LoRA file on disk; must already exist.
// multiplier: blend strength applied when merging the LoRA deltas.
//
// On any failure (file missing, tensors unreadable) a warning is logged
// and the model is left unmodified.
void apply_lora_from_file(const std::string& lora_path, float multiplier) {
    int64_t t0 = ggml_time_ms();
    // No path rewriting happens here, so the extra st_file_path/file_path
    // copies from apply_lora() are unnecessary; check the input directly.
    if (!file_exists(lora_path)) {
        // The old message printed the same path twice ("%s for lora %s"
        // with identical arguments) — report it once.
        LOG_WARN("can not find lora file %s", lora_path.c_str());
        return;
    }
    LoraModel lora(backend, model_data_type, lora_path);
    if (!lora.load_from_file()) {
        LOG_WARN("load lora tensors from %s failed", lora_path.c_str());
        return;
    }
    lora.multiplier = multiplier;
    lora.apply(tensors, n_threads);
    lora.free_params_buffer();
    int64_t t1 = ggml_time_ms();
    LOG_INFO("lora '%s' applied, taking %.2fs",
             lora_path.c_str(),
             (t1 - t0) * 1.0f / 1000);
}
void apply_lora(const std::string& lora_name, float multiplier) {
int64_t t0 = ggml_time_ms();
std::string st_file_path = path_join(lora_model_dir, lora_name + ".safetensors");