From aaa3d07ae749b781d6135eaff23c7fa8a4ab404a Mon Sep 17 00:00:00 2001 From: lhez Date: Fri, 8 Aug 2025 13:47:03 +0900 Subject: [PATCH 1/4] opencl: support sink in `soft_max` (attn sinks) (#15152) --- ggml/src/ggml-opencl/ggml-opencl.cpp | 49 +++++++++++-------- ggml/src/ggml-opencl/kernels/softmax_4_f16.cl | 12 ++++- ggml/src/ggml-opencl/kernels/softmax_4_f32.cl | 12 ++++- ggml/src/ggml-opencl/kernels/softmax_f16.cl | 12 ++++- ggml/src/ggml-opencl/kernels/softmax_f32.cl | 12 ++++- 5 files changed, 68 insertions(+), 29 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 4f765ab53..b32d5da30 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -2520,8 +2520,6 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_OP_CLAMP: return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_SOFT_MAX: - // TODO: support attention sinks [TAG_ATTN_SINKS] - return op->src[2] == nullptr; case GGML_OP_NORM: case GGML_OP_RMS_NORM: return true; @@ -6594,17 +6592,24 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c GGML_ASSERT(src1->extra); } + const ggml_tensor * src2 = dst->src[2]; + if (src2) { + GGML_ASSERT(src2->extra); + } + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; ggml_tensor_extra_cl * extra1 = src1 ? (ggml_tensor_extra_cl *)src1->extra : nullptr; + ggml_tensor_extra_cl * extra2 = src2 ? (ggml_tensor_extra_cl *)src2->extra : nullptr; cl_ulong offset0 = extra0->offset + src0->view_offs; cl_ulong offsetd = extrad->offset + dst->view_offs; cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0; + cl_ulong offset2 = extra2 ? extra2->offset + src2->view_offs : offset0; const int ne00 = src0->ne[0]; const int ne01 = src0->ne[1]; @@ -6672,25 +6677,27 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), extra1 ? &extra1->data_device : &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne13)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12)); - CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13)); - CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb1)); - CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb2)); - CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb3)); - CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float), &scale)); - CL_CHECK(clSetKernelArg(kernel, 19, sizeof(float), &max_bias)); - CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float), &m0)); - CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float), &m1)); - CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &n_head_log2)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), extra2 ? &extra2->data_device : &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb3)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float), &scale)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float), &max_bias)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(float), &m0)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float), &m1)); + CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &n_head_log2)); size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; size_t local_work_size[] = {(size_t)nth, 1, 1}; diff --git a/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl b/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl index a6d8ede67..571d16507 100644 --- a/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +++ b/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl @@ -26,6 +26,8 @@ kernel void kernel_soft_max_4_f16( ulong offset0, global char * src1, ulong offset1, + global char * src2, + ulong offset2, global char * dst, ulong offsetd, int ne00, @@ -48,6 +50,7 @@ kernel void kernel_soft_max_4_f16( ) { src0 = src0 + offset0; src1 = src1 + offset1; + src2 = src2 + offset2; dst = dst + offsetd; int i03 = get_group_id(2); @@ -60,6 +63,7 @@ kernel void kernel_soft_max_4_f16( global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03); global half4 * pmask = src1 != src0 ? (global half4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0; + global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0; global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3); float slope = 1.0f; @@ -75,7 +79,7 @@ kernel void kernel_soft_max_4_f16( } // parallel max - float4 lmax4 = -INFINITY; + float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY; for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { lmax4 = fmax(lmax4, psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f)); } @@ -92,7 +96,11 @@ kernel void kernel_soft_max_4_f16( } float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3; - const float sum = sub_group_reduce_add(lsum); + float sum = sub_group_reduce_add(lsum); + + if (psrc2) { + sum += exp(psrc2[i02] - max); + } for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { pdst4[i00] /= sum; diff --git a/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl b/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl index 35b5573b4..1f944b220 100644 --- a/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +++ b/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl @@ -26,6 +26,8 @@ kernel void kernel_soft_max_4( ulong offset0, global char * src1, ulong offset1, + global char * src2, + ulong offset2, global char * dst, ulong offsetd, int ne00, @@ -48,6 +50,7 @@ kernel void kernel_soft_max_4( ) { src0 = src0 + offset0; src1 = src1 + offset1; + src2 = src2 + offset2; dst = dst + offsetd; int i03 = get_group_id(2); @@ -60,6 +63,7 @@ kernel void kernel_soft_max_4( global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03); global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0; + global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0; global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3); float slope = 1.0f; @@ -75,7 +79,7 @@ kernel void kernel_soft_max_4( } // parallel max - float4 lmax4 = -INFINITY; + float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY; for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); } @@ -92,7 +96,11 @@ kernel void kernel_soft_max_4( } float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3; - const float sum = sub_group_reduce_add(lsum); + float sum = sub_group_reduce_add(lsum); + + if (psrc2) { + sum += exp(psrc2[i02] - max); + } for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { pdst4[i00] /= sum; diff --git a/ggml/src/ggml-opencl/kernels/softmax_f16.cl b/ggml/src/ggml-opencl/kernels/softmax_f16.cl index 9d292b574..4baa6c28e 100644 --- a/ggml/src/ggml-opencl/kernels/softmax_f16.cl +++ b/ggml/src/ggml-opencl/kernels/softmax_f16.cl @@ -26,6 +26,8 @@ kernel void kernel_soft_max_f16( ulong offset0, global char * src1, ulong offset1, + global char * src2, + ulong offset2, global char * dst, ulong offsetd, int ne00, @@ -48,6 +50,7 @@ kernel void kernel_soft_max_f16( ) { src0 = src0 + offset0; src1 = src1 + offset1; + src2 = src2 + offset2; dst = dst + offsetd; int i03 = get_group_id(2); @@ -60,6 +63,7 @@ kernel void kernel_soft_max_f16( global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03); global half * pmask = src1 != src0 ? (global half *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0; + global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0; global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3); float slope = 1.0f; @@ -75,7 +79,7 @@ kernel void kernel_soft_max_f16( } // parallel max - float lmax = -INFINITY; + float lmax = psrc2 ? psrc2[i02] : -INFINITY; for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); } @@ -91,7 +95,11 @@ kernel void kernel_soft_max_f16( pdst[i00] = exp_psrc0; } - const float sum = sub_group_reduce_add(lsum); + float sum = sub_group_reduce_add(lsum); + + if (psrc2) { + sum += exp(psrc2[i02] - max); + } for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { pdst[i00] /= sum; diff --git a/ggml/src/ggml-opencl/kernels/softmax_f32.cl b/ggml/src/ggml-opencl/kernels/softmax_f32.cl index 7c53dfbe5..d503190b4 100644 --- a/ggml/src/ggml-opencl/kernels/softmax_f32.cl +++ b/ggml/src/ggml-opencl/kernels/softmax_f32.cl @@ -26,6 +26,8 @@ kernel void kernel_soft_max( ulong offset0, global char * src1, ulong offset1, + global char * src2, + ulong offset2, global char * dst, ulong offsetd, int ne00, @@ -48,6 +50,7 @@ kernel void kernel_soft_max( ) { src0 = src0 + offset0; src1 = src1 + offset1; + src2 = src2 + offset2; dst = dst + offsetd; int i03 = get_group_id(2); @@ -60,6 +63,7 @@ kernel void kernel_soft_max( global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03); global float * pmask = src1 != src0 ? (global float *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0; + global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0; global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3); float slope = 1.0f; @@ -75,7 +79,7 @@ kernel void kernel_soft_max( } // parallel max - float lmax = -INFINITY; + float lmax = psrc2 ? psrc2[i02] : -INFINITY; for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); } @@ -91,7 +95,11 @@ kernel void kernel_soft_max( pdst[i00] = exp_psrc0; } - const float sum = sub_group_reduce_add(lsum); + float sum = sub_group_reduce_add(lsum); + + if (psrc2) { + sum += exp(psrc2[i02] - max); + } for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { pdst[i00] /= sum; From 1425f587a82bc303469b5c32759a2746ba4e1e20 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Fri, 8 Aug 2025 08:19:58 +0200 Subject: [PATCH 2/4] CUDA: attention sinks for mma FlashAttention (#15157) --- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 87 ++++++++++++++++++++++------ ggml/src/ggml-cuda/fattn.cu | 2 +- ggml/src/ggml-cuda/ggml-cuda.cu | 3 +- 3 files changed, 73 insertions(+), 19 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 371253844..39731baae 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -785,6 +785,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const half2 * const __restrict__ K_h2, const half2 * const __restrict__ V_h2, const half2 * const __restrict__ mask_h2, + const float * const __restrict__ sinks_f, float2 * const __restrict__ dstk, float2 * const __restrict__ dstk_fixup, const float scale, @@ -957,6 +958,52 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } + // If attention sinks are used, potentially re-scale if KQ_max is small. + // Also add the sink as a value to KQ_rowsum, this is done after synchonization of KQ_rowsum + // so it's being done unconditionally for every thread. + if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) { + float KQ_max_scale[cols_per_thread]; +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented"); + const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col); + const float sink = sinks_f[jc % ncols2]; + + const float KQ_max_new = fmaxf(KQ_max[col], sink); + const float KQ_max_diff = KQ_max[col] - KQ_max_new; + KQ_max_scale[col] = expf(KQ_max_diff); + KQ_max[col] = KQ_max_new; + + *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD; + + const float KQ_max_add = expf(sink - KQ_max_new); + KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add; + } + + if (ntiles == 1) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]); +#pragma unroll + for (int i = 0; i < DV/tile_C_VKQ::I; ++i) { +#pragma unroll + for (int l = 0; l < tile_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } + } else { +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); +#pragma unroll + for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) { +#pragma unroll + for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) { + VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2; + } + } + } + } + } + // Combine VKQ accumulator values if np > 1. // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. // So also write VKQ accumulators to shared memory in column-major format if np == 1. @@ -1271,18 +1318,21 @@ static __global__ void flash_attn_ext_f16( while (kbc < kbc_stop && kb0_stop == iter_k) { const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2)); - const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); - const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile. + const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2 + const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile. - const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2)); - const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio)); + const int head0 = zt * ncols2; + + const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0); + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr : (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1); - float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2); + float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2); - const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio)); + const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); + const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr; - const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f; + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; const int kb0_start_kernel = kb0_start * kb_niter; int kb0_stop_kernel = kb0_stop * kb_niter; @@ -1295,12 +1345,12 @@ static __global__ void flash_attn_ext_f16( if (kb0_start == 0) { constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } else { constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } @@ -1316,18 +1366,21 @@ static __global__ void flash_attn_ext_f16( } const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2)); - const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); - const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile. + const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2 + const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile. - const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2)); - const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio)); + const int head0 = zt * ncols2; + + const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0); + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr : (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1); - float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2); + float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2); - const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio)); + const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); + const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr; - const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f; + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; const int kb0_start_kernel = kb0_start * kb_niter; int kb0_stop_kernel = kb0_stop * kb_niter; @@ -1339,7 +1392,7 @@ static __global__ void flash_attn_ext_f16( constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. constexpr bool needs_fixup = false; flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); #else GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 8ddd0415b..6c1185dea 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -282,7 +282,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV); // TODO: currently only vec implementation for sinks is supported [TAG_ATTN_SINKS] - if (sinks) { + if (sinks && !fp16_mma_available(cc)) { if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) { ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); } else { diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index ec7ab2551..19e9c405e 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3532,7 +3532,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0; } // TODO: more general-purpose attention sink support [TAG_ATTN_SINKS] - if (op->src[4] && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) { // currently only sinks for head_size 64 and 128 are supported + if (op->src[4] && !fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) + && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) { return false; } if (op->src[0]->ne[0] == 192) { From 6c7e9a54406dbba5e53754a8f70a285414717b06 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Fri, 8 Aug 2025 10:45:18 +0100 Subject: [PATCH 3/4] vendor: sync minja (#15161) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * vendor: sync minja * Update minja.hpp * Apply suggestions from code review Co-authored-by: Sigbjørn Skjæret --------- Co-authored-by: Sigbjørn Skjæret --- vendor/minja/chat-template.hpp | 15 ++++++++--- vendor/minja/minja.hpp | 49 +++++++++++++++++++++++----------- 2 files changed, 44 insertions(+), 20 deletions(-) diff --git a/vendor/minja/chat-template.hpp b/vendor/minja/chat-template.hpp index cf113bf22..d5295b335 100644 --- a/vendor/minja/chat-template.hpp +++ b/vendor/minja/chat-template.hpp @@ -162,8 +162,15 @@ class chat_template { }), false); caps_.supports_tools = contains(out, "some_tool"); - auto out_empty = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", ""}}}), {}, false); - auto out_null = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", nullptr}}}), {}, false); + const auto render_with_content = [&](const json & content) { + const json assistant_msg {{"role", "assistant"}, {"content", content}}; + // Render two assistant messages as some templates like QwQ-32B are handling + // the content differently depending on whether it's the last message or not + // (to remove the tag in all but the last message). + return try_raw_render(json::array({dummy_user_msg, assistant_msg, dummy_user_msg, assistant_msg}), {}, false); + }; + auto out_empty = render_with_content(""); + auto out_null = render_with_content(json()); caps_.requires_non_null_content = contains(out_empty, user_needle) && !contains(out_null, user_needle); json j_null; @@ -191,12 +198,12 @@ class chat_template { dummy_user_msg, make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj.dump())})), }), {}, false); - auto tool_call_renders_str_arguments = contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':"); + auto tool_call_renders_str_arguments = contains(out, "") || contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':"); out = try_raw_render(json::array({ dummy_user_msg, make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj)})), }), {}, false); - auto tool_call_renders_obj_arguments = contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':"); + auto tool_call_renders_obj_arguments = contains(out, "") || contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':"); caps_.supports_tool_calls = tool_call_renders_str_arguments || tool_call_renders_obj_arguments; caps_.requires_object_arguments = !tool_call_renders_str_arguments && tool_call_renders_obj_arguments; diff --git a/vendor/minja/minja.hpp b/vendor/minja/minja.hpp index dd107dccd..dad75efbb 100644 --- a/vendor/minja/minja.hpp +++ b/vendor/minja/minja.hpp @@ -1291,6 +1291,12 @@ public: } }; +static bool in(const Value & value, const Value & container) { + return (((container.is_array() || container.is_object()) && container.contains(value)) || + (value.is_string() && container.is_string() && + container.to_str().find(value.to_str()) != std::string::npos)); +} + class BinaryOpExpr : public Expression { public: enum class Op { StrConcat, Add, Sub, Mul, MulMul, Div, DivDiv, Mod, Eq, Ne, Lt, Gt, Le, Ge, And, Or, In, NotIn, Is, IsNot }; @@ -1355,13 +1361,8 @@ public: case Op::Gt: return l > r; case Op::Le: return l <= r; case Op::Ge: return l >= r; - case Op::In: return (((r.is_array() || r.is_object()) && r.contains(l)) || - (l.is_string() && r.is_string() && - r.to_str().find(l.to_str()) != std::string::npos)); - case Op::NotIn: - return !(((r.is_array() || r.is_object()) && r.contains(l)) || - (l.is_string() && r.is_string() && - r.to_str().find(l.to_str()) != std::string::npos)); + case Op::In: return in(l, r); + case Op::NotIn: return !in(l, r); default: break; } throw std::runtime_error("Unknown binary operator"); @@ -1500,6 +1501,13 @@ public: } else if (method->get_name() == "pop") { vargs.expectArgs("pop method", {1, 1}, {0, 0}); return obj.pop(vargs.args[0]); + } else if (method->get_name() == "keys") { + vargs.expectArgs("keys method", {0, 0}, {0, 0}); + auto result = Value::array(); + for (const auto& key : obj.keys()) { + result.push_back(Value(key)); + } + return result; } else if (method->get_name() == "get") { vargs.expectArgs("get method", {1, 2}, {0, 0}); auto key = vargs.args[0]; @@ -1541,6 +1549,16 @@ public: } else if (method->get_name() == "capitalize") { vargs.expectArgs("capitalize method", {0, 0}, {0, 0}); return Value(capitalize(str)); + } else if (method->get_name() == "upper") { + vargs.expectArgs("upper method", {0, 0}, {0, 0}); + auto result = str; + std::transform(result.begin(), result.end(), result.begin(), ::toupper); + return Value(result); + } else if (method->get_name() == "lower") { + vargs.expectArgs("lower method", {0, 0}, {0, 0}); + auto result = str; + std::transform(result.begin(), result.end(), result.begin(), ::tolower); + return Value(result); } else if (method->get_name() == "endswith") { vargs.expectArgs("endswith method", {1, 1}, {0, 0}); auto suffix = vargs.args[0].get(); @@ -2646,15 +2664,11 @@ inline std::shared_ptr Context::builtins() { auto items = Value::array(); if (args.contains("object")) { auto & obj = args.at("object"); - if (obj.is_string()) { - auto json_obj = json::parse(obj.get()); - for (const auto & kv : json_obj.items()) { - items.push_back(Value::array({kv.key(), kv.value()})); - } - } else if (!obj.is_null()) { - for (auto & key : obj.keys()) { - items.push_back(Value::array({key, obj.at(key)})); - } + if (!obj.is_object()) { + throw std::runtime_error("Can only get item pairs from a mapping"); + } + for (auto & key : obj.keys()) { + items.push_back(Value::array({key, obj.at(key)})); } } return items; @@ -2782,6 +2796,9 @@ inline std::shared_ptr Context::builtins() { if (!items.is_array()) throw std::runtime_error("object is not iterable"); return items; })); + globals.set("in", simple_function("in", { "item", "items" }, [](const std::shared_ptr &, Value & args) -> Value { + return in(args.at("item"), args.at("items")); + })); globals.set("unique", simple_function("unique", { "items" }, [](const std::shared_ptr &, Value & args) -> Value { auto & items = args.at("items"); if (!items.is_array()) throw std::runtime_error("object is not iterable"); From cd6983d56d2cce94ecb86bb114ae8379a609073c Mon Sep 17 00:00:00 2001 From: AN Long Date: Fri, 8 Aug 2025 21:37:22 +0900 Subject: [PATCH 4/4] ggml : fix field name when new ggml_backend (#14944) --- ggml/src/ggml-blas/ggml-blas.cpp | 8 ++++---- ggml/src/ggml-cpu/ggml-cpu.cpp | 8 ++++---- ggml/src/ggml-cuda/ggml-cuda.cu | 8 ++++---- ggml/src/ggml-opencl/ggml-opencl.cpp | 8 ++++---- ggml/src/ggml-rpc/ggml-rpc.cpp | 8 ++++---- ggml/src/ggml-sycl/ggml-sycl.cpp | 8 ++++---- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 8 ++++---- 7 files changed, 28 insertions(+), 28 deletions(-) diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp index ec158dfac..aeac2e574 100644 --- a/ggml/src/ggml-blas/ggml-blas.cpp +++ b/ggml/src/ggml-blas/ggml-blas.cpp @@ -281,10 +281,10 @@ ggml_backend_t ggml_backend_blas_init(void) { ggml_backend_blas_context * ctx = new ggml_backend_blas_context; ggml_backend_t backend = new ggml_backend { - /* .guid = */ ggml_backend_blas_guid(), - /* .interface = */ blas_backend_i, - /* .device = */ ggml_backend_reg_dev_get(ggml_backend_blas_reg(), 0), - /* .context = */ ctx, + /* .guid = */ ggml_backend_blas_guid(), + /* .iface = */ blas_backend_i, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_blas_reg(), 0), + /* .context = */ ctx, }; #if defined(OPENBLAS_VERSION) && defined(GGML_USE_OPENMP) diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index e16cdc9d4..8dacd3671 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -214,10 +214,10 @@ ggml_backend_t ggml_backend_cpu_init(void) { ctx->abort_callback_data = NULL; ggml_backend_t cpu_backend = new ggml_backend { - /* .guid = */ ggml_backend_cpu_guid(), - /* .interface = */ ggml_backend_cpu_i, - /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0), - /* .context = */ ctx, + /* .guid = */ ggml_backend_cpu_guid(), + /* .iface = */ ggml_backend_cpu_i, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0), + /* .context = */ ctx, }; if (cpu_backend == NULL) { diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 19e9c405e..d9110491e 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3799,10 +3799,10 @@ ggml_backend_t ggml_backend_cuda_init(int device) { } ggml_backend_t cuda_backend = new ggml_backend { - /* .guid = */ ggml_backend_cuda_guid(), - /* .interface = */ ggml_backend_cuda_interface, - /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device), - /* .context = */ ctx, + /* .guid = */ ggml_backend_cuda_guid(), + /* .iface = */ ggml_backend_cuda_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device), + /* .context = */ ctx, }; return cuda_backend; diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index b32d5da30..8ba1e00df 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -2624,10 +2624,10 @@ ggml_backend_t ggml_backend_opencl_init(void) { ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(dev); ggml_backend_t backend = new ggml_backend { - /* .guid = */ ggml_backend_opencl_guid(), - /* .interface = */ ggml_backend_opencl_i, - /* .device = */ dev, - /* .context = */ backend_ctx + /* .guid = */ ggml_backend_opencl_guid(), + /* .iface = */ ggml_backend_opencl_i, + /* .device = */ dev, + /* .context = */ backend_ctx }; return backend; diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 29bc421d5..df6ba5407 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -823,10 +823,10 @@ ggml_backend_t ggml_backend_rpc_init(const char * endpoint) { }; ggml_backend_t backend = new ggml_backend { - /* .guid = */ ggml_backend_rpc_guid(), - /* .interface = */ ggml_backend_rpc_interface, - /* .device = */ ggml_backend_rpc_add_device(endpoint), - /* .context = */ ctx + /* .guid = */ ggml_backend_rpc_guid(), + /* .iface = */ ggml_backend_rpc_interface, + /* .device = */ ggml_backend_rpc_add_device(endpoint), + /* .context = */ ctx }; return backend; } diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 6fa27418c..3992dad01 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4586,10 +4586,10 @@ ggml_backend_t ggml_backend_sycl_init(int device) { }; ggml_backend_t sycl_backend = new ggml_backend { - /* .guid = */ ggml_backend_sycl_guid(), - /* .interface = */ ggml_backend_sycl_interface, - /* .device = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), device), - /* .context = */ ctx + /* .guid = */ ggml_backend_sycl_guid(), + /* .iface = */ ggml_backend_sycl_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), device), + /* .context = */ ctx }; return sycl_backend; diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index b1cbbc986..4070e248b 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -10767,10 +10767,10 @@ ggml_backend_t ggml_backend_vk_init(size_t dev_num) { ggml_vk_init(ctx, dev_num); ggml_backend_t vk_backend = new ggml_backend { - /* .guid = */ ggml_backend_vk_guid(), - /* .interface = */ ggml_backend_vk_interface, - /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), dev_num), - /* .context = */ ctx, + /* .guid = */ ggml_backend_vk_guid(), + /* .iface = */ ggml_backend_vk_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), dev_num), + /* .context = */ ctx, }; return vk_backend;