ggml-webgpu: Fix how to dispatch WG to some ops (#23750)

This commit is contained in:
Masashi Yoshimura 2026-05-28 01:48:12 +09:00 committed by GitHub
parent c6e4088376
commit c40006a62e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 58 additions and 50 deletions

View file

@ -749,8 +749,11 @@ static webgpu_encoded_op ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src
ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst),
};
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
uint32_t wg_x;
uint32_t wg_y;
uint32_t total_wg = CEIL_DIV(ne, decisions->wg_size);
compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
}
static webgpu_encoded_op ggml_webgpu_set(webgpu_context & ctx,
@ -974,9 +977,10 @@ static webgpu_encoded_op ggml_webgpu_conv_2d(webgpu_context & ctx,
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
uint32_t wg_x;
uint32_t wg_y;
uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size);
uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg);
uint32_t wg_y = CEIL_DIV(total_wg, wg_x);
compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
}
@ -1064,9 +1068,10 @@ static webgpu_encoded_op ggml_webgpu_im2col(webgpu_context & ctx,
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
uint32_t wg_x;
uint32_t wg_y;
uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size);
uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg);
uint32_t wg_y = CEIL_DIV(total_wg, wg_x);
compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
}
@ -1689,14 +1694,11 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx,
gathered_count_ids_binding_size),
};
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
const uint32_t gather_total_wg = param_n_expert;
const uint32_t gather_wg_x = std::min(gather_total_wg, max_wg_per_dim);
const uint32_t gather_wg_y = CEIL_DIV(gather_total_wg, gather_wg_x);
// n_expert is much smaller than maxComputeWorkgroupsPerDimension (e.g., n_expert=256 for Qwen3.5-35B-A3B), so a 1D dispatch suffices
const uint32_t gather_wg_x = param_n_expert;
dispatches.push_back({
gather_pipeline, std::move(gather_params), std::move(gather_entries), { gather_wg_x, gather_wg_y }
gather_pipeline, std::move(gather_params), std::move(gather_entries), { gather_wg_x, 1 }
});
// params for mul_mat_id.wgsl
@ -1748,7 +1750,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx,
uint32_t max_wg_n = CEIL_DIV(total_gathered, tile_n_s) + max_active_experts;
uint32_t total_wg = wg_m * max_wg_n;
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y);
dispatches.push_back({
main_pipeline, std::move(main_params), std::move(main_entries), { wg_x, wg_y }
@ -2771,10 +2773,12 @@ static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor *
block_size, npr, nrows
};
const uint32_t total_wg_init = npr * nrows;
const uint32_t max_wg = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
const uint32_t wg_x_init = std::min(total_wg_init, max_wg);
const uint32_t wg_y_init = CEIL_DIV(total_wg_init, wg_x_init);
uint32_t wg_x_init;
uint32_t wg_y_init;
const uint32_t total_wg_init = npr * nrows;
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
compute_2d_workgroups(total_wg_init, max_wg_per_dim, wg_x_init, wg_y_init);
std::vector<wgpu::BindGroupEntry> init_entries = {
ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src),
ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), init_align_offset, init_binding_size)
@ -2831,9 +2835,11 @@ static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor *
ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(dst), align_out, size_out)
};
uint32_t wg_x_merge;
uint32_t wg_y_merge;
const uint32_t total_wg_merge = nm * nrows;
const uint32_t wg_x_merge = std::min(total_wg_merge, max_wg);
const uint32_t wg_y_merge = CEIL_DIV(total_wg_merge, wg_x_merge);
compute_2d_workgroups(total_wg_merge, max_wg_per_dim, wg_x_merge, wg_y_merge);
dispatches.push_back({
argsort_merge_pipeline, std::move(merge_params), std::move(merge_entries), { wg_x_merge, wg_y_merge }
});
@ -2953,9 +2959,12 @@ static webgpu_encoded_op ggml_webgpu_upscale(webgpu_context ctx, ggml_tensor * s
webgpu_pipeline pipeline = ctx->shader_lib->get_upscale_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size);
uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg);
uint32_t wg_y = CEIL_DIV(total_wg, wg_x);
uint32_t wg_x;
uint32_t wg_y;
uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size);
compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
}

View file

@ -49,12 +49,14 @@ struct Params{
var<uniform> params: Params;
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne) {
fn main(
@builtin(global_invocation_index) gindex: u32,
) {
if (gindex >= params.ne) {
return;
}
var i = gid.x;
var i = gindex;
let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0);
i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0);
let i2 = i / (params.src_ne1 * params.src_ne0);
@ -62,7 +64,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let i1 = i / params.src_ne0;
let i0 = i % params.src_ne0;
var j = gid.x;
var j = gindex;
let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
let j2 = j / (params.dst_ne1 * params.dst_ne0);

View file

@ -21,35 +21,32 @@ var<workgroup> count:atomic<u32>;
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(num_workgroups) num_wg: vec3<u32>) {
@builtin(local_invocation_id) local_id: vec3<u32>) {
let thread_id = local_id.x;
let own_expert = wg_id.y * num_wg.x + wg_id.x; // the expert assigned to this workgroup
let own_expert = wg_id.x; // the expert assigned to this workgroup
if (own_expert < params.n_expert) {
if (thread_id == 0u) {
atomicStore(&count, 0);
}
if (thread_id == 0u) {
atomicStore(&count, 0);
}
workgroupBarrier();
workgroupBarrier();
for (var i = thread_id;i < params.n_expert_used * params.n_tokens;i += WG_SIZE) {
let row = i / params.n_expert_used;
let col = i % params.n_expert_used;
let expert = u32(ids[params.offset_ids + row * params.stride_ids_1 + col]);
if (own_expert == expert) {
let pos = atomicAdd(&count, 1u);
let gathered_id = own_expert * params.n_tokens + pos;
global_gathered_expert_used[gathered_id] = col;
global_gathered_tokens[gathered_id] = row;
}
}
workgroupBarrier();
if (thread_id == 0u) {
gathered_count_ids[own_expert] = atomicLoad(&count);
for (var i = thread_id;i < params.n_expert_used * params.n_tokens;i += WG_SIZE) {
let row = i / params.n_expert_used;
let col = i % params.n_expert_used;
let expert = u32(ids[params.offset_ids + row * params.stride_ids_1 + col]);
if (own_expert == expert) {
let pos = atomicAdd(&count, 1u);
let gathered_id = own_expert * params.n_tokens + pos;
global_gathered_expert_used[gathered_id] = col;
global_gathered_tokens[gathered_id] = row;
}
}
workgroupBarrier();
if (thread_id == 0u) {
gathered_count_ids[own_expert] = atomicLoad(&count);
}
}