diff --git a/.clang-format b/.clang-format
index 45232b80e..47d96b6b4 100644
--- a/.clang-format
+++ b/.clang-format
@@ -22,8 +22,8 @@ AllowShortIfStatementsOnASingleLine: Never
 AllowShortLambdasOnASingleLine: Inline
 AllowShortLoopsOnASingleLine: false
 AlwaysBreakBeforeMultilineStrings: true
-BinPackArguments: true
-BinPackParameters: true # OnePerLine
+BinPackArguments: false
+BinPackParameters: false # OnePerLine
 BitFieldColonSpacing: Both
 BreakBeforeBraces: Custom # Attach
 BraceWrapping:
@@ -70,15 +70,18 @@ ExperimentalAutoDetectBinPacking: false
 FixNamespaceComments: true
 IncludeBlocks: Regroup
 IncludeCategories:
-  - Regex: '^<.*\.h>'
+  - Regex: '".*"'
     Priority: 1
     SortPriority: 0
-  - Regex: '^<.*'
+  - Regex: '^<.*\.h>'
     Priority: 2
     SortPriority: 0
-  - Regex: '.*'
+  - Regex: '^<.*'
     Priority: 3
     SortPriority: 0
+  - Regex: '.*'
+    Priority: 4
+    SortPriority: 0
 IncludeIsMainRegex: '([-_](test|unittest))?$'
 IncludeIsMainSourceRegex: ''
 IndentAccessModifiers: false
diff --git a/common/common.cpp b/common/common.cpp
index 9d0ebb8c4..871158ae3 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -456,6 +456,15 @@ void string_replace_all(std::string & s, const std::string & search, const std::string & replace) {
 bool string_ends_with(const std::string_view & str, const std::string_view & suffix) {
     return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
 }
+
+bool string_remove_suffix(std::string & str, const std::string_view & suffix) {
+    bool has_suffix = string_ends_with(str, suffix);
+    if (has_suffix) {
+        str = str.substr(0, str.size() - suffix.size());
+    }
+    return has_suffix;
+}
+
 size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop) {
     if (!str.empty() && !stop.empty()) {
         const char text_last_char = str.back();
diff --git a/common/common.h b/common/common.h
index 0fc5e07ed..0e8fa156f 100644
--- a/common/common.h
+++ b/common/common.h
@@ -530,6 +530,7 @@ static bool string_starts_with(const std::string & str,
 // While we wait for C++20's std::string::ends_with...
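 //
 // usage sketch for the new helper (illustrative values, not part of this patch):
 //   std::string name = "model.gguf";
 //   if (string_remove_suffix(name, ".gguf")) {
 //       // suffix was present and removed in place; name == "model"
 //   }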
bool string_ends_with(const std::string_view & str, const std::string_view & suffix); +bool string_remove_suffix(std::string & str, const std::string_view & suffix); size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop); bool string_parse_kv_override(const char * data, std::vector & overrides); diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index 5fd379f6a..fcc552da5 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -22,21 +22,6 @@ static bool ggml_is_view(const struct ggml_tensor * t) { return t->view_src != NULL; } -static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) { - if (a->type != b->type) { - return false; - } - for (int i = 0; i < GGML_MAX_DIMS; i++) { - if (a->ne[i] != b->ne[i]) { - return false; - } - if (a->nb[i] != b->nb[i]) { - return false; - } - } - return true; -} - // ops that return true for this function must not use restrict pointers for their backend implementations static bool ggml_op_can_inplace(enum ggml_op op) { switch (op) { diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index eafd5fd34..04a35806b 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -352,21 +352,6 @@ ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend) { // backend copy -static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) { - if (a->type != b->type) { - return false; - } - for (int i = 0; i < GGML_MAX_DIMS; i++) { - if (a->ne[i] != b->ne[i]) { - return false; - } - if (a->nb[i] != b->nb[i]) { - return false; - } - } - return true; -} - void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst) { GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts"); diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.cpp b/ggml/src/ggml-cpu/kleidiai/kernels.cpp deleted file mode 100644 index 910fd0ee4..000000000 --- a/ggml/src/ggml-cpu/kleidiai/kernels.cpp +++ /dev/null @@ -1,337 +0,0 @@ -// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates -// SPDX-License-Identifier: MIT -// - -// KleidiAI micro-kernels -#include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h" -#include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h" -#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h" -#include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h" -#include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h" -#include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h" -#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h" -#include "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h" - -#include "kai_lhs_pack_bf16p2vlx2_f32_sme.h" -#include "kai_lhs_quant_pack_qsi8d32p_f32.h" -#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h" - -#include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h" -#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h" -#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h" - -#include "kai_common.h" - -#include "kernels.h" - -#define NELEMS(x) sizeof(x) / sizeof(*x) -static ggml_kleidiai_kernels gemm_gemv_kernels[] = { -#if defined(__ARM_FEATURE_SME) - { - /* SME GEMM */ - /* .kern_info = */ { - /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_n_step = */ 
kai_get_n_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - }, - /* SME GEMV */ - /* .kern_info = */ { - /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, - /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, - /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, - /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, - /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, - /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, - /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, - /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, - }, - /* .lhs_info = */ { - /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon, - /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon, - }, - /* .rhs_info = */ { - /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, - /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, - }, - /* .required_cpu = */ CPU_FEATURE_SME, - /* .lhs_type = */ GGML_TYPE_F32, - /* .rhs_type = */ GGML_TYPE_Q4_0, - /* .op_type = */ GGML_TYPE_F32, - }, - { - /* SME GEMM */ - /* .kern_info = */ { - /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - /* .get_mr = */ kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - /* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - /* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - /* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - /* .get_rhs_packed_offset = */ 
kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - /* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - }, - /* SME GEMV */ - /* .kern_info = */ { - /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - /* .get_mr = */ kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - /* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - /* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - /* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - /* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - }, - /* .lhs_info = */ { - /* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme, - /* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme, - }, - /* .rhs_info = */ { - /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme, - /* .pack_func = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme, - }, - /* .required_cpu = */ CPU_FEATURE_SME, - /* .lhs_type = */ GGML_TYPE_F32, - /* .rhs_type = */ GGML_TYPE_F16, - /* .op_type = */ GGML_TYPE_F32, - }, -#endif -#if defined(__APPLE__) -#if defined(__ARM_FEATURE_DOTPROD) - { - /* DOTPROD GEMM */ - /* .kern_info = */ { - /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - }, - /* DOTPROD GEMV */ - /* .kern_info = */ { - /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - /* .get_n_step = */ 
kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - }, - /* .lhs_info = */ { - /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, - /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, - }, - /* .rhs_info = */ { - /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - }, - /* .required_cpu = */ CPU_FEATURE_DOTPROD, - /* .lhs_type = */ GGML_TYPE_F32, - /* .rhs_type = */ GGML_TYPE_Q4_0, - /* .op_type = */ GGML_TYPE_F32, - }, -#endif -#if defined(__ARM_FEATURE_MATMUL_INT8) - { - /* i8mm GEMM */ - /* .kern_info = */ { - /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - }, - /* i8mm GEMV */ - /* .kern_info = */ { - /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_rhs_packed_offset = */ 
kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - }, - /* .lhs_info = */ { - /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, - /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, - }, - /* .rhs_info = */ { - /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - }, - /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, - /* .lhs_type = */ GGML_TYPE_F32, - /* .rhs_type = */ GGML_TYPE_Q4_0, - /* .op_type = */ GGML_TYPE_F32, - }, -#endif -#else -#if defined(__ARM_FEATURE_MATMUL_INT8) - { - /* i8mm GEMM */ - /* .kern_info = */ { - /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - }, - /* i8mm GEMV */ - /* .kern_info = */ { - /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - }, - /* .lhs_info = */ { - /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, - /* 
.packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, - /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, - }, - /* .rhs_info = */ { - /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - }, - /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, - /* .lhs_type = */ GGML_TYPE_F32, - /* .rhs_type = */ GGML_TYPE_Q4_0, - /* .op_type = */ GGML_TYPE_F32, - }, -#endif -#if defined(__ARM_FEATURE_DOTPROD) - { - /* DOTPROD GEMM */ - /* .kern_info = */ { - /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - }, - /* DOTPROD GEMV */ - /* .kern_info = */ { - /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - }, - /* .lhs_info = */ { - /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, - /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, - }, - /* .rhs_info = */ { - /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - }, - /* .required_cpu = */ CPU_FEATURE_DOTPROD, - /* .lhs_type = */ GGML_TYPE_F32, - /* .rhs_type = */ GGML_TYPE_Q4_0, - /* .op_type = */ GGML_TYPE_F32, - }, -#endif -#endif -}; - -ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, 
const ggml_tensor * tensor) { - ggml_kleidiai_kernels * kernel = nullptr; - - if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) { - for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) { - if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu && - gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type && - gemm_gemv_kernels[i].rhs_type == tensor->src[0]->type && - gemm_gemv_kernels[i].op_type == tensor->type) { - kernel = &gemm_gemv_kernels[i]; - break; - } - } - } - - return kernel; -} - -ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features) { - ggml_kleidiai_kernels * kernels = nullptr; - - for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) { - if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) { - kernels = &gemm_gemv_kernels[i]; - break; - } - } - - return kernels; -} diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.h b/ggml/src/ggml-cpu/kleidiai/kernels.h deleted file mode 100644 index 3b268d4a2..000000000 --- a/ggml/src/ggml-cpu/kleidiai/kernels.h +++ /dev/null @@ -1,95 +0,0 @@ -// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates -// SPDX-License-Identifier: MIT -// - -#pragma once - -#include -#include -#include "ggml.h" - -enum cpu_feature { - CPU_FEATURE_NONE = 0, - CPU_FEATURE_DOTPROD = 1, - CPU_FEATURE_I8MM = 2, - CPU_FEATURE_SVE = 4, - CPU_FEATURE_SME = 8 -}; -inline cpu_feature& operator|=(cpu_feature& lhs, cpu_feature rhs) { - lhs = static_cast(lhs | rhs); - return lhs; -} -inline cpu_feature operator|(cpu_feature lhs, cpu_feature rhs) { - return static_cast(static_cast(lhs) | static_cast(rhs)); -} - -struct kernel_info { - size_t (*get_m_step)(void); - size_t (*get_n_step)(void); - size_t (*get_mr)(void); - size_t (*get_nr)(void); - size_t (*get_kr)(void); - size_t (*get_sr)(void); - std::variant< - std::function, - std::function - > get_lhs_offset; - std::variant< - std::function, - std::function - > get_rhs_packed_offset; - size_t (*get_dst_offset)(size_t m_idx, size_t n_idx, size_t stride); - size_t (*get_dst_size)(size_t m, size_t n); - std::variant< - std::function, - std::function - > run_kernel; -}; - -struct lhs_packing_info { - size_t (*get_offset)(size_t m_idx, size_t lhs_stride); - std::variant< - std::function, - std::function - > get_packed_offset; - std::variant< - std::function, - std::function - > packed_size; - std::variant< - std::function, - std::function - > pack_func; -}; - -struct rhs_packing_info { - std::variant< - std::function, - std::function - > packed_size; - std::variant< - std::function, - std::function - > pack_func; -}; - -struct ggml_kleidiai_kernels { - kernel_info gemm; - kernel_info gemv; - lhs_packing_info lhs_info; - rhs_packing_info rhs_info; - - cpu_feature required_cpu; - ggml_type lhs_type; - ggml_type rhs_type; - ggml_type op_type; -}; - -ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor); -ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features); diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp deleted file mode 100644 index fafe45e6c..000000000 --- a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +++ /dev/null @@ -1,482 +0,0 @@ -// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates -// SPDX-License-Identifier: MIT -// -#include -#include -#include -#include -#include -#include -#include -#if defined(__linux__) -#include 
-#include -#elif defined(__APPLE__) -#include -#include -#include -#elif defined(_WIN32) -#include -#include -#endif - -#include "kleidiai.h" - -#include "ggml-cpu.h" -#include "ggml-impl.h" -#include "ggml-backend-impl.h" -#include "ggml-threading.h" -#include "traits.h" - -#include "kernels.h" - -#include "kai_common.h" - -#define GGML_COMMON_DECL_CPP -#include "ggml-common.h" - -struct ggml_kleidiai_context { - cpu_feature features; - ggml_kleidiai_kernels * kernels; -} static ctx = { CPU_FEATURE_NONE, NULL }; - -static void init_kleidiai_context(void) { - - ggml_critical_section_start(); - static bool initialized = false; - - if (!initialized) { - initialized = true; - const char *env_var = getenv("GGML_KLEIDIAI_SME"); - int sme_enabled = 0; - - ctx.features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) | - (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) | - (ggml_cpu_has_sve() ? CPU_FEATURE_SVE : CPU_FEATURE_NONE); - - if (env_var) { - sme_enabled = atoi(env_var); - } - - if (sme_enabled != 0) { - ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE; - } - ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features); - } - ggml_critical_section_end(); -} - -static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) { - GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS); - return tensor->ne[dim]; -} - -template -static Ret variant_call(const Variant & var, Args&&... args) { - return std::visit([&](auto&& func) -> Ret { - if constexpr (std::is_invocable_r_v) { - return func(std::forward(args)...); - } else { - throw std::runtime_error("Invalid function type in variant_call"); - } - }, var); -} - -namespace ggml::cpu::kleidiai { - -static size_t round_down(size_t x, size_t y) { - return y == 0 ? x : x - (x % y); -} - -static void transpose_f32kxn_f16nxk(size_t n, size_t k, float * dst, const uint16_t * src, size_t rhs_stride) { - size_t src_stride = rhs_stride / sizeof(uint16_t); - size_t dst_stride = n; - - for (size_t k_idx = 0; k_idx < k; ++k_idx) { - for (size_t n_idx = 0; n_idx < n; ++n_idx) { - uint16_t v = *(src + k_idx + n_idx * src_stride); - *(dst + n_idx + k_idx * dst_stride) = kai_cast_f32_f16(v); - } - } -} - -class tensor_traits : public ggml::cpu::tensor_traits { - bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override { - ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op); - GGML_ASSERT(kernels); - kernel_info * kernel = op->src[1]->ne[1] == 1 ? 
&kernels->gemv : &kernels->gemm; - - size_t k = op->src[0]->ne[0]; - size_t n = op->src[0]->ne[1]; - size_t m = op->src[1]->ne[1]; - - size_t mr = kernel->get_mr(); - size_t kr = kernel->get_kr(); - size_t sr = kernel->get_sr(); - - if (kernels->rhs_type == GGML_TYPE_Q4_0) { - size = variant_call(kernels->lhs_info.packed_size, m, k, QK4_0, mr, kr, sr); - } else if (kernels->rhs_type == GGML_TYPE_F16) { - size = variant_call(kernels->lhs_info.packed_size, m, k, mr, kr, sr) + - variant_call(kernels->rhs_info.packed_size, n, k) + - k * n * sizeof(float) + n * sizeof(float); - } else { - GGML_ASSERT(false); - } - - return true; - } - - - bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override { - if (dst->op == GGML_OP_MUL_MAT) { - if (dst->src[0]->type == GGML_TYPE_Q4_0) { - return compute_forward_q4_0(params, dst); - } else if (dst->src[0]->type == GGML_TYPE_F16) { - return compute_forward_kv_cache(params, dst); - } - } - return false; - } - - bool compute_forward_kv_cache(ggml_compute_params * params, struct ggml_tensor * dst) { - static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT; - - const ggml_tensor * src0 = dst->src[0]; - const ggml_tensor * src1 = dst->src[1]; - - GGML_TENSOR_BINARY_OP_LOCALS - - ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst); - GGML_ASSERT(kernels); - - kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm; - GGML_ASSERT(kernel); - - const int nth = params->nth; - const int ith = params->ith; - - const int64_t lhs_batch_size0 = ne12; - const int64_t rhs_batch_size0 = ne02; - const int64_t batch_size = rhs_batch_size0; - - const int64_t r = lhs_batch_size0 / rhs_batch_size0; - - const int64_t m = ne11 * r; - const int64_t n = ne01; - const int64_t k = ne00; - - const size_t lhs_stride = src1->nb[1]; - const size_t rhs_stride = src0->nb[1]; - const size_t dst_stride = dst->nb[1]; - - const int64_t mr = static_cast(kernel->get_mr()); - const int64_t nr = static_cast(kernel->get_nr()); - const int64_t kr = static_cast(kernel->get_kr()); - const int64_t sr = static_cast(kernel->get_sr()); - - const size_t lhs_packed_size = variant_call(kernels->lhs_info.packed_size, m, k, mr, kr, sr); - const size_t rhs_packed_size = variant_call(kernels->rhs_info.packed_size, n, k); - const size_t kxn_size = k * n * sizeof(float); - const size_t bias_size = n * sizeof(float); - - const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size; - GGML_ASSERT(wsize_required <= params->wsize); - - uint8_t * lhs_packed = static_cast(params->wdata); - uint8_t * rhs_packed = lhs_packed + lhs_packed_size; - uint8_t * rhs_kxn = rhs_packed + rhs_packed_size; - uint8_t * bias = rhs_kxn + kxn_size; - - for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { - const uint8_t * lhs_batch = static_cast(src1->data) + batch_idx * m * lhs_stride; - const uint8_t * rhs_batch = static_cast(src0->data) + batch_idx * n * rhs_stride; - uint8_t * dst_batch = static_cast(dst->data) + batch_idx * m * dst_stride; - - // LHS packing - { - const int64_t m_roundup_mr = kai_roundup(m, mr); - const int64_t num_threads = KAI_MIN(m_roundup_mr / mr, nth); - - if (ith < num_threads) { - const int64_t num_m_per_thread0 = round_down(m_roundup_mr / num_threads, mr); - const int64_t num_m_per_threadN_1 = m - (num_threads - 1) * num_m_per_thread0; - - const int64_t m_start = ith * num_m_per_thread0; - const int64_t num_m_per_thread = (ith == num_threads - 1) ? 
num_m_per_threadN_1 : num_m_per_thread0; - - const size_t lhs_offset = variant_call(kernels->gemm.get_lhs_offset, m_start, lhs_stride); - const size_t lhs_packed_offset = variant_call(kernels->lhs_info.get_packed_offset, m_start, k, mr, kr, sr); - - const void * src_ptr = static_cast(lhs_batch) + lhs_offset; - void * dst_ptr = static_cast(lhs_packed) + lhs_packed_offset; - - variant_call(kernels->lhs_info.pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr); - } - } - - // RHS packing - if (first_to_arrive.test_and_set(std::memory_order_acquire) == false) { - // First thread to reach this point handles RHS packing - memset(bias, 0, n * sizeof(float)); - transpose_f32kxn_f16nxk(n, k, reinterpret_cast(rhs_kxn), - reinterpret_cast(rhs_batch), rhs_stride); - - variant_call(kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, n * sizeof(float), - rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr); - } - - ggml_barrier(params->threadpool); - - first_to_arrive.clear(std::memory_order_release); - - // Perform the matmul - { - const int64_t m_to_process = m; - const int64_t m_start = 0; - - const int64_t n_step = static_cast(kernel->get_n_step()); - const int64_t num_threads = KAI_MIN(n / n_step, nth); - - if (ith < num_threads) { - const int64_t num_n_per_thread0 = round_down(n / num_threads, n_step); - const int64_t num_n_per_threadN_1 = n - (num_threads - 1) * num_n_per_thread0; - - const int64_t n_start = ith * num_n_per_thread0; - const int64_t n_to_process = (ith == num_threads - 1) ? num_n_per_threadN_1 : num_n_per_thread0; - - const size_t lhs_packed_offset = variant_call(kernel->get_lhs_offset, m_start, k); - const size_t rhs_packed_offset = variant_call(kernel->get_rhs_packed_offset, n_start, k); - const size_t dst_offset = kernel->get_dst_offset(m_start, n_start, dst_stride); - - const void * lhs_ptr = lhs_packed + lhs_packed_offset; - const void * rhs_ptr = rhs_packed + rhs_packed_offset; - float * dst_ptr = reinterpret_cast(dst_batch + dst_offset); - - variant_call(kernel->run_kernel, m_to_process, n_to_process, k, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX); - } - } - - if (batch_idx != batch_size - 1) { - // This barrier is necessary when the batch size is larger than 1. While processing a batch, - // the work data buffer (params->wdata) is used as temporary storage which means that only - // a single batch can be processed at any given time. No barrier is needed for the last - // batch since GGML inserts a barrier between the execution of every operator. - ggml_barrier(params->threadpool); - } - } - - return true; - } - - bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const ggml_tensor * src1 = dst->src[1]; - - GGML_TENSOR_BINARY_OP_LOCALS - - ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst); - GGML_ASSERT(kernels); - - kernel_info * kernel = src1->ne[1] == 1 ? 
&kernels->gemv : &kernels->gemm; - lhs_packing_info * lhs_info = &kernels->lhs_info; - - GGML_ASSERT(kernel); - - const int ith = params->ith; - const int nth = params->nth; - - const size_t k = ne00; - const size_t m = ne11; - const size_t n = ne01; - - size_t mr = kernel->get_mr(); - size_t kr = kernel->get_kr(); - size_t sr = kernel->get_sr(); - - const uint8_t * lhs = static_cast(src1->data); - uint8_t * lhs_packed = (uint8_t*)params->wdata; - const uint8_t * rhs_packed = static_cast(src0->data); - - const size_t n_step = kernel->get_n_step(); - const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step); - const size_t n_start = ith * num_n_per_thread; - - size_t n_to_process = num_n_per_thread; - if ((n_start + n_to_process) > n) { - n_to_process = n - n_start; - } - - // Calculate number of columns to be processed per thread - const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth; - const size_t m_start = ith * num_m_per_thread; - size_t m_to_process = num_m_per_thread; - if ((m_start + m_to_process) > m) { - m_to_process = m - m_start; - } - - if (m_start < m) { - // Transform LHS - const size_t src_stride = src1->nb[1]; - const float * src_ptr = reinterpret_cast(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1])); - const size_t lhs_packed_offset = variant_call(lhs_info->get_packed_offset, m_start, k, QK4_0, mr, kr, sr); - void * lhs_packed_ptr = static_cast(lhs_packed + lhs_packed_offset); - - variant_call(lhs_info->pack_func, m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr); - } - - ggml_barrier(params->threadpool); - - // Perform the operation - const size_t dst_stride = dst->nb[1]; - const size_t lhs_packed_offset = variant_call(lhs_info->get_packed_offset, 0, k, QK4_0, mr, kr, sr); - const size_t rhs_packed_offset = variant_call(kernel->get_rhs_packed_offset, n_start, k, QK4_0); - const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride); - const void * rhs_ptr = static_cast(rhs_packed + rhs_packed_offset); - const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset); - float *dst_ptr = reinterpret_cast(static_cast(dst->data) + dst_offset); - - variant_call(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, - sizeof(float), -FLT_MAX, FLT_MAX); - - return true; - } - -public: - int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) { - GGML_ASSERT(ctx.kernels); - const size_t n = tensor->ne[1]; - const size_t k = tensor->ne[0]; - size_t nr = ctx.kernels->gemm.get_nr(); - size_t kr = ctx.kernels->gemm.get_kr(); - size_t sr = ctx.kernels->gemm.get_sr(); - -#ifndef NDEBUG - const size_t repacked_size = variant_call(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0); - GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!"); -#endif - struct kai_rhs_pack_qs4cxs1s0_param params; - params.lhs_zero_point = 1; - params.rhs_zero_point = 8; - variant_call(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, ¶ms); - - return 0; - - GGML_UNUSED(data_size); - } -}; - -static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struct ggml_tensor *) { - static tensor_traits traits; - return &traits; -} -} // namespace ggml::cpu::kleidiai - -static enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { - tensor->extra = (void *) 
ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor); - - GGML_UNUSED(buffer); - return GGML_STATUS_SUCCESS; -} - -static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, - const void * data, size_t offset, size_t size) { - GGML_ASSERT(offset == 0); - GGML_ASSERT(size == ggml_nbytes(tensor)); - - auto tensor_traits = (ggml::cpu::kleidiai::tensor_traits *) tensor->extra; - auto OK = tensor_traits->repack(tensor, data, size); - - GGML_ASSERT(OK == 0); - GGML_UNUSED(buffer); -} - -static const char * ggml_backend_cpu_kleidiai_buffer_type_get_name(ggml_backend_buffer_type_t buft) { - return "CPU_KLEIDIAI"; - - GGML_UNUSED(buft); -} - -static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { - ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); - - if (buffer == nullptr) { - return nullptr; - } - - buffer->buft = buft; - buffer->iface.init_tensor = ggml_backend_cpu_kleidiai_buffer_init_tensor; - buffer->iface.set_tensor = ggml_backend_cpu_kleidiai_buffer_set_tensor; - buffer->iface.get_tensor = nullptr; - buffer->iface.cpy_tensor = nullptr; - return buffer; -} - -static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { - return TENSOR_ALIGNMENT; - - GGML_UNUSED(buft); -} - -namespace ggml::cpu::kleidiai { -class extra_buffer_type : ggml::cpu::extra_buffer_type { - bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override { - if (op->op == GGML_OP_MUL_MAT && - op->src[0]->type == GGML_TYPE_Q4_0 && - op->src[0]->buffer && - (ggml_n_dims(op->src[0]) == 2) && - op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) { - if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { - return false; - } - if (op->src[1]->type == GGML_TYPE_F32 && - ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) { - return true; - } - } - return false; - } - - ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override { - if (op->op == GGML_OP_MUL_MAT) { - if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) { - return (ggml::cpu::tensor_traits *) op->src[0]->extra; - } - else if (ggml_kleidiai_select_kernels(ctx.features, op) && - op->src[0]->op == GGML_OP_VIEW && - (op->src[1]->op == GGML_OP_PERMUTE || op->src[1]->op == GGML_OP_SOFT_MAX) && - op->src[1]->ne[1] > 1) { - if ((op->src[0]->nb[0] != 2) || - (op->src[1]->nb[0] != 4) || - (op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) || - (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) { - return nullptr; - } - - return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL); - } - } - return nullptr; - } -}; -} // namespace ggml::cpu::kleidiai - -ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void) { - static ggml::cpu::kleidiai::extra_buffer_type ctx; - static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_kleidiai = { - /* .iface = */ { - /* .get_name = */ ggml_backend_cpu_kleidiai_buffer_type_get_name, - /* .alloc_buffer = */ ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer, - /* .get_alignment = */ ggml_backend_cpu_kleidiai_buffer_type_get_alignment, - /* .get_max_size = */ nullptr, // defaults to SIZE_MAX - /* .get_alloc_size = */ nullptr, // defaults to ggml_nbytes - /* .is_host = */ nullptr, - }, - /* .device = */ 
ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0), - /* .context = */ &ctx, - }; - - init_kleidiai_context(); - - return &ggml_backend_cpu_buffer_type_kleidiai; -} diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.h b/ggml/src/ggml-cpu/kleidiai/kleidiai.h deleted file mode 100644 index 38eac58f7..000000000 --- a/ggml/src/ggml-cpu/kleidiai/kleidiai.h +++ /dev/null @@ -1,17 +0,0 @@ -// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates -// SPDX-License-Identifier: MIT -// - -#pragma once - -#include "ggml-alloc.h" - -#ifdef __cplusplus -extern "C" { -#endif - -ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void); - -#ifdef __cplusplus -} -#endif diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index 4972558c9..a2e30994c 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -73,6 +73,22 @@ static inline int ggml_up(int n, int m) { return (n + m - 1) & ~(m - 1); } +// TODO: move to ggml.h? +static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) { + if (a->type != b->type) { + return false; + } + for (int i = 0; i < GGML_MAX_DIMS; i++) { + if (a->ne[i] != b->ne[i]) { + return false; + } + if (a->nb[i] != b->nb[i]) { + return false; + } + } + return true; +} + // // logging // diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 752d55c21..b7b3fc49a 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -126,6 +126,7 @@ typedef struct { uint64_t nb2; uint64_t nb3; uint64_t offs; + uint64_t o1[8]; } ggml_metal_kargs_bin; typedef struct { @@ -240,7 +241,7 @@ typedef struct { float max_bias; float m0; float m1; - uint16_t n_head_log2; + int32_t n_head_log2; float logit_softcap; } ggml_metal_kargs_flash_attn_ext; @@ -377,8 +378,16 @@ typedef struct { typedef struct { int32_t ne00; int32_t ne00_4; - uint64_t nb01; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; float eps; + int32_t nef1[3]; + int32_t nef2[3]; + int32_t nef3[3]; + uint64_t nbf1[3]; + uint64_t nbf2[3]; + uint64_t nbf3[3]; } ggml_metal_kargs_rms_norm; typedef struct { @@ -484,7 +493,7 @@ typedef struct { float max_bias; float m0; float m1; - uint32_t n_head_log2; + int32_t n_head_log2; } ggml_metal_kargs_soft_max; typedef struct { diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index bdeef2cde..b267f394c 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -55,6 +55,12 @@ static struct ggml_backend_metal_device_context { bool has_residency_sets; bool has_bfloat; bool use_bfloat; + bool use_fusion; + + int debug_fusion; + + // how many times a given op was fused + uint64_t fuse_cnt[GGML_OP_COUNT]; size_t max_size; @@ -69,6 +75,9 @@ static struct ggml_backend_metal_device_context { /*.has_residency_sets =*/ false, /*.has_bfloat =*/ false, /*.use_bfloat =*/ false, + /*.use_fusion =*/ true, + /*.debug_fusion =*/ 0, + /*.fuse_cnt =*/ { 0 }, /*.max_size =*/ 0, /*.name =*/ "", }; @@ -83,16 +92,14 @@ static id ggml_backend_metal_device_acq(struct ggml_backend_metal_dev if (ctx->mtl_device == nil) { ctx->mtl_device = MTLCreateSystemDefaultDevice(); - } - if (ctx->mtl_device) { ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7]; ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML]; ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7]; #if defined(GGML_METAL_HAS_RESIDENCY_SETS) - ctx->has_residency_sets = 
getenv("GGML_METAL_NO_RESIDENCY") == NULL; + ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil; #endif ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML]; @@ -103,6 +110,14 @@ static id ggml_backend_metal_device_acq(struct ggml_backend_metal_dev #else ctx->use_bfloat = false; #endif + ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil; + + { + const char * val = getenv("GGML_METAL_FUSION_DEBUG"); + ctx->debug_fusion = val ? atoi(val) : 0; + } + + memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt)); ctx->max_size = ctx->mtl_device.maxBufferLength; @@ -122,6 +137,18 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte ctx->mtl_device_ref_count--; if (ctx->mtl_device_ref_count == 0) { + if (ctx->debug_fusion > 0) { + fprintf(stderr, "%s: fusion stats:\n", __func__); + for (int i = 0; i < GGML_OP_COUNT; i++) { + if (ctx->fuse_cnt[i] == 0) { + continue; + } + + // note: cannot use ggml_log here + fprintf(stderr, "%s: - %s: %" PRIu64 "\n", __func__, ggml_op_name((enum ggml_op) i), ctx->fuse_cnt[i]); + } + } + if (ctx->mtl_lock) { [ctx->mtl_lock release]; ctx->mtl_lock = nil; @@ -147,13 +174,27 @@ struct ggml_metal_kernel { enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_ADD, - GGML_METAL_KERNEL_TYPE_ADD_ROW, + GGML_METAL_KERNEL_TYPE_ADD_FUSE_2, + GGML_METAL_KERNEL_TYPE_ADD_FUSE_3, + GGML_METAL_KERNEL_TYPE_ADD_FUSE_4, + GGML_METAL_KERNEL_TYPE_ADD_FUSE_5, + GGML_METAL_KERNEL_TYPE_ADD_FUSE_6, + GGML_METAL_KERNEL_TYPE_ADD_FUSE_7, + GGML_METAL_KERNEL_TYPE_ADD_FUSE_8, + GGML_METAL_KERNEL_TYPE_ADD_ROW_C4, + GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2, + GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3, + GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4, + GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5, + GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6, + GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7, + GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8, GGML_METAL_KERNEL_TYPE_SUB, - GGML_METAL_KERNEL_TYPE_SUB_ROW, + GGML_METAL_KERNEL_TYPE_SUB_ROW_C4, GGML_METAL_KERNEL_TYPE_MUL, - GGML_METAL_KERNEL_TYPE_MUL_ROW, + GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, GGML_METAL_KERNEL_TYPE_DIV, - GGML_METAL_KERNEL_TYPE_DIV_ROW, + GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, GGML_METAL_KERNEL_TYPE_REPEAT_F32, GGML_METAL_KERNEL_TYPE_REPEAT_F16, GGML_METAL_KERNEL_TYPE_REPEAT_I32, @@ -218,6 +259,8 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, GGML_METAL_KERNEL_TYPE_RMS_NORM, + GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL, + GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD, GGML_METAL_KERNEL_TYPE_L2_NORM, GGML_METAL_KERNEL_TYPE_GROUP_NORM, GGML_METAL_KERNEL_TYPE_NORM, @@ -1135,13 +1178,27 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de // simd_sum and simd_max requires MTLGPUFamilyApple7 GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_2, add_fuse_2, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_3, add_fuse_3, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_4, add_fuse_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_5, add_fuse_5, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_6, add_fuse_6, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_7, add_fuse_7, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_8, add_fuse_8, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4, 
add_row_c4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2, add_row_c4_fuse_2, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3, add_row_c4_fuse_3, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4, add_row_c4_fuse_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5, add_row_c4_fuse_5, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6, add_row_c4_fuse_6, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7, add_row_c4_fuse_7, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8, add_row_c4_fuse_8, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB, sub, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW_C4, sub_row_c4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, mul_row_c4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, div_row_c4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true); @@ -1206,6 +1263,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL, rms_norm_mul, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD, rms_norm_mul_add, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); @@ -1893,7 +1952,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex } } -static bool ggml_metal_encode_node( +static int ggml_metal_encode_node( ggml_backend_t backend, int idx, id encoder, @@ -1903,7 +1962,10 @@ static bool ggml_metal_encode_node( struct ggml_cgraph * gf = ctx->gf; - struct ggml_tensor * node = ggml_graph_node(gf, idx); + enum ggml_op ops[8]; + + struct ggml_tensor ** nodes = ggml_graph_nodes(gf) + idx; + struct ggml_tensor * node = nodes[0]; //GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op)); @@ -1913,7 +1975,7 @@ static bool ggml_metal_encode_node( struct ggml_tensor * dst = node; if (ggml_is_empty(dst)) { - return true; + return 1; } switch (dst->op) { @@ -1924,7 +1986,7 @@ static bool ggml_metal_encode_node( case GGML_OP_PERMUTE: { // noop -> next node - } return true; + } return 1; default: { } break; @@ -1991,6 +2053,8 @@ static bool ggml_metal_encode_node( id id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil; id id_dst = dst ? 
ggml_metal_get_buffer(dst, &offs_dst) : nil; + int n_fuse = 1; + #if 0 GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op)); if (src0) { @@ -2062,37 +2126,15 @@ static bool ggml_metal_encode_node( GGML_ASSERT(src0t == GGML_TYPE_F32); GGML_ASSERT(src1t == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous_rows(src0)); + GGML_ASSERT(ggml_is_contiguous_rows(src1)); + const size_t offs = 0; bool bcast_row = false; id pipeline = nil; - if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { - GGML_ASSERT(ggml_is_contiguous(src0)); - - // src1 is a row - GGML_ASSERT(ne11 == 1); - - switch (dst->op) { - case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break; - case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break; - case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break; - case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break; - default: GGML_ABORT("fatal error"); - } - - bcast_row = true; - } else { - switch (dst->op) { - case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break; - case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break; - case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break; - case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break; - default: GGML_ABORT("fatal error"); - } - } - ggml_metal_kargs_bin args = { /*.ne00 =*/ ne00, /*.ne01 =*/ ne01, @@ -2119,12 +2161,117 @@ static bool ggml_metal_encode_node( /*.nb2 =*/ nb2, /*.nb3 =*/ nb3, /*.offs =*/ offs, + /*.o1 =*/ { offs_src1 }, }; + // c[0] = add(a, b[0]) + // c[1] = add(c[0], b[1]) + // c[2] = add(c[1], b[2]) + // ... + if (ctx_dev->use_fusion) { + ops[0] = GGML_OP_ADD; + ops[1] = GGML_OP_ADD; + ops[2] = GGML_OP_ADD; + ops[3] = GGML_OP_ADD; + ops[4] = GGML_OP_ADD; + ops[5] = GGML_OP_ADD; + ops[6] = GGML_OP_ADD; + ops[7] = GGML_OP_ADD; + + size_t offs_fuse; + id id_fuse; + + for (n_fuse = 0; n_fuse <= 6; ++n_fuse) { + if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) { + break; + } + + if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) { + break; + } + + // b[0] === b[1] === ... 
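+                // in the fused case the kernel reads src0 once, applies each b[i] using the
+                // per-op src1 byte offsets collected in args.o1, and writes the final result,
+                // so every candidate src1 must have the same layout as b[0] and live in the
+                // same Metal buffer (both conditions are checked below)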
+ if (!ggml_are_same_layout(nodes[n_fuse]->src[1], nodes[n_fuse + 1]->src[1])) { + break; + } + + // only fuse nodes if src1 is in the same Metal buffer + id_fuse = ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse); + if (id_fuse != id_src1) { + break; + } + + ctx_dev->fuse_cnt[nodes[n_fuse + 1]->op]++; + + args.o1[n_fuse + 1] = offs_fuse; + } + + ++n_fuse; + + if (ctx_dev->debug_fusion > 1 && n_fuse > 1) { + GGML_LOG_DEBUG("%s: fuse: ADD x %d\n", __func__, n_fuse); + } + } + + if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { + GGML_ASSERT(ggml_is_contiguous(src0)); + + // src1 is a row + GGML_ASSERT(ne11 == 1); + + switch (dst->op) { + case GGML_OP_ADD: + { + switch (n_fuse) { + case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4 ].pipeline; break; + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5].pipeline; break; + case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6].pipeline; break; + case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7].pipeline; break; + case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8].pipeline; break; + default: GGML_ABORT("fatal error"); + } + } break; + case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW_C4].pipeline; break; + case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW_C4].pipeline; break; + case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW_C4].pipeline; break; + default: GGML_ABORT("fatal error"); + } + + bcast_row = true; + } else { + switch (dst->op) { + case GGML_OP_ADD: + { + switch (n_fuse) { + case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD ].pipeline; break; + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_3].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_5].pipeline; break; + case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_6].pipeline; break; + case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_7].pipeline; break; + case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_8].pipeline; break; + default: GGML_ABORT("fatal error"); + } + } break; + case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break; + case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break; + case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break; + default: GGML_ABORT("fatal error"); + } + } + + if (n_fuse > 1) { + id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst); + } + [encoder setComputePipelineState:pipeline]; [encoder setBytes:&args length:sizeof(args) atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_src1 offset:0 atIndex:2]; [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; if (bcast_row) { @@ -2132,7 +2279,11 @@ static bool ggml_metal_encode_node( [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) 
threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } else { - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); + int nth = 32; + + while (16*nth < ne0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) { + nth *= 2; + } [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } @@ -2257,12 +2408,13 @@ static bool ggml_metal_encode_node( /*.nb2 =*/ pnb2, /*.nb3 =*/ pnb3, /*.offs =*/ offs, + /*.o1 =*/ { offs_src1}, }; [encoder setComputePipelineState:pipeline]; [encoder setBytes:&args length:sizeof(args) atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_src1 offset:0 atIndex:2]; [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); @@ -2764,7 +2916,7 @@ static bool ggml_metal_encode_node( id h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0)); if (!h_src0) { GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0)); - return false; + return 0; } offs_src0 = 0; @@ -3640,7 +3792,7 @@ static bool ggml_metal_encode_node( id h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1); if (!h_src1) { GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1); - return false; + return 0; } const int64_t neh0 = ne0; @@ -3656,7 +3808,7 @@ static bool ggml_metal_encode_node( id h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst); if (!h_dst) { GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst); - return false; + return 0; } // tokens per expert @@ -3664,7 +3816,7 @@ static bool ggml_metal_encode_node( id h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe); if (!h_tpe) { GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe); - return false; + return 0; } // id map @@ -3673,7 +3825,7 @@ static bool ggml_metal_encode_node( id h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids); if (!h_ids) { GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids); - return false; + return 0; } { @@ -4105,12 +4257,95 @@ static bool ggml_metal_encode_node( case GGML_OP_RMS_NORM: { GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_rows(src0)); float eps; memcpy(&eps, dst->op_params, sizeof(float)); - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline; + ggml_metal_kargs_rms_norm args = { + /*.ne00 =*/ ne00, + /*.ne00_4 =*/ ne00/4, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.eps =*/ eps, + /*.nef1 =*/ { ne01 }, + /*.nef2 =*/ { ne02 }, + /*.nef3 =*/ { ne03 }, + /*.nbf1 =*/ { nb01 }, + /*.nbf2 =*/ { nb02 }, + /*.nbf3 =*/ { nb03 }, + }; + + size_t offs_fuse[2] = { 0, 0 }; + id id_fuse[2] = { id_src0, id_src0 }; + + // d[0] = rms_norm(a) + // d[1] = mul(d[0], b) + // d[2] = add(d[1], c) + if (ctx_dev->use_fusion) { + ops[0] = GGML_OP_RMS_NORM; + ops[1] = GGML_OP_MUL; + ops[2] = GGML_OP_ADD; + + for (n_fuse = 0; n_fuse <= 1; ++n_fuse) { + if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) { + break; + } + + if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) { + break; + } + + if (nodes[n_fuse + 1]->src[1]->ne[0] != node->ne[0]) { + break; + } + + if (!ggml_is_contiguous_rows(nodes[n_fuse + 1]->src[1])) { + break; + } + + if (nodes[n_fuse + 1]->type != GGML_TYPE_F32) { + break; + } + 
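+ // all fusion checks passed: count the fused op, look up the Metal buffer holding the
+ // next node's src1 operand, and record that operand's shape and strides in the kernel
+ // arguments below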
+ ctx_dev->fuse_cnt[nodes[n_fuse + 1]->op]++; + + id_fuse[n_fuse] = ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse[n_fuse]); + + args.nef1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[1]; + args.nef2[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[2]; + args.nef3[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[3]; + + args.nbf1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[1]; + args.nbf2[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[2]; + args.nbf3[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[3]; + } + + ++n_fuse; + + if (ctx_dev->debug_fusion > 1 && n_fuse > 1) { + if (n_fuse == 2) { + GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL\n", __func__); + } + if (n_fuse == 3) { + GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL + ADD\n", __func__); + } + } + } + + if (n_fuse > 1) { + id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst); + } + + id pipeline; + + switch (n_fuse) { + case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM ].pipeline; break; + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL ].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD].pipeline; break; + default: GGML_ABORT("unsupported n_fuse = %d\n", n_fuse); + } int nth = 32; // SIMD width @@ -4121,23 +4356,16 @@ static bool ggml_metal_encode_node( nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup); nth = MIN(nth, ne00/4); - ggml_metal_kargs_rms_norm args = { - /*.ne00 =*/ ne00, - /*.ne00_4 =*/ ne00/4, - /*.nb01 =*/ nb01, - /*.eps =*/ eps, - }; - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_fuse[0] offset:offs_fuse[0] atIndex:2]; + [encoder setBuffer:id_fuse[1] offset:offs_fuse[1] atIndex:3]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - const int64_t nrows = ggml_nrows(src0); - - [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; case GGML_OP_L2_NORM: { @@ -5532,7 +5760,7 @@ static bool ggml_metal_encode_node( } } - return true; + return n_fuse; } static enum ggml_status ggml_metal_graph_compute( @@ -6038,20 +6266,22 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool; ggml_metal_mem_pool_reset(mem_pool); - for (int idx = node_start; idx < node_end; ++idx) { + for (int idx = node_start; idx < node_end;) { if (should_capture) { [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]]; } - const bool res = ggml_metal_encode_node(backend, idx, encoder, mem_pool); + const int res = ggml_metal_encode_node(backend, idx, encoder, mem_pool); if (should_capture) { [encoder popDebugGroup]; } - if (!res) { + if (res == 0) { break; } + + idx += res; } [encoder endEncoding]; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 13235e288..f62b9ad54 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -832,7 +832,8 @@ enum ggml_sort_order { // general-purpose kernel for 
addition, subtraction, multiplication and division of two tensors // pros: works for non-contiguous tensors, supports broadcast across all dims // cons: not very efficient -kernel void kernel_add( +template +kernel void kernel_add_fuse_impl( constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, @@ -848,16 +849,39 @@ kernel void kernel_add( const int i12 = i02%args.ne12; const int i11 = i01%args.ne11; - device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; - device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; - device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; + device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs); + device float * dst_ptr = (device float *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs); + + device const float * src1_ptr[F]; + for (short j = 0; j < F; ++j) { + src1_ptr[j] = (device const float *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11); + } for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { const int i10 = i0%args.ne10; - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) + *((device float *)(src1_ptr + i10*args.nb10)); + + float res = src0_ptr[i0]; + +#pragma unroll + for (short j = 0; j < F; ++j) { + res += src1_ptr[j][i10]; + } + + dst_ptr[i0] = res; } } +typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t; + +template [[host_name("kernel_add")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>; +template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>; +template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>; +template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>; +template [[host_name("kernel_add_fuse_5")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<5>; +template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<6>; +template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>; +template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>; + kernel void kernel_sub( constant ggml_metal_kargs_bin & args, device const char * src0, @@ -875,7 +899,7 @@ kernel void kernel_sub( const int i11 = i01%args.ne11; device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; - device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; + device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { @@ -900,9 +924,9 @@ kernel void kernel_mul( const int i12 = i02%args.ne12; const int i11 = i01%args.ne11; - device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01; - device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; - device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1; + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; + device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; + device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; for (int i0 
= tpitg.x; i0 < args.ne0; i0 += ntg.x) { const int i10 = i0%args.ne10; @@ -926,9 +950,9 @@ kernel void kernel_div( const int i12 = i02%args.ne12; const int i11 = i01%args.ne11; - device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01; - device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; - device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1; + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; + device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; + device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { const int i10 = i0%args.ne10; @@ -970,46 +994,145 @@ template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat // assumption: src1 is a row // broadcast src1 into src0 -kernel void kernel_add_row( +template +kernel void kernel_add_row_c4_fuse_impl( constant ggml_metal_kargs_bin & args, - device const float4 * src0, - device const float4 * src1, - device float4 * dst, + device const char * src0, + device const char * src1, + device char * dst, uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; - dst[tpig] = src0[tpig] + src1[tpig % nb]; + const uint i = tpig % nb; + + device const float4 * src0_row = (device const float4 *) (src0); + device float4 * dst_row = (device float4 *) (dst); + + device const float4 * src1_row[F]; + for (short j = 0; j < F; ++j) { + src1_row[j] = (device const float4 *) (src1 + args.o1[j]); + } + + float4 res = src0_row[tpig]; + +#pragma unroll(F) + for (short j = 0; j < F; ++j) { + res += src1_row[j][i]; + } + + dst_row[tpig] = res; } -kernel void kernel_sub_row( +typedef decltype(kernel_add_row_c4_fuse_impl<1>) kernel_add_row_c4_fuse_t; + +template [[host_name("kernel_add_row_c4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>; +template [[host_name("kernel_add_row_c4_fuse_2")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<2>; +template [[host_name("kernel_add_row_c4_fuse_3")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<3>; +template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>; +template [[host_name("kernel_add_row_c4_fuse_5")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<5>; +template [[host_name("kernel_add_row_c4_fuse_6")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<6>; +template [[host_name("kernel_add_row_c4_fuse_7")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<7>; +template [[host_name("kernel_add_row_c4_fuse_8")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<8>; + +template +kernel void kernel_sub_row_c4_fuse_impl( constant ggml_metal_kargs_bin & args, - device const float4 * src0, - device const float4 * src1, - device float4 * dst, + device const char * src0, + device const char * src1, + device char * dst, uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; - dst[tpig] = src0[tpig] - src1[tpig % nb]; + const uint i = tpig % nb; + + device const float4 * src0_row = (device const float4 *) (src0); + device float4 * dst_row = (device float4 *) (dst); + + device const float4 * src1_row[F]; + for (short j = 0; j < F; ++j) { + src1_row[j] = (device const float4 *) (src1 + args.o1[j]); + } + + float4 res = src0_row[tpig]; + +#pragma unroll(F) + for (short j = 0; j < F; ++j) { 
+ res -= src1_row[j][i]; + } + + dst_row[tpig] = res; } -kernel void kernel_mul_row( +typedef decltype(kernel_sub_row_c4_fuse_impl<1>) kernel_sub_row_c4_fuse_t; + +template [[host_name("kernel_sub_row_c4")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>; + +template +kernel void kernel_mul_row_c4_fuse_impl( constant ggml_metal_kargs_bin & args, - device const float4 * src0, - device const float4 * src1, - device float4 * dst, + device const char * src0, + device const char * src1, + device char * dst, uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; - dst[tpig] = src0[tpig] * src1[tpig % nb]; + const uint i = tpig % nb; + + device const float4 * src0_row = (device const float4 *) (src0); + device float4 * dst_row = (device float4 *) (dst); + + device const float4 * src1_row[F]; + for (short j = 0; j < F; ++j) { + src1_row[j] = (device const float4 *) (src1 + args.o1[j]); + } + + float4 res = src0_row[tpig]; + +#pragma unroll(F) + for (short j = 0; j < F; ++j) { + res *= src1_row[j][i]; + } + + dst_row[tpig] = res; } -kernel void kernel_div_row( +typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t; + +template [[host_name("kernel_mul_row_c4")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>; + +template +kernel void kernel_div_row_c4_fuse_impl( constant ggml_metal_kargs_bin & args, - device const float4 * src0, - device const float4 * src1, - device float4 * dst, + device const char * src0, + device const char * src1, + device char * dst, uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; - dst[tpig] = src0[tpig] / src1[tpig % nb]; + const uint i = tpig % nb; + + device const float4 * src0_row = (device const float4 *) (src0); + device float4 * dst_row = (device float4 *) (dst); + + device const float4 * src1_row[F]; + for (short j = 0; j < F; ++j) { + src1_row[j] = (device const float4 *) (src1 + args.o1[j]); + } + + float4 res = src0_row[tpig]; + +#pragma unroll(F) + for (short j = 0; j < F; ++j) { + res /= src1_row[j][i]; + } + + dst_row[tpig] = res; } +typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t; + +template [[host_name("kernel_div_row_c4")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>; + kernel void kernel_scale( device const float * src0, device float * dst, @@ -2116,26 +2239,39 @@ kernel void kernel_norm( } } -kernel void kernel_rms_norm( +// F == 1 : rms_norm (no fuse) +// F == 2 : rms_norm + mul +// F == 3 : rms_norm + mul + add +template +kernel void kernel_rms_norm_fuse_impl( constant ggml_metal_kargs_rms_norm & args, device const char * src0, + device const char * src1_0, + device const char * src1_1, device char * dst, threadgroup float * shmem_f32 [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - ushort tpitg[[thread_position_in_threadgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort ntg[[threads_per_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { if (sgitg == 0) { shmem_f32[tiisg] = 0.0f; } - device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01); + const int i01 = tgpig.x; + const int i02 = tgpig.y; + const int i03 = tgpig.z; + + device const float4 * x = (device const float4 *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + 
i01*args.nbf1[0]); + + device const float4 * f0 = (device const float4 *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]); + device const float4 * f1 = (device const float4 *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]); float sumf = 0.0f; // parallel sum - for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) { sumf += dot(x[i00], x[i00]); } sumf = simd_sum(sumf); @@ -2154,12 +2290,26 @@ kernel void kernel_rms_norm( const float mean = sumf/args.ne00; const float scale = 1.0f/sqrt(mean + args.eps); - device float4 * y = (device float4 *) dst + tgpig*args.ne00_4; - for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { - y[i00] = x[i00] * scale; + device float4 * y = (device float4 *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1); + for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) { + if (F == 1) { + y[i00] = (x[i00]*scale); + } + if (F == 2) { + y[i00] = (x[i00]*scale)*f0[i00]; + } + if (F == 3) { + y[i00] = (x[i00]*scale)*f0[i00] + f1[i00]; + } } } +typedef decltype(kernel_rms_norm_fuse_impl<1>) kernel_rms_norm_fuse_t; + +template [[host_name("kernel_rms_norm")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<1>; +template [[host_name("kernel_rms_norm_mul")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<2>; +template [[host_name("kernel_rms_norm_mul_add")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<3>; + kernel void kernel_l2_norm( constant ggml_metal_kargs_l2_norm & args, device const char * src0, diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index a5b26d06c..b14a04760 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -344,6 +344,7 @@ struct vk_device_struct { uint64_t max_memory_allocation_size; uint64_t suballocation_block_size; bool fp16; + bool bf16; bool pipeline_robustness; vk::Device device; uint32_t vendor_id; @@ -498,6 +499,7 @@ struct vk_device_struct { vk_pipeline pipeline_rwkv_wkv6_f32; vk_pipeline pipeline_rwkv_wkv7_f32; vk_pipeline pipeline_opt_step_adamw_f32; + vk_pipeline pipeline_conv2d_f32; vk_pipeline pipeline_conv2d_dw_whcn_f32; vk_pipeline pipeline_conv2d_dw_cwhn_f32; @@ -891,6 +893,38 @@ struct vk_op_rwkv_wkv7_push_constants { uint32_t H; }; +struct vk_op_conv2d_push_constants { + uint32_t Cout; + uint32_t Cin; + uint32_t N; + + uint32_t KW; + uint32_t KH; + uint32_t W; + uint32_t H; + uint32_t OW; + uint32_t OH; + + uint32_t s0; + uint32_t s1; + uint32_t p0; + uint32_t p1; + uint32_t d0; + uint32_t d1; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + + uint32_t nb1; + uint32_t nb2; + uint32_t nb3; +}; + struct vk_op_conv2d_dw_push_constants { uint32_t ne; uint32_t batches; @@ -990,18 +1024,45 @@ private: #endif // GGML_VULKAN_MEMORY_DEBUG class vk_perf_logger { -public: + public: void print_timings() { + if (timings.empty()) { + return; + } + uint64_t total_all_op_times = 0; std::cerr << "----------------\nVulkan Timings:" << std::endl; - for (const auto& t : timings) { - uint64_t total = 0; - for (const auto& time : t.second) { - total += time; + for (const auto & t : timings) { + uint64_t total_op_times = 0; + for (const auto & time : t.second) { + total_op_times += time; } - std::cerr << t.first << ": " << t.second.size() << " x " << (total / t.second.size() / 1000.0) << " us" << 
std::endl; + std::cerr << t.first << ": " << t.second.size() << " x " << (total_op_times / t.second.size() / 1000.0) + << " us"; + + // If we have as many flops entries as timing entries for the op, then compute and log the flops/S. + auto it = flops.find(t.first); + if (it != flops.end() && (it->second).size() == t.second.size()) { + uint64_t total_op_flops = 0; + for (const auto & elem : it->second) { + total_op_flops += elem; + } + std::cerr << " (" + << (double(total_op_flops) / (1000.0 * 1000.0 * 1000.0)) / + (double(total_op_times) / (1000.0 * 1000.0 * 1000.0)) + << " GFLOPS/s)"; + } + + total_all_op_times += total_op_times; + + std::cerr << std::endl; + } + + if (timings.size() > 0) { + std::cerr << "Total time: " << total_all_op_times / 1000.0 << " us." << std::endl; } timings.clear(); + flops.clear(); } void log_timing(const ggml_tensor * node, uint64_t time) { @@ -1010,22 +1071,45 @@ public: return; } if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) { - const uint64_t m = node->src[0]->ne[1]; - const uint64_t n = node->src[1]->ne[1]; - const uint64_t k = node->src[1]->ne[0]; - std::string name = ggml_op_name(node->op); + const uint64_t m = node->src[0]->ne[1]; + const uint64_t n = node->src[1]->ne[1]; + const uint64_t k = node->src[1]->ne[0]; + std::string name = ggml_op_name(node->op); if (n == 1) { name += "_VEC m=" + std::to_string(m) + " k=" + std::to_string(k); } else { name += " m=" + std::to_string(m) + " n=" + std::to_string(n) + " k=" + std::to_string(k); } timings[name].push_back(time); + flops[name].push_back(m * n * (k + (k - 1))); + return; + } + if (node->op == GGML_OP_CONV_2D) { + std::string name = ggml_op_name(node->op); + ggml_tensor * knl = node->src[0]; + uint64_t OW = node->ne[0]; + uint64_t OH = node->ne[1]; + uint64_t N = node->ne[3]; + uint64_t Cout = node->ne[2]; + uint64_t KW = knl->ne[0]; + uint64_t KH = knl->ne[1]; + uint64_t Cin = knl->ne[2]; + // KxCRS @ CRSxNPQ = KxNPQ -> M=K, K=CRS, N=NPQ + uint64_t size_M = Cout; + uint64_t size_K = Cin * KW * KH; + uint64_t size_N = N * OW * OH; + uint64_t n_flops = size_M * size_N * (size_K + (size_K - 1)); + name += " M=Cout=" + std::to_string(size_M) + ", K=Cin*KW*KH=" + std::to_string(size_K) + + ", N=N*OW*OH=" + std::to_string(size_N); + flops[name].push_back(n_flops); + timings[name].push_back(time); return; } timings[ggml_op_name(node->op)].push_back(time); } -private: + private: std::map> timings; + std::map> flops; }; struct ggml_backend_vk_context { @@ -2128,6 +2212,7 @@ static void ggml_vk_load_shaders(vk_device& device) { } compile_count++; } + compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint, parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size)); }; @@ -2977,6 +3062,42 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + // conv2d + uint32_t conv2d_WG_SIZE = 256; + uint32_t conv2d_BS_K = 128; + uint32_t conv2d_BS_CRS = 16; + uint32_t use_collectives = 0; // Enables subgroup ops for preventing the re-calculation of indices. 
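+    // for reference: with the default block sizes chosen below (BS_K = 128, BS_CRS = 16, BS_NPQ = 128),
+    // the two shared-memory tiles take (128*(16 + 1) + 16*(128 + 1)) * sizeof(float) = 16960 bytes;
+    // when this exceeds maxComputeSharedMemorySize, BS_CRS is reduced to 8 further down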
+ if (device->subgroup_shuffle && + device->vendor_id != VK_VENDOR_ID_INTEL) { // Do not enable collectives on Intel, see PR 14316 + use_collectives = 1; + conv2d_BS_CRS = std::min( + device->subgroup_size, + conv2d_BS_CRS); // CRS block size should be capped at sugroup size for correctness when shuffle is used. + } + uint32_t conv2d_BS_NPQ = 128; + uint32_t conv2d_TS_K = 8; + uint32_t conv2d_shmem_req = + (conv2d_BS_K * (conv2d_BS_CRS + 1) + conv2d_BS_CRS * (conv2d_BS_NPQ + 1)) * sizeof(float); + if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) { + conv2d_BS_CRS = 8; + if (use_collectives) { + conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS); + } + } + + if (use_collectives) { + ggml_vk_create_pipeline( + device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3, + sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 }, + { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, true); + } else { + ggml_vk_create_pipeline( + device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3, + sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 }, + { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, + false); + } + ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); @@ -3297,6 +3418,12 @@ static vk_device ggml_vk_get_device(size_t idx) { device->fp16 = device->fp16 && vk12_features.shaderFloat16; +#if defined(VK_KHR_shader_bfloat16) + device->bf16 = bfloat16_support && bfloat16_features.shaderBFloat16Type; +#else + device->bf16 = false; +#endif + device->pipeline_robustness = pl_robustness_features.pipelineRobustness; if (device->subgroup_size_control) { @@ -3639,6 +3766,7 @@ static void ggml_vk_print_gpu_info(size_t idx) { bool coopmat_support = false; bool coopmat2_support = false; bool integer_dot_product = false; + bool bfloat16_support = false; for (auto properties : ext_props) { if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) { @@ -3659,6 +3787,11 @@ static void ggml_vk_print_gpu_info(size_t idx) { } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 && !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) { integer_dot_product = true; +#endif +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + } else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_BFLOAT16")) { + bfloat16_support = true; #endif } } @@ -3725,10 +3858,25 @@ static void ggml_vk_print_gpu_info(size_t idx) { last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features; } +#if defined(VK_KHR_shader_bfloat16) + VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {}; + bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR; + if (bfloat16_support) { + last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features; + last_struct = (VkBaseOutStructure *)&bfloat16_features; + } +#endif + vkGetPhysicalDeviceFeatures2(physical_device, &device_features2); fp16 = fp16 && 
vk12_features.shaderFloat16; +#if defined(VK_KHR_shader_bfloat16) + bool bf16 = bfloat16_support && bfloat16_features.shaderBFloat16Type; +#else + bool bf16 = false; +#endif + uint32_t default_subgroup_size = get_subgroup_size("", device_architecture); const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize; const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu; @@ -3746,8 +3894,8 @@ static void ggml_vk_print_gpu_info(size_t idx) { std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none"; std::string device_name = props2.properties.deviceName.data(); - GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n", - idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size, + GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | bf16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n", + idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, bf16, subgroup_size, props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str()); if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) { @@ -6833,6 +6981,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_leaky_relu_f32; } return nullptr; + case GGML_OP_CONV_2D: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && + ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) { + return ctx->device->pipeline_conv2d_f32; + } + return nullptr; case GGML_OP_CONV_2D_DW: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { if (ggml_is_contiguous(src1)) { @@ -7155,6 +7309,31 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co const uint32_t OW = dst->ne[0]; elements = { N * OC * OH * OW, 1, 1}; } break; + case GGML_OP_CONV_2D: + { + // src0 - kernel: [KW, KH, Cin, Cout] + // src1 - input: [W, H, Cin, N] + // dst - result: [OW, OH, Cout, N] + + // Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) + auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t { + return (ins + 2 * p - d * (ks - 1) - 1) / s + 1; + }; + // parallelize in {OW/BS_K, OH/BS_NPQ, 1} + int64_t W = src1->ne[0]; + int64_t H = src1->ne[1]; + int64_t KW = src0->ne[0]; + int64_t KH = src0->ne[1]; + int64_t Cout = src0->ne[3]; + int64_t N = src1->ne[3]; + int64_t OH = calc_conv_output_size(H, KH, dst->op_params[1], dst->op_params[3], dst->op_params[5]); + int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], dst->op_params[2], dst->op_params[4]); + int64_t NPQ = N * OW * OH; + + // Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups + elements = { static_cast(Cout), static_cast(NPQ), 1 }; + } + break; case GGML_OP_ADD: case GGML_OP_SUB: case GGML_OP_DIV: @@ -8021,6 +8200,55 @@ static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c }, dryrun); } +static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0, + const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + 
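+    // dst->op_params holds { s0, s1, p0, p1, d0, d1 } (stride, padding, dilation per axis);
+    // output extents follow OW = (W + 2*p0 - d0*(KW - 1) - 1)/s0 + 1, e.g. a 3x3 kernel with
+    // stride 1, padding 1, dilation 1 preserves the spatial size (W = 224 -> OW = 224)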
GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(nb00 == sizeof(float)); + GGML_ASSERT(nb10 == sizeof(float)); + GGML_ASSERT(nb0 == sizeof(float)); + + vk_op_conv2d_push_constants p{}; + p.Cout = static_cast(ne03); + p.Cin = static_cast(ne02); + p.N = static_cast(ne13); + + p.KW = static_cast(ne00); + p.KH = static_cast(ne01); + p.W = static_cast(ne10); + p.H = static_cast(ne11); + p.OW = static_cast(ne0); + p.OH = static_cast(ne1); + + p.s0 = static_cast(dst->op_params[0]); + p.s1 = static_cast(dst->op_params[1]); + p.p0 = static_cast(dst->op_params[2]); + p.p1 = static_cast(dst->op_params[3]); + p.d0 = static_cast(dst->op_params[4]); + p.d1 = static_cast(dst->op_params[5]); + + p.nb01 = static_cast(nb01 / nb00); + p.nb02 = static_cast(nb02 / nb00); + p.nb03 = static_cast(nb03 / nb00); + + p.nb11 = static_cast(nb11 / nb10); + p.nb12 = static_cast(nb12 / nb10); + p.nb13 = static_cast(nb13 / nb10); + + p.nb1 = static_cast(nb1 / nb0); + p.nb2 = static_cast(nb2 / nb0); + p.nb3 = static_cast(nb3 / nb0); + + GGML_ASSERT(ne03 == ne2); + GGML_ASSERT(ne02 == ne12); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D, std::move(p), dryrun); +} + static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { vk_op_conv2d_dw_push_constants p{}; p.ne = ggml_nelements(dst); @@ -9083,6 +9311,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_POOL_2D: + case GGML_OP_CONV_2D: case GGML_OP_CONV_2D_DW: case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: @@ -9150,6 +9379,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_POOL_2D: + case GGML_OP_CONV_2D: case GGML_OP_CONV_2D_DW: case GGML_OP_LEAKY_RELU: { @@ -9356,6 +9586,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_POOL_2D: ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun); + break; + case GGML_OP_CONV_2D: + ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node, dryrun); + break; case GGML_OP_CONV_2D_DW: ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node, dryrun); @@ -9486,6 +9720,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_POOL_2D: + case GGML_OP_CONV_2D: case GGML_OP_CONV_2D_DW: case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: @@ -10067,6 +10302,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false); if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) { total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]); + } else if (cgraph->nodes[i]->op == GGML_OP_CONV_2D) { + // Return CRSxNPQxsizeof(*) to account as many bytes as mul_mat has in im2col->mul_mat mode. 
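+        // i.e. the bytes of the CRS x NPQ im2col matrix that the im2col -> mul_mat
+        // fallback would have to materialize for the same convolution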
+ auto CRS_size = + cgraph->nodes[i]->src[0]->ne[0] * cgraph->nodes[i]->src[0]->ne[1] * cgraph->nodes[i]->src[0]->ne[2]; + auto NPQ_size = cgraph->nodes[i]->ne[0] * cgraph->nodes[i]->ne[1] * cgraph->nodes[i]->ne[3]; + total_mat_mul_bytes += NPQ_size * CRS_size * ggml_type_size(cgraph->nodes[i]->type); } i += ctx->num_additional_fused_ops; ctx->num_additional_fused_ops = 0; @@ -10643,6 +10884,20 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return true; case GGML_OP_CONV_TRANSPOSE_1D: return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; + case GGML_OP_CONV_2D: + { + // Op is disabled for Apple because it segfaults at pipeline create time on MoltenVK + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + const vk_device& device = ggml_vk_get_device(ctx->device); + bool is_Apple = ggml_vk_get_device(ctx->device)->vendor_id == VK_VENDOR_ID_APPLE; + // Channel-contiguous format is not supported yet. + return (op->src[0]->type == GGML_TYPE_F32 && + op->src[1]->type == GGML_TYPE_F32 && + op->type == GGML_TYPE_F32 && + ggml_is_contiguous(op->src[0]) && + ggml_is_contiguous(op->src[1]) && + ggml_is_contiguous(op)) && !is_Apple; + } default: return false; } @@ -11201,6 +11456,14 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * const int32_t p1 = tensor->op_params[6]; tensor_clone = ggml_pool_2d(ggml_ctx, src_clone[0], op, k0, k1, s0, s1, p0, p1); + } else if (tensor->op == GGML_OP_CONV_2D) { + const int32_t s0 = tensor->op_params[0]; + const int32_t s1 = tensor->op_params[1]; + const int32_t p0 = tensor->op_params[2]; + const int32_t p1 = tensor->op_params[3]; + const int32_t d0 = tensor->op_params[4]; + const int32_t d1 = tensor->op_params[5]; + tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1); } else if (tensor->op == GGML_OP_LEAKY_RELU) { const float * op_params = (const float *)tensor->op_params; tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp new file mode 100644 index 000000000..481940a52 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp @@ -0,0 +1,265 @@ +#version 450 + +#ifdef USE_COLLECTIVES +# extension GL_KHR_shader_subgroup_shuffle : enable +#endif + +#include "types.comp" + +// Make spec constant +#define SHMEM_PAD 0 + +// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j +layout(binding = 0) readonly buffer A { + A_TYPE knl_data[]; +}; // src0 - kernel: [KW, KH, Cin, Cout] + +layout(binding = 1) readonly buffer B { + B_TYPE src_data[]; +}; // src1 - input: [W, H, Cin, N] -- channel_first format + +layout(binding = 2) writeonly buffer D { + D_TYPE dst_data[]; +}; // dst - result: [OW, OH, Cout, N] + +layout(push_constant) uniform parameter { + // I/O channels, batch size + uint32_t Cout; + uint32_t Cin; + uint32_t N; + + // Tensor spatial sizes: kernel, input, output + uint32_t KW; + uint32_t KH; + uint32_t W; + uint32_t H; + uint32_t OW; + uint32_t OH; + + // Parameters: stride, padding, dilation - 0=y, 1=x + uint32_t s0; + uint32_t s1; + uint32_t p0; + uint32_t p1; + uint32_t d0; + uint32_t d1; + + // Strides in elements + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + + uint32_t nb1; + uint32_t nb2; + uint32_t nb3; +} + +p; + +layout(local_size_x_id = 0, 
local_size_y = 1, local_size_z = 1) in; +// Blocktile sizes +layout(constant_id = 1) const uint BS_K = 128; +layout(constant_id = 2) const uint BS_CRS = 16; +layout(constant_id = 3) const uint BS_NPQ = 128; +// Thread-tile sizes +layout(constant_id = 4) const uint TS_K = 8; +layout(constant_id = 5) const uint use_collectives = 1; + +uint32_t tid = gl_LocalInvocationID.x; +const uint32_t WG_SIZE = gl_WorkGroupSize.x; + +uint splitWork(uint work_size, uint block_size) { + return (block_size + work_size - 1) / block_size; +} + +uint32_t K = p.Cout; +uint32_t CRS = p.Cin * p.KH * p.KW; +uint32_t NPQ = p.N * p.OH * p.OW; + +uint32_t n_elems_out = K * NPQ; + +// Number of blocktiles per input +uint32_t NB_CRS = splitWork(CRS, BS_CRS); + +const uint32_t Ash_stride = BS_CRS + SHMEM_PAD; +const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD; + +const uint32_t Ash_numel = BS_K * BS_CRS; +const uint32_t Bsh_numel = BS_CRS * BS_NPQ; + +const uint32_t Ash_len = BS_K * Ash_stride; +const uint32_t Bsh_len = BS_CRS * Bsh_stride; + +shared float Ash[Ash_len]; // K x CRS +shared float Bsh[Bsh_len]; // CRS x NPQ + +// Threadtile sizes +const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K; + +// Number of threadtiles per blocktile +const uint32_t NT_K = BS_K / TS_K; +const uint32_t NT_NPQ = BS_NPQ / TS_NPQ; + +float regA[TS_K]; +float regB[TS_NPQ]; +float regC[TS_K][TS_NPQ]; + +/* +Compute +KxCRS @ CRSxNPQ = K x NPQ +K=Cout +C=Cin +R,S=KH,KW +P,Q=OH,OW +*/ + +uint32_t B_idx_K = gl_WorkGroupID.x; +uint32_t B_idx_NPQ = gl_WorkGroupID.y; + +uint32_t T_y = tid / NT_NPQ; +uint32_t T_x = tid % NT_NPQ; + +uint32_t Ar = tid / BS_CRS; +uint32_t Ac = tid % BS_CRS; +const uint32_t ArpWg = WG_SIZE / BS_CRS; + +uint32_t Br = tid / BS_NPQ; +uint32_t Bc = tid % BS_NPQ; +const uint32_t BrpWg = WG_SIZE / BS_NPQ; + +void main() { + for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { + for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { + regC[T_ly][T_lx] = 0.0; + } + } + /* Advance block in CRS dim */ + for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) { + uint32_t CRS_idx_a; + uint32_t Cin_idx_a; + uint32_t KH_idx_a; + uint32_t KW_idx_a; + +#ifdef USE_COLLECTIVES + uint32_t cached_CRS_idx; + uint32_t cached_Cin_idx; + uint32_t cached_KH_idx; + uint32_t cached_KW_idx; + if (use_collectives == 1) { + cached_CRS_idx = B_idx_CRS * BS_CRS + gl_SubgroupInvocationID; + cached_Cin_idx = cached_CRS_idx / (p.KW * p.KH); + uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx * p.KW * p.KH); + cached_KH_idx = cached_CRS_remainder / p.KW; + cached_KW_idx = cached_CRS_remainder - cached_KH_idx * p.KW; + + CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac); + Cin_idx_a = subgroupShuffle(cached_Cin_idx, Ac); + KH_idx_a = subgroupShuffle(cached_KH_idx, Ac); + KW_idx_a = subgroupShuffle(cached_KW_idx, Ac); + } else { + CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A) + Cin_idx_a = CRS_idx_a / (p.KW * p.KH); + uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH; + KH_idx_a = CRS_remainder / p.KW; + KW_idx_a = CRS_remainder - KH_idx_a * p.KW; + } +#else + CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A) + Cin_idx_a = CRS_idx_a / (p.KW * p.KH); + CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH; + KH_idx_a = CRS_remainder / p.KW; + KW_idx_a = CRS_remainder - KH_idx_a * p.KW; +#endif + + /* Load kernel to A_block: (BS_K x BS_CRS)*/ + for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) { + uint32_t B_ly = r_offset + Ar; + uint32_t B_lx = Ac; + uint32_t K_idx = 
B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/ + uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03, K * CRS - 1); + float val = knl_data[knl_idx]; + if (K_idx >= K || CRS_idx_a >= CRS) { + val = 0.0; + } + Ash[B_ly * Ash_stride + B_lx] = val; + } + /* Load input to B_block: (BS_CRS x BS_NPQ) */ + for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) { + uint32_t B_ly = r_offset + Br; /* Row index of B block */ + uint32_t B_lx = Bc; + uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */ + uint32_t N_idx = NPQ_idx / (p.OH * p.OW); + uint32_t NPQ_remainder = NPQ_idx - N_idx * p.OH * p.OW; + uint32_t OH_idx = NPQ_remainder / p.OW; + uint32_t OW_idx = NPQ_remainder - OH_idx * p.OW; + + uint32_t CRS_idx_b; + uint32_t Cin_idx_b; + uint32_t KH_idx_b; + uint32_t KW_idx_b; +#ifdef USE_COLLECTIVES + if (use_collectives == 1) { + CRS_idx_b = subgroupShuffle(cached_CRS_idx, r_offset + Br); + Cin_idx_b = subgroupShuffle(cached_Cin_idx, r_offset + Br); + KH_idx_b = subgroupShuffle(cached_KH_idx, r_offset + Br); + KW_idx_b = subgroupShuffle(cached_KW_idx, r_offset + Br); + } else { + CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */ + Cin_idx_b = CRS_idx_b / (p.KW * p.KH); + uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH; + KH_idx_b = CRS_remainder / p.KW; + KW_idx_b = CRS_remainder - KH_idx_b * p.KW; + } +#else + CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */ + Cin_idx_b = CRS_idx_b / (p.KW * p.KH); + uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH; + KH_idx_b = CRS_remainder / p.KW; + KW_idx_b = CRS_remainder - KH_idx_b * p.KW; +#endif + + uint32_t H_idx = OH_idx * p.s1 + KH_idx_b * p.d1 - p.p1; + uint32_t W_idx = OW_idx * p.s0 + KW_idx_b * p.d0 - p.p0; + uint32_t src_idx = + min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1); + float val = src_data[src_idx]; + if (CRS_idx_b >= CRS || NPQ_idx >= NPQ || H_idx < 0 || H_idx >= p.H || W_idx < 0 || W_idx >= p.W) { + val = 0.0; + } + Bsh[B_ly * Bsh_stride + B_lx] = val; + } + barrier(); + for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) { + for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { + regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx]; + } + for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { + regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx]; + } + for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { + for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { + regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]); + } + } + } + barrier(); + } + /* Save C* */ + for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { + for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { + uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly; + uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx; + uint32_t N_idx = NPQ_idx / (p.OH * p.OW); + uint32_t OH_idx = (NPQ_idx - N_idx * p.OH * p.OW) / p.OW; + uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW; + uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3; + if (K_idx < K && NPQ_idx < NPQ) { + dst_data[dst_idx] = regC[T_ly][T_lx]; + } + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 020144f76..3cf184fdd 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ 
b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -669,6 +669,8 @@ void process_shaders() { string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}}); + string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}})); string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}})); @@ -779,8 +781,8 @@ void write_output_files() { len += "};\n"; } } - fprintf(src, data.c_str()); - fprintf(src, len.c_str()); + fputs(data.c_str(), src); + fputs(len.c_str(), src); } fclose(hdr); fclose(src); diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 40e809f1a..680210db7 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -233,6 +233,11 @@ class Keys: TYPE = "adapter.type" LORA_ALPHA = "adapter.lora.alpha" + class IMatrix: + CHUNK_COUNT = "imatrix.chunk_count" + CHUNK_SIZE = "imatrix.chunk_size" + DATASETS = "imatrix.datasets" + class Clip: PROJECTOR_TYPE = "clip.projector_type" HAS_VISION_ENCODER = "clip.has_vision_encoder" @@ -282,6 +287,7 @@ class Keys: class GGUFType: MODEL = "model" ADAPTER = "adapter" + IMATRIX = "imatrix" MMPROJ = "mmproj" # dummy, unused for now diff --git a/klite.embd b/klite.embd index 50b1642ae..8b3c2c0f9 100644 --- a/klite.embd +++ b/klite.embd @@ -3425,6 +3425,9 @@ Current version indicated by LITEVER below. websearch_retain: false, websearch_template: "", wordsearch_enabled: false, + second_ep_qty:0, + second_ep_model:"gpt2", + second_ep_url:"", max_context_length: (localflag?4096:2048), max_length: (localflag?512:256), @@ -12674,6 +12677,9 @@ Current version indicated by LITEVER below. document.getElementById("dry_base").value = localsettings.dry_base; document.getElementById("dry_allowed_length").value = localsettings.dry_allowed_length; document.getElementById("token_count_multiplier").value = localsettings.token_count_multiplier; + document.getElementById("second_ep_qty").value = localsettings.second_ep_qty; + document.getElementById("second_ep_model").value = localsettings.second_ep_model; + document.getElementById("second_ep_url").value = localsettings.second_ep_url; if(is_using_kcpp_with_mirostat()) { @@ -13247,6 +13253,9 @@ Current version indicated by LITEVER below. localsettings.token_count_multiplier = parseInt(document.getElementById("token_count_multiplier").value); localsettings.guidance_scale = parseFloat(document.getElementById("guidance_scale").value); localsettings.guidance_prompt = document.getElementById("guidance_prompt").value; + localsettings.second_ep_qty = parseInt(document.getElementById("second_ep_qty").value); + localsettings.second_ep_model = document.getElementById("second_ep_model").value; + localsettings.second_ep_url = document.getElementById("second_ep_url").value; localsettings.extrastopseq = document.getElementById("extrastopseq").value; localsettings.tokenbans = document.getElementById("tokenbans").value; @@ -13446,6 +13455,7 @@ Current version indicated by LITEVER below. 
localsettings.xtc_threshold = cleannum(localsettings.xtc_threshold, 0.0, 1.0); localsettings.sampler_seed = cleannum(localsettings.sampler_seed, -1, 999999); localsettings.token_count_multiplier = cleannum(localsettings.token_count_multiplier, 70, 130); + localsettings.second_ep_qty = cleannum(Math.floor(localsettings.second_ep_qty), 0, 1024); toggle_invert_colors(); toggle_sidepanel_mode(); @@ -13901,6 +13911,7 @@ Current version indicated by LITEVER below. on_searchsummary_done = onDoneFn; + submit_payload = finalize_submit_payload(submit_payload, false); dispatch_submit_generation(submit_payload, false); render_gametext(); } @@ -13975,7 +13986,8 @@ Current version indicated by LITEVER below. //v2 api specific fields submit_payload.workers = selected_workers.map((m) => { return m.id }); - dispatch_submit_generation(submit_payload,false); + submit_payload = finalize_submit_payload(submit_payload, false); + dispatch_submit_generation(submit_payload, false); render_gametext(); document.getElementById("memorytext").value = "[<|Generating summary, do not close window...|>]" }; @@ -16545,6 +16557,14 @@ Current version indicated by LITEVER below. if (!doNotGenerate) { + submit_payload = finalize_submit_payload(submit_payload, user_input_empty); + if(localsettings.second_ep_qty>0 && localsettings.second_ep_url!="") + { + let decensored_prefix = yield FetchDecensoredPrefix(submit_payload, localsettings.second_ep_url, localsettings.second_ep_model, localsettings.second_ep_qty); + console.log(`Obtained Prefix: ${decensored_prefix}`); + pending_context_preinjection += decensored_prefix; + submit_payload.prompt += decensored_prefix; + } dispatch_submit_generation(submit_payload, user_input_empty); } else @@ -16714,10 +16734,8 @@ Current version indicated by LITEVER below. return resp; } - function dispatch_submit_generation(submit_payload, input_was_empty) //if input is not empty, always unban eos + function finalize_submit_payload(submit_payload, input_was_empty) { - console.log(submit_payload); - //preprocess to add extra fields if(custom_kobold_endpoint != "" && is_using_kcpp_with_mirostat()) { @@ -16772,6 +16790,28 @@ Current version indicated by LITEVER below. submit_payload.params.negative_prompt = localsettings.guidance_prompt; } + //for vesion 1.2.2 and later, send stopper tokens for chat and instruct + if (custom_kobold_endpoint != "" && kobold_endpoint_version && kobold_endpoint_version != "" && compare_version_str(kobold_endpoint_version, "1.2.2") >= 0) { + submit_payload.params.stop_sequence = get_stop_sequences(); + } + + //version 1.2.4 and later supports unban tokens + if (custom_kobold_endpoint != "" && kobold_endpoint_version && kobold_endpoint_version != "" && compare_version_str(kobold_endpoint_version, "1.2.4") >= 0) + { + submit_payload.params.use_default_badwordsids = determine_if_ban_eos(input_was_empty); + if(is_using_kcpp_with_added_memory()) + { + submit_payload.params.bypass_eos = (localsettings.eos_ban_mode == 3?true:false); + } + } + + return submit_payload; + } + + function dispatch_submit_generation(submit_payload, input_was_empty) //if input is not empty, always unban eos + { + console.log(submit_payload); + start_time_taken(); //timestamp start request if (is_using_custom_ep()) { @@ -16792,23 +16832,6 @@ Current version indicated by LITEVER below. 
let prompt = submit_payload.prompt; submit_payload = submit_payload.params; submit_payload.prompt = prompt; - let showlog = false; - submit_payload.quiet = !showlog; - - //for vesion 1.2.2 and later, send stopper tokens for chat and instruct - if (kobold_endpoint_version && kobold_endpoint_version != "" && compare_version_str(kobold_endpoint_version, "1.2.2") >= 0) { - submit_payload.stop_sequence = get_stop_sequences(); - } - - //version 1.2.4 and later supports unban tokens - if (kobold_endpoint_version && kobold_endpoint_version != "" && compare_version_str(kobold_endpoint_version, "1.2.4") >= 0) - { - submit_payload.use_default_badwordsids = determine_if_ban_eos(input_was_empty); - if(is_using_kcpp_with_added_memory()) - { - submit_payload.bypass_eos = (localsettings.eos_ban_mode == 3?true:false); - } - } last_request_str = JSON.stringify(submit_payload); last_response_obj = null; @@ -17006,18 +17029,6 @@ Current version indicated by LITEVER below. } else { - //apply custom logit bias for official OAI only - let needbaneos = (custom_oai_endpoint.toLowerCase().includes("api.openai.com") && determine_if_ban_eos(input_was_empty)); - - if(needbaneos) - { - if(oai_payload.logit_bias) - { - oai_payload.logit_bias["50256"] = -100; - }else{ - oai_payload.logit_bias = { "50256": -100 }; - } - } oai_payload.prompt = submit_payload.prompt; } @@ -19039,7 +19050,7 @@ Current version indicated by LITEVER below. { if(!localsettings.allow_continue_chat) { - pending_context_preinjection = get_instructendplaceholder(); + pending_context_preinjection = get_instructendplaceholder(); if(localsettings.inject_timestamps) { @@ -20351,6 +20362,7 @@ Current version indicated by LITEVER below. fulltxt = replace_placeholders(fulltxt,true); } + let alreadyShowedStreaming = false; //sometimes, streaming is injected inline to preserve indentation if(localsettings.opmode==4 && !inEditMode) { @@ -20380,6 +20392,36 @@ Current version indicated by LITEVER below. 
{ let curr = instruct_turns[i]; let currmsg = curr.msg; + + //in some special cases, add streaming text inline + let allow_cont_prev_turn = (localsettings.opmode==4 || (localsettings.opmode==3 && localsettings.allow_continue_chat)); + if(i>0 && (i==instruct_turns.length-1) && curr.myturn==false && currmsg && allow_cont_prev_turn) + { + //inject into previous turn, only for instruct OR continuechat + currmsg += `${escape_html(pending_context_preinjection) + format_streaming_text(escape_html(synchro_pending_stream))}`; + alreadyShowedStreaming = true; + } + + //apply stylization to time tags + if(localsettings.inject_timestamps && localsettings.instruct_has_markdown) + { + currmsg = currmsg.replace(/(\[\d{1,2}\/\d{1,2}\/\d{4}, \d{1,2}:\d{2} [AP]M\])/g, "$1\n"); + } + if(localsettings.inject_chatnames_instruct && localsettings.instruct_has_markdown) + { + let m_name = localsettings.chatname + ": "; + currmsg = replaceAll(currmsg, m_name, `` + escape_html(m_name) + ``); + let m_opps = localsettings.chatopponent.split("||$||"); + for(let i=0;i` + escape_html(m_opp) + ``); + } + } + } + if(localsettings.instruct_has_markdown && (localsettings.render_streaming_markdown||synchro_pending_stream==""||(i+1)` + escape_html(m_name) + ``); - let m_opps = localsettings.chatopponent.split("||$||"); - for(let i=0;i` + escape_html(m_opp) + ``); - } - } - } - }else{ + } else { fulltxt = replaceAll(fulltxt, get_instruct_starttag(true), `%SclStg%`+escape_html(get_instruct_starttag(true))+`%SpnEtg%`); fulltxt = replaceAll(fulltxt, get_instruct_endtag(true), `%SclStg%`+escape_html(get_instruct_endtag(true))+`%SpnEtg%`); if (localsettings.separate_end_tags) { @@ -20463,7 +20486,7 @@ Current version indicated by LITEVER below. { if(!document.getElementById("allowediting").checked && !fulltxt.startsWith("\n")) { - fulltxt = "\n"+fulltxt; + fulltxt = "\n" + fulltxt; } //for chat mode, highlight our name in blue and opponent in red @@ -20483,7 +20506,6 @@ Current version indicated by LITEVER below. return `` + oname + ``; }); fulltxt = replaceAll(fulltxt,m_name, `` + escape_html(m_name) + ``); - } //for adventure mode, highlight our actions in green @@ -20493,8 +20515,8 @@ Current version indicated by LITEVER below. }); } - //streaming display - if(synchro_pending_stream!="" && waiting_for_tool_call==0) + //streaming display for all other cases + if(!alreadyShowedStreaming && synchro_pending_stream!="" && waiting_for_tool_call==0) { fulltxt += `${escape_html(pending_context_preinjection) + format_streaming_text(escape_html(synchro_pending_stream))}`; } @@ -23906,6 +23928,59 @@ Current version indicated by LITEVER below. 
resultsContainer.appendChild(btn); }); } + + //queries the decensoring prefix from a second OAI compatible endpoint and returns a string to add to our main request + const FetchDecensoredPrefix = asyncRunner(function* (submit_payload, endpoint_url, modelused, num_tokens) + { + let oaiheaders = { + 'Content-Type': 'application/json' + }; + if (custom_oai_key!="" && custom_oai_key!=dummy_api_key) { + oaiheaders['Authorization'] = 'Bearer ' + custom_oai_key; + } + + let scaled_rep_pen = 0; + if(submit_payload.params.presence_penalty > 0) + { + scaled_rep_pen = submit_payload.params.presence_penalty; + }else{ + //original range between 1 and 3, scale to 0 and 2 + scaled_rep_pen = (submit_payload.params.rep_pen - 1.0); + } + let oai_payload = + { + "prompt": submit_payload.prompt, + "max_tokens": num_tokens, + "model": modelused, + "temperature": submit_payload.params.temperature, + "top_p": submit_payload.params.top_p, + "stream": false, + "presence_penalty": scaled_rep_pen, + "stop": get_stop_sequences().slice(0, 4) + }; + + return fetch(endpoint_url, { + method: 'POST', + headers: oaiheaders, + body: JSON.stringify(oai_payload), + referrerPolicy: 'no-referrer', + }) + .then((response) => response.json()) + .then((data) => { + let reply = ""; + if (data.choices != null && data.choices.length > 0) { + let dch = data.choices[0]; + if (dch.text) { + reply = dch.text; + } + } + return reply; + }) + .catch((error) => { + console.error('Error:', error); + return ""; + }); + }); @@ -24873,6 +24948,19 @@ Current version indicated by LITEVER below.
+
+ 2ndEpSampleQty ?
+ If enabled, sample N tokens from a second OpenAI Compatible endpoint first before a request. This can be useful for jailbreak purposes.
+
+ Qty:
+
+ Model:
+
+ Url:
+
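The FetchDecensoredPrefix helper above reuses KoboldAI-style sampler settings for an OpenAI-compatible request: it passes presence_penalty through when it is set, and otherwise shifts rep_pen (roughly 1 to 3) down by 1 so it lands in the 0 to 2 presence-penalty range. A minimal standalone sketch of that mapping follows, written in C++ purely for illustration; the SamplerParams struct and function name are hypothetical and not part of this patch.

// Sketch only: mirrors the scaled_rep_pen logic in FetchDecensoredPrefix above.
// SamplerParams and scaled_presence_penalty are illustrative names, not from the patch.
#include <cstdio>

struct SamplerParams {
    double presence_penalty; // OpenAI-style, 0..2; treated as "unset" when <= 0
    double rep_pen;          // KoboldAI-style repetition penalty, roughly 1..3
};

static double scaled_presence_penalty(const SamplerParams & p) {
    if (p.presence_penalty > 0) {
        return p.presence_penalty; // already in the target range, pass through
    }
    return p.rep_pen - 1.0;        // shift the 1..3 range down to 0..2
}

int main() {
    const SamplerParams a = { 0.0, 1.15 };
    const SamplerParams b = { 0.6, 1.15 };
    std::printf("%.2f %.2f\n", scaled_presence_penalty(a), scaled_presence_penalty(b)); // prints 0.15 0.60
    return 0;
}

So a rep_pen of 1.15 with no presence penalty set reaches the second endpoint as presence_penalty 0.15.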
diff --git a/koboldcpp.py b/koboldcpp.py
index 0df5daa23..357d81164 100644
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -63,7 +63,7 @@ dry_seq_break_max = 128
 extra_images_max = 4
 # global vars
-KcppVersion = "1.96.1"
+KcppVersion = "1.96.2"
 showdebug = True
 kcpp_instance = None #global running instance
 global_memory = {"tunnel_url": "", "restart_target":"", "input_to_exit":False, "load_complete":False, "restart_override_config_target":""}
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 7cac3b98f..b63a41053 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -428,6 +428,8 @@ void llm_graph_result::reset() {
     t_embd        = nullptr;
     t_embd_pooled = nullptr;
+    params = {};
+
     inputs.clear();
     buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
@@ -905,20 +907,25 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         cb(cur, "ffn_moe_weighted", il);
     }
+    ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr };
+
+    assert(n_expert_used > 0);
+
+    // order the views before the adds
+    for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
+        cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]);
+
+        ggml_build_forward_expand(gf, cur_experts[i]);
+    }
+
     // aggregate experts
     // note: here we explicitly use hparams.n_expert_used instead of n_expert_used
     // to avoid potentially a large number of add nodes during warmup
     // ref: https://github.com/ggml-org/llama.cpp/pull/14753
-    ggml_tensor * moe_out = nullptr;
-    for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
-        ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens,
-                experts->nb[2], i*experts->nb[1]);
+    ggml_tensor * moe_out = cur_experts[0];
-        if (i == 0) {
-            moe_out = cur_expert;
-        } else {
-            moe_out = ggml_add(ctx0, moe_out, cur_expert);
-        }
+    for (uint32_t i = 1; i < hparams.n_expert_used; ++i) {
+        moe_out = ggml_add(ctx0, moe_out, cur_experts[i]);
     }
     if (hparams.n_expert_used == 1) {
diff --git a/tools/quantize/quantize.cpp b/tools/quantize/quantize.cpp
index efbf27ae6..7c6e009be 100644
--- a/tools/quantize/quantize.cpp
+++ b/tools/quantize/quantize.cpp
@@ -1,12 +1,14 @@
 #include "build-info.h"
 #include "common/common.h"
 #include "llama.h"
+#include "gguf.h"
 #include
 #include
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -69,6 +71,11 @@ static const char * const LLM_KV_QUANTIZE_IMATRIX_DATASET = "quantize.imatrix
 static const char * const LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES = "quantize.imatrix.entries_count";
 static const char * const LLM_KV_QUANTIZE_IMATRIX_N_CHUNKS  = "quantize.imatrix.chunks_count";
+// TODO: share with imatrix.cpp
+static const char * const LLM_KV_IMATRIX_DATASETS    = "imatrix.datasets";
+static const char * const LLM_KV_IMATRIX_CHUNK_COUNT = "imatrix.chunk_count";
+static const char * const LLM_KV_IMATRIX_CHUNK_SIZE  = "imatrix.chunk_size";
+
 static bool striequals(const char * a, const char * b) {
     while (*a && *b) {
         if (std::tolower(*a) != std::tolower(*b)) {
@@ -85,7 +92,7 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp
     for (auto ch : ftype_str_in) {
         ftype_str.push_back(std::toupper(ch));
     }
-    for (auto & it : QUANT_OPTIONS) {
+    for (const auto & it : QUANT_OPTIONS) {
         if (striequals(it.name.c_str(), ftype_str.c_str())) {
             ftype = it.ftype;
             ftype_str_out = it.name;
@@ -94,7 +101,7 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp
     }
     try {
         int ftype_int = std::stoi(ftype_str);
-        for (auto & it : QUANT_OPTIONS) {
+        for (const auto & it : QUANT_OPTIONS) {
            if (it.ftype == ftype_int) {
                ftype = it.ftype;
                ftype_str_out = it.name;
@@ -130,7 +137,7 @@ static void usage(const char * executable) {
     printf("      Advanced option to override model metadata by key in the quantized model. May be specified multiple times.\n");
     printf("Note: --include-weights and --exclude-weights cannot be used together\n");
     printf("\nAllowed quantization types:\n");
-    for (auto & it : QUANT_OPTIONS) {
+    for (const auto & it : QUANT_OPTIONS) {
         if (it.name != "COPY") {
             printf("  %2d  or  ", it.ftype);
         } else {
@@ -141,7 +148,7 @@ static void usage(const char * executable) {
     exit(1);
 }
-static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_dataset, std::unordered_map<std::string, std::vector<float>> & imatrix_data) {
+static int load_legacy_imatrix(const std::string & imatrix_file, std::vector<std::string> & imatrix_datasets, std::unordered_map<std::string, std::vector<float>> & imatrix_data) {
     std::ifstream in(imatrix_file.c_str(), std::ios::binary);
     if (!in) {
         printf("%s: failed to open %s\n",__func__, imatrix_file.c_str());
@@ -181,7 +188,9 @@ static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_
             exit(1);
         }
         if (ncall > 0) {
-            for (auto& v : e) v /= ncall;
+            for (auto & v : e) {
+                v /= ncall;
+            }
         }
         if (getenv("LLAMA_TRACE")) {
@@ -189,7 +198,7 @@ static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_
         }
     }
-    // latest imatrix version contains the dataset filename at the end of the file
+    // latest legacy imatrix version contains the dataset filename at the end of the file
     int m_last_call = 0;
     if (in.peek() != EOF) {
         in.read((char *)&m_last_call, sizeof(m_last_call));
@@ -197,15 +206,130 @@ static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_
         in.read((char *)&dataset_len, sizeof(dataset_len));
         std::vector<char> dataset_as_vec(dataset_len);
         in.read(dataset_as_vec.data(), dataset_len);
-        imatrix_dataset.assign(dataset_as_vec.begin(), dataset_as_vec.end());
-        printf("%s: imatrix dataset='%s'\n", __func__, imatrix_dataset.c_str());
+        imatrix_datasets.resize(1);
+        imatrix_datasets[0].assign(dataset_as_vec.begin(), dataset_as_vec.end());
+        printf("%s: imatrix dataset='%s'\n", __func__, imatrix_datasets[0].c_str());
     }
     printf("%s: loaded %d importance matrix entries from %s computed on %d chunks\n", __func__, int(imatrix_data.size()), imatrix_file.c_str(), m_last_call);
     return m_last_call;
 }
+static int load_imatrix(const std::string & imatrix_file, std::vector<std::string> & imatrix_datasets, std::unordered_map<std::string, std::vector<float>> & imatrix_data) {
+
+    struct ggml_context * ctx = nullptr;
+    struct gguf_init_params meta_gguf_params = {
+        /* .no_alloc = */ false, // the data is needed
+        /* .ctx      = */ &ctx,
+    };
+    struct gguf_context * ctx_gguf = gguf_init_from_file(imatrix_file.c_str(), meta_gguf_params);
+    if (!ctx_gguf) {
+        fprintf(stderr, "%s: imatrix file '%s' is using old format\n", __func__, imatrix_file.c_str());
+        return load_legacy_imatrix(imatrix_file, imatrix_datasets, imatrix_data);
+    }
+    const int32_t n_entries = gguf_get_n_tensors(ctx_gguf);
+    if (n_entries < 1) {
+        fprintf(stderr, "%s: no data in file %s\n", __func__, imatrix_file.c_str());
+        gguf_free(ctx_gguf);
+        ggml_free(ctx);
+        exit(1);
+    }
+
+    const int dataset_idx     = gguf_find_key(ctx_gguf, LLM_KV_IMATRIX_DATASETS);
+    const int chunk_count_idx = gguf_find_key(ctx_gguf, LLM_KV_IMATRIX_CHUNK_COUNT);
+    const int chunk_size_idx  = gguf_find_key(ctx_gguf, LLM_KV_IMATRIX_CHUNK_SIZE);
+    if (dataset_idx < 0 || chunk_count_idx < 0 || chunk_size_idx < 0) {
+        fprintf(stderr, "%s: missing imatrix metadata in file %s\n", __func__, imatrix_file.c_str());
+        gguf_free(ctx_gguf);
+        ggml_free(ctx);
+        exit(1);
+    }
+
+    const uint32_t chunk_size = gguf_get_val_u32(ctx_gguf, chunk_size_idx);
+
+    const std::string sums_suffix{ ".in_sum2" };
+    const std::string counts_suffix{ ".counts" };
+
+    // Using an ordered map to get a deterministic iteration order.
+    std::map<std::string, std::pair<const struct ggml_tensor *, const struct ggml_tensor *>> sums_counts_for;
+
+    for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) {
+        std::string name = cur->name;
+
+        if (name.empty()) { continue; }
+
+        if (string_remove_suffix(name, sums_suffix)) {
+            // in_sum2
+            sums_counts_for[std::move(name)].first = cur;
+        } else if (string_remove_suffix(name, counts_suffix)) {
+            // counts
+            sums_counts_for[std::move(name)].second = cur;
+        } else {
+            // ignore other tensors
+        }
+    }
+
+    for (const auto & sc : sums_counts_for) {
+        const std::string &        name   = sc.first;
+        const struct ggml_tensor * sums   = sc.second.first;
+        const struct ggml_tensor * counts = sc.second.second;
+
+        if (!sums || !counts) {
+            fprintf(stderr, "%s: mismatched sums and counts for %s\n", __func__, name.c_str());
+            gguf_free(ctx_gguf);
+            ggml_free(ctx);
+            exit(1);
+        }
+
+        const int64_t ne0 = sums->ne[0];
+        const int64_t ne1 = sums->ne[1];
+
+        auto & e = imatrix_data[name];
+        e.resize(ggml_nelements(sums));
+        float max_count = 0.0f;
+        for (int64_t j = 0; j < ne1; ++j) {
+            const float count = ((const float *) counts->data)[j];
+            if (count > 0.0f) {
+                for (int64_t i = 0; i < ne0; ++i) {
+                    e[j*ne0 + i] = ((const float *) sums->data)[j*ne0 + i] / count;
+                }
+            } else {
+                // Partial imatrix data, this tensor never got any input during calibration
                for (int64_t i = 0; i < ne0; ++i) {
+                    e[j*ne0 + i] = 1;
+                }
+            }
+            if (count > max_count) {
+                max_count = count;
+            }
+        }
+        if (getenv("LLAMA_TRACE")) {
+            printf("%s: loaded data (size = %6d, n_tokens = %6d, n_chunks = %6d) for '%s'\n", __func__, int(e.size()), int(max_count), int(max_count / chunk_size), name.c_str());
+        }
+    }
+
+    int m_last_chunk = gguf_get_val_u32(ctx_gguf, chunk_count_idx);
+
+    int64_t n_datasets = gguf_get_arr_n(ctx_gguf, dataset_idx);
+    imatrix_datasets.reserve(n_datasets);
+    for (int64_t i = 0; i < n_datasets; ++i) {
+        imatrix_datasets.push_back(gguf_get_arr_str(ctx_gguf, dataset_idx, i));
+    }
+    printf("%s: imatrix datasets=['%s'", __func__, imatrix_datasets[0].c_str());
+    for (size_t i = 1; i < imatrix_datasets.size(); ++i) {
+        printf(", '%s'", imatrix_datasets[i].c_str());
+    }
+    printf("]\n");
+
+    printf("%s: loaded %d importance matrix entries from %s computed on %d chunks\n", __func__, int(imatrix_data.size()), imatrix_file.c_str(), m_last_chunk);
+
+    gguf_free(ctx_gguf);
+    ggml_free(ctx);
+
+    return m_last_chunk;
+}
+
 static int prepare_imatrix(const std::string & imatrix_file,
-        std::string & imatrix_dataset,
+        std::vector<std::string> & imatrix_dataset,
         const std::vector<std::string> & included_weights,
         const std::vector<std::string> & excluded_weights,
         std::unordered_map<std::string, std::vector<float>> & imatrix_data) {
@@ -217,18 +341,21 @@ static int prepare_imatrix(const std::string & imatrix_file,
         return m_last_call;
     }
     if (!excluded_weights.empty()) {
-        for (auto& name : excluded_weights) {
-            for (auto it = imatrix_data.begin(); it != imatrix_data.end(); ) {
+        for (const auto & name : excluded_weights) {
+            for (auto it = imatrix_data.begin(); it != imatrix_data.end();) {
                 auto pos = it->first.find(name);
-                if (pos != std::string::npos) it = imatrix_data.erase(it);
-                else ++it;
+                if (pos != std::string::npos) {
+                    it = imatrix_data.erase(it);
+                } else {
+                    ++it;
+                }
             }
         }
     }
     if (!included_weights.empty()) {
         std::unordered_map<std::string, std::vector<float>> tmp;
-        for (auto& name : included_weights) {
-            for (auto& e : imatrix_data) {
+        for (const auto & name : included_weights) {
+            for (auto & e : imatrix_data) {
                 auto pos = e.first.find(name);
                 if (pos != std::string::npos) {
                     tmp.emplace(std::move(e));
@@ -397,9 +524,9 @@ int main(int argc, char ** argv) {
         usage(argv[0]);
     }
-    std::string imatrix_dataset;
+    std::vector<std::string> imatrix_datasets;
     std::unordered_map<std::string, std::vector<float>> imatrix_data;
-    int m_last_call = prepare_imatrix(imatrix_file, imatrix_dataset, included_weights, excluded_weights, imatrix_data);
+    int m_last_call = prepare_imatrix(imatrix_file, imatrix_datasets, included_weights, excluded_weights, imatrix_data);
     if (!imatrix_data.empty()) {
         params.imatrix = &imatrix_data;
         {
@@ -410,11 +537,12 @@ int main(int argc, char ** argv) {
             kvo.val_str[127] = '\0';
             kv_overrides.emplace_back(std::move(kvo));
         }
-        if (!imatrix_dataset.empty()) {
+        if (!imatrix_datasets.empty()) {
             llama_model_kv_override kvo;
+            // TODO: list multiple datasets when there are more than one
             std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_DATASET);
             kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR;
-            strncpy(kvo.val_str, imatrix_dataset.c_str(), 127);
+            strncpy(kvo.val_str, imatrix_datasets[0].c_str(), 127);
             kvo.val_str[127] = '\0';
             kv_overrides.emplace_back(std::move(kvo));
         }
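The new GGUF-based imatrix path above pairs each tensor's ".in_sum2" sums with its ".counts" tensor and divides to recover per-value means, falling back to a neutral weight of 1 for rows that saw no calibration input. A self-contained sketch of just that averaging step, operating on plain float buffers instead of ggml tensors; the helper name is illustrative and not part of the patch.

// Sketch of the averaging performed by load_imatrix: sums2 holds ne1 rows of ne0
// accumulated squared activations (".in_sum2"), counts holds one activation count
// per row (".counts"). Rows with a zero count get a neutral weight of 1.
#include <cstdint>
#include <vector>

static std::vector<float> average_imatrix(const float * sums2, const float * counts,
                                          int64_t ne0, int64_t ne1) {
    std::vector<float> e(static_cast<size_t>(ne0 * ne1));
    for (int64_t j = 0; j < ne1; ++j) {
        const float count = counts[j];
        for (int64_t i = 0; i < ne0; ++i) {
            // mean of squared activations, or 1.0f when this row never saw input
            e[j*ne0 + i] = count > 0.0f ? sums2[j*ne0 + i] / count : 1.0f;
        }
    }
    return e;
}

The resulting per-tensor vectors are what quantize passes to llama_model_quantize via params.imatrix.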