From 7b04191eac8354a3fc00dcbb788f468b1bd709cd Mon Sep 17 00:00:00 2001 From: Concedo <39025047+LostRuins@users.noreply.github.com> Date: Mon, 25 Aug 2025 21:22:36 +0800 Subject: [PATCH] try fix fattn again, porting some older code. the cc detection is not working well, so its hacky --- ggml/src/ggml-cuda/fattn.cu | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index f8b492f8c..aebcf2c9f 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -298,6 +298,12 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const const int warp_size = ggml_cuda_info().devices[device].warp_size; const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV); + #if defined(GGML_HIP_ROCWMMA_FATTN) + if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) { //kcpp: fix for rocwmma + return BEST_FATTN_KERNEL_WMMA_F16; + } + #endif // defined(GGML_HIP_ROCWMMA_FATTN) + switch (K->ne[0]) { case 64: case 128: @@ -415,15 +421,21 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_WMMA_F16; } - //kcpp: always force WMMA for older gpus, fix issues like "FlashAttention without tensor cores only supports head sizes 64 and 128." - if (ggml_cuda_highest_compiled_arch(cc) <= GGML_CUDA_CC_TURING || cc == GGML_CUDA_CC_TURING) { + //kcpp: always force WMMA for Turing and Volta if above check fails, fix "FlashAttention without tensor cores only supports head sizes 64 and 128." + if (cc == GGML_CUDA_CC_TURING || cc == GGML_CUDA_CC_VOLTA) { return BEST_FATTN_KERNEL_WMMA_F16; } // If there is no suitable kernel for tensor cores or small batch sizes, use the generic kernel for large batch sizes: if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) { + if (Q->ne[1] <= 8 || Q->ne[0] == 256) { + return BEST_FATTN_KERNEL_VEC_F16; //kcpp: patch from previous version for my sanity. it worked before, idk it should work now. + } return BEST_FATTN_KERNEL_TILE_F16; } + if (Q->ne[1] <= 8 || Q->ne[0] == 256) { + return BEST_FATTN_KERNEL_VEC_F32; //kcpp: patch from previous version for my sanity. it worked before, idk it should work now. + } return BEST_FATTN_KERNEL_TILE_F32; }