opencl: generalize Adreno MoE kernels on M (#23449)

This commit is contained in:
Shawn Gu 2026-05-22 17:08:41 -07:00 committed by GitHub
parent 1acee6bf89
commit 0f3cb3fc8b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 145 additions and 17 deletions

View file

@ -4693,7 +4693,7 @@ inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, c
// Predicate on a tensor's name and row count; presumably used to decide whether
// the Adreno MoE kernel path applies to `tensor` (inferred from the name - confirm
// against callers).
//
// Returns true when BOTH hold:
//   * tensor->name contains "ffn" and "exps", or contains "as"; and
//   * tensor->ne[1] is a multiple of the divisor in the return expression.
//
// backend_ctx is accepted but not consulted.
//
// NOTE(review): strstr(name, "as") matches "as" anywhere in the name, not only as
// a whole token - confirm that breadth is intended.
inline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) {
// Parameter is intentionally unused; the macro marks it as such.
GGML_UNUSED(backend_ctx);
// Only dimension 1 of the tensor takes part in the divisibility check below.
int ne01 = tensor->ne[1];
// NOTE(review): two consecutive return statements follow. This looks like the
// removed (% 64) and added (% 32) lines of a diff hunk shown without +/- markers.
// Read literally as code, only the first return executes and the second is
// unreachable - confirm which divisor the real file carries.
return (((strstr(tensor->name, "ffn") != NULL) && (strstr(tensor->name, "exps") != NULL)) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0);
return (((strstr(tensor->name, "ffn") != NULL) && (strstr(tensor->name, "exps") != NULL)) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 32 == 0);
}
inline bool enable_adreno_trans_weight(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) {
@ -14297,7 +14297,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
CL_CHECK(status);
// set thread grid
global_size[0] = static_cast<size_t>(ne01);
global_size[0] = static_cast<size_t>(((ne01 + 63) / 64) * 64);
global_size[1] = 4;
global_size[2] = static_cast<size_t>(ne20);
local_size[1] = 4;
@ -14513,7 +14513,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
CL_CHECK(status);
// set thread grid
global_size[0] = static_cast<size_t>(ne01);
global_size[0] = static_cast<size_t>(((ne01 + 63) / 64) * 64);
global_size[1] = 4;
global_size[2] = static_cast<size_t>(ne20);
local_size[1] = 4;
@ -14689,7 +14689,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
CL_CHECK(status);
// set thread grid
global_size[0] = static_cast<size_t>(ne01);
global_size[0] = static_cast<size_t>(((ne01 + 63) / 64) * 64);
global_size[1] = 4;
global_size[2] = static_cast<size_t>(ne20);
local_size[1] = 4;
@ -14865,7 +14865,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
CL_CHECK(status);
// set thread grid
global_size[0] = static_cast<size_t>(ne01);
global_size[0] = static_cast<size_t>(((ne01 + 63) / 64) * 64);
global_size[1] = 4;
global_size[2] = static_cast<size_t>(ne20);
local_size[1] = 4;
@ -15118,7 +15118,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
CL_CHECK(status);
// set thread grid
global_size[0] = static_cast<size_t>(ne01);
global_size[0] = static_cast<size_t>(((ne01 + 63) / 64) * 64);
global_size[1] = 4;
global_size[2] = static_cast<size_t>(ne20);
local_size[1] = 4;
@ -15291,7 +15291,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
CL_CHECK(status);
// set thread grid
global_size[0] = static_cast<size_t>(ne01);
global_size[0] = static_cast<size_t>(((ne01 + 63) / 64) * 64);
global_size[1] = 4;
global_size[2] = static_cast<size_t>(ne20);
local_size[1] = 4;
@ -15469,7 +15469,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
CL_CHECK(status);
// set thread grid
global_size[0] = static_cast<size_t>(ne01);
global_size[0] = static_cast<size_t>(((ne01 + 63) / 64) * 64);
global_size[1] = 4;
global_size[2] = static_cast<size_t>(ne20);
local_size[1] = 4;
@ -15644,7 +15644,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
CL_CHECK(status);
// set thread grid
global_size[0] = static_cast<size_t>(ne01);
global_size[0] = static_cast<size_t>(((ne01 + 63) / 64) * 64);
global_size[1] = 4;
global_size[2] = static_cast<size_t>(ne20);
local_size[1] = 4;

View file

@ -220,6 +220,10 @@ kernel void kernel_convert_block_q4_0_trans4_ns(
uint i01 = get_global_id(0);
uint i02 = get_global_id(2);
if (i01 >= ne01) {
return;
}
uint ne00_blk = ne00 / QK4_0;
uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
@ -263,6 +267,10 @@ kernel void kernel_restore_block_q4_0_trans4_ns(
uint i01 = get_global_id(0);
uint i02 = get_global_id(2);
if (i01 >= ne01) {
return;
}
uint ne00_blk = ne00 / QK4_0;
uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
uint src_d_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
@ -401,6 +409,10 @@ kernel void kernel_convert_block_q4_1_trans4_ns(
uint i01 = get_global_id(0);
uint i02 = get_global_id(2);
if (i01 >= ne01) {
return;
}
uint ne00_blk = ne00 / QK4_1;
uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
@ -446,6 +458,10 @@ kernel void kernel_restore_block_q4_1_trans4_ns(
uint i01 = get_global_id(0);
uint i02 = get_global_id(2);
if (i01 >= ne01) {
return;
}
uint ne00_blk = ne00 / QK4_1;
uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
uint src_dm_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
@ -491,6 +507,10 @@ kernel void kernel_convert_block_q5_0_trans4_ns(
uint i01 = get_global_id(0);
uint i02 = get_global_id(2);
if (i01 >= ne01) {
return;
}
uint ne00_blk = ne00 / QK5_0;
uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
@ -536,6 +556,10 @@ kernel void kernel_restore_block_q5_0_trans4_ns(
uint i01 = get_global_id(0);
uint i02 = get_global_id(2);
if (i01 >= ne01) {
return;
}
uint ne00_blk = ne00 / QK5_0;
uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
@ -583,6 +607,10 @@ kernel void kernel_convert_block_q5_1_trans4_ns(
uint i01 = get_global_id(0);
uint i02 = get_global_id(2);
if (i01 >= ne01) {
return;
}
uint ne00_blk = ne00 / QK5_1;
uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
@ -630,6 +658,10 @@ kernel void kernel_restore_block_q5_1_trans4_ns(
uint i01 = get_global_id(0);
uint i02 = get_global_id(2);
if (i01 >= ne01) {
return;
}
uint ne00_blk = ne00 / QK5_1;
uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
@ -679,6 +711,10 @@ kernel void kernel_convert_block_q4_k_trans4_ns(
uint i01 = get_global_id(0);
uint i02 = get_global_id(2);
if (i01 >= ne01) {
return;
}
uint ne00_blk = ne00 / QK_K;
uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
@ -732,6 +768,10 @@ kernel void kernel_restore_block_q4_k_trans4_ns(
uint i01 = get_global_id(0); // row index
uint i02 = get_global_id(2); // batch index
if (i01 >= ne01) {
return;
}
uint ne00_blk = ne00 / QK_K;
uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
@ -784,6 +824,10 @@ kernel void kernel_convert_block_q5_k_trans4_ns(
uint i01 = get_global_id(0);
uint i02 = get_global_id(2);
if (i01 >= ne01) {
return;
}
uint ne00_blk = ne00 / QK_K;
uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
@ -850,6 +894,10 @@ kernel void kernel_restore_block_q5_k_trans4_ns(
uint i01 = get_global_id(0); // row index
uint i02 = get_global_id(2); // batch index
if (i01 >= ne01) {
return;
}
uint ne00_blk = ne00 / QK_K;
uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
@ -916,6 +964,10 @@ kernel void kernel_convert_block_q6_k_trans4_ns(
uint i01 = get_global_id(0);
uint i02 = get_global_id(2);
if (i01 >= ne01) {
return;
}
uint ne00_blk = ne00 / QK_K;
uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
@ -993,6 +1045,10 @@ kernel void kernel_restore_block_q6_k_trans4_ns(
uint i01 = get_global_id(0); // row index
uint i02 = get_global_id(2); // batch index
if (i01 >= ne01) {
return;
}
uint ne00_blk = ne00 / QK_K;
uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
@ -1147,6 +1203,10 @@ kernel void kernel_convert_block_mxfp4_trans4_ns(
uint i01 = get_global_id(0);
uint i02 = get_global_id(2);
if (i01 >= ne01) {
return;
}
uint ne00_blk = ne00 / QK_MXFP4;
uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
@ -1190,6 +1250,10 @@ kernel void kernel_restore_block_mxfp4_trans4_ns(
uint i01 = get_global_id(0);
uint i02 = get_global_id(2);
if (i01 >= ne01) {
return;
}
uint ne00_blk = ne00 / QK_MXFP4;
uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
uint src_d_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;

View file

@ -163,7 +163,7 @@ kernel void kernel_gemm_moe_mxfp4_f32_ns(
uint block_id_n = get_global_id(2); // n_tile
// Boundary check
if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) {
if (block_id_n >= total_tiles[0]) {
return;
}
@ -248,6 +248,10 @@ kernel void kernel_gemm_moe_mxfp4_f32_ns(
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
}
if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) {
return;
}
// Load post router and share in LM
__local uint out_idx[TILESIZE_N];

View file

@ -115,7 +115,7 @@ kernel void kernel_gemm_moe_q4_0_f32_ns(
uint block_id_n = get_global_id(2); // n_tile
// Boundary check
if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) {
if (block_id_n >= total_tiles[0]) {
return;
}
@ -198,6 +198,10 @@ kernel void kernel_gemm_moe_q4_0_f32_ns(
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
}
if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) {
return;
}
// Load post router and share in LM
__local uint out_idx[TILESIZE_N];

View file

@ -116,7 +116,7 @@ kernel void kernel_gemm_moe_q4_1_f32_ns(
uint block_id_n = get_global_id(2); // n_tile
// Boundary check
if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) {
if (block_id_n >= total_tiles[0]) {
return;
}
@ -200,6 +200,10 @@ kernel void kernel_gemm_moe_q4_1_f32_ns(
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
}
if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) {
return;
}
// Load post router and share in LM
__local uint out_idx[TILESIZE_N];

View file

@ -133,7 +133,7 @@ kernel void kernel_gemm_moe_q4_k_f32_ns(
uint block_id_n = get_global_id(2); // n_tile
// Boundary check
if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) {
if (block_id_n >= total_tiles[0]) {
return;
}
@ -225,6 +225,10 @@ kernel void kernel_gemm_moe_q4_k_f32_ns(
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
}
if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) {
return;
}
// Load post router and share in LM
__local uint out_idx[TILESIZE_N];

View file

@ -116,7 +116,7 @@ kernel void kernel_gemm_moe_q5_0_f32_ns(
uint block_id_n = get_global_id(2); // n_tile
// Boundary check
if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) {
if (block_id_n >= total_tiles[0]) {
return;
}
@ -202,6 +202,10 @@ kernel void kernel_gemm_moe_q5_0_f32_ns(
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
}
if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) {
return;
}
// Load post router and share in LM
__local uint out_idx[TILESIZE_N];

View file

@ -117,7 +117,7 @@ kernel void kernel_gemm_moe_q5_1_f32_ns(
uint block_id_n = get_global_id(2); // n_tile
// Boundary check
if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) {
if (block_id_n >= total_tiles[0]) {
return;
}
@ -204,6 +204,10 @@ kernel void kernel_gemm_moe_q5_1_f32_ns(
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
}
if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) {
return;
}
// Load post router and share in LM
__local uint out_idx[TILESIZE_N];

View file

@ -134,7 +134,7 @@ kernel void kernel_gemm_moe_q5_k_f32_ns(
uint block_id_n = get_global_id(2); // n_tile
// Boundary check
if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) {
if (block_id_n >= total_tiles[0]) {
return;
}
@ -230,6 +230,10 @@ kernel void kernel_gemm_moe_q5_k_f32_ns(
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
}
if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) {
return;
}
// Load post router and share in LM
__local uint out_idx[TILESIZE_N];

View file

@ -117,7 +117,7 @@ kernel void kernel_gemm_moe_q6_k_f32_ns(
uint block_id_n = get_global_id(2); // n_tile
// Boundary check
if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) {
if (block_id_n >= total_tiles[0]) {
return;
}
@ -209,6 +209,10 @@ kernel void kernel_gemm_moe_q6_k_f32_ns(
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
}
if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) {
return;
}
// Load post router and share in LM
__local uint out_idx[TILESIZE_N];

View file

@ -82,6 +82,10 @@ __kernel void kernel_gemv_moe_mxfp4_f32_ns(
uint sgid = get_local_id(1);
uint slid = get_sub_group_local_id();
if (i01 >= ne01) {
return;
}
uint i11 = i20 % ne11;
uint expert_id = src2[i20];

View file

@ -37,6 +37,10 @@ __kernel void kernel_gemv_moe_q4_0_f32_ns(
uint sgid = get_local_id(1);
uint slid = get_sub_group_local_id();
if (i01 >= ne01) {
return;
}
uint i11 = i20 % ne11;
uint expert_id = src2[i20];

View file

@ -38,6 +38,10 @@ __kernel void kernel_gemv_moe_q4_1_f32_ns(
uint sgid = get_local_id(1);
uint slid = get_sub_group_local_id();
if (i01 >= ne01) {
return;
}
uint i11 = i20 % ne11;
uint expert_id = src2[i20];

View file

@ -54,6 +54,10 @@ __kernel void kernel_gemv_moe_q4_k_f32_ns(
uint sgid = get_local_id(1);
uint slid = get_sub_group_local_id();
if (i01 >= ne01) {
return;
}
uint i11 = i20 % ne11;
uint expert_id = src2[i20];

View file

@ -38,6 +38,10 @@ __kernel void kernel_gemv_moe_q5_0_f32_ns(
uint sgid = get_local_id(1);
uint slid = get_sub_group_local_id();
if (i01 >= ne01) {
return;
}
uint i11 = i20 % ne11;
uint expert_id = src2[i20];

View file

@ -39,6 +39,10 @@ __kernel void kernel_gemv_moe_q5_1_f32_ns(
uint sgid = get_local_id(1);
uint slid = get_sub_group_local_id();
if (i01 >= ne01) {
return;
}
uint i11 = i20 % ne11;
uint expert_id = src2[i20];

View file

@ -55,6 +55,10 @@ __kernel void kernel_gemv_moe_q5_k_f32_ns(
uint sgid = get_local_id(1);
uint slid = get_sub_group_local_id();
if (i01 >= ne01) {
return;
}
uint i11 = i20 % ne11;
uint expert_id = src2[i20];

View file

@ -38,6 +38,10 @@ __kernel void kernel_gemv_moe_q6_k_f32_ns(
uint sgid = get_local_id(1);
uint slid = get_sub_group_local_id();
if (i01 >= ne01) {
return;
}
uint i11 = i20 % ne11;
uint expert_id = src2[i20];