mirror of https://github.com/LostRuins/koboldcpp.git (synced 2025-09-11 09:34:37 +00:00)

commit 5ce2fdad24 (parent 2ee808a747)
taesd for sdxl, add lora loading done

12 changed files with 59 additions and 13 deletions

The hunks shown below make two changes: when an SDXL model is detected, the bundled TAESD autoencoder path is rewritten from taesd.embd to taesd_xl.embd so the XL-specific tiny autoencoder is loaded, and a new apply_lora_from_file() method loads and applies a LoRA from an explicit file path rather than resolving a name against the lora model directory.
@@ -121,6 +121,7 @@ public:
                        schedule_t schedule,
                        bool control_net_cpu) {
         use_tiny_autoencoder = taesd_path.size() > 0;
+        std::string taesd_path_fixed = taesd_path;
 #ifdef SD_USE_CUBLAS
         LOG_DEBUG("Using CUDA backend");
         backend = ggml_backend_cuda_init(0);
@@ -165,6 +166,17 @@ public:
             return false;
         }
         LOG_INFO("Stable Diffusion %s ", model_version_to_str[version]);
+
+        if(use_tiny_autoencoder && version==VERSION_XL)
+        {
+            std::string to_search = "taesd.embd";
+            std::string to_replace = "taesd_xl.embd";
+            size_t pos = taesd_path_fixed.find(to_search);
+            if (pos != std::string::npos) {
+                taesd_path_fixed.replace(pos, to_search.length(), to_replace);
+            }
+        }
+
         if (wtype == GGML_TYPE_COUNT) {
             model_data_type = model_loader.get_sd_wtype();
         } else {
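For reference, the block added above is a plain std::string find/replace on the autoencoder filename. A minimal standalone sketch of the same substitution follows; the main() wrapper and the models/ prefix are illustrative only, not part of the commit:

#include <iostream>
#include <string>

int main() {
    // Mirrors the hunk above: if the default taesd.embd path was requested
    // but the model is SDXL, point at taesd_xl.embd instead.
    std::string taesd_path_fixed = "models/taesd.embd"; // illustrative path
    const std::string to_search  = "taesd.embd";
    const std::string to_replace = "taesd_xl.embd";

    size_t pos = taesd_path_fixed.find(to_search);
    if (pos != std::string::npos) {
        taesd_path_fixed.replace(pos, to_search.length(), to_replace);
    }
    std::cout << taesd_path_fixed << std::endl; // prints models/taesd_xl.embd
    return 0;
}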
@@ -175,7 +187,7 @@ public:

         if (version == VERSION_XL) {
             scale_factor = 0.13025f;
-            if (vae_path.size() == 0 && taesd_path.size() == 0) {
+            if (vae_path.size() == 0 && taesd_path_fixed.size() == 0) {
                 LOG_WARN(
                     "!!!It looks like you are using SDXL model. "
                     "If you find that the generated images are completely black, "
@@ -287,7 +299,7 @@ public:
         if (!use_tiny_autoencoder) {
             vae_params_mem_size = first_stage_model->get_params_mem_size();
         } else {
-            if (!tae_first_stage->load_from_file(taesd_path)) {
+            if (!tae_first_stage->load_from_file(taesd_path_fixed)) {
                 return false;
             }
             vae_params_mem_size = tae_first_stage->get_params_mem_size();
@@ -390,6 +402,33 @@ public:
         return result < -1;
     }

+    void apply_lora_from_file(const std::string& lora_path, float multiplier) {
+        int64_t t0 = ggml_time_ms();
+        std::string st_file_path = lora_path;
+        std::string file_path;
+        if (file_exists(st_file_path)) {
+            file_path = st_file_path;
+        } else {
+            LOG_WARN("can not find %s for lora %s", st_file_path.c_str(), lora_path.c_str());
+            return;
+        }
+        LoraModel lora(backend, model_data_type, file_path);
+        if (!lora.load_from_file()) {
+            LOG_WARN("load lora tensors from %s failed", file_path.c_str());
+            return;
+        }
+
+        lora.multiplier = multiplier;
+        lora.apply(tensors, n_threads);
+        lora.free_params_buffer();
+
+        int64_t t1 = ggml_time_ms();
+
+        LOG_INFO("lora '%s' applied, taking %.2fs",
+                 lora_path.c_str(),
+                 (t1 - t0) * 1.0f / 1000);
+    }
+
     void apply_lora(const std::string& lora_name, float multiplier) {
         int64_t t0 = ggml_time_ms();
         std::string st_file_path = path_join(lora_model_dir, lora_name + ".safetensors");
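A hedged usage sketch for the new entry point; the sd pointer and the LoRA path below are assumptions for illustration, not taken from the commit:

// Assumed call site: 'sd' points at an instance of the class patched above.
// Applies a LoRA from an explicit .safetensors path at 0.8 strength.
// apply_lora_from_file() warns and returns early on a missing or unloadable
// file, so the caller needs no extra error handling.
sd->apply_lora_from_file("/path/to/lora.safetensors", 0.8f);

Unlike the existing apply_lora(), which joins lora_model_dir with a bare name, this variant takes the full path as-is, which is what lets callers load LoRAs from arbitrary locations.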