sycl : fused MoE mul_mat_vec_q for TG (#21920)

* sycl : fused MoE mul_mat_vec_q for TG

Create an MMVQ kernel so ggml_sycl_mul_mat_id can consolidate
n_experts_used matmuls in a single kernel launch. The kernel
also reads expert IDs directly, removing a per-call host sync.

This is similar to the CUDA backend's ggml_cuda_mul_mat_vec_q*
paths.

All types supported in the current MMVQ are supported here as well:
Q2_K, Q3_K, Q4_K, Q5_K, Q6_K, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0

It will fall back to the existing per-expert path when src0 has been rewritten
by opt_for_reorder(), and for any shape the fused path doesn't handle.

test-backend-ops passes for supported type/shape combos.

Benchmark: Qwen3-Next-35B-A3B Q4_K_M on Intel Arc B70 (SYCL0),
baseline 707c0b7a6, 16k context, -fa 0.

  build/bin/llama-bench -hf unsloth/Qwen3-Next-35B-A3B-GGUF:Q4_K_M \
    -p 1024 -n 128 -d 16384 -ngl 99 -fa 0 -ub 2048 -r 2 -dev SYCL0

Before (3 runs on 707c0b7a6):

  | test            |            run 1 |            run 2 |            run 3 |
  | --------------- | ----------------:| ----------------:| ----------------:|
  | pp1024 @ d16384 |   533.26 ±  4.87 |   535.20 ±  2.78 |   524.27 ±  3.10 |
  | tg128  @ d16384 |    33.47 ±  0.02 |    33.31 ±  0.02 |    33.17 ±  0.05 |

After (3 runs on 707c0b7a6 + this patch):

  | test            |            run 1 |            run 2 |            run 3 |
  | --------------- | ----------------:| ----------------:| ----------------:|
  | pp1024 @ d16384 |   534.06 ±  0.97 |   531.95 ±  0.02 |   520.94 ± 20.10 |
  | tg128  @ d16384 |    45.85 ±  0.21 |    45.95 ±  0.45 |    46.22 ±  0.12 |

disclosure: Claude wrote it, but I reviewed and understand the implementation
(albeit my C is a little rusty).

* sycl: also support nvfp4 and mxfp4 expert types

* sycl: terser comments/nested dispatch in response to review

* sycl: more comment cleanup in mmvq.cpp/hpp

---------

Co-authored-by: Debian <aaron@openllmi.net.bots.is>
This commit is contained in:
abotsis 2026-04-22 23:18:56 -06:00 committed by GitHub
parent b76429a69c
commit 60b68a6279
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 218 additions and 0 deletions

View file

@ -3808,6 +3808,51 @@ __dpct_inline__ static void k_copy_dst_from_contiguous(
}
}
// Fused MoE TG fast path. Returns false to fall back to the per-expert loop below.
// Quantizes src1 once to Q8_1, then launches a single kernel that handles all
// n_ids experts (reading expert IDs on-device, so no host sync is needed).
static bool ggml_sycl_mul_mat_id_mmvq_fused(
    ggml_backend_sycl_context & ctx, const ggml_tensor * src0,
    const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst)
{
    const int64_t ne10 = src1->ne[0];  // src1 row length (columns)
    const int64_t ne11 = src1->ne[1];  // src1 rows: 1 (shared) or one per used expert
    const int64_t ne12 = src1->ne[2];  // token count; fused path is single-token (TG) only
    if (ne12 != 1) return false;
    if (src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) return false;
    // Row lengths must match and quantize into whole Q8_1 blocks.
    if (ne10 != src0->ne[0] || ne10 % QK8_1 != 0) return false;
    if (!ggml_is_contiguous(src1)) return false;
    // Reorder layout not supported; fall back.
    const ggml_tensor_extra_gpu * src0_extra =
        static_cast<const ggml_tensor_extra_gpu *>(src0->extra);
    if (src0_extra && src0_extra->optimized_feature.reorder) return false;
    const int64_t n_ids_per_group = ids->ne[0];  // experts used per token
    if (ids->ne[1] != 1) return false;           // exactly one token's id row
    // src1 is either shared by all experts (ne11 == 1) or has one row per used expert.
    if (ne11 != 1 && ne11 != n_ids_per_group) return false;
    const queue_ptr stream = ctx.stream();
    // Pad the quantized row length so the kernel never reads a partial tail block.
    const int src1_padded_cols = GGML_PAD((int) ne10, MATRIX_ROW_PADDING);
    const int n_experts_used = (int) n_ids_per_group;
    const int nrows = (int) src0->ne[1];  // output rows produced per expert
    // Device scratch for src1 quantized to Q8_1 (padded rows); freed via RAII pool alloc.
    ggml_sycl_pool_alloc<char> src1_q8_alloc(ctx.pool(),
        (size_t) ne11 * src1_padded_cols * sizeof(block_q8_1) / QK8_1);
    char * src1_ddq = src1_q8_alloc.get();
    quantize_row_q8_1_sycl<quantize_q8_1>(
        (const float *) src1->data, src1_ddq, (int) ne10, (int) ne11,
        src1_padded_cols, stream);
    const size_t bytes_per_qrow = (size_t) src1_padded_cols * sizeof(block_q8_1) / QK8_1;
    // Stride 0 makes every expert read the same (single) quantized src1 row.
    const size_t src1_row_stride = (ne11 == 1) ? 0 : bytes_per_qrow;
    // NOTE(review): dst->nb[1] as the per-expert output stride assumes dst rows are
    // laid out one per expert slot along dim 1 — consistent with the ne12 == 1 shape
    // checked above; confirm against ggml_sycl_mul_mat_id's dst layout if it changes.
    return ggml_sycl_mul_mat_vec_q_id(
        src0->type, src0->data, src1_ddq, (const int32_t *) ids->data,
        (float *) dst->data, (int) ne10, nrows, n_experts_used,
        /*expert_weight_stride=*/ src0->nb[2],
        /*dst_row_stride=*/ dst->nb[1],
        src1_row_stride, stream);
}
static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
ggml_tensor *dst) try {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3);
@ -3823,6 +3868,12 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
const int64_t n_as = ne02;
const int64_t n_ids = ids->ne[0];
if (ne12 == 1) {
if (ggml_sycl_mul_mat_id_mmvq_fused(ctx, src0, src1, ids, dst)) {
return;
}
}
std::vector<char> ids_host(ggml_nbytes(ids));
const char * ids_dev = (const char *) ids->data;

View file

@ -1199,3 +1199,154 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
GGML_UNUSED(src1_ddf_i);
GGML_UNUSED(ctx);
}
// src1_row_stride: 0 for shared src1 (gate/up proj), else per-expert stride (down proj).
//
// MoE matrix-vector product: group dim 1 indexes the expert slot, group dim 2 tiles
// the output rows, and each sub-group (local dim 2, WARP_SIZE wide) reduces one row.
// Expert IDs are read on-device from ids_dev, so no host-side sync is required.
template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
static void mul_mat_vec_q_moe(
    const void * __restrict__ vx_base, const void * __restrict__ vy_base,
    float * __restrict__ dst_base, const int32_t * __restrict__ ids_dev,
    const int ncols, const int nrows,
    const size_t expert_weight_stride, const size_t dst_row_stride,
    const size_t src1_row_stride,
    const sycl::nd_item<3> & item_ct1) {
    // Map this work-group's expert slot to the actual expert id.
    const int expert_idx = item_ct1.get_group(1);
    const int i02 = ids_dev[expert_idx];
    // Offset weights by expert id; activations/output by expert slot.
    const char * vx = (const char *) vx_base + (size_t) i02 * expert_weight_stride;
    const char * vy = (const char *) vy_base + (size_t) expert_idx * src1_row_stride;
    float * dst = (float *) ((char *) dst_base + (size_t) expert_idx * dst_row_stride);
    // Output row handled by this sub-group (row varies over local dim 1 only,
    // so all items of a sub-group — which spans local dim 2 — share one row,
    // making the early exit uniform w.r.t. the shuffle reduction below).
    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1);
    if (row >= nrows) {
        return;
    }
    const int blocks_per_row = ncols / qk;
    // Quant blocks advanced per iteration by a full sub-group (rounded up).
    constexpr int blocks_per_warp = (vdr * WARP_SIZE + qi - 1) / qi;
    float tmp = 0.0f;  // this work-item's partial dot product
    const block_q_t * x = (const block_q_t *) vx;
    const block_q8_1 * y = (const block_q8_1 *) vy;
    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; i += blocks_per_warp) {
        const int ibx = row * blocks_per_row + i;  // x block index
        // y blocks are QK8_1 wide, so each x block spans qk/QK8_1 y blocks.
        const int iby = i * (qk / QK8_1);
        for (size_t elem = 0; elem < qi / vdr; elem += WARP_SIZE) {
            // iqs: this item's quant offset within the block (vdr values per item).
            const int iqs = elem + vdr * (item_ct1.get_local_id(2) % (qi / vdr));
            tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs);
        }
    }
    // Butterfly (XOR) reduction across the sub-group.
#pragma unroll
    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
        tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
    }
    // Lane 0 holds the full sum for this row.
    if (item_ct1.get_local_id(2) == 0) {
        dst[row] = tmp;
    }
}
template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
static void launch_mul_mat_vec_q_moe(
const void * vx_base, const void * vy, const int32_t * ids_dev,
float * dst_base, const int ncols, const int nrows, const int n_experts_used,
const size_t expert_weight_stride, const size_t dst_row_stride,
const size_t src1_row_stride,
dpct::queue_ptr stream) {
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, (unsigned) n_experts_used, (unsigned) block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
stream->submit([&](sycl::handler & cgh) {
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_moe<qk, qi, block_q_t, vdr, vec_dot_q_sycl>(
vx_base, vy, dst_base, ids_dev, ncols, nrows,
expert_weight_stride, dst_row_stride, src1_row_stride, item);
});
});
}
// Dispatch the fused MoE matvec by src0 quantization type.
// Returns false for unhandled types; the caller falls back to the per-expert path.
bool ggml_sycl_mul_mat_vec_q_id(
    enum ggml_type src0_type,
    const void * vx_base,
    const void * vy,
    const int32_t * ids_dev,
    float * dst_base,
    int ncols,
    int nrows,
    int n_experts_used,
    size_t expert_weight_stride,
    size_t dst_row_stride,
    size_t src1_row_stride,
    dpct::queue_ptr stream) {
    // Each case instantiates the launcher for one format; the QK/QI/block/VDR/vec_dot
    // tuples mirror the non-MoE MMVQ dispatch in this file.
#define GGML_SYCL_MMVQ_MOE_CASE(TYPE, QK, QI, BLOCK_T, VDR, VEC_DOT)          \
    case TYPE:                                                                \
        launch_mul_mat_vec_q_moe<QK, QI, BLOCK_T, VDR, VEC_DOT>(              \
            vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used,     \
            expert_weight_stride, dst_row_stride, src1_row_stride, stream);   \
        return true;

    switch (src0_type) {
        GGML_SYCL_MMVQ_MOE_CASE(GGML_TYPE_Q4_0,  QK4_0,    QI4_0,    block_q4_0,  VDR_Q4_0_Q8_1_MMVQ,  vec_dot_q4_0_q8_1)
        GGML_SYCL_MMVQ_MOE_CASE(GGML_TYPE_Q4_1,  QK4_1,    QI4_1,    block_q4_1,  VDR_Q4_1_Q8_1_MMVQ,  vec_dot_q4_1_q8_1)
        GGML_SYCL_MMVQ_MOE_CASE(GGML_TYPE_Q5_0,  QK5_0,    QI5_0,    block_q5_0,  VDR_Q5_0_Q8_1_MMVQ,  vec_dot_q5_0_q8_1)
        GGML_SYCL_MMVQ_MOE_CASE(GGML_TYPE_Q5_1,  QK5_1,    QI5_1,    block_q5_1,  VDR_Q5_1_Q8_1_MMVQ,  vec_dot_q5_1_q8_1)
        GGML_SYCL_MMVQ_MOE_CASE(GGML_TYPE_Q8_0,  QK8_0,    QI8_0,    block_q8_0,  VDR_Q8_0_Q8_1_MMVQ,  vec_dot_q8_0_q8_1)
        GGML_SYCL_MMVQ_MOE_CASE(GGML_TYPE_Q2_K,  QK_K,     QI2_K,    block_q2_K,  VDR_Q2_K_Q8_1_MMVQ,  vec_dot_q2_K_q8_1)
        GGML_SYCL_MMVQ_MOE_CASE(GGML_TYPE_Q3_K,  QK_K,     QI3_K,    block_q3_K,  VDR_Q3_K_Q8_1_MMVQ,  vec_dot_q3_K_q8_1)
        GGML_SYCL_MMVQ_MOE_CASE(GGML_TYPE_Q4_K,  QK_K,     QI4_K,    block_q4_K,  VDR_Q4_K_Q8_1_MMVQ,  vec_dot_q4_K_q8_1)
        GGML_SYCL_MMVQ_MOE_CASE(GGML_TYPE_Q5_K,  QK_K,     QI5_K,    block_q5_K,  VDR_Q5_K_Q8_1_MMVQ,  vec_dot_q5_K_q8_1)
        GGML_SYCL_MMVQ_MOE_CASE(GGML_TYPE_Q6_K,  QK_K,     QI6_K,    block_q6_K,  VDR_Q6_K_Q8_1_MMVQ,  vec_dot_q6_K_q8_1)
        GGML_SYCL_MMVQ_MOE_CASE(GGML_TYPE_MXFP4, QK_MXFP4, QI_MXFP4, block_mxfp4, VDR_MXFP4_Q8_1_MMVQ, vec_dot_mxfp4_q8_1)
        GGML_SYCL_MMVQ_MOE_CASE(GGML_TYPE_NVFP4, QK_NVFP4, QI_NVFP4, block_nvfp4, VDR_NVFP4_Q8_1_MMVQ, vec_dot_nvfp4_q8_1)
        default:
            return false;
    }
#undef GGML_SYCL_MMVQ_MOE_CASE
}

View file

@ -24,4 +24,20 @@ void ggml_sycl_op_mul_mat_vec_q(
const int64_t src1_ncols, const int64_t src1_padded_row_size,
const dpct::queue_ptr &stream);
// Fused MoE matvec: one kernel launch computes all n_experts_used expert matvecs,
// reading expert IDs from device memory (no host sync).
// Requires standard (non-reorder) block layout for src0.
// Returns false if src0_type isn't handled; caller should fall back.
bool ggml_sycl_mul_mat_vec_q_id(
    enum ggml_type src0_type,
    const void * vx_base,        // start of stacked expert weights
    const void * vy,             // pre-quantized src1 (Q8_1); padded rows
    const int32_t * ids_dev,     // device-side int32, length n_experts_used
    float * dst_base,            // output; one row of ncols-dot results per expert
    int ncols,                   // row length; caller ensures it divides into Q8_1 blocks
    int nrows,                   // output rows per expert weight matrix
    int n_experts_used,
    size_t expert_weight_stride, // bytes between experts in vx_base
    size_t dst_row_stride,       // bytes between dst rows
    size_t src1_row_stride,      // 0 = shared src1, else per-expert stride in bytes
    dpct::queue_ptr stream);
#endif // GGML_SYCL_MMVQ_HPP