mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-13 10:29:43 +00:00
try hack in missing hmax2 functions (+1 squashed commits)
Squashed commits: [c98d0ab6] try hack in missing hmax2 functions (+1 squashed commits) Squashed commits: [9ba8599f] try hack in missing hmax2 functions (+2 squashed commit) Squashed commit: [be497493] try hack in missing hmax2 functions [159ee4c3] bypass missing hmax functions on old cuda
This commit is contained in:
parent
b48ea96ead
commit
cea46750b0
1 changed files with 26 additions and 3 deletions
|
@ -11,6 +11,29 @@
|
|||
#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
|
||||
#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
|
||||
|
||||
//hack: polyfill hmax and hmax2 for older cuda version
|
||||
#if CUDART_VERSION < CUDART_HMAX
|
||||
__device__ __inline__ __half hmax(const __half a, const __half b) {
|
||||
const float fa = __half2float(a);
|
||||
const float fb = __half2float(b);
|
||||
return __float2half(fa > fb ? fa : fb);
|
||||
}
|
||||
__device__ __inline__ __half2 hmax2(const __half2 a, const __half2 b) {
|
||||
__half2 result;
|
||||
result.x = hmax(a.x, b.x);
|
||||
result.y = hmax(a.y, b.y);
|
||||
return result;
|
||||
}
|
||||
#else
|
||||
__device__ __inline__ __half hmax(const __half a, const __half b) {
|
||||
return __hmax(a,b);
|
||||
}
|
||||
__device__ __inline__ __half2 hmax2(const __half2 a, const __half2 b) {
|
||||
return __hmax2(a,b);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
template<int D, int parallel_blocks> // D == head size
|
||||
__launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1)
|
||||
static __global__ void flash_attn_vec_ext_f16(
|
||||
|
@ -116,7 +139,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||
sum2 = warp_reduce_sum(sum2);
|
||||
half sum = __low2half(sum2) + __high2half(sum2);
|
||||
sum += mask ? maskh[k_VKQ_0 + i_KQ] : __float2half(0.0f);
|
||||
kqmax_new = __hmax(kqmax_new, sum);
|
||||
kqmax_new = hmax(kqmax_new, sum);
|
||||
if (threadIdx.x == 0) {
|
||||
KQ[i_KQ] = sum;
|
||||
}
|
||||
|
@ -416,9 +439,9 @@ static __global__ void flash_attn_ext_f16(
|
|||
const int k = k0 + threadIdx.x;
|
||||
|
||||
KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
|
||||
KQ_max_new = __hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
|
||||
KQ_max_new = hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
|
||||
}
|
||||
KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
|
||||
KQ_max_new = __half2half2(warp_reduce_max(hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
|
||||
const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
|
||||
KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
|
||||
const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue