mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-08 01:41:37 +00:00
fix Chroma workaround for flash attention (#1746)
chroma_use_dit_mask is a context parameter, so changing it after creating the context has no effect.
This commit is contained in:
parent
9e7661352c
commit
dd0bf706a3
1 changed files with 18 additions and 15 deletions
|
|
@ -338,6 +338,11 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
|
|||
params.chroma_use_t5_mask = sd_params->chroma_use_t5_mask;
|
||||
params.chroma_t5_mask_pad = sd_params->chroma_t5_mask_pad;
|
||||
|
||||
if (params.chroma_use_dit_mask && params.diffusion_flash_attn) {
|
||||
// note we don't know yet if it's a Chroma model
|
||||
params.chroma_use_dit_mask = false;
|
||||
}
|
||||
|
||||
sd_ctx = new_sd_ctx(¶ms);
|
||||
|
||||
if (sd_ctx == NULL) {
|
||||
|
|
@ -346,6 +351,14 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
|
|||
return false;
|
||||
}
|
||||
|
||||
if (!sd_is_quiet) {
|
||||
if (loaded_model_is_chroma(sd_ctx) && sd_params->diffusion_flash_attn && sd_params->chroma_use_dit_mask)
|
||||
{
|
||||
printf("Chroma: flash attention is on, disabling DiT mask (this will lower image quality)\n");
|
||||
// disabled before loading
|
||||
}
|
||||
}
|
||||
|
||||
std::filesystem::path mpath(inputs.model_filename);
|
||||
sdmodelfilename = mpath.filename().string();
|
||||
|
||||
|
|
@ -528,22 +541,12 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
|
|||
auto loadedsdver = get_loaded_sd_version(sd_ctx);
|
||||
if (loadedsdver == SDVersion::VERSION_FLUX)
|
||||
{
|
||||
if (loaded_model_is_chroma(sd_ctx)) {
|
||||
if (sd_params->diffusion_flash_attn && sd_params->chroma_use_dit_mask) {
|
||||
if (!sd_is_quiet && sddebugmode) {
|
||||
printf("Chroma: flash attention is on, disabling DiT mask\n");
|
||||
}
|
||||
sd_params->chroma_use_dit_mask = false;
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (sd_params->cfg_scale != 1.0f) {
|
||||
//non chroma clamp cfg scale
|
||||
if (!sd_is_quiet && sddebugmode) {
|
||||
printf("Flux: clamping CFG Scale to 1\n");
|
||||
}
|
||||
sd_params->cfg_scale = 1.0f;
|
||||
if (!loaded_model_is_chroma(sd_ctx) && sd_params->cfg_scale != 1.0f) {
|
||||
//non chroma clamp cfg scale
|
||||
if (!sd_is_quiet && sddebugmode) {
|
||||
printf("Flux: clamping CFG Scale to 1\n");
|
||||
}
|
||||
sd_params->cfg_scale = 1.0f;
|
||||
}
|
||||
if (sampler == "euler a" || sampler == "k_euler_a" || sampler == "euler_a") {
|
||||
//euler a broken on flux
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue