vulkan : support ggml_mean (#15393)

* vulkan : support ggml_mean

* vulkan : support sum, sum_rows and mean with non-contiguous tensors

* vulkan : fix subbuffer size not accounting for misalign offset

* tests : add backend-op tests for non-contiguous sum_rows

* cuda : require contiguous src for SUM_ROWS, MEAN support
* sycl : require contiguous src for SUM, SUM_ROWS, ARGSORT support

* require ggml_contiguous_rows in supports_op and expect nb00=1 in the shader
This commit is contained in:
Acly 2025-08-23 08:35:21 +02:00 committed by GitHub
parent 330c3d2d21
commit 0a9b43e507
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 135 additions and 18 deletions

View file

@ -4300,20 +4300,32 @@ struct test_sum : public test_case {
struct test_sum_rows : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
const bool permute;
const bool slice;
std::string vars() override {
return VARS_TO_STR2(type, ne);
return VARS_TO_STR4(type, ne, permute, slice);
}
test_sum_rows(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 5, 4, 3})
: type(type), ne(ne) {}
std::array<int64_t, 4> ne = {10, 5, 4, 3},
bool permute = false, bool slice = false)
: type(type), ne(ne), permute(permute), slice(slice) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(a);
ggml_set_name(a, "a");
if (slice) {
a = ggml_view_4d(ctx, a,
ne[0], ne[1], ne[2] / 2, ne[3] - 1,
a->nb[1], a->nb[2] * 2, a->nb[3], /*offset=*/a->nb[3]);
}
if (permute) {
a = ggml_permute(ctx, a, 0, 2, 3, 1);
}
ggml_tensor * out = ggml_sum_rows(ctx, a);
ggml_set_name(out, "out");
@ -6195,6 +6207,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_sum());
test_cases.emplace_back(new test_sum_rows());
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, false));
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, false, true));
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, true));
test_cases.emplace_back(new test_mean());
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1, 1, 1 }));
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1, 1, 1 }));