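Try CI fix: guard the flash_attn_ext_f16 device function bodies with NEW_MMA_AVAILABLE (NO_DEVICE_CODE otherwise)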

This commit is contained in:
Johannes Gäßler 2025-02-15 22:44:27 +01:00
parent eb4f7954b6
commit 727db805a2

View file

@ -89,7 +89,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
float2 & KQ_max,
float2 & KQ_rowsum,
const int kb0) {
#ifdef NEW_MMA_AVAILABLE
constexpr int np = nwarps*tile_B::I / ncols; // Number of parallel CUDA warps per Q column.
constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
@ -241,6 +241,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
#ifndef CP_ASYNC_AVAILABLE
__syncthreads(); // Only needed if tile_K == tile_V.
#endif // CP_ASYNC_AVAILABLE
#else
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
@ -262,6 +266,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const int jt,
const int kb0_start,
const int kb0_stop) {
#ifdef NEW_MMA_AVAILABLE
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
static_assert(nwarps*tile_B::I % ncols == 0, "bad nwarps");
@ -518,6 +523,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
if (np > 1) {
__syncthreads();
}
#else
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap>