sd: sync to master-431-23fce0b (#1893)

* sd: sync to master-427-78e15bd

* add kl_optimal to the available schedulers list
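  Note: kl_optimal follows the KL-optimal step spacing from "Align Your Steps" (Sabour et al., 2024), where sigmas are spaced linearly in arctan(sigma) space. A rough sketch of the computation; the function name and signature below are illustrative, not the ones used in the denoiser code:

  ```cpp
  #include <cmath>
  #include <vector>

  // Sketch of a KL-optimal sigma schedule: interpolate linearly between
  // arctan(sigma_max) and arctan(sigma_min), then map back through tan.
  std::vector<float> kl_optimal_sigmas(int n, float sigma_min, float sigma_max) {
      std::vector<float> sigmas(n + 1);
      const float a_min = std::atan(sigma_min);
      const float a_max = std::atan(sigma_max);
      for (int i = 0; i < n; i++) {
          const float t = (n > 1) ? (float)i / (float)(n - 1) : 0.0f;
          sigmas[i] = std::tan((1.0f - t) * a_max + t * a_min);
      }
      sigmas[n] = 0.0f;  // final step denoises down to sigma = 0
      return sigmas;
  }
  ```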

* more robust workaround to avoid stb linkage issues
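  Note: the hunks for the stb change are not shown below. The usual source of these linkage issues is that stb's single-header libraries must have their *_IMPLEMENTATION macro defined in exactly one translation unit (zero definitions gives undefined references, more than one gives duplicate symbols). A common way to keep that robust, sketched with an illustrative file name:

  ```cpp
  // stb_impl.cpp (illustrative): the only translation unit that defines the
  // implementation macros, so the stb functions are emitted exactly once.
  // Every other .cpp includes the stb headers without these defines.
  #define STB_IMAGE_IMPLEMENTATION
  #define STB_IMAGE_WRITE_IMPLEMENTATION
  #include "stb_image.h"
  #include "stb_image_write.h"
  ```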

* sd: sync to master-431-23fce0b

* add TAEHV support and disable TAE if the model isn't found
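  Note: TAEHV is a tiny video autoencoder used here for Wan-family latents (wired up below as TinyVideoAutoEncoder). A hedged usage sketch; the file names are hypothetical, and sd_ctx_params_init()/new_sd_ctx() are assumed to be the public entry points matching the sd_ctx_params_t fields referenced in this diff:

  ```cpp
  sd_ctx_params_t params;
  sd_ctx_params_init(&params);                       // assumed initializer for the params struct
  params.model_path = "wan2.2-ti2v-5b.safetensors";  // hypothetical Wan checkpoint
  params.taesd_path = "taehv.safetensors";           // tiny video autoencoder weights
  // With this change, a missing taesd_path logs a warning and falls back to the
  // full VAE instead of erroring out.
  sd_ctx_t* ctx = new_sd_ctx(&params);
  ```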
Wagner Bruna 2025-12-22 04:07:09 -03:00 committed by GitHub
parent 27c53099f4
commit 44ce1a80b3
11 changed files with 787 additions and 196 deletions


@@ -136,7 +136,6 @@ public:
std::map<std::string, struct ggml_tensor*> tensors;
std::string lora_model_dir;
// lora_name => multiplier
std::unordered_map<std::string, float> curr_lora_state;
@@ -219,7 +218,6 @@ public:
n_threads = sd_ctx_params->n_threads;
vae_decode_only = sd_ctx_params->vae_decode_only;
free_params_immediately = sd_ctx_params->free_params_immediately;
lora_model_dir = SAFE_STR(sd_ctx_params->lora_model_dir);
taesd_path = SAFE_STR(sd_ctx_params->taesd_path);
use_tiny_autoencoder = taesd_path.size() > 0;
offload_params_to_cpu = sd_ctx_params->offload_params_to_cpu;
@@ -418,6 +416,14 @@ public:
{
to_replace = "taesd_3.embd";
}
else if (version == VERSION_WAN2_2_TI2V)
{
to_replace = "taesd_w22.embd";
}
else if (sd_version_is_wan(version) || sd_version_is_qwen_image(version))
{
to_replace = "taesd_w21.embd";
}
if (to_replace != "")
{
@@ -432,6 +438,12 @@ public:
taesd_path_fixed = "";
use_tiny_autoencoder = false;
}
if (use_tiny_autoencoder && !file_exists(taesd_path_fixed))
{
LOG_WARN("cannot use TAESD: \"%s\" not found, TAESD disabled", taesd_path_fixed.c_str());
taesd_path_fixed = "";
use_tiny_autoencoder = false;
}
}
ggml_type wtype = (int)sd_ctx_params->wtype < std::min<int>(SD_TYPE_COUNT, GGML_TYPE_COUNT)
@@ -663,6 +675,9 @@ public:
if (sd_ctx_params->diffusion_flash_attn) {
LOG_INFO("Using flash attention in the diffusion model");
diffusion_model->set_flash_attn_enabled(true);
if (high_noise_diffusion_model) {
high_noise_diffusion_model->set_flash_attn_enabled(true);
}
}
cond_stage_model->alloc_params_buffer();
@@ -688,14 +703,27 @@ public:
}
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
first_stage_model = std::make_shared<WAN::WanVAERunner>(vae_backend,
offload_params_to_cpu,
tensor_storage_map,
"first_stage_model",
vae_decode_only,
version);
first_stage_model->alloc_params_buffer();
first_stage_model->get_param_tensors(tensors, "first_stage_model");
if (!use_tiny_autoencoder) {
first_stage_model = std::make_shared<WAN::WanVAERunner>(vae_backend,
offload_params_to_cpu,
tensor_storage_map,
"first_stage_model",
vae_decode_only,
version);
first_stage_model->alloc_params_buffer();
first_stage_model->get_param_tensors(tensors, "first_stage_model");
} else {
tae_first_stage = std::make_shared<TinyVideoAutoEncoder>(vae_backend,
offload_params_to_cpu,
tensor_storage_map,
"decoder",
vae_decode_only,
version);
if (sd_ctx_params->vae_conv_direct) {
LOG_INFO("Using Conv2d direct in the tae model");
tae_first_stage->set_conv2d_direct_enabled(true);
}
}
} else if (version == VERSION_CHROMA_RADIANCE) {
first_stage_model = std::make_shared<FakeVAE>(vae_backend,
offload_params_to_cpu);
@@ -722,14 +750,13 @@ public:
}
first_stage_model->alloc_params_buffer();
first_stage_model->get_param_tensors(tensors, "first_stage_model");
}
if (use_tiny_autoencoder) {
tae_first_stage = std::make_shared<TinyAutoEncoder>(vae_backend,
offload_params_to_cpu,
tensor_storage_map,
"decoder.layers",
vae_decode_only,
version);
} else if (use_tiny_autoencoder) {
tae_first_stage = std::make_shared<TinyImageAutoEncoder>(vae_backend,
offload_params_to_cpu,
tensor_storage_map,
"decoder.layers",
vae_decode_only,
version);
if (sd_ctx_params->vae_conv_direct) {
LOG_INFO("Using Conv2d direct in the tae model");
tae_first_stage->set_conv2d_direct_enabled(true);
@@ -823,6 +850,8 @@ public:
if (stacked_id) {
ignore_tensors.insert("pmid.unet.");
}
ignore_tensors.insert("model.diffusion_model.__x0__");
ignore_tensors.insert("model.diffusion_model.__32x32__");
if (vae_decode_only) {
ignore_tensors.insert("first_stage_model.encoder");
@@ -957,6 +986,7 @@ public:
}
} else if (sd_version_is_flux(version)) {
pred_type = FLUX_FLOW_PRED;
if (flow_shift == INFINITY) {
flow_shift = 1.0f; // TODO: validate
for (const auto& [name, tensor_storage] : tensor_storage_map) {
@@ -1612,6 +1642,17 @@ public:
std::vector<int> skip_layers(guidance.slg.layers, guidance.slg.layers + guidance.slg.layer_count);
float cfg_scale = guidance.txt_cfg;
if (cfg_scale < 1.f) {
if (cfg_scale == 0.f) {
// Diffusers follows the convention from the original CFG paper
// (https://arxiv.org/abs/2207.12598v1), so many distilled model docs
// recommend 0 as guidance; warn the user that it disables prompt following
LOG_WARN("unconditioned mode, images won't follow the prompt (use cfg-scale=1 for distilled models)");
} else {
LOG_WARN("cfg value out of expected range may produce unexpected results");
}
}
float img_cfg_scale = std::isfinite(guidance.img_cfg) ? guidance.img_cfg : guidance.txt_cfg;
float slg_scale = guidance.slg.scale;
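  For context on the warning above: classifier-free guidance blends the unconditional and conditional predictions, so a scale of 0 reproduces the purely unconditional output and ignores the prompt, while 1 keeps only the conditional one (the usual setting for guidance-distilled models). A minimal sketch of the blend, not this project's actual sampling kernel:

  ```cpp
  // Per-element classifier-free guidance blend (sketch):
  //   cfg_scale = 0 -> purely unconditional (prompt is ignored)
  //   cfg_scale = 1 -> purely conditional
  static inline float cfg_blend(float uncond, float cond, float cfg_scale) {
      return uncond + cfg_scale * (cond - uncond);
  }
  ```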
@@ -2527,6 +2568,7 @@ const char* scheduler_to_str[] = {
"sgm_uniform",
"simple",
"smoothstep",
"kl_optimal",
"lcm",
};
@@ -2664,7 +2706,6 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
"vae_path: %s\n"
"taesd_path: %s\n"
"control_net_path: %s\n"
"lora_model_dir: %s\n"
"photo_maker_path: %s\n"
"tensor_type_rules: %s\n"
"vae_decode_only: %s\n"
@@ -2694,7 +2735,6 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
SAFE_STR(sd_ctx_params->vae_path),
SAFE_STR(sd_ctx_params->taesd_path),
SAFE_STR(sd_ctx_params->control_net_path),
SAFE_STR(sd_ctx_params->lora_model_dir),
SAFE_STR(sd_ctx_params->photo_maker_path),
SAFE_STR(sd_ctx_params->tensor_type_rules),
BOOL_STR(sd_ctx_params->vae_decode_only),
@@ -2893,13 +2933,16 @@ enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx) {
return EULER_A_SAMPLE_METHOD;
}
enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx) {
enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx, enum sample_method_t sample_method) {
if (sd_ctx != nullptr && sd_ctx->sd != nullptr) {
auto edm_v_denoiser = std::dynamic_pointer_cast<EDMVDenoiser>(sd_ctx->sd->denoiser);
if (edm_v_denoiser) {
return EXPONENTIAL_SCHEDULER;
}
}
if (sample_method == LCM_SAMPLE_METHOD) {
return LCM_SCHEDULER;
}
return DISCRETE_SCHEDULER;
}
@@ -3334,9 +3377,13 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
LOG_WARN("sample_steps != custom_sigmas_count - 1, set sample_steps to %d", sample_steps);
}
} else {
scheduler_t scheduler = sd_img_gen_params->sample_params.scheduler;
if (scheduler == SCHEDULER_COUNT) {
scheduler = sd_get_default_scheduler(sd_ctx, sample_method);
}
sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps,
sd_ctx->sd->get_image_seq_len(height, width),
sd_img_gen_params->sample_params.scheduler,
scheduler,
sd_ctx->sd->version);
}
@@ -3619,9 +3666,13 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
}
}
} else {
scheduler_t scheduler = sd_vid_gen_params->sample_params.scheduler;
if (scheduler == SCHEDULER_COUNT) {
scheduler = sd_get_default_scheduler(sd_ctx, sample_method);
}
sigmas = sd_ctx->sd->denoiser->get_sigmas(total_steps,
0,
sd_vid_gen_params->sample_params.scheduler,
scheduler,
sd_ctx->sd->version);
}
@@ -3746,7 +3797,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
denoise_mask = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], 1);
ggml_set_f32(denoise_mask, 1.f);
sd_ctx->sd->process_latent_out(init_latent);
if (!sd_ctx->sd->use_tiny_autoencoder)
sd_ctx->sd->process_latent_out(init_latent);
ggml_ext_tensor_iter(init_image_latent, [&](ggml_tensor* t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = ggml_ext_tensor_get_f32(t, i0, i1, i2, i3);
@@ -3756,7 +3808,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
}
});
sd_ctx->sd->process_latent_in(init_latent);
if (!sd_ctx->sd->use_tiny_autoencoder)
sd_ctx->sd->process_latent_in(init_latent);
int64_t t2 = ggml_time_ms();
LOG_INFO("encode_first_stage completed, taking %" PRId64 " ms", t2 - t1);
@@ -3979,7 +4032,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
struct ggml_tensor* vid = sd_ctx->sd->decode_first_stage(work_ctx, final_latent, true);
int64_t t5 = ggml_time_ms();
LOG_INFO("decode_first_stage completed, taking %.2fs", (t5 - t4) * 1.0f / 1000);
if (sd_ctx->sd->free_params_immediately) {
if (sd_ctx->sd->free_params_immediately && !sd_ctx->sd->use_tiny_autoencoder) {
sd_ctx->sd->first_stage_model->free_params_buffer();
}