From d74c16e6e0b37c3d1d7bf01fc3406ebf4460aeaf Mon Sep 17 00:00:00 2001 From: Wagner Bruna Date: Sat, 5 Jul 2025 00:20:51 -0300 Subject: [PATCH] enable flash attention for image generation (#1633) --- expose.h | 1 + koboldcpp.py | 2 ++ otherarch/sdcpp/sdtype_adapter.cpp | 5 +++++ 3 files changed, 8 insertions(+) diff --git a/expose.h b/expose.h index 8f299e07e..cafa4bf90 100644 --- a/expose.h +++ b/expose.h @@ -161,6 +161,7 @@ struct sd_load_model_inputs const char * vulkan_info = nullptr; const int threads = 0; const int quant = 0; + const bool flash_attention = false; const bool taesd = false; const int tiled_vae_threshold = 0; const char * t5xxl_filename = nullptr; diff --git a/koboldcpp.py b/koboldcpp.py index 75e979f50..329e34794 100644 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -273,6 +273,7 @@ class sd_load_model_inputs(ctypes.Structure): ("vulkan_info", ctypes.c_char_p), ("threads", ctypes.c_int), ("quant", ctypes.c_int), + ("flash_attention", ctypes.c_bool), ("taesd", ctypes.c_bool), ("tiled_vae_threshold", ctypes.c_int), ("t5xxl_filename", ctypes.c_char_p), @@ -1624,6 +1625,7 @@ def sd_load_model(model_filename,vae_filename,lora_filename,t5xxl_filename,clipl inputs.threads = thds inputs.quant = quant + inputs.flash_attention = args.flashattention inputs.taesd = True if args.sdvaeauto else False inputs.tiled_vae_threshold = args.sdtiledvae inputs.vae_filename = vae_filename.encode("UTF-8") diff --git a/otherarch/sdcpp/sdtype_adapter.cpp b/otherarch/sdcpp/sdtype_adapter.cpp index 6e111e74c..b50da61c6 100644 --- a/otherarch/sdcpp/sdtype_adapter.cpp +++ b/otherarch/sdcpp/sdtype_adapter.cpp @@ -179,6 +179,10 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) { printf("With PhotoMaker Model: %s\n",photomaker_filename.c_str()); photomaker_enabled = true; } + if(inputs.flash_attention) + { + printf("Flash Attention is enabled\n"); + } if(inputs.quant) { printf("Note: Loading a pre-quantized model is always faster than using compress weights!\n"); @@ -213,6 +217,7 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) { sd_params->model_path = inputs.model_filename; sd_params->wtype = (inputs.quant==0?SD_TYPE_COUNT:SD_TYPE_Q4_0); sd_params->n_threads = inputs.threads; //if -1 use physical cores + sd_params->diffusion_flash_attn = inputs.flash_attention; sd_params->input_path = ""; //unused sd_params->batch_count = 1; sd_params->vae_path = vaefilename;