diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index c59cd5371..d777f5413 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -89,7 +89,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( float2 & KQ_max, float2 & KQ_rowsum, const int kb0) { - +#ifdef NEW_MMA_AVAILABLE constexpr int np = nwarps*tile_B::I / ncols; // Number of parallel CUDA warps per Q column. constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. @@ -241,6 +241,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #ifndef CP_ASYNC_AVAILABLE __syncthreads(); // Only needed if tile_K == tile_V. #endif // CP_ASYNC_AVAILABLE + +#else + NO_DEVICE_CODE; +#endif // NEW_MMA_AVAILABLE } template @@ -262,6 +266,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int jt, const int kb0_start, const int kb0_stop) { +#ifdef NEW_MMA_AVAILABLE //In this kernel Q, K, V are matrices while i, j, k are matrix indices. static_assert(nwarps*tile_B::I % ncols == 0, "bad nwarps"); @@ -518,6 +523,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( if (np > 1) { __syncthreads(); } +#else + NO_DEVICE_CODE; +#endif // NEW_MMA_AVAILABLE } template