fix Chroma workaround for flash attention (#1746)

chroma_use_dit_mask is a context parameter, so changing it
after creating the context has no effect.
This commit is contained in:
Wagner Bruna 2025-09-21 09:53:40 -03:00 committed by GitHub
parent 9e7661352c
commit dd0bf706a3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -338,6 +338,11 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
params.chroma_use_t5_mask = sd_params->chroma_use_t5_mask;
params.chroma_t5_mask_pad = sd_params->chroma_t5_mask_pad;
if (params.chroma_use_dit_mask && params.diffusion_flash_attn) {
// note we don't know yet if it's a Chroma model
params.chroma_use_dit_mask = false;
}
sd_ctx = new_sd_ctx(&params);
if (sd_ctx == NULL) {
@ -346,6 +351,14 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
return false;
}
if (!sd_is_quiet) {
if (loaded_model_is_chroma(sd_ctx) && sd_params->diffusion_flash_attn && sd_params->chroma_use_dit_mask)
{
printf("Chroma: flash attention is on, disabling DiT mask (this will lower image quality)\n");
// disabled before loading
}
}
std::filesystem::path mpath(inputs.model_filename);
sdmodelfilename = mpath.filename().string();
@ -528,22 +541,12 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
auto loadedsdver = get_loaded_sd_version(sd_ctx);
if (loadedsdver == SDVersion::VERSION_FLUX)
{
if (loaded_model_is_chroma(sd_ctx)) {
if (sd_params->diffusion_flash_attn && sd_params->chroma_use_dit_mask) {
if (!sd_is_quiet && sddebugmode) {
printf("Chroma: flash attention is on, disabling DiT mask\n");
}
sd_params->chroma_use_dit_mask = false;
}
}
else {
if (sd_params->cfg_scale != 1.0f) {
//non chroma clamp cfg scale
if (!sd_is_quiet && sddebugmode) {
printf("Flux: clamping CFG Scale to 1\n");
}
sd_params->cfg_scale = 1.0f;
if (!loaded_model_is_chroma(sd_ctx) && sd_params->cfg_scale != 1.0f) {
//non chroma clamp cfg scale
if (!sd_is_quiet && sddebugmode) {
printf("Flux: clamping CFG Scale to 1\n");
}
sd_params->cfg_scale = 1.0f;
}
if (sampler == "euler a" || sampler == "k_euler_a" || sampler == "euler_a") {
//euler a broken on flux