From cd465d823c378853f1f6570eebfb77f69c7e1d39 Mon Sep 17 00:00:00 2001 From: Romain Biessy Date: Mon, 21 Jul 2025 18:39:29 +0200 Subject: [PATCH 1/8] sycl: Fix im2col (#14797) --- ggml/src/ggml-sycl/im2col.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-sycl/im2col.cpp b/ggml/src/ggml-sycl/im2col.cpp index 52737cc74..7adcb3d9d 100644 --- a/ggml/src/ggml-sycl/im2col.cpp +++ b/ggml/src/ggml-sycl/im2col.cpp @@ -26,7 +26,7 @@ static void im2col_kernel(const float * x, T * dst, int64_t batch_offset, int64_ // make each work-item deal with more elements since sycl global range can not exceed max int for (int64_t i = global_id; i < pelements; i += (work_group_size * item_ct1.get_group_range(2))) { - const int64_t ksize = OW * (KH > 1 ? KW : 1); + const int64_t ksize = OW * KH; const int64_t kx = i / ksize; const int64_t kd = kx * ksize; const int64_t ky = (i - kd) / OW; From 6c9ee3b17e19dcc82ab93d52ae46fdd0226d4777 Mon Sep 17 00:00:00 2001 From: rmatif Date: Mon, 21 Jul 2025 19:03:19 +0200 Subject: [PATCH 2/8] opencl: add conv2d kernel (#14403) * add conv2d kernel * fix trailing whitespace * whitespace fixe * handle f16 input and f16 kernel, more opt * resolve conflicts * use enqueue_ndrange_kernel --- ggml/src/ggml-opencl/CMakeLists.txt | 2 + ggml/src/ggml-opencl/ggml-opencl.cpp | 134 +++++++++++++ ggml/src/ggml-opencl/kernels/conv2d.cl | 185 ++++++++++++++++++ .../src/ggml-opencl/kernels/conv2d_f16_f32.cl | 176 +++++++++++++++++ 4 files changed, 497 insertions(+) create mode 100644 ggml/src/ggml-opencl/kernels/conv2d.cl create mode 100644 ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index ec5d8cf59..015fa8f06 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -105,6 +105,8 @@ set(GGML_OPENCL_KERNELS pad repeat mul_mat_f16_f32 + conv2d + conv2d_f16_f32 ) foreach (K ${GGML_OPENCL_KERNELS}) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 338825915..a31483b61 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -390,6 +390,9 @@ struct ggml_backend_opencl_context { cl_program program_tanh; cl_program program_upscale; cl_program program_concat; + cl_program program_conv_2d_f16; + cl_program program_conv_2d_f32; + cl_program program_conv_2d_f16_f32; cl_program program_tsembd; cl_program program_mul_mv_id_q4_0_f32_8x_flat; @@ -441,6 +444,9 @@ struct ggml_backend_opencl_context { cl_kernel kernel_upscale_bilinear; cl_kernel kernel_concat_f32_contiguous; cl_kernel kernel_concat_f32_non_contiguous; + cl_kernel kernel_conv_2d_f16; + cl_kernel kernel_conv_2d_f32; + cl_kernel kernel_conv_2d_f16_f32; cl_kernel kernel_timestep_embedding; cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat; @@ -1478,6 +1484,47 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // conv2d + { + #ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "conv2d.cl.h" + }; + const std::string kernel_src_f16_f32 { + #include "conv2d_f16_f32.cl.h" + }; + #else + const std::string kernel_src = read_file("conv2d.cl"); + const std::string kernel_src_f16_f32 = read_file("conv2d_f16_f32.cl"); + #endif + if (!kernel_src.empty()) { + backend_ctx->program_conv_2d_f16 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), (std::string(compile_opts) + " -DUSE_FP16=1").c_str()); + CL_CHECK((backend_ctx->kernel_conv_2d_f16 = clCreateKernel(backend_ctx->program_conv_2d_f16, "kernel_conv_2d", &err), err)); + GGML_LOG_CONT("."); + backend_ctx->program_conv_2d_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_conv_2d_f32 = clCreateKernel(backend_ctx->program_conv_2d_f32, "kernel_conv_2d", &err), err)); + GGML_LOG_CONT("."); + } else { + GGML_LOG_WARN("ggml_opencl: conv2d kernel source not found or empty. This op will not be available.\n"); + backend_ctx->program_conv_2d_f16 = nullptr; + backend_ctx->kernel_conv_2d_f16 = nullptr; + backend_ctx->program_conv_2d_f32 = nullptr; + backend_ctx->kernel_conv_2d_f32 = nullptr; + } + if (!kernel_src_f16_f32.empty()) { + backend_ctx->program_conv_2d_f16_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16_f32.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_conv_2d_f16_f32 = clCreateKernel(backend_ctx->program_conv_2d_f16_f32, "kernel_conv_2d", &err), err)); + GGML_LOG_CONT("."); + } else { + GGML_LOG_WARN("ggml_opencl: conv2d_f16_f32 kernel source not found or empty. This op will not be available.\n"); + backend_ctx->program_conv_2d_f16_f32 = nullptr; + backend_ctx->kernel_conv_2d_f16_f32 = nullptr; + } + } + // mul_mv_id_q4_0_f32_8x_flat { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -2361,6 +2408,10 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te op->src[0]->ne[3] == 1 && op->ne[3] == 1; case GGML_OP_UPSCALE: return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; + case GGML_OP_CONV_2D: + return (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16) || + (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) || + (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32); case GGML_OP_CONCAT: return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; case GGML_OP_TIMESTEP_EMBEDDING: @@ -4998,6 +5049,83 @@ static void ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_ten backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst); } +static void ggml_cl_conv_2d(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_TENSOR_BINARY_OP_LOCALS; + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const cl_uint Cout = ne03; const cl_uint Cin = ne02; const cl_uint N = ne13; + const cl_uint KW = ne00; const cl_uint KH = ne01; const cl_uint W = ne10; const cl_uint H = ne11; const cl_uint OW = ne0; const cl_uint OH = ne1; + + const cl_uint s0 = dst->op_params[0]; const cl_uint s1 = dst->op_params[1]; + const cl_uint p0 = dst->op_params[2]; const cl_uint p1 = dst->op_params[3]; + const cl_uint d0 = dst->op_params[4]; const cl_uint d1 = dst->op_params[5]; + + const cl_uint cl_nb01 = nb01/ggml_type_size(src0->type); const cl_uint cl_nb02 = nb02/ggml_type_size(src0->type); const cl_uint cl_nb03 = nb03/ggml_type_size(src0->type); + const cl_uint cl_nb11 = nb11/ggml_type_size(src1->type); const cl_uint cl_nb12 = nb12/ggml_type_size(src1->type); const cl_uint cl_nb13 = nb13/ggml_type_size(src1->type); + const cl_uint cl_nb1 = nb1/ggml_type_size(dst->type); const cl_uint cl_nb2 = nb2/ggml_type_size(dst->type); const cl_uint cl_nb3 = nb3/ggml_type_size(dst->type); + + const int64_t NPQ = (int64_t)N * OW * OH; + + const uint32_t BS_K = 64; + const uint32_t BS_NPQ = 64; + const uint32_t BS_CRS = 16; + const uint32_t VEC_SIZE = 4; + + const uint32_t TS_K = 4; + const uint32_t TS_NPQ = 8; + + const uint32_t WG_K = BS_K / TS_K; + const uint32_t WG_NPQ = BS_NPQ / TS_NPQ; + + auto splitWork = [](uint32_t work_size, uint32_t block_size) { return (block_size + work_size - 1) / block_size; }; + const uint32_t NB_K = splitWork(Cout, BS_K); + const uint32_t NB_NPQ = splitWork(NPQ, BS_NPQ); + + cl_kernel kernel; + size_t shmem_size; + + if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { + kernel = backend_ctx->kernel_conv_2d_f16; + shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_half4)); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_conv_2d_f32; + shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_float) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_float4)); + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_conv_2d_f16_f32; + shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_float4)); + } else { + GGML_ASSERT(false && "Unsupported data type combination for conv2d"); + return; + } + + cl_uint idx = 0; + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra1->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extrad->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, idx++, shmem_size, NULL)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &Cout)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &Cin)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &N)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &KW)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &KH)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &W)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &H)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &OW)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &OH)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &s0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &s1)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &p0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &p1)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &d0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &d1)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb01)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb02)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb03)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb11)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb12)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb13)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb1)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb2)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb3)); + + size_t global_work_size[] = { (size_t)NB_K * WG_K, (size_t)NB_NPQ * WG_NPQ, 1 }; + size_t local_work_size[] = { (size_t)WG_K, (size_t)WG_NPQ, 1 }; + + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst); +} + static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -6752,6 +6880,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } ggml_cl_upscale(backend, tensor->src[0], tensor); return true; + case GGML_OP_CONV_2D: + if (!any_on_device) { + return false; + } + func = ggml_cl_conv_2d; + break; case GGML_OP_CONCAT: if (!any_on_device) { return false; diff --git a/ggml/src/ggml-opencl/kernels/conv2d.cl b/ggml/src/ggml-opencl/kernels/conv2d.cl new file mode 100644 index 000000000..e339c90cf --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/conv2d.cl @@ -0,0 +1,185 @@ +#ifdef USE_FP16 +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#define T_FLOAT half +#define T_FLOAT4 half4 +#define VSTORE_T_FLOAT4(data, offset, p) vstore_half4_rte(data, offset, p) +#else +#define T_FLOAT float +#define T_FLOAT4 float4 +#define VSTORE_T_FLOAT4(data, offset, p) vstore4(data, offset, p) +#endif + +#if defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#else +#define REQD_SUBGROUP_SIZE_128 +#endif + +#define T_ACCUM float4 +#define VEC_SIZE 4 + +#define BS_K 64 +#define BS_NPQ 64 +#define BS_CRS 16 + +#define TS_K 4 +#define TS_NPQ 8 + +#define WG_K (BS_K / TS_K) +#define WG_NPQ (BS_NPQ / TS_NPQ) + +#define BS_NPQ_VEC (BS_NPQ / VEC_SIZE) +#define TS_NPQ_VEC (TS_NPQ / VEC_SIZE) + +static inline uint splitWork(uint work_size, uint block_size){ + return (work_size + block_size - 1) / block_size; +} + +REQD_SUBGROUP_SIZE_128 +kernel void kernel_conv_2d( + global void* p_knl, + ulong off_knl, + global void* p_src, + ulong off_src, + global void* p_dst, + ulong off_dst, + local void* shared, + uint Cout, uint Cin, uint N, + uint KW, uint KH, uint W, uint H, uint OW, uint OH, + uint s0, uint s1, uint p0, uint p1, uint d0, uint d1, + uint nb01, uint nb02, uint nb03, + uint nb11, uint nb12, uint nb13, + uint nb1, uint nb2, uint nb3 +) { + global T_FLOAT* knl_data = (global T_FLOAT*) ((global char*)p_knl + off_knl); + global T_FLOAT* src_data = (global T_FLOAT*) ((global char*)p_src + off_src); + global T_FLOAT* dst_data = (global T_FLOAT*) ((global char*)p_dst + off_dst); + + const uint K = Cout; + const uint CRS = Cin*KH*KW; + const uint NPQ = N*OH*OW; + + const uint lid_k = get_local_id(0); + const uint lid_npq = get_local_id(1); + const uint tid = lid_npq * WG_K + lid_k; + + const uint B_idx_K = get_group_id(0); + const uint B_idx_NPQ = get_group_id(1); + + const uint offset_k = B_idx_K * BS_K; + const uint offset_npq = B_idx_NPQ * BS_NPQ; + + local T_FLOAT* Ash = (local T_FLOAT*)shared; + local T_FLOAT4* Bsh = (local T_FLOAT4*) &Ash[BS_K * BS_CRS]; + + T_ACCUM regC[TS_K][TS_NPQ_VEC]; + for (int i = 0; i < TS_K; ++i) { + for (int j = 0; j < TS_NPQ_VEC; ++j) { + regC[i][j] = (T_ACCUM)(0.0f); + } + } + + const uint NB_CRS = splitWork(CRS, BS_CRS); + + for (uint B_idx_CRS = 0; B_idx_CRS < NB_CRS; ++B_idx_CRS) { + const uint offset_crs = B_idx_CRS * BS_CRS; + + for (int i = tid; i < BS_K * BS_CRS; i += (WG_K * WG_NPQ)) { + const uint k_l = i / BS_CRS; + const uint crs_l = i % BS_CRS; + const uint k_g = offset_k + k_l; + const uint crs_g = offset_crs + crs_l; + + if (k_g < K && crs_g < CRS) { + const uint Cin_idx = crs_g / (KW*KH); + const uint KH_idx = (crs_g - Cin_idx*KW*KH) / KW; + const uint KW_idx = crs_g - Cin_idx*KW*KH - KH_idx*KW; + const uint knl_idx = KW_idx + KH_idx*nb01 + Cin_idx*nb02 + k_g*nb03; + Ash[k_l * BS_CRS + crs_l] = knl_data[knl_idx]; + } else { + Ash[k_l * BS_CRS + crs_l] = (T_FLOAT)0.0f; + } + } + + for (int i = tid; i < BS_CRS * BS_NPQ_VEC; i += (WG_K * WG_NPQ)) { + const uint crs_l = i / BS_NPQ_VEC; + const uint npq_l_vec = i % BS_NPQ_VEC; + const uint crs_g = offset_crs + crs_l; + + T_FLOAT4 val = (T_FLOAT4)(0.0f); + if (crs_g < CRS) { + const uint Cin_idx = crs_g / (KW * KH); + const uint KH_idx = (crs_g - Cin_idx * KW * KH) / KW; + const uint KW_idx = crs_g - Cin_idx * KW * KH - KH_idx * KW; + for (int v = 0; v < VEC_SIZE; ++v) { + const uint npq_g = offset_npq + npq_l_vec * VEC_SIZE + v; + if (npq_g < NPQ) { + const uint N_idx = npq_g / (OH * OW); + const uint pq_idx = npq_g % (OH * OW); + const uint OH_idx = pq_idx / OW; + const uint OW_idx = pq_idx % OW; + const int H_idx = (int)(OH_idx * s1 + KH_idx * d1 - p1); + const int W_idx = (int)(OW_idx * s0 + KW_idx * d0 - p0); + + if (H_idx >= 0 && H_idx < H && W_idx >= 0 && W_idx < W) { + const uint src_idx = W_idx + H_idx * nb11 + Cin_idx * nb12 + N_idx * nb13; + ((T_FLOAT*)&val)[v] = src_data[src_idx]; + } + } + } + } + Bsh[crs_l * BS_NPQ_VEC + npq_l_vec] = val; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + #pragma unroll + for (uint crs_l = 0; crs_l < BS_CRS; ++crs_l) { + T_FLOAT regA[TS_K]; + for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) { + regA[k_l_reg] = Ash[(lid_k * TS_K + k_l_reg) * BS_CRS + crs_l]; + } + + for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) { + T_FLOAT4 regB = Bsh[crs_l * BS_NPQ_VEC + lid_npq * TS_NPQ_VEC + npq_l_vec_reg]; + for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) { + regC[k_l_reg][npq_l_vec_reg] = mad(convert_float(regA[k_l_reg]), convert_float4(regB), regC[k_l_reg][npq_l_vec_reg]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) { + const uint k_g = offset_k + lid_k * TS_K + k_l_reg; + if (k_g >= K) continue; + + for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) { + const uint npq_g_base = offset_npq + (lid_npq * TS_NPQ_VEC + npq_l_vec_reg) * VEC_SIZE; + + const uint N_idx = npq_g_base / (OH * OW); + const uint pq_idx = npq_g_base % (OH * OW); + const uint OH_idx = pq_idx / OW; + const uint OW_idx = pq_idx % OW; + + if (nb1 == OW && OW_idx + VEC_SIZE <= OW && npq_g_base + VEC_SIZE <= NPQ) { + const uint dst_idx = OW_idx + OH_idx*nb1 + k_g*nb2 + N_idx*nb3; + VSTORE_T_FLOAT4(regC[k_l_reg][npq_l_vec_reg], 0, &dst_data[dst_idx]); + } else { + T_ACCUM res = regC[k_l_reg][npq_l_vec_reg]; + for (int v = 0; v < VEC_SIZE; ++v) { + const uint npq_g = npq_g_base + v; + if (npq_g < NPQ) { + const uint N_idx_s = npq_g / (OH*OW); + const uint pq_idx_s = npq_g % (OH*OW); + const uint OH_idx_s = pq_idx_s / OW; + const uint OW_idx_s = pq_idx_s % OW; + const uint dst_idx_s = OW_idx_s + OH_idx_s*nb1 + k_g*nb2 + N_idx_s*nb3; + dst_data[dst_idx_s] = (T_FLOAT)(((float*)&res)[v]); + } + } + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl b/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl new file mode 100644 index 000000000..cb05637f3 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl @@ -0,0 +1,176 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#if defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#else +#define REQD_SUBGROUP_SIZE_128 +#endif + +#define T_ACCUM float4 +#define VEC_SIZE 4 + +#define BS_K 64 +#define BS_NPQ 64 +#define BS_CRS 16 + +#define TS_K 4 +#define TS_NPQ 8 + +#define WG_K (BS_K / TS_K) +#define WG_NPQ (BS_NPQ / TS_NPQ) + +#define BS_NPQ_VEC (BS_NPQ / VEC_SIZE) +#define TS_NPQ_VEC (TS_NPQ / VEC_SIZE) + +static inline uint splitWork(uint work_size, uint block_size){ + return (work_size + block_size - 1) / block_size; +} + +REQD_SUBGROUP_SIZE_128 +kernel void kernel_conv_2d( + global void* p_knl, + ulong off_knl, + global void* p_src, + ulong off_src, + global void* p_dst, + ulong off_dst, + local void* shared, + uint Cout, uint Cin, uint N, + uint KW, uint KH, uint W, uint H, uint OW, uint OH, + uint s0, uint s1, uint p0, uint p1, uint d0, uint d1, + uint nb01, uint nb02, uint nb03, + uint nb11, uint nb12, uint nb13, + uint nb1, uint nb2, uint nb3 +) { + global half* knl_data = (global half*) ((global char*)p_knl + off_knl); + global float* src_data = (global float*) ((global char*)p_src + off_src); + global float* dst_data = (global float*) ((global char*)p_dst + off_dst); + + const uint K = Cout; + const uint CRS = Cin*KH*KW; + const uint NPQ = N*OH*OW; + + const uint lid_k = get_local_id(0); + const uint lid_npq = get_local_id(1); + const uint tid = lid_npq * WG_K + lid_k; + + const uint B_idx_K = get_group_id(0); + const uint B_idx_NPQ = get_group_id(1); + + const uint offset_k = B_idx_K * BS_K; + const uint offset_npq = B_idx_NPQ * BS_NPQ; + + local half* Ash = (local half*)shared; + local float4* Bsh = (local float4*) &Ash[BS_K * BS_CRS]; + + T_ACCUM regC[TS_K][TS_NPQ_VEC]; + for (int i = 0; i < TS_K; ++i) { + for (int j = 0; j < TS_NPQ_VEC; ++j) { + regC[i][j] = (T_ACCUM)(0.0f); + } + } + + const uint NB_CRS = splitWork(CRS, BS_CRS); + + for (uint B_idx_CRS = 0; B_idx_CRS < NB_CRS; ++B_idx_CRS) { + const uint offset_crs = B_idx_CRS * BS_CRS; + + for (int i = tid; i < BS_K * BS_CRS; i += (WG_K * WG_NPQ)) { + const uint k_l = i / BS_CRS; + const uint crs_l = i % BS_CRS; + const uint k_g = offset_k + k_l; + const uint crs_g = offset_crs + crs_l; + + if (k_g < K && crs_g < CRS) { + const uint Cin_idx = crs_g / (KW*KH); + const uint KH_idx = (crs_g - Cin_idx*KW*KH) / KW; + const uint KW_idx = crs_g - Cin_idx*KW*KH - KH_idx*KW; + const uint knl_idx = KW_idx + KH_idx*nb01 + Cin_idx*nb02 + k_g*nb03; + Ash[k_l * BS_CRS + crs_l] = knl_data[knl_idx]; + } else { + Ash[k_l * BS_CRS + crs_l] = (half)0.0f; + } + } + + for (int i = tid; i < BS_CRS * BS_NPQ_VEC; i += (WG_K * WG_NPQ)) { + const uint crs_l = i / BS_NPQ_VEC; + const uint npq_l_vec = i % BS_NPQ_VEC; + const uint crs_g = offset_crs + crs_l; + + float4 val = (float4)(0.0f); + if (crs_g < CRS) { + const uint Cin_idx = crs_g / (KW * KH); + const uint KH_idx = (crs_g - Cin_idx * KW * KH) / KW; + const uint KW_idx = crs_g - Cin_idx * KW * KH - KH_idx * KW; + for (int v = 0; v < VEC_SIZE; ++v) { + const uint npq_g = offset_npq + npq_l_vec * VEC_SIZE + v; + if (npq_g < NPQ) { + const uint N_idx = npq_g / (OH * OW); + const uint pq_idx = npq_g % (OH * OW); + const uint OH_idx = pq_idx / OW; + const uint OW_idx = pq_idx % OW; + const int H_idx = (int)(OH_idx * s1 + KH_idx * d1 - p1); + const int W_idx = (int)(OW_idx * s0 + KW_idx * d0 - p0); + + if (H_idx >= 0 && H_idx < H && W_idx >= 0 && W_idx < W) { + const uint src_idx = W_idx + H_idx * nb11 + Cin_idx * nb12 + N_idx * nb13; + ((float*)&val)[v] = src_data[src_idx]; + } + } + } + } + Bsh[crs_l * BS_NPQ_VEC + npq_l_vec] = val; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + #pragma unroll + for (uint crs_l = 0; crs_l < BS_CRS; ++crs_l) { + half regA[TS_K]; + for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) { + regA[k_l_reg] = Ash[(lid_k * TS_K + k_l_reg) * BS_CRS + crs_l]; + } + + for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) { + float4 regB = Bsh[crs_l * BS_NPQ_VEC + lid_npq * TS_NPQ_VEC + npq_l_vec_reg]; + for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) { + regC[k_l_reg][npq_l_vec_reg] = mad(convert_float(regA[k_l_reg]), regB, regC[k_l_reg][npq_l_vec_reg]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) { + const uint k_g = offset_k + lid_k * TS_K + k_l_reg; + if (k_g >= K) continue; + + for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) { + const uint npq_g_base = offset_npq + (lid_npq * TS_NPQ_VEC + npq_l_vec_reg) * VEC_SIZE; + + const uint N_idx = npq_g_base / (OH * OW); + const uint pq_idx = npq_g_base % (OH * OW); + const uint OH_idx = pq_idx / OW; + const uint OW_idx = pq_idx % OW; + + if (nb1 == OW && OW_idx + VEC_SIZE <= OW && npq_g_base + VEC_SIZE <= NPQ) { + const uint dst_idx = OW_idx + OH_idx*nb1 + k_g*nb2 + N_idx*nb3; + vstore4(regC[k_l_reg][npq_l_vec_reg], 0, &dst_data[dst_idx]); + } else { + T_ACCUM res = regC[k_l_reg][npq_l_vec_reg]; + for (int v = 0; v < VEC_SIZE; ++v) { + const uint npq_g = npq_g_base + v; + if (npq_g < NPQ) { + const uint N_idx_s = npq_g / (OH*OW); + const uint pq_idx_s = npq_g % (OH*OW); + const uint OH_idx_s = pq_idx_s / OW; + const uint OW_idx_s = pq_idx_s % OW; + const uint dst_idx_s = OW_idx_s + OH_idx_s*nb1 + k_g*nb2 + N_idx_s*nb3; + dst_data[dst_idx_s] = ((float*)&res)[v]; + } + } + } + } + } +} From 38d3af1b73302377111e88ddd725257638339286 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Mon, 21 Jul 2025 22:55:10 +0200 Subject: [PATCH 3/8] opencl: fix `im2col` when `KW!=KH` (#14803) --- ggml/src/ggml-opencl/kernels/im2col_f16.cl | 2 +- ggml/src/ggml-opencl/kernels/im2col_f32.cl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-opencl/kernels/im2col_f16.cl b/ggml/src/ggml-opencl/kernels/im2col_f16.cl index b84c89846..cf6cdaa4c 100644 --- a/ggml/src/ggml-opencl/kernels/im2col_f16.cl +++ b/ggml/src/ggml-opencl/kernels/im2col_f16.cl @@ -31,7 +31,7 @@ kernel void kernel_im2col_f16( src1 = (global float*)((global char*)src1 + offset1); dst = (global half*)((global char*)dst + offsetd); - long ksize = OW * (KH > 1 ? KW : 1); + long ksize = OW * KH; long kx = i / ksize; long kd = kx * ksize; long ky = (i - kd) / OW; diff --git a/ggml/src/ggml-opencl/kernels/im2col_f32.cl b/ggml/src/ggml-opencl/kernels/im2col_f32.cl index 4bf65e4ea..1ecdb2344 100644 --- a/ggml/src/ggml-opencl/kernels/im2col_f32.cl +++ b/ggml/src/ggml-opencl/kernels/im2col_f32.cl @@ -31,7 +31,7 @@ kernel void kernel_im2col_f32( src1 = (global float*)((global char*)src1 + offset1); dst = (global float*)((global char*)dst + offsetd); - long ksize = OW * (KH > 1 ? KW : 1); + long ksize = OW * KH; long kx = i / ksize; long kd = kx * ksize; long ky = (i - kd) / OW; From 48b86c4fdb1b1246d9fe17d2d53e3a1c2a0c3245 Mon Sep 17 00:00:00 2001 From: R0CKSTAR Date: Tue, 22 Jul 2025 07:45:26 +0800 Subject: [PATCH 4/8] cuda: remove linking to cublasLt (#14790) Signed-off-by: Xiaodong Ye --- ggml/src/ggml-cuda/CMakeLists.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index c9ff4aa32..98ed29bc9 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -102,12 +102,12 @@ if (CUDAToolkit_FOUND) if (GGML_STATIC) if (WIN32) # As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library - target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas CUDA::cublasLt) + target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas) else () - target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) + target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static) endif() else() - target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas CUDA::cublasLt) + target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas) endif() if (GGML_CUDA_NO_VMM) From adef81781a15083f218eae6c488b95cdad781971 Mon Sep 17 00:00:00 2001 From: Molly Sophia Date: Tue, 22 Jul 2025 09:24:22 +0800 Subject: [PATCH 5/8] server : allow setting `--reverse-prompt` arg (#14799) Signed-off-by: Molly Sophia --- common/arg.cpp | 2 +- tools/server/server.cpp | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/common/arg.cpp b/common/arg.cpp index c1151f51d..80f965cc7 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1612,7 +1612,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, const std::string & value) { params.antiprompt.emplace_back(value); } - ).set_examples({LLAMA_EXAMPLE_MAIN})); + ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"-sp", "--special"}, string_format("special tokens output enabled (default: %s)", params.special ? "true" : "false"), diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 256a2928b..022b5d0b3 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -253,6 +253,7 @@ struct server_task { defaults.sampling = params_base.sampling; defaults.speculative = params_base.speculative; defaults.n_keep = params_base.n_keep; + defaults.antiprompt = params_base.antiprompt; // enabling this will output extra debug information in the HTTP responses from the server params.verbose = params_base.verbosity > 9; @@ -490,6 +491,10 @@ struct server_task { } } } + // set reverse prompt from cli args if not set in the request + if (params.antiprompt.empty()) { + params.antiprompt = defaults.antiprompt; + } } { From 8e6f8bc875358968b63e08f7bbbe0a288f29d856 Mon Sep 17 00:00:00 2001 From: lhez Date: Mon, 21 Jul 2025 23:53:30 -0700 Subject: [PATCH 6/8] opencl: remove unreachable `return` (#14806) --- ggml/src/ggml-opencl/ggml-opencl.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index a31483b61..63ac4a989 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -5103,7 +5103,6 @@ static void ggml_cl_conv_2d(ggml_backend_t backend, const ggml_tensor * src0, co shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_float4)); } else { GGML_ASSERT(false && "Unsupported data type combination for conv2d"); - return; } cl_uint idx = 0; From e28c0b80c249b5618c3a0d5e960f63929f380ac9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Tue, 22 Jul 2025 12:33:10 +0200 Subject: [PATCH 7/8] cuda : implement bf16 cpy ops and enable bf16 cont (#14763) * implement bf16 cpy ops and enable bf16 cont * deduplicate copy functions * deduplicate checks --- ggml/src/ggml-cuda/cpy-utils.cuh | 46 ++++------------- ggml/src/ggml-cuda/cpy.cu | 89 ++++++++++++-------------------- ggml/src/ggml-cuda/ggml-cuda.cu | 18 ++----- ggml/src/ggml-cuda/set-rows.cu | 20 +------ 4 files changed, 49 insertions(+), 124 deletions(-) diff --git a/ggml/src/ggml-cuda/cpy-utils.cuh b/ggml/src/ggml-cuda/cpy-utils.cuh index e7a0bd2f1..410c12b7b 100644 --- a/ggml/src/ggml-cuda/cpy-utils.cuh +++ b/ggml/src/ggml-cuda/cpy-utils.cuh @@ -2,24 +2,13 @@ #include "ggml-common.h" -static __device__ __forceinline__ void convert_f32_f32(const float * src, float * dst) { - *dst = *src; -} - -static __device__ __forceinline__ void convert_f32_f16(const float * src, half * dst) { - *dst = __float2half(*src); -} - -static __device__ __forceinline__ void convert_f32_bf16(const float * src, nv_bfloat16 * dst) { - *dst = *src; -} - -static __device__ __forceinline__ void convert_f16_f16(const half * src, half * dst) { - *dst = *src; -} - -static __device__ __forceinline__ void convert_f16_f32(const half * src, float * dst) { - *dst = *src; +template +static __device__ __forceinline__ void convert_flt(const src_t * src, dst_t * dst) { + if constexpr (std::is_same_v) { + *dst = *src; + } else { + *dst = float(*src); + } } static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) { @@ -230,22 +219,7 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) { quantize_f32_iq4_nl_block((const float *)cxi, (block_iq4_nl *)cdsti); } -static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) { - convert_f32_f32((const float *)cxi, (float *)cdsti); -} - -static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) { - convert_f32_f16((const float *)cxi, (half *)cdsti); -} - -static __device__ void cpy_1_f32_bf16(const char * cxi, char * cdsti) { - convert_f32_bf16((const float *)cxi, (nv_bfloat16 *)cdsti); -} - -static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) { - convert_f16_f16((const half *)cxi, (half *)cdsti); -} - -static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) { - convert_f16_f32((const half *)cxi, (float *)cdsti); +template +static __device__ void cpy_1_flt(const char * cxi, char * cdsti) { + convert_flt((const src_t *)cxi, (dst_t *)cdsti); } diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index e7d0da087..0e5964907 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -8,10 +8,10 @@ typedef void (*cpy_kernel_t)(const char * cx, char * cdst); template -static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) { +static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne, + const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) { const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= ne) { @@ -139,43 +139,14 @@ void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_des #endif } -static void ggml_cpy_f16_f32_cuda( +template +static void ggml_cpy_flt_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; - cpy_f32_f16<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); -} - -static void ggml_cpy_f32_f32_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { - - const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; - cpy_f32_f16<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); -} - -static void ggml_cpy_f32_bf16_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { - - const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; - cpy_f32_f16<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); -} - -static void ggml_cpy_f32_f16_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { - - const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; - cpy_f32_f16<<>> + cpy_flt><<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } @@ -307,16 +278,6 @@ static void ggml_cpy_f32_iq4_nl_cuda( (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } -static void ggml_cpy_f16_f16_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { - - const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; - cpy_f32_f16<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); -} - void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) { const int64_t ne = ggml_nelements(src0); GGML_ASSERT(ne == ggml_nelements(src1)); @@ -372,11 +333,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { - ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) { - ggml_cpy_f32_bf16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { - ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) { ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) { @@ -403,9 +364,17 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { - ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { - ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else { GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type), ggml_type_name(src1->type)); @@ -430,11 +399,11 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) { if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { return nullptr; } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { - return (void*) cpy_f32_f16; + return (void*) cpy_flt>; } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) { - return (void*) cpy_f32_f16; + return (void*) cpy_flt>; } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { - return (void*) cpy_f32_f16; + return (void*) cpy_flt>; } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) { return (void*) cpy_f32_q; } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) { @@ -458,9 +427,17 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) { } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { return (void*) cpy_q_f32, QK5_1>; } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { - return (void*) cpy_f32_f16; + return (void*) cpy_flt>; + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) { + return (void*) cpy_flt>; } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { - return (void*) cpy_f32_f16; + return (void*) cpy_flt>; + } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) { + return (void*) cpy_flt>; + } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) { + return (void*) cpy_flt>; + } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) { + return (void*) cpy_flt>; } else { GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type), ggml_type_name(src1->type)); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index dfc50ef0d..548bc31ce 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3242,13 +3242,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g { ggml_type src0_type = op->src[0]->type; ggml_type src1_type = op->src[1]->type; - if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { - return true; - } - if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_BF16) { - return true; - } - if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) { + if ((src0_type == GGML_TYPE_F32 || src0_type == GGML_TYPE_BF16 || src0_type == GGML_TYPE_F16) && + (src1_type == GGML_TYPE_F32 || src1_type == GGML_TYPE_BF16 || src1_type == GGML_TYPE_F16) + ) { return true; } if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) { @@ -3284,12 +3280,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) { return true; } - if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { - return true; - } - if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { - return true; - } if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) { return true; } @@ -3370,7 +3360,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g return op->src[0]->ne[1] % 128 == 0; } case GGML_OP_CONT: - return op->src[0]->type != GGML_TYPE_BF16; + return true; case GGML_OP_DIAG_MASK_INF: return true; case GGML_OP_SOFT_MAX: diff --git a/ggml/src/ggml-cuda/set-rows.cu b/ggml/src/ggml-cuda/set-rows.cu index 560604d09..b2acdf855 100644 --- a/ggml/src/ggml-cuda/set-rows.cu +++ b/ggml/src/ggml-cuda/set-rows.cu @@ -4,24 +4,8 @@ typedef void (*set_rows_kernel_t)(const char * src, char * dst); template -__device__ void set_rows_1(const src_t * src_f, dst_t * dst_f) { - GGML_UNUSED(src_f); - GGML_UNUSED(dst_f); -} - -template<> -__device__ __forceinline__ void set_rows_1(const float * src_f, half * dst_h) { - convert_f32_f16(src_f, dst_h); -} - -template<> -__device__ __forceinline__ void set_rows_1(const float * src_f, nv_bfloat16 * dst_b) { - convert_f32_bf16(src_f, dst_b); -} - -template<> -__device__ __forceinline__ void set_rows_1(const float * src_f, float * dst_f) { - convert_f32_f32(src_f, dst_f); +__device__ __forceinline__ void set_rows_1(const src_t * src_f, dst_t * dst_f) { + convert_flt(src_f, dst_f); } // Generic quantized set_rows kernel template From c8ade30036139e32108fee53d8b7164dbfda4bee Mon Sep 17 00:00:00 2001 From: stduhpf Date: Tue, 22 Jul 2025 12:51:03 +0200 Subject: [PATCH 8/8] Mtmd: add a way to select device for vision encoder (#14236) * Mtmd: add a way to select device for vision encoder * simplify * format * Warn user if manual device selection failed * initialize backend to nullptr --- tools/mtmd/clip.cpp | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 9146c9e9c..be191404c 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -367,8 +367,8 @@ struct clip_ctx { std::vector backend_ptrs; std::vector backend_buft; - ggml_backend_t backend; - ggml_backend_t backend_cpu; + ggml_backend_t backend = nullptr; + ggml_backend_t backend_cpu = nullptr; ggml_backend_buffer_ptr buf; int max_nodes = 8192; @@ -384,9 +384,18 @@ struct clip_ctx { if (!backend_cpu) { throw std::runtime_error("failed to initialize CPU backend"); } - backend = ctx_params.use_gpu - ? ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_GPU, nullptr) - : nullptr; + if (ctx_params.use_gpu) { + auto backend_name = std::getenv("MTMD_BACKEND_DEVICE"); + if (backend_name != nullptr) { + backend = ggml_backend_init_by_name(backend_name, nullptr); + if (!backend) { + LOG_WRN("%s: Warning: Failed to initialize \"%s\" backend, falling back to default GPU backend\n", __func__, backend_name); + } + } + if (!backend) { + backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_GPU, nullptr); + } + } if (backend) { LOG_INF("%s: CLIP using %s backend\n", __func__, ggml_backend_name(backend));