sycl: Q5_K reorder MMVQ/dequant + Q8_0 reorder MMVQ path (#22152)

* sycl: Q5_K reorder MMVQ/dequant + Q8_0 reorder MMVQ path

Signed-off-by: Chun Tao <chun.tao@intel.com>

* Remove duplicate definitions

---------

Signed-off-by: Chun Tao <chun.tao@intel.com>
Co-authored-by: Chun Tao <chun.tao@intel.com>
Co-authored-by: Todd Malsbary <todd.malsbary@intel.com>
This commit is contained in:
Intel AI Get-to Market Customer Success and Solutions 2026-05-08 22:48:07 -07:00 committed by GitHub
parent 4a4f819cb6
commit 60489932ec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 265 additions and 26 deletions

View file

@ -252,6 +252,23 @@ static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k,
#endif
}
template <typename dst_t>
static void dequantize_row_q5_K_sycl_reorder(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) {
const int64_t nb = k / QK_K;
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
stream->submit([&](sycl::handler & cgh) {
sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(K_SCALE_SIZE), cgh);
cgh.parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q5_K_reorder(vx, y, get_pointer(scale_local_acc), item_ct1, nb);
});
});
}
template <typename dst_t>
static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
@ -643,7 +660,11 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
return dequantize_row_q4_K_sycl;
}
case GGML_TYPE_Q5_K:
return dequantize_row_q5_K_sycl;
if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
return dequantize_row_q5_K_sycl_reorder;
} else {
return dequantize_row_q5_K_sycl;
}
case GGML_TYPE_Q6_K:
if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
return dequantize_row_q6_K_sycl_reorder;
@ -718,7 +739,11 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
return dequantize_row_q4_K_sycl;
}
case GGML_TYPE_Q5_K:
return dequantize_row_q5_K_sycl;
if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
return dequantize_row_q5_K_sycl_reorder;
} else {
return dequantize_row_q5_K_sycl;
}
case GGML_TYPE_Q6_K:
if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
return dequantize_row_q6_K_sycl_reorder;

View file

@ -537,6 +537,63 @@ static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restri
#endif
}
// Dequantizes one Q5_K block (selected by the work-group index in dimension 2) from the
// reordered layout into yy.
//
//   vx           - base of the reordered buffer (layout described below)
//   yy           - output; the block with index ib writes to yy[ib * QK_K ...]
//   scales_local - work-group local scratch of at least K_SCALE_SIZE bytes
//   item_ct1     - nd_item of the launching kernel; the code assumes 64 work-items per group
//   n_blocks     - total number of blocks in the buffer, used to locate each section
//
// Each work-item writes four outputs: y[0], y[1], y[32] and y[33] relative to its own base.
template <typename dst_t>
static void dequantize_block_q5_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy,
                                          uint8_t * scales_local, const sycl::nd_item<3> & item_ct1,
                                          int64_t n_blocks) {
    const int64_t ib = item_ct1.get_group(2);
#if QK_K == 256
    // assume 64 threads
    const int64_t tid = item_ct1.get_local_id(2);
    const int64_t il  = tid / 16;  // 0...3
    const int64_t ir  = tid % 16;  // 0...15
    const int64_t is  = 2 * il;

    dst_t * y = yy + ib * QK_K + 64 * il + 2 * ir;

    const uint8_t * base = static_cast<const uint8_t *>(vx);

    // Reordered layout: [qs (QK_K/2 per block)] [qh (QK_K/8 per block)] [scales (K_SCALE_SIZE per block)] [dm (half2 per block)]
    const size_t qs_offset     = ib * (QK_K / 2);
    const size_t qh_offset     = n_blocks * (QK_K / 2) + ib * (QK_K / 8);
    const size_t scales_offset = n_blocks * (QK_K / 2) + n_blocks * (QK_K / 8) + ib * K_SCALE_SIZE;
    const size_t dm_offset =
        n_blocks * (QK_K / 2) + n_blocks * (QK_K / 8) + n_blocks * K_SCALE_SIZE + ib * sizeof(ggml_half2);

    const uint8_t * qs_ptr     = base + qs_offset;
    const uint8_t * qh_ptr     = base + qh_offset;
    const uint8_t * scales_ptr = base + scales_offset;

    const ggml_half2 dm_values = *reinterpret_cast<const ggml_half2 *>(base + dm_offset);
    const float      dall      = dm_values.x();
    const float      dmin      = dm_values.y();

    const uint8_t * ql = qs_ptr + 32 * il + 2 * ir;
    const uint8_t * qh = qh_ptr + 2 * ir;

    // The first K_SCALE_SIZE work-items copy this block's packed scales into local memory;
    // the barrier makes them visible to the whole work-group before they are decoded.
    if (tid < K_SCALE_SIZE) {
        scales_local[tid] = scales_ptr[tid];
    }
    item_ct1.barrier(sycl::access::fence_space::local_space);

    uint8_t sc, m;
    get_scale_min_k4(is + 0, scales_local, sc, m);
    const float d1 = dall * sc;
    const float m1 = dmin * m;
    get_scale_min_k4(is + 1, scales_local, sc, m);
    const float d2 = dall * sc;
    const float m2 = dmin * m;

    // hm selects the qh bit that supplies the fifth bit (+16) for the low-nibble values;
    // it is shifted left by one for the high-nibble values.
    uint8_t hm = 1 << (2 * il);
    y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm ? 16 : 0)) - m1;
    y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm ? 16 : 0)) - m1;
    hm <<= 1;
    y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2;
    y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
#else
    // Only names that exist in this branch may be referenced here: tid is declared solely in
    // the QK_K == 256 branch above, while vx is the parameter left unused on this path.
    GGML_UNUSED(ib); GGML_UNUSED(vx); GGML_UNUSED(yy); GGML_UNUSED(scales_local); GGML_UNUSED(n_blocks);
    GGML_ABORT("Q5_K reorder dequantize not supported for QK_K != 256");
#endif
}
template<typename dst_t>
static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
const sycl::nd_item<3> &item_ct1) {

View file

@ -3303,6 +3303,7 @@ inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
case GGML_TYPE_Q8_0:
return true;
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
return !g_ggml_sycl_prioritize_dmmv;
default:
@ -3325,6 +3326,7 @@ inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
return true;
default:
@ -3541,6 +3543,54 @@ static bool reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d
return true;
}
static bool reorder_qw_q5_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
GGML_ASSERT(size % sizeof(block_q5_K) == 0);
GGML_ASSERT(offset % sizeof(block_q5_K) == 0);
const int nblocks = size / sizeof(block_q5_K);
sycl_reorder_temp_buffer tmp(stream, size);
if (!tmp) {
GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size);
return false;
}
uint8_t * tmp_buf = static_cast<uint8_t *>(tmp.ptr);
sycl::event copy_event;
SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
if (!g_ggml_sycl_use_async_mem_op) {
copy_event.wait();
}
auto * qs_ptr = data_device;
auto * qh_ptr = qs_ptr + (QK_K / 2) * nblocks;
auto * scales_ptr = qh_ptr + (QK_K / 8) * nblocks;
auto * dm_ptr = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks);
auto reorder_event = stream->parallel_for(nblocks, [=](auto i) {
const block_q5_K * x = (const block_q5_K *) tmp_buf;
const int ib = i;
for (int j = 0; j < QK_K / 2; ++j) {
qs_ptr[ib * (QK_K / 2) + j] = x[ib].qs[j];
}
for (int j = 0; j < QK_K / 8; ++j) {
qh_ptr[ib * (QK_K / 8) + j] = x[ib].qh[j];
}
for (int j = 0; j < K_SCALE_SIZE; ++j) {
scales_ptr[ib * K_SCALE_SIZE + j] = x[ib].scales[j];
}
dm_ptr[ib] = x[ib].dm;
});
if (!g_ggml_sycl_use_async_mem_op) {
reorder_event.wait_and_throw();
}
return true;
}
static bool reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
GGML_ASSERT(size % sizeof(block_q6_K) == 0);
GGML_ASSERT(offset % sizeof(block_q6_K) == 0);
@ -3607,6 +3657,8 @@ static bool reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
return reorder_qw_q8_0(data_device, ncols, nrows, size, 0, stream);
case GGML_TYPE_Q4_K:
return reorder_qw_q4_k(data_device, size, 0, stream);
case GGML_TYPE_Q5_K:
return reorder_qw_q5_k(data_device, size, 0, stream);
case GGML_TYPE_Q6_K:
return reorder_qw_q6_k(data_device, size, 0, stream);
default:

View file

@ -839,6 +839,26 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
}
}
static void reorder_mul_mat_vec_q5_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
const int nrows, dpct::queue_ptr stream) {
GGML_ASSERT(ncols % QK_K == 0);
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
constexpr size_t num_subgroups = 16;
GGML_ASSERT(block_num_y % num_subgroups == 0);
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
stream->submit([&](sycl::handler & cgh) {
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q5_K>>(vx, vy, dst, ncols,
nrows, nd_item);
});
});
}
static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
const int nrows, dpct::queue_ptr stream) {
GGML_ASSERT(ncols % QK_K == 0);
@ -1125,6 +1145,7 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q8_0_q8_1_sycl\n");
reorder_mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
} else {
GGML_SYCL_DEBUG("Calling mul_mat_vec_q8_0_q8_1_sycl\n");
mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
}
break;
@ -1145,7 +1166,14 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
}
break;
case GGML_TYPE_Q5_K:
mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q5_k_q8_1_sycl\n");
reorder_mul_mat_vec_q5_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
} else {
GGML_SYCL_DEBUG("Calling mul_mat_vec_q5_K_q8_1_sycl\n");
mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
}
break;
case GGML_TYPE_Q6_K:
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&

View file

@ -79,6 +79,31 @@ template <> struct block_q_t<GGML_TYPE_Q4_K> {
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
};
// Offset helpers for Q5_K data stored in the reordered layout:
// [qs (QK_K/2 per block)] [qh (QK_K/8 per block)] [scales] [dm]
template <> struct block_q_t<GGML_TYPE_Q5_K> {
    struct traits {
        static constexpr uint32_t qk       = QK_K;
        static constexpr uint32_t qi       = QI5_K;
        static constexpr uint32_t qr       = QR5_K;
        static constexpr uint32_t vdr_mmvq = 2;
    };

    // Returns { byte offset of the block's qs data, byte offset of the block's qh data }.
    static constexpr std::pair<int, int> get_block_offset(const int block_index, const int n_blocks) {
        const auto qs_section_bytes = n_blocks * (QK_K / 2);
        const auto qs_offset        = block_index * (QK_K / 2);
        const auto qh_offset        = qs_section_bytes + block_index * (QK_K / 8);
        return { qs_offset, qh_offset };
    }

    // Returns { byte offset of the block's scales, byte offset of the block's dm value }.
    static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
        const auto nblocks        = (nrows * (ncols / QK_K));
        const auto total_qs_bytes = nblocks * (QK_K / 2) + nblocks * (QK_K / 8);
        const auto dm_section     = total_qs_bytes + nblocks * K_SCALE_SIZE;
        return { total_qs_bytes + block_index * K_SCALE_SIZE, dm_section + block_index * sizeof(ggml_half2) };
    }

    // Number of QK8_1-sized blocks spanned by one block of this type.
    static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
};
template <> struct block_q_t<GGML_TYPE_Q6_K> {
struct traits {
static constexpr uint32_t qk = QK_K;

View file

@ -357,38 +357,31 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q8_0> {
using q8_0_block = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q8_0>;
using q8_0_traits = typename q8_0_block::traits;
__dpct_inline__ float vec_dot_q8_0_q8_1_impl(const int * v, const int * u, const float & d8_0, const sycl::half2 & ds8) {
int sumi = 0;
#pragma unroll
for (size_t i = 0; i < q8_0_traits::vdr_mmvq; ++i) {
// Q8_0 values are signed int8, no nibble extraction needed
// Direct dp4a: each int packs 4 int8 values
sumi = dpct::dp4a(v[i], u[i], sumi);
}
const sycl::float2 ds8f = ds8.convert<float, sycl::rounding_mode::automatic>();
// Q8_0 has no bias term (values are signed), so just scale
return d8_0 * sumi * ds8f.x();
}
__dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset,
const std::pair<int, int> d_offset, const int8_t * q8_1_quant_ptr,
const sycl::half2 * q8_1_ds, const int & iqs) {
const int8_t * bq8_0 = static_cast<const int8_t *>(vbq) + ibx_offset.first;
const ggml_half d = *(reinterpret_cast<const ggml_half *>(static_cast<const uint8_t *>(vbq) + d_offset.first));
int v[q8_0_traits::vdr_mmvq];
int u[q8_0_traits::vdr_mmvq];
const uint8_t * base = static_cast<const uint8_t *>(vbq);
const int8_t * qs = reinterpret_cast<const int8_t *>(base + ibx_offset.first);
const ggml_half d = *reinterpret_cast<const ggml_half *>(base + d_offset.first);
int v[q8_0_traits::vdr_mmvq];
int u[q8_0_traits::vdr_mmvq];
#pragma unroll
for (size_t i = 0; i < q8_0_traits::vdr_mmvq; ++i) {
v[i] = get_int_from_int8(bq8_0, iqs + i);
v[i] = get_int_from_int8(qs, iqs + i);
u[i] = get_int_from_int8_aligned(q8_1_quant_ptr, iqs + i);
}
return vec_dot_q8_0_q8_1_impl(v, u, d, *q8_1_ds);
};
int sumi = 0;
#pragma unroll
for (size_t i = 0; i < q8_0_traits::vdr_mmvq; ++i) {
sumi = dpct::dp4a(v[i], u[i], sumi);
}
const sycl::half2 ds_values = *q8_1_ds;
return static_cast<float>(d) * static_cast<float>(ds_values[0]) * sumi;
}
};
static inline float vec_dot_q4_K_q8_1_common(const int * __restrict__ q4, const uint16_t * __restrict__ scales,
@ -481,6 +474,65 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K> {
}
};
// Dot-product functor for Q5_K data in the reordered layout against Q8_1 activations.
// Used as the template argument of mul_mat_vec_q_reorder by the Q5_K reorder MMVQ launcher.
template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q5_K> {
static constexpr ggml_type gtype = GGML_TYPE_Q5_K;
using q5_k_block = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q5_K>;
using q5_k_traits = typename q5_k_block::traits;
// vbq            - base pointer of the reordered Q5_K buffer
// ibx_offset     - { byte offset of this block's qs data, byte offset of its qh data }
// d_offset       - { byte offset of this block's scales, byte offset of its dm value }
// q8_1_quant_ptr - Q8_1 quantized values, indexed here in strides of QK8_1 bytes
// q8_1_ds        - Q8_1 half2 entries, one read per QR5_K sub-block
// iqs            - quant index that selects which part of the block this call processes
// Returns the result of vec_dot_q5_K_q8_1_impl_vmmq on the gathered operands.
__dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset,
const std::pair<int, int> d_offset, const int8_t * q8_1_quant_ptr,
const sycl::half2 * q8_1_ds, const int & iqs) {
const uint8_t * base = static_cast<const uint8_t *>(vbq);
const uint8_t * qs = base + ibx_offset.first; // low 4 bits
const uint8_t * qh_base = base + ibx_offset.second; // high bit
const uint8_t * scs = base + d_offset.first;
const ggml_half2 * dms = reinterpret_cast<const ggml_half2 *>(base + d_offset.second);
// Index of the first Q8_1 sub-block touched by this call; it also selects the qs region
// (16 bytes per step) and the right-shift applied to the qh words below.
const int bq8_offset = QR5_K * ((iqs / 2) / (QI8_1 / 2));
const int * ql_ptr = (const int *) (qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4));
const int * qh_ptr = (const int *) (qh_base + 4 * ((iqs / 2) % 4));
const uint16_t * scales = (const uint16_t *) scs;
int vl[2];
int vh[2];
int u[2 * QR5_K];
float d8[QR5_K];
// Two 32-bit words of low-nibble data, four ints (16 bytes) apart.
vl[0] = ql_ptr[0];
vl[1] = ql_ptr[4];
// Matching high-bit words, shifted so the bits for this sub-block start at bit 0.
vh[0] = qh_ptr[0] >> bq8_offset;
vh[1] = qh_ptr[4] >> bq8_offset;
// Unpack the packed scale bytes into aux: for j < 2 the low 6 bits of each byte are kept;
// otherwise each byte is built from a 4-bit nibble plus the top two bits of another word.
uint16_t aux[2];
const int j = (QR5_K * ((iqs / 2) / (QI8_1 / 2))) / 2;
if (j < 2) {
aux[0] = scales[j + 0] & 0x3f3f;
aux[1] = scales[j + 2] & 0x3f3f;
} else {
aux[0] = ((scales[j + 2] >> 0) & 0x0f0f) | ((scales[j - 2] & 0xc0c0) >> 2);
aux[1] = ((scales[j + 2] >> 4) & 0x0f0f) | ((scales[j - 0] & 0xc0c0) >> 2);
}
// sc views the two bytes of aux[0] and m the two bytes of aux[1]
// (presumably the sub-block scales and mins expected by the impl — confirm against its signature).
const uint8_t * sc = (const uint8_t *) aux;
const uint8_t * m = sc + 2;
// Gather, for each of the QR5_K Q8_1 sub-blocks, the first component of its half2 entry
// and two 32-bit words of quantized activations (four ints apart).
for (int i = 0; i < QR5_K; ++i) {
const int8_t* quant_base_ptr = q8_1_quant_ptr + (bq8_offset + i) * QK8_1;
sycl::half2 ds_values = *(q8_1_ds + bq8_offset + i);
d8[i] = ds_values[0];
const int * q8 = (const int *) quant_base_ptr + ((iqs / 2) % 4);
u[2 * i + 0] = q8[0];
u[2 * i + 1] = q8[4];
}
return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, *dms, d8);
}
};
template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K> {
static constexpr ggml_type gtype = GGML_TYPE_Q6_K;