Concedo 2026-01-06 20:41:47 +08:00
parent bd51d775be
commit 246ce4babd

View file

@ -6012,6 +6012,7 @@ template <typename T, uint32_t N> const T *push_constant_data(const std::array<T
return t.data();
}
static bool kcpp_wg_warning = false;
template <typename T>
static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list<vk::DescriptorBufferInfo> const& descriptor_buffer_infos, const T &push_constants, std::array<uint32_t, 3> elements) {
const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]);
@ -6022,9 +6023,16 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context&
std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), ";
}
std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))");
GGML_ASSERT_CONTINUE(wg0 <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] &&
wg1 <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
wg2 <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
// GGML_ASSERT_CONTINUE(wg0 <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] &&
// wg1 <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
// wg2 <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
if(!kcpp_wg_warning && !(wg0 <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] &&
wg1 <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
wg2 <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]))
{
kcpp_wg_warning = true;
fprintf(stderr, "\nWarning: Workgroup exceeds max count: wg0=%d wg1=%d wg2=%d vs (%d, %d, %d)\n",wg0,wg1,wg2,ctx->device->properties.limits.maxComputeWorkGroupCount[0],ctx->device->properties.limits.maxComputeWorkGroupCount[1],ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
}
GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size());
GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT);
GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size());