diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 38f6f13c1..a5df508ff 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -6012,6 +6012,7 @@ template const T *push_constant_data(const std::array static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list const& descriptor_buffer_infos, const T &push_constants, std::array elements) { const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]); @@ -6022,9 +6023,16 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), "; } std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))"); - GGML_ASSERT_CONTINUE(wg0 <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] && - wg1 <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] && - wg2 <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]); + // GGML_ASSERT_CONTINUE(wg0 <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] && + // wg1 <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] && + // wg2 <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]); + if(!kcpp_wg_warning && !(wg0 <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] && + wg1 <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] && + wg2 <= ctx->device->properties.limits.maxComputeWorkGroupCount[2])) + { + kcpp_wg_warning = true; + fprintf(stderr, "\nWarning: Workgroup exceeds max count: wg0=%d wg1=%d wg2=%d vs (%d, %d, %d)\n",wg0,wg1,wg2,ctx->device->properties.limits.maxComputeWorkGroupCount[0],ctx->device->properties.limits.maxComputeWorkGroupCount[1],ctx->device->properties.limits.maxComputeWorkGroupCount[2]); + } GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size()); GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT); GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size());