mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-07 17:22:04 +00:00
try fix fattn again, porting some older code. the cc detection is not working well, so its hacky
This commit is contained in:
parent
9423de5ea2
commit
7b04191eac
1 changed files with 14 additions and 2 deletions
|
|
@ -298,6 +298,12 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
|||
const int warp_size = ggml_cuda_info().devices[device].warp_size;
|
||||
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
|
||||
|
||||
#if defined(GGML_HIP_ROCWMMA_FATTN)
|
||||
if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) { //kcpp: fix for rocwmma
|
||||
return BEST_FATTN_KERNEL_WMMA_F16;
|
||||
}
|
||||
#endif // defined(GGML_HIP_ROCWMMA_FATTN)
|
||||
|
||||
switch (K->ne[0]) {
|
||||
case 64:
|
||||
case 128:
|
||||
|
|
@ -415,15 +421,21 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
|||
return BEST_FATTN_KERNEL_WMMA_F16;
|
||||
}
|
||||
|
||||
//kcpp: always force WMMA for older gpus, fix issues like "FlashAttention without tensor cores only supports head sizes 64 and 128."
|
||||
if (ggml_cuda_highest_compiled_arch(cc) <= GGML_CUDA_CC_TURING || cc == GGML_CUDA_CC_TURING) {
|
||||
//kcpp: always force WMMA for Turing and Volta if above check fails, fix "FlashAttention without tensor cores only supports head sizes 64 and 128."
|
||||
if (cc == GGML_CUDA_CC_TURING || cc == GGML_CUDA_CC_VOLTA) {
|
||||
return BEST_FATTN_KERNEL_WMMA_F16;
|
||||
}
|
||||
|
||||
// If there is no suitable kernel for tensor cores or small batch sizes, use the generic kernel for large batch sizes:
|
||||
if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
|
||||
if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
|
||||
return BEST_FATTN_KERNEL_VEC_F16; //kcpp: patch from previous version for my sanity. it worked before, idk it should work now.
|
||||
}
|
||||
return BEST_FATTN_KERNEL_TILE_F16;
|
||||
}
|
||||
if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
|
||||
return BEST_FATTN_KERNEL_VEC_F32; //kcpp: patch from previous version for my sanity. it worked before, idk it should work now.
|
||||
}
|
||||
return BEST_FATTN_KERNEL_TILE_F32;
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue