diff --git a/ggml/src/ggml-sycl/dmmv.cpp b/ggml/src/ggml-sycl/dmmv.cpp index 5577bf73b..4ae431a96 100644 --- a/ggml/src/ggml-sycl/dmmv.cpp +++ b/ggml/src/ggml-sycl/dmmv.cpp @@ -3,6 +3,13 @@ #include "dequantize.hpp" #include "presets.hpp" +#if defined(__INTEL_LLVM_COMPILER) + #if __has_include() + #include + #define GGML_SYCL_DMMV_HAS_BF16 + #endif +#endif + static void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ const sycl::half *x = (const sycl::half *)vx; @@ -11,6 +18,16 @@ static void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat v.y() = x[ib + iqs + 1]; } +#ifdef GGML_SYCL_DMMV_HAS_BF16 +static void convert_bf16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ + const sycl::ext::oneapi::bfloat16 *x = (const sycl::ext::oneapi::bfloat16 *)vx; + + // automatic bfloat16 -> float type cast if dfloat == float + v.x() = x[ib + iqs + 0]; + v.y() = x[ib + iqs + 1]; +} +#endif + static void convert_f32(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ const float * x = (const float *) vx; @@ -217,6 +234,28 @@ static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y, } } +#ifdef GGML_SYCL_DMMV_HAS_BF16 +static void convert_mul_mat_vec_bf16_sycl(const void *vx, const dfloat *y, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + // The qk=1 kernel iterates with stride 2*GGML_SYCL_DMMV_X, so ncols must be a + // multiple of that — not just GGML_SYCL_DMMV_X — to avoid out-of-bounds reads. + GGML_ASSERT(ncols % (2*GGML_SYCL_DMMV_X) == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + dequantize_mul_mat_vec<1, 1, convert_bf16>(vx, y, dst, ncols, + nrows, item_ct1); + }); + } +} +#endif + /* DPCT1110:4: The total declared local variable size in device function dequantize_mul_mat_vec_q2_k exceeds 128 bytes and may cause high register @@ -1497,7 +1536,8 @@ void ggml_sycl_op_dequantize_mul_mat_vec( bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 || - src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16; + src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16 || + src0->type == GGML_TYPE_BF16; if (src1_convert_f16) { scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_sycl", dst, /*num_src=*/2, @@ -1565,6 +1605,11 @@ void ggml_sycl_op_dequantize_mul_mat_vec( case GGML_TYPE_F16: convert_mul_mat_vec_f16_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); break; +#ifdef GGML_SYCL_DMMV_HAS_BF16 + case GGML_TYPE_BF16: + convert_mul_mat_vec_bf16_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + break; +#endif default: printf("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type); GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 2ea47f715..bba37a6f8 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -3455,6 +3455,7 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) { case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: case GGML_TYPE_F16: + case GGML_TYPE_BF16: return true; default: return false; @@ -3818,8 +3819,13 @@ static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * static bool can_use_dequantize_mul_mat_vec(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + // The F16/BF16 qk=1 kernel iterates with stride 2*DMMV_X, requiring ne[0] to be + // a multiple of 2*DMMV_X. Quantized types use block-structured kernels that only + // need ne[0] % DMMV_X == 0. + const int64_t dmmv_x_required = (src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F16) ? + 2*GGML_SYCL_DMMV_X : GGML_SYCL_DMMV_X; return ggml_sycl_supports_dmmv(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && - src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1; + src0->ne[0] % dmmv_x_required == 0 && src1->ne[1] == 1; } static bool can_use_mul_mat_vec_q(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {