SYCL: improve MoE prefill throughput (#23142)

- change `k_copy_src1_to_contiguous` so that it uses a precomputed contiguous mapping where all rows "owned" by an expert are in one slice with a known start and end
- switch the `O(n_as * n_routed_rows)` contraption to a counting sort-based procedure with `O(n_as + n_routed_rows)` complexity
This commit is contained in:
Alexey Kopytko 2026-05-22 21:50:17 +09:00 committed by GitHub
parent bcfd1989e9
commit cc9e331213
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -3919,35 +3919,17 @@ struct mmid_row_mapping {
__dpct_inline__ static void k_copy_src1_to_contiguous(
const char *__restrict__ src1_original, char *__restrict__ src1_contiguous,
int *__restrict__ cur_src1_row, mmid_row_mapping *__restrict__ row_mapping,
const char *__restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
const mmid_row_mapping *__restrict__ row_mapping,
int64_t ne11, int64_t ne10, size_t nb11, size_t nb12,
const sycl::nd_item<3> &item_ct1, int &src1_row) {
int32_t iid1 = item_ct1.get_group(2);
int32_t id = item_ct1.get_group(1);
const sycl::nd_item<3> &item_ct1) {
const int32_t src1_row = item_ct1.get_group(2);
const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
if (row_id_i != i02) {
return;
}
const int32_t iid1 = row_mapping[src1_row].i2;
const int32_t id = row_mapping[src1_row].i1;
const int64_t i11 = id % ne11;
const int64_t i12 = iid1;
if (item_ct1.get_local_id(2) == 0) {
src1_row =
dpct::atomic_fetch_add<sycl::access::address_space::generic_space>(
cur_src1_row, 1);
row_mapping[src1_row] = {id, iid1};
}
/*
DPCT1065:194: Consider replacing sycl::nd_item::barrier() with
sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
performance if there is no access to global memory.
*/
item_ct1.barrier();
const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
@ -4022,6 +4004,47 @@ static bool ggml_sycl_mul_mat_id_mmvq_fused(
src1_row_stride, stream);
}
// Counting sort of the routed rows by expert id (row_id_i, as chosen by the router):
// builds a projection of a memory layout where each expert's slice is contiguous.
//
// Inputs:
//   ids            - tensor describing the id matrix; only ne[1], nb[0] and nb[1] are read here
//   ids_host       - bytes the int32 expert ids are read from on the CPU, addressed as
//                    ids_host + iid1*ids->nb[1] + id*ids->nb[0]
//                    (presumably a host copy of the ids tensor data - confirm against the caller)
//   n_ids          - number of ids visited per iid1 row
//   n_as           - number of experts; every id must lie in [0, n_as)
//   n_routed_rows  - size given to routed_row_src; must be able to hold every visited id
//                    (ids->ne[1] * n_ids entries)
//
// Outputs (previous contents are overwritten):
//   expert_counts       - n_as entries: how many routed rows each expert "owns"
//   expert_row_offsets  - n_as + 1 entries: expert e owns the contiguous rows
//                         [expert_row_offsets[e], expert_row_offsets[e + 1])
//   routed_row_src      - for every contiguous row, the {id, iid1} pair it originates from;
//                         within one expert's slice the pairs keep their traversal order
//                         (iid1 outer, id inner)
//
// Two linear passes over the ids plus one pass over the experts.
static void mmid_counting_sort_rows(
        const ggml_tensor * ids, const char * ids_host,
        int64_t n_ids, int64_t n_as, int64_t n_routed_rows,
        std::vector<int64_t> & expert_counts,
        std::vector<int64_t> & expert_row_offsets,
        std::vector<mmid_row_mapping> & routed_row_src) {
    // loop-invariant layout of the id matrix
    const int64_t n_id_rows = ids->ne[1];
    const size_t  ids_nb0   = ids->nb[0];
    const size_t  ids_nb1   = ids->nb[1];

    // pass 1 - frequencies: how many routed rows each expert "owns"
    expert_counts.assign(n_as, 0);
    for (int64_t iid1 = 0; iid1 < n_id_rows; iid1++) {
        for (int64_t id = 0; id < n_ids; id++) {
            const int32_t row_id_i = *(const int32_t *) (ids_host + iid1*ids_nb1 + id*ids_nb0);
            GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
            expert_counts[row_id_i]++;
        }
    }

    // exclusive prefix sum: where each expert's slice starts (row indices) and the previous ends
    expert_row_offsets.assign(n_as + 1, 0);
    for (int64_t i02 = 0; i02 < n_as; i02++) {
        expert_row_offsets[i02 + 1] = expert_row_offsets[i02] + expert_counts[i02];
    }

    // The scatter below writes one entry per visited id. Make sure the destination is large
    // enough, otherwise a caller passing a too small n_routed_rows would corrupt the heap.
    GGML_ASSERT(expert_row_offsets[n_as] <= n_routed_rows);

    // pass 2 - scatter: hand out the rows of each expert's slice in traversal order
    std::vector<int64_t> expert_row_next = expert_row_offsets;
    routed_row_src.resize(n_routed_rows);
    for (int64_t iid1 = 0; iid1 < n_id_rows; iid1++) {
        for (int64_t id = 0; id < n_ids; id++) {
            const int32_t row_id_i = *(const int32_t *) (ids_host + iid1*ids_nb1 + id*ids_nb0);
            GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);

            // find and validate the next free row for a given expert (row_id_i)
            const int64_t routed_row = expert_row_next[row_id_i]++;
            GGML_ASSERT(routed_row >= expert_row_offsets[row_id_i]);
            GGML_ASSERT(routed_row <  expert_row_offsets[row_id_i + 1]);

            routed_row_src[routed_row] = {static_cast<int32_t>(id), static_cast<int32_t>(iid1)};
        }
    }
}
static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
ggml_tensor *dst) try {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3);
@ -4100,99 +4123,91 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
src1_row.data = src1_contiguous.get();
dst_row.data = dst_contiguous.get();
// how many "owned" routed rows to pass to each expert
std::vector<int64_t> expert_row_counts;
// where each expert's slice starts and the previous ends (row indices, right-exclusive)
std::vector<int64_t> expert_row_offsets;
// the sources (slot/token pairs) of contiguous rows to guide k_copy_src1_to_contiguous
std::vector<mmid_row_mapping> routed_row_src;
mmid_counting_sort_rows(ids, ids_host.data(), n_ids, n_as, n_routed_rows,
expert_row_counts, expert_row_offsets, routed_row_src);
ggml_sycl_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), n_routed_rows);
SYCL_CHECK(CHECK_TRY_ERROR(
stream->memcpy(dev_row_mapping.get(), routed_row_src.data(), n_routed_rows*sizeof(mmid_row_mapping))));
const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device];
assert(max_work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
{
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, max_work_group_size));
sycl::range<3> grid_dims(1, 1, n_routed_rows);
stream->submit([&](sycl::handler &cgh) {
char *__restrict src1_contiguous_get =
src1_contiguous.get();
mmid_row_mapping *__restrict dev_row_mapping_get =
dev_row_mapping.get();
cgh.parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
k_copy_src1_to_contiguous(
src1_original, src1_contiguous_get,
dev_row_mapping_get,
ne11, ne10, nb11, nb12,
item_ct1);
});
});
}
for (int64_t i02 = 0; i02 < n_as; i02++) {
int64_t num_src1_rows = 0;
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
for (int64_t id = 0; id < n_ids; id++) {
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
if (row_id_i != i02) {
continue;
}
num_src1_rows++;
}
}
const int64_t num_src1_rows = expert_row_counts[i02];
if (num_src1_rows == 0) {
continue;
}
ggml_sycl_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
ggml_sycl_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
SYCL_CHECK(CHECK_TRY_ERROR(
stream->memset(dev_cur_src1_row.get(), 0, sizeof(int))));
const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device];
assert(max_work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
{
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, max_work_group_size));
sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);
stream->submit([&](sycl::handler &cgh) {
sycl::local_accessor<int, 0> src1_row_acc(cgh);
char *__restrict src1_contiguous_get =
src1_contiguous.get();
int *__restrict dev_cur_src1_row_get =
dev_cur_src1_row.get();
mmid_row_mapping *__restrict dev_row_mapping_get =
dev_row_mapping.get();
size_t ids_nb_ct6 = ids->nb[1];
size_t ids_nb_ct7 = ids->nb[0];
cgh.parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
k_copy_src1_to_contiguous(
src1_original, src1_contiguous_get,
dev_cur_src1_row_get,
dev_row_mapping_get, ids_dev, i02,
ids_nb_ct6, ids_nb_ct7, ne11, ne10, nb11, nb12,
item_ct1, src1_row_acc);
});
});
}
const int64_t expert_row_offset = expert_row_offsets[i02];
src0_row.data = src0_original + i02*nb02;
GGML_ASSERT(nb11 == sizeof(float)*ne10);
GGML_ASSERT(nb1 == sizeof(float)*ne0);
src1_row.data = src1_contiguous.get() + expert_row_offset*nb11;
src1_row.ne[1] = num_src1_rows;
src1_row.nb[1] = nb11;
src1_row.nb[2] = num_src1_rows*nb11;
src1_row.nb[3] = num_src1_rows*nb11;
dst_row.data = dst_contiguous.get() + expert_row_offset*nb1;
dst_row.ne[1] = num_src1_rows;
dst_row.nb[1] = nb1;
dst_row.nb[2] = num_src1_rows*nb1;
dst_row.nb[3] = num_src1_rows*nb1;
ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
}
{
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, max_work_group_size));
sycl::range<3> grid_dims(1, 1, num_src1_rows);
stream->submit([&](sycl::handler &cgh) {
const char *__restrict dst_contiguous_get =
dst_contiguous.get();
const mmid_row_mapping *__restrict dev_row_mapping_get =
dev_row_mapping.get();
{
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, max_work_group_size));
sycl::range<3> grid_dims(1, 1, n_routed_rows);
stream->submit([&](sycl::handler &cgh) {
const char *__restrict dst_contiguous_get =
dst_contiguous.get();
const mmid_row_mapping *__restrict dev_row_mapping_get =
dev_row_mapping.get();
cgh.parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
k_copy_dst_from_contiguous(dst_original,
dst_contiguous_get,
dev_row_mapping_get,
ne0, nb1, nb2, item_ct1);
});
});
}
cgh.parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
k_copy_dst_from_contiguous(dst_original,
dst_contiguous_get,
dev_row_mapping_get,
ne0, nb1, nb2, item_ct1);
});
});
}
}
}