diff --git a/koboldcpp.py b/koboldcpp.py
index 31beebdb7..078045ab8 100755
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -3951,7 +3951,8 @@ Change Mode
if friendlysdmodelname=="inactive" or fullsdmodelpath=="":
response_body = (json.dumps([]).encode())
else:
- response_body = (json.dumps([{"name":"Euler","aliases":["k_euler"],"options":{}},{"name":"Euler a","aliases":["k_euler_a","k_euler_ancestral"],"options":{}},{"name":"Heun","aliases":["k_heun"],"options":{}},{"name":"DPM2","aliases":["k_dpm_2"],"options":{}},{"name":"DPM++ 2M","aliases":["k_dpmpp_2m"],"options":{}},{"name":"DDIM","aliases":["ddim"],"options":{}},{"name":"LCM","aliases":["k_lcm"],"options":{}},{"name":"Default","aliases":["default"],"options":{}}]).encode())
+ response_body = (json.dumps([{"name":"Euler","aliases":["k_euler"],"options":{}},{"name":"Euler a","aliases":["k_euler_a","k_euler_ancestral"],"options":{}},{"name":"Heun","aliases":["k_heun"],"options":{}},{"name":"DPM2","aliases":["k_dpm_2"],"options":{}},{"name":"DPM++ 2M","aliases":["k_dpmpp_2m"],"options":{}},{"name":"DDIM","aliases":["ddim"],"options":{}},{"name":"LCM","aliases":["k_lcm"],"options":{}},{"name":"Res 2s","aliases":["k_res_2s"],"options":{}},{"name":"Res Multistep","aliases":["k_res_multistep"],"options":{}},
+ {"name":"Default","aliases":["default"],"options":{}}]).encode())
elif clean_path.endswith('/sdapi/v1/schedulers'):
if friendlysdmodelname=="inactive" or fullsdmodelpath=="":
response_body = (json.dumps([]).encode())
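
The two new entries surface through the A1111-compatible sampler listing, so clients can discover them at runtime. A minimal sketch of the round trip, assuming a locally running instance (the base URL is illustrative; only the /sdapi/v1/samplers path comes from the code above):

    import json
    import urllib.request

    # Fetch the advertised sampler list; port 5001 is KoboldCpp's usual default.
    with urllib.request.urlopen("http://localhost:5001/sdapi/v1/samplers") as r:
        samplers = json.load(r)

    print([s["name"] for s in samplers])
    # With an image model loaded, the list should now include
    # "Res 2s" (alias "k_res_2s") and "Res Multistep" (alias "k_res_multistep").
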
diff --git a/otherarch/sdcpp/common/common.hpp b/otherarch/sdcpp/common/common.hpp
index cadc92aaa..749ab2b86 100644
--- a/otherarch/sdcpp/common/common.hpp
+++ b/otherarch/sdcpp/common/common.hpp
@@ -462,6 +462,7 @@ struct SDContextParams {
bool control_net_cpu = false;
bool clip_on_cpu = false;
bool vae_on_cpu = false;
+ bool flash_attn = false;
bool diffusion_flash_attn = false;
bool diffusion_conv_direct = false;
bool vae_conv_direct = false;
@@ -620,9 +621,13 @@ struct SDContextParams {
"--vae-on-cpu",
"keep vae in cpu (for low vram)",
true, &vae_on_cpu},
+ {"",
+ "--fa",
+ "use flash attention",
+ true, &flash_attn},
{"",
"--diffusion-fa",
- "use flash attention in the diffusion model",
+ "use flash attention in the diffusion model only",
true, &diffusion_flash_attn},
{"",
"--diffusion-conv-direct",
@@ -909,6 +914,7 @@ struct SDContextParams {
<< " control_net_cpu: " << (control_net_cpu ? "true" : "false") << ",\n"
<< " clip_on_cpu: " << (clip_on_cpu ? "true" : "false") << ",\n"
<< " vae_on_cpu: " << (vae_on_cpu ? "true" : "false") << ",\n"
+ << " flash_attn: " << (flash_attn ? "true" : "false") << ",\n"
<< " diffusion_flash_attn: " << (diffusion_flash_attn ? "true" : "false") << ",\n"
<< " diffusion_conv_direct: " << (diffusion_conv_direct ? "true" : "false") << ",\n"
<< " vae_conv_direct: " << (vae_conv_direct ? "true" : "false") << ",\n"
@@ -973,6 +979,7 @@ struct SDContextParams {
clip_on_cpu,
control_net_cpu,
vae_on_cpu,
+ flash_attn,
diffusion_flash_attn,
taesd_preview,
diffusion_conv_direct,
@@ -1483,17 +1490,17 @@ struct SDGenerationParams {
on_seed_arg},
{"",
"--sampling-method",
- "sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd] "
+ "sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd, res_multistep, res_2s] "
"(default: euler for Flux/SD3/Wan, euler_a otherwise)",
on_sample_method_arg},
{"",
"--high-noise-sampling-method",
- "(high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd]"
+ "(high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd, res_multistep, res_2s]"
" default: euler for Flux/SD3/Wan, euler_a otherwise",
on_high_noise_sample_method_arg},
{"",
"--scheduler",
- "denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm], default: discrete",
+ "denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm, bong_tangent], default: discrete",
on_scheduler_arg},
{"",
"--sigmas",
diff --git a/otherarch/sdcpp/conditioner.hpp b/otherarch/sdcpp/conditioner.hpp
index a4e84aa3b..b1876954f 100644
--- a/otherarch/sdcpp/conditioner.hpp
+++ b/otherarch/sdcpp/conditioner.hpp
@@ -34,6 +34,7 @@ struct Conditioner {
virtual void free_params_buffer() = 0;
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) = 0;
virtual size_t get_params_buffer_size() = 0;
+ virtual void set_flash_attention_enabled(bool enabled) = 0;
virtual void set_weight_adapter(const std::shared_ptr& adapter) {}
virtual std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
int n_threads,
@@ -115,6 +116,13 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
return buffer_size;
}
+ void set_flash_attention_enabled(bool enabled) override {
+ text_model->set_flash_attention_enabled(enabled);
+ if (sd_version_is_sdxl(version)) {
+ text_model2->set_flash_attention_enabled(enabled);
+ }
+ }
+
void set_weight_adapter(const std::shared_ptr& adapter) override {
text_model->set_weight_adapter(adapter);
if (sd_version_is_sdxl(version)) {
@@ -783,6 +791,18 @@ struct SD3CLIPEmbedder : public Conditioner {
return buffer_size;
}
+ void set_flash_attention_enabled(bool enabled) override {
+ if (clip_l) {
+ clip_l->set_flash_attention_enabled(enabled);
+ }
+ if (clip_g) {
+ clip_g->set_flash_attention_enabled(enabled);
+ }
+ if (t5) {
+ t5->set_flash_attention_enabled(enabled);
+ }
+ }
+
void set_weight_adapter(const std::shared_ptr& adapter) override {
if (clip_l) {
clip_l->set_weight_adapter(adapter);
@@ -1191,6 +1211,15 @@ struct FluxCLIPEmbedder : public Conditioner {
return buffer_size;
}
+ void set_flash_attention_enabled(bool enabled) override {
+ if (clip_l) {
+ clip_l->set_flash_attention_enabled(enabled);
+ }
+ if (t5) {
+ t5->set_flash_attention_enabled(enabled);
+ }
+ }
+
void set_weight_adapter(const std::shared_ptr& adapter) {
if (clip_l) {
clip_l->set_weight_adapter(adapter);
@@ -1440,6 +1469,12 @@ struct T5CLIPEmbedder : public Conditioner {
return buffer_size;
}
+ void set_flash_attention_enabled(bool enabled) override {
+ if (t5) {
+ t5->set_flash_attention_enabled(enabled);
+ }
+ }
+
void set_weight_adapter(const std::shared_ptr& adapter) override {
if (t5) {
t5->set_weight_adapter(adapter);
@@ -1650,6 +1685,10 @@ struct LLMEmbedder : public Conditioner {
return buffer_size;
}
+ void set_flash_attention_enabled(bool enabled) override {
+ llm->set_flash_attention_enabled(enabled);
+ }
+
void set_weight_adapter(const std::shared_ptr& adapter) override {
if (llm) {
llm->set_weight_adapter(adapter);
diff --git a/otherarch/sdcpp/denoiser.hpp b/otherarch/sdcpp/denoiser.hpp
index 98aef702d..7e99b84a8 100644
--- a/otherarch/sdcpp/denoiser.hpp
+++ b/otherarch/sdcpp/denoiser.hpp
@@ -1,6 +1,8 @@
#ifndef __DENOISER_HPP__
#define __DENOISER_HPP__
+#include <cmath>
+
#include "ggml_extend.hpp"
#include "gits_noise.inl"
@@ -351,6 +353,95 @@ struct SmoothStepScheduler : SigmaScheduler {
}
};
+struct BongTangentScheduler : SigmaScheduler {
+ static constexpr float kPi = 3.14159265358979323846f;
+
+ static std::vector<float> get_bong_tangent_sigmas(int steps, float slope, float pivot, float start, float end) {
+ std::vector<float> sigmas;
+ if (steps <= 0) {
+ return sigmas;
+ }
+
+ float smax = ((2.0f / kPi) * atanf(-slope * (0.0f - pivot)) + 1.0f) * 0.5f;
+ float smin = ((2.0f / kPi) * atanf(-slope * ((float)(steps - 1) - pivot)) + 1.0f) * 0.5f;
+ float srange = smax - smin;
+ float sscale = start - end;
+
+ sigmas.reserve(steps);
+
+ if (fabsf(srange) < 1e-8f) {
+ if (steps == 1) {
+ sigmas.push_back(start);
+ return sigmas;
+ }
+ for (int i = 0; i < steps; ++i) {
+ float t = (float)i / (float)(steps - 1);
+ sigmas.push_back(start + (end - start) * t);
+ }
+ return sigmas;
+ }
+
+ float inv_srange = 1.0f / srange;
+ for (int x = 0; x < steps; ++x) {
+ float v = ((2.0f / kPi) * atanf(-slope * ((float)x - pivot)) + 1.0f) * 0.5f;
+ float sigma = ((v - smin) * inv_srange) * sscale + end;
+ sigmas.push_back(sigma);
+ }
+
+ return sigmas;
+ }
+
+ std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t /*t_to_sigma*/) override {
+ std::vector<float> result;
+ if (n == 0) {
+ return result;
+ }
+
+ float start = sigma_max;
+ float end = sigma_min;
+ float middle = sigma_min + (sigma_max - sigma_min) * 0.5f;
+
+ float pivot_1 = 0.6f;
+ float pivot_2 = 0.6f;
+ float slope_1 = 0.2f;
+ float slope_2 = 0.2f;
+
+ int steps = static_cast<int>(n) + 2;
+ int midpoint = static_cast<int>(((float)steps * pivot_1 + (float)steps * pivot_2) * 0.5f);
+ int pivot_1_i = static_cast<int>((float)steps * pivot_1);
+ int pivot_2_i = static_cast<int>((float)steps * pivot_2);
+
+ float slope_scale = (float)steps / 40.0f;
+ slope_1 = slope_1 / slope_scale;
+ slope_2 = slope_2 / slope_scale;
+
+ int stage_2_len = steps - midpoint;
+ int stage_1_len = steps - stage_2_len;
+
+ std::vector<float> sigmas_1 = get_bong_tangent_sigmas(stage_1_len, slope_1, (float)pivot_1_i, start, middle);
+ std::vector<float> sigmas_2 = get_bong_tangent_sigmas(stage_2_len, slope_2, (float)(pivot_2_i - stage_1_len), middle, end);
+
+ if (!sigmas_1.empty()) {
+ sigmas_1.pop_back();
+ }
+
+ result.reserve(n + 1);
+ result.insert(result.end(), sigmas_1.begin(), sigmas_1.end());
+ result.insert(result.end(), sigmas_2.begin(), sigmas_2.end());
+
+ if (result.size() < n + 1) {
+ while (result.size() < n + 1) {
+ result.push_back(end);
+ }
+ } else if (result.size() > n + 1) {
+ result.resize(n + 1);
+ }
+
+ result[n] = 0.0f;
+ return result;
+ }
+};
+
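
For readers who want to inspect the resulting curve, here is a standalone Python transliteration of the scheduler above; the constants mirror the hard-coded ones (pivot 0.6, slope 0.2, slope normalized for 40 steps), and the function names are mine:

    import math

    def bong_tangent_stage(steps, slope, pivot, start, end):
        # Normalized arctangent ease from `start` down to `end`, pivoting at `pivot`.
        if steps <= 0:
            return []
        s = [((2 / math.pi) * math.atan(-slope * (x - pivot)) + 1) / 2 for x in range(steps)]
        smax, smin = s[0], s[-1]
        if abs(smax - smin) < 1e-8:  # degenerate range: fall back to a linear ramp
            return [start + (end - start) * i / max(steps - 1, 1) for i in range(steps)]
        return [(v - smin) / (smax - smin) * (start - end) + end for v in s]

    def bong_tangent(n, sigma_min, sigma_max):
        steps = n + 2
        middle = sigma_min + (sigma_max - sigma_min) * 0.5
        midpoint = int(steps * 0.6)              # pivot_1 == pivot_2 == 0.6
        slope = 0.2 / (steps / 40.0)
        stage_1 = bong_tangent_stage(midpoint, slope, midpoint, sigma_max, middle)[:-1]
        stage_2 = bong_tangent_stage(steps - midpoint, slope, 0, middle, sigma_min)
        sigmas = (stage_1 + stage_2)[: n + 1]
        sigmas += [sigma_min] * (n + 1 - len(sigmas))  # pad if rounding fell short
        sigmas[n] = 0.0
        return sigmas
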
struct KLOptimalScheduler : SigmaScheduler {
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
std::vector<float> sigmas;
@@ -431,6 +522,10 @@ struct Denoiser {
LOG_INFO("get_sigmas with SmoothStep scheduler");
scheduler = std::make_shared<SmoothStepScheduler>();
break;
+ case BONG_TANGENT_SCHEDULER:
+ LOG_INFO("get_sigmas with bong_tangent scheduler");
+ scheduler = std::make_shared<BongTangentScheduler>();
+ break;
case KL_OPTIMAL_SCHEDULER:
LOG_INFO("get_sigmas with KL Optimal scheduler");
scheduler = std::make_shared<KLOptimalScheduler>();
@@ -1634,6 +1729,216 @@ static bool sample_k_diffusion(sample_method_t method,
}
}
} break;
+ case RES_MULTISTEP_SAMPLE_METHOD: // Res Multistep sampler
+ {
+ struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x);
+ struct ggml_tensor* old_denoised = ggml_dup_tensor(work_ctx, x);
+
+ bool have_old_sigma = false;
+ float old_sigma_down = 0.0f;
+
+ auto t_fn = [](float sigma) -> float { return -logf(sigma); };
+ auto sigma_fn = [](float t) -> float { return expf(-t); };
+ auto phi1_fn = [](float t) -> float {
+ if (fabsf(t) < 1e-6f) {
+ return 1.0f + t * 0.5f + (t * t) / 6.0f;
+ }
+ return (expf(t) - 1.0f) / t;
+ };
+ auto phi2_fn = [&](float t) -> float {
+ if (fabsf(t) < 1e-6f) {
+ return 0.5f + t / 6.0f + (t * t) / 24.0f;
+ }
+ float phi1_val = phi1_fn(t);
+ return (phi1_val - 1.0f) / t;
+ };
+
+ for (int i = 0; i < steps; i++) {
+ ggml_tensor* denoised = model(x, sigmas[i], i + 1);
+ if (denoised == nullptr) {
+ return false;
+ }
+
+ float sigma_from = sigmas[i];
+ float sigma_to = sigmas[i + 1];
+ float sigma_up = 0.0f;
+ float sigma_down = sigma_to;
+
+ if (eta > 0.0f) {
+ float sigma_from_sq = sigma_from * sigma_from;
+ float sigma_to_sq = sigma_to * sigma_to;
+ if (sigma_from_sq > 0.0f) {
+ float term = sigma_to_sq * (sigma_from_sq - sigma_to_sq) / sigma_from_sq;
+ if (term > 0.0f) {
+ sigma_up = eta * std::sqrt(term);
+ }
+ }
+ sigma_up = std::min(sigma_up, sigma_to);
+ float sigma_down_sq = sigma_to_sq - sigma_up * sigma_up;
+ sigma_down = sigma_down_sq > 0.0f ? std::sqrt(sigma_down_sq) : 0.0f;
+ }
+
+ if (sigma_down == 0.0f || !have_old_sigma) {
+ float dt = sigma_down - sigma_from;
+ float* vec_x = (float*)x->data;
+ float* vec_denoised = (float*)denoised->data;
+
+ for (int j = 0; j < ggml_nelements(x); j++) {
+ float d = (vec_x[j] - vec_denoised[j]) / sigma_from;
+ vec_x[j] = vec_x[j] + d * dt;
+ }
+ } else {
+ float t = t_fn(sigma_from);
+ float t_old = t_fn(old_sigma_down);
+ float t_next = t_fn(sigma_down);
+ float t_prev = t_fn(sigmas[i - 1]);
+ float h = t_next - t;
+ float c2 = (t_prev - t_old) / h;
+
+ float phi1_val = phi1_fn(-h);
+ float phi2_val = phi2_fn(-h);
+ float b1 = phi1_val - phi2_val / c2;
+ float b2 = phi2_val / c2;
+
+ if (!std::isfinite(b1)) {
+ b1 = 0.0f;
+ }
+ if (!std::isfinite(b2)) {
+ b2 = 0.0f;
+ }
+
+ float sigma_h = sigma_fn(h);
+ float* vec_x = (float*)x->data;
+ float* vec_denoised = (float*)denoised->data;
+ float* vec_old_denoised = (float*)old_denoised->data;
+
+ for (int j = 0; j < ggml_nelements(x); j++) {
+ vec_x[j] = sigma_h * vec_x[j] + h * (b1 * vec_denoised[j] + b2 * vec_old_denoised[j]);
+ }
+ }
+
+ if (sigmas[i + 1] > 0 && sigma_up > 0.0f) {
+ ggml_ext_im_set_randn_f32(noise, rng);
+ float* vec_x = (float*)x->data;
+ float* vec_noise = (float*)noise->data;
+
+ for (int j = 0; j < ggml_nelements(x); j++) {
+ vec_x[j] = vec_x[j] + vec_noise[j] * sigma_up;
+ }
+ }
+
+ float* vec_old_denoised = (float*)old_denoised->data;
+ float* vec_denoised = (float*)denoised->data;
+ for (int j = 0; j < ggml_nelements(x); j++) {
+ vec_old_denoised[j] = vec_denoised[j];
+ }
+
+ old_sigma_down = sigma_down;
+ have_old_sigma = true;
+ }
+ } break;
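
In sigma-log time t = -log(sigma), the probability-flow update being integrated is dx/dt = D(x) - x, and the multistep branch above is the standard second-order exponential integrator step x' = exp(-h)*x + h*(b1*D_i + b2*D_{i-1}) with phi-function weights. A scalar Python sketch of the arithmetic (helper names are mine):

    import math

    def phi1(t):
        # Same small-|t| series fallback as the lambda above.
        return 1 + t / 2 + t * t / 6 if abs(t) < 1e-6 else math.expm1(t) / t

    def phi2(t):
        return 0.5 + t / 6 + t * t / 24 if abs(t) < 1e-6 else (phi1(t) - 1) / t

    def res_multistep_step(x, d, d_old, h, c2):
        # x' = exp(-h)*x + h*(b1*d + b2*d_old); since b1 + b2 = phi1(-h) and
        # exp(-h) = 1 - h*phi1(-h), a constant prediction d == d_old is a fixed
        # point: x == d stays exactly at d.
        b2 = phi2(-h) / c2
        b1 = phi1(-h) - b2
        return math.exp(-h) * x + h * (b1 * d + b2 * d_old)
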
+ case RES_2S_SAMPLE_METHOD: // Res 2s sampler
+ {
+ struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, x);
+ struct ggml_tensor* x0 = ggml_dup_tensor(work_ctx, x);
+ struct ggml_tensor* x2 = ggml_dup_tensor(work_ctx, x);
+
+ const float c2 = 0.5f;
+ auto t_fn = [](float sigma) -> float { return -logf(sigma); };
+ auto phi1_fn = [](float t) -> float {
+ if (fabsf(t) < 1e-6f) {
+ return 1.0f + t * 0.5f + (t * t) / 6.0f;
+ }
+ return (expf(t) - 1.0f) / t;
+ };
+ auto phi2_fn = [&](float t) -> float {
+ if (fabsf(t) < 1e-6f) {
+ return 0.5f + t / 6.0f + (t * t) / 24.0f;
+ }
+ float phi1_val = phi1_fn(t);
+ return (phi1_val - 1.0f) / t;
+ };
+
+ for (int i = 0; i < steps; i++) {
+ float sigma_from = sigmas[i];
+ float sigma_to = sigmas[i + 1];
+
+ ggml_tensor* denoised = model(x, sigma_from, -(i + 1));
+ if (denoised == nullptr) {
+ return false;
+ }
+
+ float sigma_up = 0.0f;
+ float sigma_down = sigma_to;
+ if (eta > 0.0f) {
+ float sigma_from_sq = sigma_from * sigma_from;
+ float sigma_to_sq = sigma_to * sigma_to;
+ if (sigma_from_sq > 0.0f) {
+ float term = sigma_to_sq * (sigma_from_sq - sigma_to_sq) / sigma_from_sq;
+ if (term > 0.0f) {
+ sigma_up = eta * std::sqrt(term);
+ }
+ }
+ sigma_up = std::min(sigma_up, sigma_to);
+ float sigma_down_sq = sigma_to_sq - sigma_up * sigma_up;
+ sigma_down = sigma_down_sq > 0.0f ? std::sqrt(sigma_down_sq) : 0.0f;
+ }
+
+ float* vec_x = (float*)x->data;
+ float* vec_x0 = (float*)x0->data;
+ for (int j = 0; j < ggml_nelements(x); j++) {
+ vec_x0[j] = vec_x[j];
+ }
+
+ if (sigma_down == 0.0f || sigma_from == 0.0f) {
+ float* vec_denoised = (float*)denoised->data;
+ for (int j = 0; j < ggml_nelements(x); j++) {
+ vec_x[j] = vec_denoised[j];
+ }
+ } else {
+ float t = t_fn(sigma_from);
+ float t_next = t_fn(sigma_down);
+ float h = t_next - t;
+
+ float a21 = c2 * phi1_fn(-h * c2);
+ float phi1_val = phi1_fn(-h);
+ float phi2_val = phi2_fn(-h);
+ float b2 = phi2_val / c2;
+ float b1 = phi1_val - b2;
+
+ float sigma_c2 = expf(-(t + h * c2));
+
+ float* vec_denoised = (float*)denoised->data;
+ float* vec_x2 = (float*)x2->data;
+ for (int j = 0; j < ggml_nelements(x); j++) {
+ float eps1 = vec_denoised[j] - vec_x0[j];
+ vec_x2[j] = vec_x0[j] + h * a21 * eps1;
+ }
+
+ ggml_tensor* denoised2 = model(x2, sigma_c2, i + 1);
+ if (denoised2 == nullptr) {
+ return false;
+ }
+ float* vec_denoised2 = (float*)denoised2->data;
+
+ for (int j = 0; j < ggml_nelements(x); j++) {
+ float eps1 = vec_denoised[j] - vec_x0[j];
+ float eps2 = vec_denoised2[j] - vec_x0[j];
+ vec_x[j] = vec_x0[j] + h * (b1 * eps1 + b2 * eps2);
+ }
+ }
+
+ if (sigmas[i + 1] > 0 && sigma_up > 0.0f) {
+ ggml_ext_im_set_randn_f32(noise, rng);
+ float* vec_x = (float*)x->data;
+ float* vec_noise = (float*)noise->data;
+
+ for (int j = 0; j < ggml_nelements(x); j++) {
+ vec_x[j] = vec_x[j] + vec_noise[j] * sigma_up;
+ }
+ }
+ }
+ } break;
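
This case is the two-stage exponential Runge-Kutta step with midpoint c2 = 0.5: because exp(-c2*h) = 1 - c2*h*phi1(-c2*h), the x2 update is exactly the exponential-Euler midpoint stage, and b1 + b2 = phi1(-h) makes the final combination damp x0 by exp(-h). A scalar sketch of one step, assuming sigma_down < sigma_from (denoise() stands in for the model call; the small-|h| series fallback is omitted for brevity):

    import math

    def res_2s_step(x0, sigma_from, sigma_down, denoise):
        t, t_next = -math.log(sigma_from), -math.log(sigma_down)
        h, c2 = t_next - t, 0.5
        phi1 = lambda z: math.expm1(z) / z
        phi2 = lambda z: (phi1(z) - 1) / z
        a21 = c2 * phi1(-h * c2)
        b2 = phi2(-h) / c2
        b1 = phi1(-h) - b2
        eps1 = denoise(x0, sigma_from) - x0
        x2 = x0 + h * a21 * eps1                      # exponential midpoint stage
        eps2 = denoise(x2, math.exp(-(t + h * c2))) - x0
        return x0 + h * (b1 * eps1 + b2 * eps2)
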
default:
LOG_ERROR("Attempting to sample with nonexisting sample method %i", method);
diff --git a/otherarch/sdcpp/diffusion_model.hpp b/otherarch/sdcpp/diffusion_model.hpp
index 06cbecc28..3293ba9b7 100644
--- a/otherarch/sdcpp/diffusion_model.hpp
+++ b/otherarch/sdcpp/diffusion_model.hpp
@@ -38,7 +38,7 @@ struct DiffusionModel {
virtual size_t get_params_buffer_size() = 0;
virtual void set_weight_adapter(const std::shared_ptr& adapter){};
virtual int64_t get_adm_in_channels() = 0;
- virtual void set_flash_attn_enabled(bool enabled) = 0;
+ virtual void set_flash_attention_enabled(bool enabled) = 0;
virtual void set_circular_axes(bool circular_x, bool circular_y) = 0;
};
@@ -84,7 +84,7 @@ struct UNetModel : public DiffusionModel {
return unet.unet.adm_in_channels;
}
- void set_flash_attn_enabled(bool enabled) {
+ void set_flash_attention_enabled(bool enabled) {
unet.set_flash_attention_enabled(enabled);
}
@@ -149,7 +149,7 @@ struct MMDiTModel : public DiffusionModel {
return 768 + 1280;
}
- void set_flash_attn_enabled(bool enabled) {
+ void set_flash_attention_enabled(bool enabled) {
mmdit.set_flash_attention_enabled(enabled);
}
@@ -215,7 +215,7 @@ struct FluxModel : public DiffusionModel {
return 768;
}
- void set_flash_attn_enabled(bool enabled) {
+ void set_flash_attention_enabled(bool enabled) {
flux.set_flash_attention_enabled(enabled);
}
@@ -286,7 +286,7 @@ struct WanModel : public DiffusionModel {
return 768;
}
- void set_flash_attn_enabled(bool enabled) {
+ void set_flash_attention_enabled(bool enabled) {
wan.set_flash_attention_enabled(enabled);
}
@@ -357,7 +357,7 @@ struct QwenImageModel : public DiffusionModel {
return 768;
}
- void set_flash_attn_enabled(bool enabled) {
+ void set_flash_attention_enabled(bool enabled) {
qwen_image.set_flash_attention_enabled(enabled);
}
@@ -424,7 +424,7 @@ struct ZImageModel : public DiffusionModel {
return 768;
}
- void set_flash_attn_enabled(bool enabled) {
+ void set_flash_attention_enabled(bool enabled) {
z_image.set_flash_attention_enabled(enabled);
}
diff --git a/otherarch/sdcpp/ggml_extend.hpp b/otherarch/sdcpp/ggml_extend.hpp
index 8195a17a7..3419fa918 100644
--- a/otherarch/sdcpp/ggml_extend.hpp
+++ b/otherarch/sdcpp/ggml_extend.hpp
@@ -2623,7 +2623,7 @@ public:
v = v_proj->forward(ctx, x);
}
- x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, mask); // [N, n_token, embed_dim]
+ x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, mask, false); // [N, n_token, embed_dim]
x = out_proj->forward(ctx, x); // [N, n_token, embed_dim]
return x;
diff --git a/otherarch/sdcpp/name_conversion.cpp b/otherarch/sdcpp/name_conversion.cpp
index 3ae229b63..d3e863b8a 100644
--- a/otherarch/sdcpp/name_conversion.cpp
+++ b/otherarch/sdcpp/name_conversion.cpp
@@ -842,6 +842,7 @@ std::string convert_sep_to_dot(std::string name) {
"conv_in",
"conv_out",
"lora_down",
+ "lora_mid",
"lora_up",
"diff_b",
"hada_w1_a",
@@ -997,10 +998,13 @@ std::string convert_tensor_name(std::string name, SDVersion version) {
if (is_lora) {
std::map<std::string, std::string> lora_suffix_map = {
{".lora_down.weight", ".weight.lora_down"},
+ {".lora_mid.weight", ".weight.lora_mid"},
{".lora_up.weight", ".weight.lora_up"},
{".lora.down.weight", ".weight.lora_down"},
+ {".lora.mid.weight", ".weight.lora_mid"},
{".lora.up.weight", ".weight.lora_up"},
{"_lora.down.weight", ".weight.lora_down"},
+ {"_lora.mid.weight", ".weight.lora_mid"},
{"_lora.up.weight", ".weight.lora_up"},
{".lora_A.weight", ".weight.lora_down"},
{".lora_B.weight", ".weight.lora_up"},
diff --git a/otherarch/sdcpp/sdtype_adapter.cpp b/otherarch/sdcpp/sdtype_adapter.cpp
index 114ff288b..61e6b3a4e 100644
--- a/otherarch/sdcpp/sdtype_adapter.cpp
+++ b/otherarch/sdcpp/sdtype_adapter.cpp
@@ -370,6 +370,9 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
params.lora_apply_mode = (lora_apply_mode_t)lora_apply_mode;
// params.flow_shift = 5.0f;
+ // also enables flash attention for the VAE and conditioner
+ params.flash_attn = params.diffusion_flash_attn;
+
if (params.chroma_use_dit_mask && params.diffusion_flash_attn) {
// note we don't know yet if it's a Chroma model
params.chroma_use_dit_mask = false;
@@ -620,6 +623,14 @@ static enum sample_method_t sampler_from_name(const std::string& sampler)
{
return sample_method_t::DPMPP2M_SAMPLE_METHOD;
}
+ else if(sampler=="res multistep" || sampler=="k_res_multistep")
+ {
+ return sample_method_t::RES_MULTISTEP_SAMPLE_METHOD;
+ }
+ else if(sampler=="res 2s" || sampler=="k_res_2s")
+ {
+ return sample_method_t::RES_2S_SAMPLE_METHOD;
+ }
else
{
return sample_method_t::SAMPLE_METHOD_COUNT;
diff --git a/otherarch/sdcpp/stable-diffusion.cpp b/otherarch/sdcpp/stable-diffusion.cpp
index b143883c0..b1c1549c1 100644
--- a/otherarch/sdcpp/stable-diffusion.cpp
+++ b/otherarch/sdcpp/stable-diffusion.cpp
@@ -69,6 +69,8 @@ const char* sampling_methods_str[] = {
"LCM",
"DDIM \"trailing\"",
"TCD",
+ "Res Multistep",
+ "Res 2s",
};
/*================================================== Helper Functions ================================================*/
@@ -583,7 +585,7 @@ public:
}
}
if (is_chroma) {
- if (sd_ctx_params->diffusion_flash_attn && sd_ctx_params->chroma_use_dit_mask) {
+ if ((sd_ctx_params->flash_attn || sd_ctx_params->diffusion_flash_attn) && sd_ctx_params->chroma_use_dit_mask) {
LOG_WARN(
"!!!It looks like you are using Chroma with flash attention. "
"This is currently unsupported. "
@@ -709,14 +711,6 @@ public:
}
}
- if (sd_ctx_params->diffusion_flash_attn) {
- LOG_INFO("Using flash attention in the diffusion model");
- diffusion_model->set_flash_attn_enabled(true);
- if (high_noise_diffusion_model) {
- high_noise_diffusion_model->set_flash_attn_enabled(true);
- }
- }
-
cond_stage_model->alloc_params_buffer();
cond_stage_model->get_param_tensors(tensors);
@@ -867,6 +861,28 @@ public:
pmid_model->get_param_tensors(tensors, "pmid");
}
+ if (sd_ctx_params->flash_attn) {
+ LOG_INFO("Using flash attention");
+ cond_stage_model->set_flash_attention_enabled(true);
+ if (clip_vision) {
+ clip_vision->set_flash_attention_enabled(true);
+ }
+ if (first_stage_model) {
+ first_stage_model->set_flash_attention_enabled(true);
+ }
+ if (tae_first_stage) {
+ tae_first_stage->set_flash_attention_enabled(true);
+ }
+ }
+
+ if (sd_ctx_params->flash_attn || sd_ctx_params->diffusion_flash_attn) {
+ LOG_INFO("Using flash attention in the diffusion model");
+ diffusion_model->set_flash_attention_enabled(true);
+ if (high_noise_diffusion_model) {
+ high_noise_diffusion_model->set_flash_attention_enabled(true);
+ }
+ }
+
diffusion_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
if (high_noise_diffusion_model) {
high_noise_diffusion_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
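
Net effect of replacing the block removed above: the new flash_attn flag fans out to the conditioner, CLIP vision, VAE and TAE, while the diffusion model gets flash attention when either flag is set (and since sdtype_adapter.cpp ties the two flags together, KoboldCpp's single toggle now covers all of them). A trivial restatement of the decision table (hypothetical helper, not part of the API):

    def flash_attn_targets(flash_attn: bool, diffusion_flash_attn: bool) -> dict:
        return {
            "conditioner": flash_attn,
            "clip_vision": flash_attn,
            "vae_or_tae": flash_attn,
            "diffusion": flash_attn or diffusion_flash_attn,
        }
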
@@ -2907,6 +2923,8 @@ const char* sample_method_to_str[] = {
"lcm",
"ddim_trailing",
"tcd",
+ "res_multistep",
+ "res_2s",
};
const char* sd_sample_method_name(enum sample_method_t sample_method) {
@@ -2936,6 +2954,7 @@ const char* scheduler_to_str[] = {
"smoothstep",
"kl_optimal",
"lcm",
+ "bong_tangent",
};
const char* sd_scheduler_name(enum scheduler_t scheduler) {
@@ -3101,6 +3120,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
"keep_clip_on_cpu: %s\n"
"keep_control_net_on_cpu: %s\n"
"keep_vae_on_cpu: %s\n"
+ "flash_attn: %s\n"
"diffusion_flash_attn: %s\n"
"circular_x: %s\n"
"circular_y: %s\n"
@@ -3132,6 +3152,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
BOOL_STR(sd_ctx_params->keep_clip_on_cpu),
BOOL_STR(sd_ctx_params->keep_control_net_on_cpu),
BOOL_STR(sd_ctx_params->keep_vae_on_cpu),
+ BOOL_STR(sd_ctx_params->flash_attn),
BOOL_STR(sd_ctx_params->diffusion_flash_attn),
BOOL_STR(sd_ctx_params->circular_x),
BOOL_STR(sd_ctx_params->circular_y),
diff --git a/otherarch/sdcpp/stable-diffusion.h b/otherarch/sdcpp/stable-diffusion.h
index 9c6c6169e..ac5e7c9a9 100644
--- a/otherarch/sdcpp/stable-diffusion.h
+++ b/otherarch/sdcpp/stable-diffusion.h
@@ -48,6 +48,8 @@ enum sample_method_t {
LCM_SAMPLE_METHOD,
DDIM_TRAILING_SAMPLE_METHOD,
TCD_SAMPLE_METHOD,
+ RES_MULTISTEP_SAMPLE_METHOD,
+ RES_2S_SAMPLE_METHOD,
SAMPLE_METHOD_COUNT
};
@@ -62,6 +64,7 @@ enum scheduler_t {
SMOOTHSTEP_SCHEDULER,
KL_OPTIMAL_SCHEDULER,
LCM_SCHEDULER,
+ BONG_TANGENT_SCHEDULER,
SCHEDULER_COUNT
};
@@ -186,6 +189,7 @@ typedef struct {
bool keep_clip_on_cpu;
bool keep_control_net_on_cpu;
bool keep_vae_on_cpu;
+ bool flash_attn;
bool diffusion_flash_attn;
bool tae_preview_only;
bool diffusion_conv_direct;
diff --git a/otherarch/sdcpp/vae.hpp b/otherarch/sdcpp/vae.hpp
index 01b99e89b..01081343b 100644
--- a/otherarch/sdcpp/vae.hpp
+++ b/otherarch/sdcpp/vae.hpp
@@ -141,7 +141,7 @@ public:
v = ggml_reshape_3d(ctx->ggml_ctx, v, c, h * w, n); // [N, h * w, in_channels]
}
- h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, false);
+ h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, ctx->flash_attn_enabled);
if (use_linear) {
h_ = proj_out->forward(ctx, h_); // [N, h * w, in_channels]
diff --git a/otherarch/sdcpp/wan.hpp b/otherarch/sdcpp/wan.hpp
index 81959efcf..7b1059785 100644
--- a/otherarch/sdcpp/wan.hpp
+++ b/otherarch/sdcpp/wan.hpp
@@ -572,8 +572,8 @@ namespace WAN {
auto v = qkv_vec[2];
v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n); // [t, c, h * w]
- v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 1, 0, 2, 3)); // [t, h * w, c]
- x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, false); // [t, h * w, c]
+ v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 1, 0, 2, 3)); // [t, h * w, c]
+ x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, ctx->flash_attn_enabled); // [t, h * w, c]
x = ggml_ext_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [t, c, h * w]
x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, c, n); // [t, c, h, w]