CUDA: fix half2 -> half conversion for HIP (#15529)

This commit is contained in:
Johannes Gäßler 2025-08-23 21:37:06 +02:00 committed by GitHub
parent 611f419cff
commit 710dfc465a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -258,7 +258,7 @@ static __global__ void flash_attn_tile_ext_f16(
const half val = hexp(sink - kqmax[j0/nwarps]); const half val = hexp(sink - kqmax[j0/nwarps]);
kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale; kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
kqsum[j0/nwarps].x = __hadd(kqsum[j0/nwarps].x, val); kqsum[j0/nwarps].x = __hadd(__low2half(kqsum[j0/nwarps]), val);
} }
#pragma unroll #pragma unroll