From dd0bf706a32805f3c5d12e9d32fc2406ad5ae5f6 Mon Sep 17 00:00:00 2001 From: Wagner Bruna Date: Sun, 21 Sep 2025 09:53:40 -0300 Subject: [PATCH] fix Chroma workaround for flash attention (#1746) chroma_use_dit_mask is a context parameter, so changing it after creating the context has no effect. --- otherarch/sdcpp/sdtype_adapter.cpp | 33 ++++++++++++++++-------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/otherarch/sdcpp/sdtype_adapter.cpp b/otherarch/sdcpp/sdtype_adapter.cpp index b962a60a5..e5e5a2ac7 100644 --- a/otherarch/sdcpp/sdtype_adapter.cpp +++ b/otherarch/sdcpp/sdtype_adapter.cpp @@ -338,6 +338,11 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) { params.chroma_use_t5_mask = sd_params->chroma_use_t5_mask; params.chroma_t5_mask_pad = sd_params->chroma_t5_mask_pad; + if (params.chroma_use_dit_mask && params.diffusion_flash_attn) { + // note we don't know yet if it's a Chroma model + params.chroma_use_dit_mask = false; + } + sd_ctx = new_sd_ctx(¶ms); if (sd_ctx == NULL) { @@ -346,6 +351,14 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) { return false; } + if (!sd_is_quiet) { + if (loaded_model_is_chroma(sd_ctx) && sd_params->diffusion_flash_attn && sd_params->chroma_use_dit_mask) + { + printf("Chroma: flash attention is on, disabling DiT mask (this will lower image quality)\n"); + // disabled before loading + } + } + std::filesystem::path mpath(inputs.model_filename); sdmodelfilename = mpath.filename().string(); @@ -528,22 +541,12 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs) auto loadedsdver = get_loaded_sd_version(sd_ctx); if (loadedsdver == SDVersion::VERSION_FLUX) { - if (loaded_model_is_chroma(sd_ctx)) { - if (sd_params->diffusion_flash_attn && sd_params->chroma_use_dit_mask) { - if (!sd_is_quiet && sddebugmode) { - printf("Chroma: flash attention is on, disabling DiT mask\n"); - } - sd_params->chroma_use_dit_mask = false; - } - } - else { - if (sd_params->cfg_scale != 1.0f) { - //non chroma clamp cfg scale - if (!sd_is_quiet && sddebugmode) { - printf("Flux: clamping CFG Scale to 1\n"); - } - sd_params->cfg_scale = 1.0f; + if (!loaded_model_is_chroma(sd_ctx) && sd_params->cfg_scale != 1.0f) { + //non chroma clamp cfg scale + if (!sd_is_quiet && sddebugmode) { + printf("Flux: clamping CFG Scale to 1\n"); } + sd_params->cfg_scale = 1.0f; } if (sampler == "euler a" || sampler == "k_euler_a" || sampler == "euler_a") { //euler a broken on flux