Mirror of https://github.com/LostRuins/koboldcpp.git, synced 2025-09-14 10:59:41 +00:00
vulkan: scalar flash attention implementation
commit 005756a2a9, parent 9070365020
3 changed files with 536 additions and 77 deletions
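Only the process_shaders() hunks from the shader generator (presumably vulkan-shaders-gen.cpp) are reproduced below; the new flash_attn.comp shader itself is presumably added in one of the other changed files. In the string_to_spv(...) calls, each key/value pair in the merged define map becomes a preprocessor define when the .comp source is compiled to SPIR-V, which is how one shader source yields the f16 and quantized variants. A minimal, hypothetical sketch of that idea (merge, base, and the glslc command line here are illustrative stand-ins, not the generator's actual code):

    // Hypothetical illustration of what the define maps in string_to_spv(...) represent:
    // every key/value pair is turned into a -DKEY=VALUE define for the shader compiler.
    #include <iostream>
    #include <map>
    #include <string>

    using Defines = std::map<std::string, std::string>;

    // Same spirit as the generator's merge_maps(): combine define maps, later entries win.
    static Defines merge(const Defines& a, const Defines& b) {
        Defines out = a;
        for (const auto& kv : b) out[kv.first] = kv.second;
        return out;
    }

    int main() {
        Defines base   = {{"FLOAT_TYPE", "float"}};  // stand-in for base_dict
        Defines scalar = merge(base, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", "float"}});

        // Rough shape of the resulting compile command for the new scalar variant.
        std::string cmd = "glslc flash_attn.comp";
        for (const auto& kv : scalar) cmd += " -D" + kv.first + "=" + kv.second;
        std::cout << cmd << "\n";
        return 0;
    }

In the hunks themselves, the flash_attn_cm2.comp variants keep passing true and the new flash_attn.comp variant passes false as the third boolean after the define map, which appears to be the flag selecting the coopmat2 (cooperative-matrix) code path.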
@@ -421,7 +421,6 @@ void process_shaders() {
 #endif
     }
 
-#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
     // flash attention
     for (const auto& f16acc : {false, true}) {
         std::string acctype = f16acc ? "float16_t" : "float";
@@ -432,6 +431,7 @@ void process_shaders() {
             }
             if (tname == "bf16") continue;
 
+#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
             if (tname == "f16") {
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
                     merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc);
@@ -440,9 +440,13 @@ void process_shaders() {
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
                     merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
             }
+#endif
+            if (tname == "f16") {
+                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
+                    merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, false, f16acc);
+            } // quants not supported yet
         }
     }
-#endif
 
     for (const auto& tname : type_names) {
         // mul mat vec
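The scalar shader source is not part of the hunks above, so as orientation for what a scalar (non-cooperative-matrix) flash attention has to compute per query row, here is a self-contained sketch of the standard single-pass online-softmax formulation in plain C++. It illustrates the algorithm named in the commit title and is not code taken from flash_attn.comp; all names in it are invented for the example.

    // Sketch of one-pass (online-softmax) attention for a single query row q against
    // a K/V sequence: o = softmax(scale * q K^T) V, accumulated without ever
    // materializing the full score row.
    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    static std::vector<float> flash_attn_row(const std::vector<float>& q,              // [d]
                                             const std::vector<std::vector<float>>& K, // [n][d]
                                             const std::vector<std::vector<float>>& V, // [n][d]
                                             float scale) {
        const size_t d = q.size();
        std::vector<float> o(d, 0.0f);   // running (unnormalized) output
        float m = -INFINITY;             // running maximum of the scaled scores
        float l = 0.0f;                  // running sum of exp(score - m)

        for (size_t j = 0; j < K.size(); ++j) {
            float s = 0.0f;
            for (size_t c = 0; c < d; ++c) s += q[c] * K[j][c];
            s *= scale;

            const float m_new = std::max(m, s);
            const float corr  = std::exp(m - m_new);   // rescale earlier accumulators
            const float p     = std::exp(s - m_new);   // weight of this key/value row

            for (size_t c = 0; c < d; ++c) o[c] = o[c] * corr + p * V[j][c];
            l = l * corr + p;
            m = m_new;
        }
        for (size_t c = 0; c < d; ++c) o[c] /= l;      // final softmax normalization
        return o;
    }

    int main() {
        const std::vector<float> q = {1.0f, 0.0f};
        const std::vector<std::vector<float>> K = {{1.0f, 0.0f}, {0.0f, 1.0f}};
        const std::vector<std::vector<float>> V = {{1.0f, 2.0f}, {3.0f, 4.0f}};
        const std::vector<float> o = flash_attn_row(q, K, V, 1.0f / std::sqrt(2.0f));
        std::printf("%f %f\n", o[0], o[1]);            // a softmax-weighted mix of the V rows
        return 0;
    }

The running maximum m and the corr rescaling are what make a single pass possible: the per-row state is only O(d), so long sequences can be processed block by block, which is presumably the property the dedicated shader exploits instead of falling back to separate matmul and softmax dispatches.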