vulkan: scalar flash attention implementation

Jeff Bolz 2025-05-05 19:34:23 -05:00
parent 9070365020
commit 005756a2a9
3 changed files with 536 additions and 77 deletions
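
For context, below is a minimal C++ sketch of the streaming-softmax ("flash attention") update that a scalar kernel such as the new flash_attn.comp is generally expected to perform per query row. The function name, tile size, and memory layout are illustrative assumptions, not code taken from the shader.

// Sketch only: streaming-softmax update for one query row, as typically
// implemented by a scalar flash attention kernel. Names, tile size Bc, and
// row-major K/V layout are assumptions for illustration.
#include <algorithm>
#include <cmath>
#include <vector>

void flash_attn_row(const float *q, const float *K, const float *V,
                    float *out, int kv_len, int d, float scale, int Bc = 32) {
    float m = -INFINITY;             // running maximum of the score row
    float l = 0.0f;                  // running softmax denominator
    std::vector<float> acc(d, 0.0f); // running weighted sum of V rows

    for (int j0 = 0; j0 < kv_len; j0 += Bc) {      // process keys in tiles of Bc
        const int jend = std::min(j0 + Bc, kv_len);
        for (int j = j0; j < jend; ++j) {
            float s = 0.0f;                        // s = scale * dot(q, K[j])
            for (int k = 0; k < d; ++k) {
                s += q[k] * K[j*d + k];
            }
            s *= scale;

            const float m_new = std::max(m, s);    // online softmax rescaling
            const float corr  = std::exp(m - m_new);
            const float p     = std::exp(s - m_new);
            l = l * corr + p;
            for (int k = 0; k < d; ++k) {
                acc[k] = acc[k] * corr + p * V[j*d + k];
            }
            m = m_new;
        }
    }
    for (int k = 0; k < d; ++k) {
        out[k] = acc[k] / l;                       // final normalization
    }
}

The tiled accumulation is what lets the kernel avoid materializing the full attention matrix; the existing flash_attn_cm2.comp path expresses the same computation with cooperative-matrix (coopmat2) instructions instead of scalar loops.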


@@ -421,7 +421,6 @@ void process_shaders() {
 #endif
     }
 
-#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
     // flash attention
     for (const auto& f16acc : {false, true}) {
         std::string acctype = f16acc ? "float16_t" : "float";
@@ -432,6 +431,7 @@ void process_shaders() {
             }
             if (tname == "bf16") continue;
 
+#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
             if (tname == "f16") {
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
                     merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc);
@@ -440,9 +440,13 @@ void process_shaders() {
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
                     merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
             }
+#endif
+            if (tname == "f16") {
+                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
+                    merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, false, f16acc);
+            } // quants not supported yet
         }
     }
-#endif
 
     for (const auto& tname : type_names) {
         // mul mat vec
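
Taken together, the hunks narrow the coopmat2 guard so that the scalar f16 variant is always generated. A hedged summary of the resulting per-type selection is sketched below; fa_variants_for is a hypothetical helper for illustration, not code from vulkan-shaders-gen.cpp.

// Illustration of the shader-variant selection implied by the hunks above.
#include <string>
#include <vector>

std::vector<std::string> fa_variants_for(const std::string &tname, bool coopmat2_glslc) {
    std::vector<std::string> shaders;
    if (tname == "bf16") {
        return shaders;                            // bf16 is skipped for flash attention
    }
    if (coopmat2_glslc) {
        // coopmat2 variants cover f16 as well as quantized K/V types
        shaders.push_back("flash_attn_cm2.comp");
    }
    if (tname == "f16") {
        // the new scalar variant is emitted regardless of coopmat2 support,
        // but only for f16 ("quants not supported yet")
        shaders.push_back("flash_attn.comp");
    }
    return shaders;
}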