diff --git a/ggml/src/ggml-zendnn/ggml-zendnn.cpp b/ggml/src/ggml-zendnn/ggml-zendnn.cpp index 6051d0820..3c33dcb11 100644 --- a/ggml/src/ggml-zendnn/ggml-zendnn.cpp +++ b/ggml/src/ggml-zendnn/ggml-zendnn.cpp @@ -88,7 +88,7 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int return true; } -static bool ggml_zendnn_sgemm(ggml_backend_zendnn_context * ctx, int64_t m, int64_t n, int64_t k, +static bool ggml_zendnn_gemm(ggml_backend_zendnn_context * ctx, int64_t m, int64_t n, int64_t k, const void * A, int64_t lda, const void * B, int64_t ldb, void * C, int64_t ldc, int Atype, int Btype, int Ctype) { @@ -200,7 +200,7 @@ static void ggml_zendnn_compute_forward_mul_mat( for (int64_t i12 = 0; i12 < ne12; i12++) { const void* wdata = (src1->type == vec_dot_type || src0->type == GGML_TYPE_Q8_0) ? src1->data : work_data; const size_t row_size = ggml_row_size(vec_dot_type, ne10); - if (!ggml_zendnn_sgemm(ctx, + if (!ggml_zendnn_gemm(ctx, ne01, // m ne11, // n ne10, // k @@ -213,7 +213,7 @@ static void ggml_zendnn_compute_forward_mul_mat( src0->type, src0->type == GGML_TYPE_Q8_0 ? GGML_TYPE_F32 : vec_dot_type, dst->type)) - GGML_ABORT("%s: ZenDNN sgemm failed\n", __func__); + GGML_ABORT("%s: ZenDNN gemm failed\n", __func__); } } } @@ -355,7 +355,7 @@ static void ggml_zendnn_compute_forward_mul_mat_id( } // batched gemm for all tokens in this expert - if (!ggml_zendnn_sgemm(ctx, + if (!ggml_zendnn_gemm(ctx, ne01, // m cne1, // n ne10, // k @@ -368,7 +368,7 @@ static void ggml_zendnn_compute_forward_mul_mat_id( src0->type, src0->type == GGML_TYPE_Q8_0 ? GGML_TYPE_F32 : vec_dot_type, dst->type)) { - GGML_ABORT("%s: ZenDNN sgemm failed\n", __func__); + GGML_ABORT("%s: ZenDNN gemm failed\n", __func__); } // scatter output rows to destination