mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-27 00:14:49 +00:00
SYCL: improve MoE prefill throughput (#23142)
- change `k_copy_src1_to_contiguous` so that it uses a precomputed contiguous mapping in which all rows "owned" by an expert lie in one slice with a known start and end - switch the `O(n_as * n_routed_rows)` contraption to a counting-sort-based procedure with `O(n_as + n_routed_rows)` complexity
This commit is contained in:
parent
bcfd1989e9
commit
cc9e331213
1 changed files with 106 additions and 91 deletions
|
|
@ -3919,35 +3919,17 @@ struct mmid_row_mapping {
|
|||
|
||||
__dpct_inline__ static void k_copy_src1_to_contiguous(
|
||||
const char *__restrict__ src1_original, char *__restrict__ src1_contiguous,
|
||||
int *__restrict__ cur_src1_row, mmid_row_mapping *__restrict__ row_mapping,
|
||||
const char *__restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
|
||||
const mmid_row_mapping *__restrict__ row_mapping,
|
||||
int64_t ne11, int64_t ne10, size_t nb11, size_t nb12,
|
||||
const sycl::nd_item<3> &item_ct1, int &src1_row) {
|
||||
int32_t iid1 = item_ct1.get_group(2);
|
||||
int32_t id = item_ct1.get_group(1);
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int32_t src1_row = item_ct1.get_group(2);
|
||||
|
||||
const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
|
||||
|
||||
if (row_id_i != i02) {
|
||||
return;
|
||||
}
|
||||
const int32_t iid1 = row_mapping[src1_row].i2;
|
||||
const int32_t id = row_mapping[src1_row].i1;
|
||||
|
||||
const int64_t i11 = id % ne11;
|
||||
const int64_t i12 = iid1;
|
||||
|
||||
if (item_ct1.get_local_id(2) == 0) {
|
||||
src1_row =
|
||||
dpct::atomic_fetch_add<sycl::access::address_space::generic_space>(
|
||||
cur_src1_row, 1);
|
||||
row_mapping[src1_row] = {id, iid1};
|
||||
}
|
||||
/*
|
||||
DPCT1065:194: Consider replacing sycl::nd_item::barrier() with
|
||||
sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
|
||||
performance if there is no access to global memory.
|
||||
*/
|
||||
item_ct1.barrier();
|
||||
|
||||
const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
|
||||
float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
|
||||
|
||||
|
|
@ -4022,6 +4004,47 @@ static bool ggml_sycl_mul_mat_id_mmvq_fused(
|
|||
src1_row_stride, stream);
|
||||
}
|
||||
|
||||
// counting sort of the routed rows by expert id (row_id_i, as chosen by the router):
|
||||
// builds a projection of a memory layout where each expert's slice is contiguous
|
||||
static void mmid_counting_sort_rows(
|
||||
const ggml_tensor * ids, const char * ids_host,
|
||||
int64_t n_ids, int64_t n_as, int64_t n_routed_rows,
|
||||
std::vector<int64_t> & expert_counts,
|
||||
std::vector<int64_t> & expert_row_offsets,
|
||||
std::vector<mmid_row_mapping> & routed_row_src) {
|
||||
|
||||
// frequencies: how many routed rows each expert "owns"
|
||||
expert_counts.assign(n_as, 0);
|
||||
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
|
||||
for (int64_t id = 0; id < n_ids; id++) {
|
||||
const int32_t row_id_i = *(const int32_t *) (ids_host + iid1*ids->nb[1] + id*ids->nb[0]);
|
||||
GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
|
||||
expert_counts[row_id_i]++;
|
||||
}
|
||||
}
|
||||
|
||||
// where each expert's slice starts (row indices) and the previous ends
|
||||
expert_row_offsets.assign(n_as + 1, 0);
|
||||
for (int64_t i02 = 0; i02 < n_as; i02++) {
|
||||
expert_row_offsets[i02 + 1] = expert_row_offsets[i02] + expert_counts[i02];
|
||||
}
|
||||
|
||||
std::vector<int64_t> expert_row_next = expert_row_offsets;
|
||||
routed_row_src.resize(n_routed_rows);
|
||||
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
|
||||
for (int64_t id = 0; id < n_ids; id++) {
|
||||
const int32_t row_id_i = *(const int32_t *) (ids_host + iid1*ids->nb[1] + id*ids->nb[0]);
|
||||
GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
|
||||
|
||||
// find and validate the next free row for a given expert (row_id_i)
|
||||
const int64_t routed_row = expert_row_next[row_id_i]++;
|
||||
GGML_ASSERT(routed_row >= expert_row_offsets[row_id_i]);
|
||||
GGML_ASSERT(routed_row < expert_row_offsets[row_id_i + 1]);
|
||||
routed_row_src[routed_row] = {(int32_t) id, (int32_t) iid1};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
|
||||
ggml_tensor *dst) try {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3);
|
||||
|
|
@ -4100,99 +4123,91 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
|
|||
src1_row.data = src1_contiguous.get();
|
||||
dst_row.data = dst_contiguous.get();
|
||||
|
||||
// how many "owned" routed rows to pass to each expert
|
||||
std::vector<int64_t> expert_row_counts;
|
||||
// where each expert's slice starts and the previous ends (row indices, right-exclusive)
|
||||
std::vector<int64_t> expert_row_offsets;
|
||||
// the sources (slot/token pairs) of contiguous rows to guide k_copy_src1_to_contiguous
|
||||
std::vector<mmid_row_mapping> routed_row_src;
|
||||
|
||||
mmid_counting_sort_rows(ids, ids_host.data(), n_ids, n_as, n_routed_rows,
|
||||
expert_row_counts, expert_row_offsets, routed_row_src);
|
||||
|
||||
ggml_sycl_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), n_routed_rows);
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(
|
||||
stream->memcpy(dev_row_mapping.get(), routed_row_src.data(), n_routed_rows*sizeof(mmid_row_mapping))));
|
||||
|
||||
const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device];
|
||||
assert(max_work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
|
||||
|
||||
{
|
||||
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, max_work_group_size));
|
||||
sycl::range<3> grid_dims(1, 1, n_routed_rows);
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
char *__restrict src1_contiguous_get =
|
||||
src1_contiguous.get();
|
||||
mmid_row_mapping *__restrict dev_row_mapping_get =
|
||||
dev_row_mapping.get();
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
k_copy_src1_to_contiguous(
|
||||
src1_original, src1_contiguous_get,
|
||||
dev_row_mapping_get,
|
||||
ne11, ne10, nb11, nb12,
|
||||
item_ct1);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
for (int64_t i02 = 0; i02 < n_as; i02++) {
|
||||
int64_t num_src1_rows = 0;
|
||||
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
|
||||
for (int64_t id = 0; id < n_ids; id++) {
|
||||
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
|
||||
|
||||
GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
|
||||
|
||||
if (row_id_i != i02) {
|
||||
continue;
|
||||
}
|
||||
|
||||
num_src1_rows++;
|
||||
}
|
||||
}
|
||||
const int64_t num_src1_rows = expert_row_counts[i02];
|
||||
|
||||
if (num_src1_rows == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
ggml_sycl_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
|
||||
ggml_sycl_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(
|
||||
stream->memset(dev_cur_src1_row.get(), 0, sizeof(int))));
|
||||
|
||||
const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device];
|
||||
assert(max_work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
|
||||
|
||||
{
|
||||
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, max_work_group_size));
|
||||
sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
sycl::local_accessor<int, 0> src1_row_acc(cgh);
|
||||
|
||||
char *__restrict src1_contiguous_get =
|
||||
src1_contiguous.get();
|
||||
int *__restrict dev_cur_src1_row_get =
|
||||
dev_cur_src1_row.get();
|
||||
mmid_row_mapping *__restrict dev_row_mapping_get =
|
||||
dev_row_mapping.get();
|
||||
size_t ids_nb_ct6 = ids->nb[1];
|
||||
size_t ids_nb_ct7 = ids->nb[0];
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
k_copy_src1_to_contiguous(
|
||||
src1_original, src1_contiguous_get,
|
||||
dev_cur_src1_row_get,
|
||||
dev_row_mapping_get, ids_dev, i02,
|
||||
ids_nb_ct6, ids_nb_ct7, ne11, ne10, nb11, nb12,
|
||||
item_ct1, src1_row_acc);
|
||||
});
|
||||
});
|
||||
}
|
||||
const int64_t expert_row_offset = expert_row_offsets[i02];
|
||||
|
||||
src0_row.data = src0_original + i02*nb02;
|
||||
|
||||
GGML_ASSERT(nb11 == sizeof(float)*ne10);
|
||||
GGML_ASSERT(nb1 == sizeof(float)*ne0);
|
||||
src1_row.data = src1_contiguous.get() + expert_row_offset*nb11;
|
||||
src1_row.ne[1] = num_src1_rows;
|
||||
|
||||
src1_row.nb[1] = nb11;
|
||||
src1_row.nb[2] = num_src1_rows*nb11;
|
||||
src1_row.nb[3] = num_src1_rows*nb11;
|
||||
|
||||
dst_row.data = dst_contiguous.get() + expert_row_offset*nb1;
|
||||
dst_row.ne[1] = num_src1_rows;
|
||||
dst_row.nb[1] = nb1;
|
||||
dst_row.nb[2] = num_src1_rows*nb1;
|
||||
dst_row.nb[3] = num_src1_rows*nb1;
|
||||
|
||||
ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
|
||||
}
|
||||
|
||||
{
|
||||
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, max_work_group_size));
|
||||
sycl::range<3> grid_dims(1, 1, num_src1_rows);
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
const char *__restrict dst_contiguous_get =
|
||||
dst_contiguous.get();
|
||||
const mmid_row_mapping *__restrict dev_row_mapping_get =
|
||||
dev_row_mapping.get();
|
||||
{
|
||||
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, max_work_group_size));
|
||||
sycl::range<3> grid_dims(1, 1, n_routed_rows);
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
const char *__restrict dst_contiguous_get =
|
||||
dst_contiguous.get();
|
||||
const mmid_row_mapping *__restrict dev_row_mapping_get =
|
||||
dev_row_mapping.get();
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
k_copy_dst_from_contiguous(dst_original,
|
||||
dst_contiguous_get,
|
||||
dev_row_mapping_get,
|
||||
ne0, nb1, nb2, item_ct1);
|
||||
});
|
||||
});
|
||||
}
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
k_copy_dst_from_contiguous(dst_original,
|
||||
dst_contiguous_get,
|
||||
dev_row_mapping_get,
|
||||
ne0, nb1, nb2, item_ct1);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue