diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index c5354293a..76a0cfa4f 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -8392,7 +8392,7 @@ static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, } static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f }, dryrun); } static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp index eaf4da341..a1d4c240d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp @@ -5,6 +5,8 @@ #extension GL_EXT_control_flow_attributes : enable +#define FLT_MAX 3.402823466e+38F + layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; @@ -19,19 +21,26 @@ void main() { const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; const uint col = gl_LocalInvocationID.x; - if (col >= p.KX) { + if (row >= p.KY) { return; } - A_TYPE amax = data_a[row*p.KX + col]; - tmp[col] = col; + + A_TYPE amax = -FLT_MAX; + uint acol = col; + + if (col < p.KX) { + amax = data_a[row*p.KX + col]; + } for (uint i = col + BLOCK_SIZE; i < p.KX; i += BLOCK_SIZE) { A_TYPE val = data_a[row*p.KX + i]; if (val > amax) { amax = val; - tmp[col] = i; + acol = i; } } + + tmp[col] = acol; tmpmax[col] = amax; barrier(); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index cc9c3a0d5..f4565f9b7 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -5528,6 +5528,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 5000, 1, 1})); test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 1, 1, 1})); + test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 513, 1, 1})); test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {100, 10, 1, 1})); test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1})); test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 12, 1, 1}));