Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	CODEOWNERS
#	docs/build.md
#	scripts/sync-ggml.last
#	tests/test-backend-ops.cpp
#	tools/imatrix/README.md
#	tools/imatrix/imatrix.cpp
Concedo 2025-07-20 22:47:31 +08:00
commit 30675b0798
21 changed files with 1397 additions and 1181 deletions

@@ -22,8 +22,8 @@ AllowShortIfStatementsOnASingleLine: Never
AllowShortLambdasOnASingleLine: Inline
AllowShortLoopsOnASingleLine: false
AlwaysBreakBeforeMultilineStrings: true
BinPackArguments: true
BinPackParameters: true # OnePerLine
BinPackArguments: false
BinPackParameters: false # OnePerLine
BitFieldColonSpacing: Both
BreakBeforeBraces: Custom # Attach
BraceWrapping:
@@ -70,15 +70,18 @@ ExperimentalAutoDetectBinPacking: false
FixNamespaceComments: true
IncludeBlocks: Regroup
IncludeCategories:
- Regex: '^<.*\.h>'
- Regex: '".*"'
Priority: 1
SortPriority: 0
- Regex: '^<.*'
- Regex: '^<.*\.h>'
Priority: 2
SortPriority: 0
- Regex: '.*'
- Regex: '^<.*'
Priority: 3
SortPriority: 0
- Regex: '.*'
Priority: 4
SortPriority: 0
IncludeIsMainRegex: '([-_](test|unittest))?$'
IncludeIsMainSourceRegex: ''
IndentAccessModifiers: false

@@ -456,6 +456,15 @@ void string_replace_all(std::string & s, const std::string & search, const std::
bool string_ends_with(const std::string_view & str, const std::string_view & suffix) {
return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
}
bool string_remove_suffix(std::string & str, const std::string_view & suffix) {
bool has_suffix = string_ends_with(str, suffix);
if (has_suffix) {
str = str.substr(0, str.size() - suffix.size());
}
return has_suffix;
}
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop) {
if (!str.empty() && !stop.empty()) {
const char text_last_char = str.back();
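// Usage sketch for the new string_remove_suffix helper (illustrative, not part of
// the diff): it strips a known suffix in place and reports whether it was present.
//
//     std::string s = "response<|im_end|>";
//     if (string_remove_suffix(s, "<|im_end|>")) {
//         // the suffix was found and removed: s == "response"
//     }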

@@ -530,6 +530,7 @@ static bool string_starts_with(const std::string & str,
// While we wait for C++20's std::string::ends_with...
bool string_ends_with(const std::string_view & str, const std::string_view & suffix);
bool string_remove_suffix(std::string & str, const std::string_view & suffix);
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop);
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);

@@ -22,21 +22,6 @@ static bool ggml_is_view(const struct ggml_tensor * t) {
return t->view_src != NULL;
}
static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
if (a->type != b->type) {
return false;
}
for (int i = 0; i < GGML_MAX_DIMS; i++) {
if (a->ne[i] != b->ne[i]) {
return false;
}
if (a->nb[i] != b->nb[i]) {
return false;
}
}
return true;
}
// ops that return true for this function must not use restrict pointers for their backend implementations
static bool ggml_op_can_inplace(enum ggml_op op) {
switch (op) {

@@ -352,21 +352,6 @@ ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend) {
// backend copy
static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
if (a->type != b->type) {
return false;
}
for (int i = 0; i < GGML_MAX_DIMS; i++) {
if (a->ne[i] != b->ne[i]) {
return false;
}
if (a->nb[i] != b->nb[i]) {
return false;
}
}
return true;
}
void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst) {
GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts");

@@ -1,337 +0,0 @@
// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
// SPDX-License-Identifier: MIT
//
// KleidiAI micro-kernels
#include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h"
#include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
#include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
#include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h"
#include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
#include "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h"
#include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
#include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h"
#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
#include "kai_common.h"
#include "kernels.h"
#define NELEMS(x) sizeof(x) / sizeof(*x)
static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
#if defined(__ARM_FEATURE_SME)
{
/* SME GEMM */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
},
/* SME GEMV */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
},
/* .lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon,
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
},
/* .rhs_info = */ {
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
},
/* .required_cpu = */ CPU_FEATURE_SME,
/* .lhs_type = */ GGML_TYPE_F32,
/* .rhs_type = */ GGML_TYPE_Q4_0,
/* .op_type = */ GGML_TYPE_F32,
},
{
/* SME GEMM */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
},
/* SME GEMV */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
},
/* .lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme,
/* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
},
/* .rhs_info = */ {
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
/* .pack_func = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
},
/* .required_cpu = */ CPU_FEATURE_SME,
/* .lhs_type = */ GGML_TYPE_F32,
/* .rhs_type = */ GGML_TYPE_F16,
/* .op_type = */ GGML_TYPE_F32,
},
#endif
#if defined(__APPLE__)
#if defined(__ARM_FEATURE_DOTPROD)
{
/* DOTPROD GEMM */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
},
/* DOTPROD GEMV */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
},
/* .lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
},
/* .rhs_info = */ {
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
},
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
/* .lhs_type = */ GGML_TYPE_F32,
/* .rhs_type = */ GGML_TYPE_Q4_0,
/* .op_type = */ GGML_TYPE_F32,
},
#endif
#if defined(__ARM_FEATURE_MATMUL_INT8)
{
/* i8mm GEMM */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
},
/* i8mm GEMV */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
},
/* .lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
},
/* .rhs_info = */ {
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
},
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
/* .lhs_type = */ GGML_TYPE_F32,
/* .rhs_type = */ GGML_TYPE_Q4_0,
/* .op_type = */ GGML_TYPE_F32,
},
#endif
#else
#if defined(__ARM_FEATURE_MATMUL_INT8)
{
/* i8mm GEMM */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
},
/* i8mm GEMV */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
},
/* .lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
},
/* .rhs_info = */ {
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
},
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
/* .lhs_type = */ GGML_TYPE_F32,
/* .rhs_type = */ GGML_TYPE_Q4_0,
/* .op_type = */ GGML_TYPE_F32,
},
#endif
#if defined(__ARM_FEATURE_DOTPROD)
{
/* DOTPROD GEMM */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
},
/* DOTPROD GEMV */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
},
/* .lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
},
/* .rhs_info = */ {
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
},
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
/* .lhs_type = */ GGML_TYPE_F32,
/* .rhs_type = */ GGML_TYPE_Q4_0,
/* .op_type = */ GGML_TYPE_F32,
},
#endif
#endif
};
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) {
ggml_kleidiai_kernels * kernel = nullptr;
if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) {
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu &&
gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type &&
gemm_gemv_kernels[i].rhs_type == tensor->src[0]->type &&
gemm_gemv_kernels[i].op_type == tensor->type) {
kernel = &gemm_gemv_kernels[i];
break;
}
}
}
return kernel;
}
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features) {
ggml_kleidiai_kernels * kernels = nullptr;
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) {
kernels = &gemm_gemv_kernels[i];
break;
}
}
return kernels;
}
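// Selection note (illustrative): a kernel matches only when every required CPU
// feature bit is present, i.e. (features & required_cpu) == required_cpu. For
// example, features = CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM satisfies a kernel
// requiring only CPU_FEATURE_DOTPROD, while features = CPU_FEATURE_DOTPROD alone
// does not satisfy CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM.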

@@ -1,95 +0,0 @@
// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
// SPDX-License-Identifier: MIT
//
#pragma once
#include <functional>
#include <variant>
#include "ggml.h"
enum cpu_feature {
CPU_FEATURE_NONE = 0,
CPU_FEATURE_DOTPROD = 1,
CPU_FEATURE_I8MM = 2,
CPU_FEATURE_SVE = 4,
CPU_FEATURE_SME = 8
};
inline cpu_feature& operator|=(cpu_feature& lhs, cpu_feature rhs) {
lhs = static_cast<cpu_feature>(lhs | rhs);
return lhs;
}
inline cpu_feature operator|(cpu_feature lhs, cpu_feature rhs) {
return static_cast<cpu_feature>(static_cast<int>(lhs) | static_cast<int>(rhs));
}
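// Usage sketch (illustrative): the two overloads let cpu_feature act as a bit mask.
//
//     cpu_feature f = CPU_FEATURE_NONE;
//     f |= CPU_FEATURE_DOTPROD;        // detected at runtime
//     f |= CPU_FEATURE_I8MM;           // f now carries both capability bits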
struct kernel_info {
size_t (*get_m_step)(void);
size_t (*get_n_step)(void);
size_t (*get_mr)(void);
size_t (*get_nr)(void);
size_t (*get_kr)(void);
size_t (*get_sr)(void);
std::variant<
std::function<size_t(size_t n_idx, size_t k, size_t bl)>,
std::function<size_t(size_t m_idx, size_t k)>
> get_lhs_offset;
std::variant<
std::function<size_t(size_t n_idx, size_t k, size_t bl)>,
std::function<size_t(size_t n_idx, size_t k)>
> get_rhs_packed_offset;
size_t (*get_dst_offset)(size_t m_idx, size_t n_idx, size_t stride);
size_t (*get_dst_size)(size_t m, size_t n);
std::variant<
std::function<void(size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed,
float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max)>,
std::function<void(size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row,
size_t dst_stride_col, float clamp_min, float clamp_max)>
> run_kernel;
};
struct lhs_packing_info {
size_t (*get_offset)(size_t m_idx, size_t lhs_stride);
std::variant<
std::function<size_t(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr)>,
std::function<size_t(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr)>
> get_packed_offset;
std::variant<
std::function<size_t(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr)>,
std::function<size_t(size_t m, size_t k, size_t mr, size_t kr, size_t sr)>
> packed_size;
std::variant<
std::function<void(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
size_t lhs_stride, void* lhs_packed)>,
std::function<void(size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride,
void* lhs_packed)>
> pack_func;
};
struct rhs_packing_info {
std::variant<
std::function<size_t(size_t n, size_t k, size_t nr, size_t kr, size_t bl)>,
std::function<size_t(size_t n, size_t k)>
> packed_size;
std::variant<
std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params)>,
std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,
const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params)>
> pack_func;
};
struct ggml_kleidiai_kernels {
kernel_info gemm;
kernel_info gemv;
lhs_packing_info lhs_info;
rhs_packing_info rhs_info;
cpu_feature required_cpu;
ggml_type lhs_type;
ggml_type rhs_type;
ggml_type op_type;
};
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor);
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features);

@@ -1,482 +0,0 @@
// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
// SPDX-License-Identifier: MIT
//
#include <arm_neon.h>
#include <assert.h>
#include <atomic>
#include <cfloat>
#include <stdexcept>
#include <stdint.h>
#include <string.h>
#if defined(__linux__)
#include <asm/hwcap.h>
#include <sys/auxv.h>
#elif defined(__APPLE__)
#include <string_view>
#include <sys/sysctl.h>
#include <sys/types.h>
#elif defined(_WIN32)
#include <windows.h>
#include <excpt.h>
#endif
#include "kleidiai.h"
#include "ggml-cpu.h"
#include "ggml-impl.h"
#include "ggml-backend-impl.h"
#include "ggml-threading.h"
#include "traits.h"
#include "kernels.h"
#include "kai_common.h"
#define GGML_COMMON_DECL_CPP
#include "ggml-common.h"
struct ggml_kleidiai_context {
cpu_feature features;
ggml_kleidiai_kernels * kernels;
} static ctx = { CPU_FEATURE_NONE, NULL };
static void init_kleidiai_context(void) {
ggml_critical_section_start();
static bool initialized = false;
if (!initialized) {
initialized = true;
const char *env_var = getenv("GGML_KLEIDIAI_SME");
int sme_enabled = 0;
ctx.features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
(ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) |
(ggml_cpu_has_sve() ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);
if (env_var) {
sme_enabled = atoi(env_var);
}
if (sme_enabled != 0) {
ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
}
ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
}
ggml_critical_section_end();
}
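// Usage note (illustrative): the SME micro-kernels are opt-in here; set
// GGML_KLEIDIAI_SME=1 in the environment before initialization to add
// CPU_FEATURE_SME to the detected feature set. DOTPROD, I8MM and SVE are
// detected automatically.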
static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
return tensor->ne[dim];
}
template<typename Ret, typename Variant, typename... Args>
static Ret variant_call(const Variant & var, Args&&... args) {
return std::visit([&](auto&& func) -> Ret {
if constexpr (std::is_invocable_r_v<Ret, decltype(func), Args...>) {
return func(std::forward<Args>(args)...);
} else {
throw std::runtime_error("Invalid function type in variant_call");
}
}, var);
}
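// Note (illustrative): variant_call invokes whichever std::function alternative the
// variant currently holds, provided it is invocable with the supplied arguments, and
// throws otherwise. The if-constexpr guard keeps mismatched signatures from being
// instantiated, so the same helper serves both the blocked (n, k, bl) and plain
// (n, k) packing interfaces, e.g.:
//
//     size_t s = variant_call<size_t>(kernels->rhs_info.packed_size, n, k);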
namespace ggml::cpu::kleidiai {
static size_t round_down(size_t x, size_t y) {
return y == 0 ? x : x - (x % y);
}
static void transpose_f32kxn_f16nxk(size_t n, size_t k, float * dst, const uint16_t * src, size_t rhs_stride) {
size_t src_stride = rhs_stride / sizeof(uint16_t);
size_t dst_stride = n;
for (size_t k_idx = 0; k_idx < k; ++k_idx) {
for (size_t n_idx = 0; n_idx < n; ++n_idx) {
uint16_t v = *(src + k_idx + n_idx * src_stride);
*(dst + n_idx + k_idx * dst_stride) = kai_cast_f32_f16(v);
}
}
}
class tensor_traits : public ggml::cpu::tensor_traits {
bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
GGML_ASSERT(kernels);
kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
size_t k = op->src[0]->ne[0];
size_t n = op->src[0]->ne[1];
size_t m = op->src[1]->ne[1];
size_t mr = kernel->get_mr();
size_t kr = kernel->get_kr();
size_t sr = kernel->get_sr();
if (kernels->rhs_type == GGML_TYPE_Q4_0) {
size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, QK4_0, mr, kr, sr);
} else if (kernels->rhs_type == GGML_TYPE_F16) {
size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr) +
variant_call<size_t>(kernels->rhs_info.packed_size, n, k) +
k * n * sizeof(float) + n * sizeof(float);
} else {
GGML_ASSERT(false);
}
return true;
}
bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
if (dst->op == GGML_OP_MUL_MAT) {
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
return compute_forward_q4_0(params, dst);
} else if (dst->src[0]->type == GGML_TYPE_F16) {
return compute_forward_kv_cache(params, dst);
}
}
return false;
}
bool compute_forward_kv_cache(ggml_compute_params * params, struct ggml_tensor * dst) {
static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT;
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_TENSOR_BINARY_OP_LOCALS
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
GGML_ASSERT(kernels);
kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
GGML_ASSERT(kernel);
const int nth = params->nth;
const int ith = params->ith;
const int64_t lhs_batch_size0 = ne12;
const int64_t rhs_batch_size0 = ne02;
const int64_t batch_size = rhs_batch_size0;
const int64_t r = lhs_batch_size0 / rhs_batch_size0;
const int64_t m = ne11 * r;
const int64_t n = ne01;
const int64_t k = ne00;
const size_t lhs_stride = src1->nb[1];
const size_t rhs_stride = src0->nb[1];
const size_t dst_stride = dst->nb[1];
const int64_t mr = static_cast<int64_t>(kernel->get_mr());
const int64_t nr = static_cast<int64_t>(kernel->get_nr());
const int64_t kr = static_cast<int64_t>(kernel->get_kr());
const int64_t sr = static_cast<int64_t>(kernel->get_sr());
const size_t lhs_packed_size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr);
const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, n, k);
const size_t kxn_size = k * n * sizeof(float);
const size_t bias_size = n * sizeof(float);
const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size;
GGML_ASSERT(wsize_required <= params->wsize);
uint8_t * lhs_packed = static_cast<uint8_t *>(params->wdata);
uint8_t * rhs_packed = lhs_packed + lhs_packed_size;
uint8_t * rhs_kxn = rhs_packed + rhs_packed_size;
uint8_t * bias = rhs_kxn + kxn_size;
for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
const uint8_t * lhs_batch = static_cast<const uint8_t *>(src1->data) + batch_idx * m * lhs_stride;
const uint8_t * rhs_batch = static_cast<const uint8_t *>(src0->data) + batch_idx * n * rhs_stride;
uint8_t * dst_batch = static_cast<uint8_t *>(dst->data) + batch_idx * m * dst_stride;
// LHS packing
{
const int64_t m_roundup_mr = kai_roundup(m, mr);
const int64_t num_threads = KAI_MIN(m_roundup_mr / mr, nth);
if (ith < num_threads) {
const int64_t num_m_per_thread0 = round_down(m_roundup_mr / num_threads, mr);
const int64_t num_m_per_threadN_1 = m - (num_threads - 1) * num_m_per_thread0;
const int64_t m_start = ith * num_m_per_thread0;
const int64_t num_m_per_thread = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
const size_t lhs_offset = variant_call<size_t>(kernels->gemm.get_lhs_offset, m_start, lhs_stride);
const size_t lhs_packed_offset = variant_call<size_t>(kernels->lhs_info.get_packed_offset, m_start, k, mr, kr, sr);
const void * src_ptr = static_cast<const uint8_t *>(lhs_batch) + lhs_offset;
void * dst_ptr = static_cast<uint8_t *>(lhs_packed) + lhs_packed_offset;
variant_call<void>(kernels->lhs_info.pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
}
}
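// Worked example of the split above (illustrative): with m = 100, mr = 4, nth = 8,
// m_roundup_mr = 100, num_threads = MIN(100/4, 8) = 8, num_m_per_thread0 =
// round_down(100/8, 4) = 12, and the last thread picks up the remainder:
// 100 - 7*12 = 16 rows.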
// RHS packing
if (first_to_arrive.test_and_set(std::memory_order_acquire) == false) {
// First thread to reach this point handles RHS packing
memset(bias, 0, n * sizeof(float));
transpose_f32kxn_f16nxk(n, k, reinterpret_cast<float *>(rhs_kxn),
reinterpret_cast<const uint16_t *>(rhs_batch), rhs_stride);
variant_call<void>(kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, n * sizeof(float),
rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr);
}
ggml_barrier(params->threadpool);
first_to_arrive.clear(std::memory_order_release);
// Perform the matmul
{
const int64_t m_to_process = m;
const int64_t m_start = 0;
const int64_t n_step = static_cast<int64_t>(kernel->get_n_step());
const int64_t num_threads = KAI_MIN(n / n_step, nth);
if (ith < num_threads) {
const int64_t num_n_per_thread0 = round_down(n / num_threads, n_step);
const int64_t num_n_per_threadN_1 = n - (num_threads - 1) * num_n_per_thread0;
const int64_t n_start = ith * num_n_per_thread0;
const int64_t n_to_process = (ith == num_threads - 1) ? num_n_per_threadN_1 : num_n_per_thread0;
const size_t lhs_packed_offset = variant_call<size_t>(kernel->get_lhs_offset, m_start, k);
const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k);
const size_t dst_offset = kernel->get_dst_offset(m_start, n_start, dst_stride);
const void * lhs_ptr = lhs_packed + lhs_packed_offset;
const void * rhs_ptr = rhs_packed + rhs_packed_offset;
float * dst_ptr = reinterpret_cast<float *>(dst_batch + dst_offset);
variant_call<void>(kernel->run_kernel, m_to_process, n_to_process, k, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
}
}
if (batch_idx != batch_size - 1) {
// This barrier is necessary when the batch size is larger than 1. While processing a batch,
// the work data buffer (params->wdata) is used as temporary storage, which means that only
// a single batch can be processed at any given time. No barrier is needed for the last
// batch since GGML inserts a barrier between the execution of every operator.
ggml_barrier(params->threadpool);
}
}
return true;
}
bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_TENSOR_BINARY_OP_LOCALS
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
GGML_ASSERT(kernels);
kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
lhs_packing_info * lhs_info = &kernels->lhs_info;
GGML_ASSERT(kernel);
const int ith = params->ith;
const int nth = params->nth;
const size_t k = ne00;
const size_t m = ne11;
const size_t n = ne01;
size_t mr = kernel->get_mr();
size_t kr = kernel->get_kr();
size_t sr = kernel->get_sr();
const uint8_t * lhs = static_cast<const uint8_t *>(src1->data);
uint8_t * lhs_packed = (uint8_t*)params->wdata;
const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
const size_t n_step = kernel->get_n_step();
const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
const size_t n_start = ith * num_n_per_thread;
size_t n_to_process = num_n_per_thread;
if ((n_start + n_to_process) > n) {
n_to_process = n - n_start;
}
// Calculate number of rows to be processed per thread
const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
const size_t m_start = ith * num_m_per_thread;
size_t m_to_process = num_m_per_thread;
if ((m_start + m_to_process) > m) {
m_to_process = m - m_start;
}
if (m_start < m) {
// Transform LHS
const size_t src_stride = src1->nb[1];
const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, m_start, k, QK4_0, mr, kr, sr);
void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
variant_call<void>(lhs_info->pack_func, m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
}
ggml_barrier(params->threadpool);
// Perform the operation
const size_t dst_stride = dst->nb[1];
const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, 0, k, QK4_0, mr, kr, sr);
const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k, QK4_0);
const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset);
float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
variant_call<void>(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
sizeof(float), -FLT_MAX, FLT_MAX);
return true;
}
public:
int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
GGML_ASSERT(ctx.kernels);
const size_t n = tensor->ne[1];
const size_t k = tensor->ne[0];
size_t nr = ctx.kernels->gemm.get_nr();
size_t kr = ctx.kernels->gemm.get_kr();
size_t sr = ctx.kernels->gemm.get_sr();
#ifndef NDEBUG
const size_t repacked_size = variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!");
#endif
struct kai_rhs_pack_qs4cxs1s0_param params;
params.lhs_zero_point = 1;
params.rhs_zero_point = 8;
variant_call<void>(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, &params);
return 0;
GGML_UNUSED(data_size);
}
};
static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struct ggml_tensor *) {
static tensor_traits traits;
return &traits;
}
} // namespace ggml::cpu::kleidiai
static enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);
GGML_UNUSED(buffer);
return GGML_STATUS_SUCCESS;
}
static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
const void * data, size_t offset, size_t size) {
GGML_ASSERT(offset == 0);
GGML_ASSERT(size == ggml_nbytes(tensor));
auto tensor_traits = (ggml::cpu::kleidiai::tensor_traits *) tensor->extra;
auto OK = tensor_traits->repack(tensor, data, size);
GGML_ASSERT(OK == 0);
GGML_UNUSED(buffer);
}
static const char * ggml_backend_cpu_kleidiai_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
return "CPU_KLEIDIAI";
GGML_UNUSED(buft);
}
static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
if (buffer == nullptr) {
return nullptr;
}
buffer->buft = buft;
buffer->iface.init_tensor = ggml_backend_cpu_kleidiai_buffer_init_tensor;
buffer->iface.set_tensor = ggml_backend_cpu_kleidiai_buffer_set_tensor;
buffer->iface.get_tensor = nullptr;
buffer->iface.cpy_tensor = nullptr;
return buffer;
}
static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
return TENSOR_ALIGNMENT;
GGML_UNUSED(buft);
}
namespace ggml::cpu::kleidiai {
class extra_buffer_type : ggml::cpu::extra_buffer_type {
bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
if (op->op == GGML_OP_MUL_MAT &&
op->src[0]->type == GGML_TYPE_Q4_0 &&
op->src[0]->buffer &&
(ggml_n_dims(op->src[0]) == 2) &&
op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
return false;
}
if (op->src[1]->type == GGML_TYPE_F32 &&
ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) {
return true;
}
}
return false;
}
ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
if (op->op == GGML_OP_MUL_MAT) {
if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
return (ggml::cpu::tensor_traits *) op->src[0]->extra;
}
else if (ggml_kleidiai_select_kernels(ctx.features, op) &&
op->src[0]->op == GGML_OP_VIEW &&
(op->src[1]->op == GGML_OP_PERMUTE || op->src[1]->op == GGML_OP_SOFT_MAX) &&
op->src[1]->ne[1] > 1) {
if ((op->src[0]->nb[0] != 2) ||
(op->src[1]->nb[0] != 4) ||
(op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
(op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
return nullptr;
}
return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL);
}
}
return nullptr;
}
};
} // namespace ggml::cpu::kleidiai
ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void) {
static ggml::cpu::kleidiai::extra_buffer_type ctx;
static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_kleidiai = {
/* .iface = */ {
/* .get_name = */ ggml_backend_cpu_kleidiai_buffer_type_get_name,
/* .alloc_buffer = */ ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer,
/* .get_alignment = */ ggml_backend_cpu_kleidiai_buffer_type_get_alignment,
/* .get_max_size = */ nullptr, // defaults to SIZE_MAX
/* .get_alloc_size = */ nullptr, // defaults to ggml_nbytes
/* .is_host = */ nullptr,
},
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
/* .context = */ &ctx,
};
init_kleidiai_context();
return &ggml_backend_cpu_buffer_type_kleidiai;
}

@@ -1,17 +0,0 @@
// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
// SPDX-License-Identifier: MIT
//
#pragma once
#include "ggml-alloc.h"
#ifdef __cplusplus
extern "C" {
#endif
ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void);
#ifdef __cplusplus
}
#endif

@@ -73,6 +73,22 @@ static inline int ggml_up(int n, int m) {
return (n + m - 1) & ~(m - 1);
}
// TODO: move to ggml.h?
static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
if (a->type != b->type) {
return false;
}
for (int i = 0; i < GGML_MAX_DIMS; i++) {
if (a->ne[i] != b->ne[i]) {
return false;
}
if (a->nb[i] != b->nb[i]) {
return false;
}
}
return true;
}
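// Example (illustrative): a transposed view has the same ne[] as a contiguous
// tensor of that shape but different nb[] strides, so ggml_are_same_layout()
// returns false for the pair and ggml_backend_tensor_copy() rejects the copy.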
//
// logging
//

@@ -126,6 +126,7 @@ typedef struct {
uint64_t nb2;
uint64_t nb3;
uint64_t offs;
uint64_t o1[8];
} ggml_metal_kargs_bin;
typedef struct {
@@ -240,7 +241,7 @@ typedef struct {
float max_bias;
float m0;
float m1;
uint16_t n_head_log2;
int32_t n_head_log2;
float logit_softcap;
} ggml_metal_kargs_flash_attn_ext;
@@ -377,8 +378,16 @@ typedef struct {
typedef struct {
int32_t ne00;
int32_t ne00_4;
uint64_t nb01;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
float eps;
int32_t nef1[3];
int32_t nef2[3];
int32_t nef3[3];
uint64_t nbf1[3];
uint64_t nbf2[3];
uint64_t nbf3[3];
} ggml_metal_kargs_rms_norm;
typedef struct {
@@ -484,7 +493,7 @@ typedef struct {
float max_bias;
float m0;
float m1;
uint32_t n_head_log2;
int32_t n_head_log2;
} ggml_metal_kargs_soft_max;
typedef struct {

@@ -55,6 +55,12 @@ static struct ggml_backend_metal_device_context {
bool has_residency_sets;
bool has_bfloat;
bool use_bfloat;
bool use_fusion;
int debug_fusion;
// how many times a given op was fused
uint64_t fuse_cnt[GGML_OP_COUNT];
size_t max_size;
@@ -69,6 +75,9 @@ static struct ggml_backend_metal_device_context {
/*.has_residency_sets =*/ false,
/*.has_bfloat =*/ false,
/*.use_bfloat =*/ false,
/*.use_fusion =*/ true,
/*.debug_fusion =*/ 0,
/*.fuse_cnt =*/ { 0 },
/*.max_size =*/ 0,
/*.name =*/ "",
};
@@ -83,16 +92,14 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
if (ctx->mtl_device == nil) {
ctx->mtl_device = MTLCreateSystemDefaultDevice();
}
if (ctx->mtl_device) {
ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == NULL;
ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil;
#endif
ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
@@ -103,6 +110,14 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
#else
ctx->use_bfloat = false;
#endif
ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
{
const char * val = getenv("GGML_METAL_FUSION_DEBUG");
ctx->debug_fusion = val ? atoi(val) : 0;
}
memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt));
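// Usage note (illustrative): defining GGML_METAL_FUSION_DISABLE (any value) turns
// op fusion off; GGML_METAL_FUSION_DEBUG=1 prints the per-op fusion counters when
// the device is released, and GGML_METAL_FUSION_DEBUG=2 additionally logs each
// fused sequence as it is encoded.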
ctx->max_size = ctx->mtl_device.maxBufferLength;
@@ -122,6 +137,18 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
ctx->mtl_device_ref_count--;
if (ctx->mtl_device_ref_count == 0) {
if (ctx->debug_fusion > 0) {
fprintf(stderr, "%s: fusion stats:\n", __func__);
for (int i = 0; i < GGML_OP_COUNT; i++) {
if (ctx->fuse_cnt[i] == 0) {
continue;
}
// note: cannot use ggml_log here
fprintf(stderr, "%s: - %s: %" PRIu64 "\n", __func__, ggml_op_name((enum ggml_op) i), ctx->fuse_cnt[i]);
}
}
if (ctx->mtl_lock) {
[ctx->mtl_lock release];
ctx->mtl_lock = nil;
@@ -147,13 +174,27 @@ struct ggml_metal_kernel {
enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_ADD,
GGML_METAL_KERNEL_TYPE_ADD_ROW,
GGML_METAL_KERNEL_TYPE_ADD_FUSE_2,
GGML_METAL_KERNEL_TYPE_ADD_FUSE_3,
GGML_METAL_KERNEL_TYPE_ADD_FUSE_4,
GGML_METAL_KERNEL_TYPE_ADD_FUSE_5,
GGML_METAL_KERNEL_TYPE_ADD_FUSE_6,
GGML_METAL_KERNEL_TYPE_ADD_FUSE_7,
GGML_METAL_KERNEL_TYPE_ADD_FUSE_8,
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4,
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2,
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3,
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4,
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5,
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6,
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7,
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8,
GGML_METAL_KERNEL_TYPE_SUB,
GGML_METAL_KERNEL_TYPE_SUB_ROW,
GGML_METAL_KERNEL_TYPE_SUB_ROW_C4,
GGML_METAL_KERNEL_TYPE_MUL,
GGML_METAL_KERNEL_TYPE_MUL_ROW,
GGML_METAL_KERNEL_TYPE_MUL_ROW_C4,
GGML_METAL_KERNEL_TYPE_DIV,
GGML_METAL_KERNEL_TYPE_DIV_ROW,
GGML_METAL_KERNEL_TYPE_DIV_ROW_C4,
GGML_METAL_KERNEL_TYPE_REPEAT_F32,
GGML_METAL_KERNEL_TYPE_REPEAT_F16,
GGML_METAL_KERNEL_TYPE_REPEAT_I32,
@@ -218,6 +259,8 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
GGML_METAL_KERNEL_TYPE_RMS_NORM,
GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL,
GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD,
GGML_METAL_KERNEL_TYPE_L2_NORM,
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
GGML_METAL_KERNEL_TYPE_NORM,
@@ -1135,13 +1178,27 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
// simd_sum and simd_max requires MTLGPUFamilyApple7
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_2, add_fuse_2, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_3, add_fuse_3, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_4, add_fuse_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_5, add_fuse_5, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_6, add_fuse_6, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_7, add_fuse_7, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_8, add_fuse_8, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4, add_row_c4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2, add_row_c4_fuse_2, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3, add_row_c4_fuse_3, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4, add_row_c4_fuse_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5, add_row_c4_fuse_5, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6, add_row_c4_fuse_6, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7, add_row_c4_fuse_7, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8, add_row_c4_fuse_8, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB, sub, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW_C4, sub_row_c4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, mul_row_c4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, div_row_c4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
@@ -1206,6 +1263,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL, rms_norm_mul, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD, rms_norm_mul_add, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
@@ -1893,7 +1952,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
}
}
static bool ggml_metal_encode_node(
static int ggml_metal_encode_node(
ggml_backend_t backend,
int idx,
id<MTLComputeCommandEncoder> encoder,
@@ -1903,7 +1962,10 @@ static bool ggml_metal_encode_node(
struct ggml_cgraph * gf = ctx->gf;
struct ggml_tensor * node = ggml_graph_node(gf, idx);
enum ggml_op ops[8];
struct ggml_tensor ** nodes = ggml_graph_nodes(gf) + idx;
struct ggml_tensor * node = nodes[0];
//GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op));
@@ -1913,7 +1975,7 @@ static bool ggml_metal_encode_node(
struct ggml_tensor * dst = node;
if (ggml_is_empty(dst)) {
return true;
return 1;
}
switch (dst->op) {
@@ -1924,7 +1986,7 @@ static bool ggml_metal_encode_node(
case GGML_OP_PERMUTE:
{
// noop -> next node
} return true;
} return 1;
default:
{
} break;
@@ -1991,6 +2053,8 @@ static bool ggml_metal_encode_node(
id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
int n_fuse = 1;
#if 0
GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
if (src0) {
@@ -2062,37 +2126,15 @@ static bool ggml_metal_encode_node(
GGML_ASSERT(src0t == GGML_TYPE_F32);
GGML_ASSERT(src1t == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous_rows(src0));
GGML_ASSERT(ggml_is_contiguous_rows(src1));
const size_t offs = 0;
bool bcast_row = false;
id<MTLComputePipelineState> pipeline = nil;
if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
GGML_ASSERT(ggml_is_contiguous(src0));
// src1 is a row
GGML_ASSERT(ne11 == 1);
switch (dst->op) {
case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
default: GGML_ABORT("fatal error");
}
bcast_row = true;
} else {
switch (dst->op) {
case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
default: GGML_ABORT("fatal error");
}
}
ggml_metal_kargs_bin args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
@@ -2119,12 +2161,117 @@ static bool ggml_metal_encode_node(
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
/*.offs =*/ offs,
/*.o1 =*/ { offs_src1 },
};
// c[0] = add(a, b[0])
// c[1] = add(c[0], b[1])
// c[2] = add(c[1], b[2])
// ...
if (ctx_dev->use_fusion) {
ops[0] = GGML_OP_ADD;
ops[1] = GGML_OP_ADD;
ops[2] = GGML_OP_ADD;
ops[3] = GGML_OP_ADD;
ops[4] = GGML_OP_ADD;
ops[5] = GGML_OP_ADD;
ops[6] = GGML_OP_ADD;
ops[7] = GGML_OP_ADD;
size_t offs_fuse;
id<MTLBuffer> id_fuse;
for (n_fuse = 0; n_fuse <= 6; ++n_fuse) {
if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
break;
}
if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) {
break;
}
// b[0] === b[1] === ...
if (!ggml_are_same_layout(nodes[n_fuse]->src[1], nodes[n_fuse + 1]->src[1])) {
break;
}
// only fuse nodes if src1 is in the same Metal buffer
id_fuse = ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse);
if (id_fuse != id_src1) {
break;
}
ctx_dev->fuse_cnt[nodes[n_fuse + 1]->op]++;
args.o1[n_fuse + 1] = offs_fuse;
}
++n_fuse;
if (ctx_dev->debug_fusion > 1 && n_fuse > 1) {
GGML_LOG_DEBUG("%s: fuse: ADD x %d\n", __func__, n_fuse);
}
}
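// Example (illustrative): three consecutive ADD nodes, each feeding src[0] of the
// next and all reading src1 tensors with identical layout from the same Metal
// buffer, give n_fuse = 3; the *_FUSE_3 pipeline then applies all three additions
// in one dispatch, using the per-node src1 offsets collected in args.o1[].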
if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
GGML_ASSERT(ggml_is_contiguous(src0));
// src1 is a row
GGML_ASSERT(ne11 == 1);
switch (dst->op) {
case GGML_OP_ADD:
{
switch (n_fuse) {
case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4 ].pipeline; break;
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2].pipeline; break;
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3].pipeline; break;
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4].pipeline; break;
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5].pipeline; break;
case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6].pipeline; break;
case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7].pipeline; break;
case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8].pipeline; break;
default: GGML_ABORT("fatal error");
}
} break;
case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW_C4].pipeline; break;
case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW_C4].pipeline; break;
case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW_C4].pipeline; break;
default: GGML_ABORT("fatal error");
}
bcast_row = true;
} else {
switch (dst->op) {
case GGML_OP_ADD:
{
switch (n_fuse) {
case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD ].pipeline; break;
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_2].pipeline; break;
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_3].pipeline; break;
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_4].pipeline; break;
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_5].pipeline; break;
case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_6].pipeline; break;
case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_7].pipeline; break;
case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_8].pipeline; break;
default: GGML_ABORT("fatal error");
}
} break;
case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
default: GGML_ABORT("fatal error");
}
}
if (n_fuse > 1) {
id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
}
[encoder setComputePipelineState:pipeline];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
[encoder setBuffer:id_src1 offset:0 atIndex:2];
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
if (bcast_row) {
@ -2132,7 +2279,11 @@ static bool ggml_metal_encode_node(
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} else {
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
int nth = 32;
while (16*nth < ne0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
nth *= 2;
}
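// nth is the smallest power of two (>= 32) for which each thread covers at most ~16 elements of ne0, capped by the pipeline limit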
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
}
@ -2257,12 +2408,13 @@ static bool ggml_metal_encode_node(
/*.nb2 =*/ pnb2,
/*.nb3 =*/ pnb3,
/*.offs =*/ offs,
/*.o1 =*/ { offs_src1},
};
[encoder setComputePipelineState:pipeline];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
[encoder setBuffer:id_src1 offset:0 atIndex:2];
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
@ -2764,7 +2916,7 @@ static bool ggml_metal_encode_node(
id<MTLBuffer> h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0));
if (!h_src0) {
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0));
return false;
return 0;
}
offs_src0 = 0;
@ -3640,7 +3792,7 @@ static bool ggml_metal_encode_node(
id<MTLBuffer> h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1);
if (!h_src1) {
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
return false;
return 0;
}
const int64_t neh0 = ne0;
@ -3656,7 +3808,7 @@ static bool ggml_metal_encode_node(
id<MTLBuffer> h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst);
if (!h_dst) {
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
return false;
return 0;
}
// tokens per expert
@ -3664,7 +3816,7 @@ static bool ggml_metal_encode_node(
id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
if (!h_tpe) {
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe);
return false;
return 0;
}
// id map
@ -3673,7 +3825,7 @@ static bool ggml_metal_encode_node(
id<MTLBuffer> h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids);
if (!h_ids) {
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
return false;
return 0;
}
{
@ -4105,12 +4257,95 @@ static bool ggml_metal_encode_node(
case GGML_OP_RMS_NORM:
{
GGML_ASSERT(ne00 % 4 == 0);
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(ggml_is_contiguous_rows(src0));
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
ggml_metal_kargs_rms_norm args = {
/*.ne00 =*/ ne00,
/*.ne00_4 =*/ ne00/4,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
/*.eps =*/ eps,
/*.nef1 =*/ { ne01 },
/*.nef2 =*/ { ne02 },
/*.nef3 =*/ { ne03 },
/*.nbf1 =*/ { nb01 },
/*.nbf2 =*/ { nb02 },
/*.nbf3 =*/ { nb03 },
};
size_t offs_fuse[2] = { 0, 0 };
id<MTLBuffer> id_fuse[2] = { id_src0, id_src0 };
// d[0] = rms_norm(a)
// d[1] = mul(d[0], b)
// d[2] = add(d[1], c)
if (ctx_dev->use_fusion) {
ops[0] = GGML_OP_RMS_NORM;
ops[1] = GGML_OP_MUL;
ops[2] = GGML_OP_ADD;
for (n_fuse = 0; n_fuse <= 1; ++n_fuse) {
if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
break;
}
if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) {
break;
}
if (nodes[n_fuse + 1]->src[1]->ne[0] != node->ne[0]) {
break;
}
if (!ggml_is_contiguous_rows(nodes[n_fuse + 1]->src[1])) {
break;
}
if (nodes[n_fuse + 1]->type != GGML_TYPE_F32) {
break;
}
ctx_dev->fuse_cnt[nodes[n_fuse + 1]->op]++;
id_fuse[n_fuse] = ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse[n_fuse]);
args.nef1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[1];
args.nef2[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[2];
args.nef3[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[3];
args.nbf1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[1];
args.nbf2[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[2];
args.nbf3[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[3];
}
++n_fuse;
if (ctx_dev->debug_fusion > 1 && n_fuse > 1) {
if (n_fuse == 2) {
GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL\n", __func__);
}
if (n_fuse == 3) {
GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL + ADD\n", __func__);
}
}
}
if (n_fuse > 1) {
id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
}
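// n_fuse selects the kernel variant: 1 = plain RMS_NORM, 2 = fused with MUL, 3 = fused with MUL + ADD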
id<MTLComputePipelineState> pipeline;
switch (n_fuse) {
case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM ].pipeline; break;
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL ].pipeline; break;
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD].pipeline; break;
default: GGML_ABORT("unsupported n_fuse = %d\n", n_fuse);
}
int nth = 32; // SIMD width
@ -4121,23 +4356,16 @@ static bool ggml_metal_encode_node(
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
nth = MIN(nth, ne00/4);
ggml_metal_kargs_rms_norm args = {
/*.ne00 =*/ ne00,
/*.ne00_4 =*/ ne00/4,
/*.nb01 =*/ nb01,
/*.eps =*/ eps,
};
[encoder setComputePipelineState:pipeline];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_fuse[0] offset:offs_fuse[0] atIndex:2];
[encoder setBuffer:id_fuse[1] offset:offs_fuse[1] atIndex:3];
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
const int64_t nrows = ggml_nrows(src0);
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_L2_NORM:
{
@ -5532,7 +5760,7 @@ static bool ggml_metal_encode_node(
}
}
return true;
return n_fuse;
}
static enum ggml_status ggml_metal_graph_compute(
@ -6038,20 +6266,22 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
ggml_metal_mem_pool_reset(mem_pool);
for (int idx = node_start; idx < node_end; ++idx) {
for (int idx = node_start; idx < node_end;) {
if (should_capture) {
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
}
const bool res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
const int res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
if (should_capture) {
[encoder popDebugGroup];
}
if (!res) {
if (res == 0) {
break;
}
idx += res;
}
[encoder endEncoding];

View file

@ -832,7 +832,8 @@ enum ggml_sort_order {
// general-purpose kernel for addition, subtraction, multiplication and division of two tensors
// pros: works for non-contiguous tensors, supports broadcast across all dims
// cons: not very efficient
kernel void kernel_add(
template <int F>
kernel void kernel_add_fuse_impl(
constant ggml_metal_kargs_bin & args,
device const char * src0,
device const char * src1,
@ -848,16 +849,39 @@ kernel void kernel_add(
const int i12 = i02%args.ne12;
const int i11 = i01%args.ne11;
device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
device float * dst_ptr = (device float *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs);
device const float * src1_ptr[F];
for (short j = 0; j < F; ++j) {
src1_ptr[j] = (device const float *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
}
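// F is a compile-time template parameter, so both loops fully unroll and src1_ptr can live in registers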
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const int i10 = i0%args.ne10;
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) + *((device float *)(src1_ptr + i10*args.nb10));
float res = src0_ptr[i0];
#pragma unroll
for (short j = 0; j < F; ++j) {
res += src1_ptr[j][i10];
}
dst_ptr[i0] = res;
}
}
typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t;
template [[host_name("kernel_add")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>;
template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>;
template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>;
template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>;
template [[host_name("kernel_add_fuse_5")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<5>;
template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<6>;
template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>;
template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>;
kernel void kernel_sub(
constant ggml_metal_kargs_bin & args,
device const char * src0,
@ -875,7 +899,7 @@ kernel void kernel_sub(
const int i11 = i01%args.ne11;
device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
@ -900,9 +924,9 @@ kernel void kernel_mul(
const int i12 = i02%args.ne12;
const int i11 = i01%args.ne11;
device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1;
device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const int i10 = i0%args.ne10;
@ -926,9 +950,9 @@ kernel void kernel_div(
const int i12 = i02%args.ne12;
const int i11 = i01%args.ne11;
device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1;
device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const int i10 = i0%args.ne10;
@ -970,46 +994,145 @@ template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat
// assumption: src1 is a row
// broadcast src1 into src0
kernel void kernel_add_row(
template <short F>
kernel void kernel_add_row_c4_fuse_impl(
constant ggml_metal_kargs_bin & args,
device const float4 * src0,
device const float4 * src1,
device float4 * dst,
device const char * src0,
device const char * src1,
device char * dst,
uint tpig[[thread_position_in_grid]]) {
const uint nb = args.ne00/4;
dst[tpig] = src0[tpig] + src1[tpig % nb];
const uint i = tpig % nb;
device const float4 * src0_row = (device const float4 *) (src0);
device float4 * dst_row = (device float4 *) (dst);
device const float4 * src1_row[F];
for (short j = 0; j < F; ++j) {
src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
}
float4 res = src0_row[tpig];
#pragma unroll(F)
for (short j = 0; j < F; ++j) {
res += src1_row[j][i];
}
dst_row[tpig] = res;
}
kernel void kernel_sub_row(
typedef decltype(kernel_add_row_c4_fuse_impl<1>) kernel_add_row_c4_fuse_t;
template [[host_name("kernel_add_row_c4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>;
template [[host_name("kernel_add_row_c4_fuse_2")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<2>;
template [[host_name("kernel_add_row_c4_fuse_3")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<3>;
template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>;
template [[host_name("kernel_add_row_c4_fuse_5")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<5>;
template [[host_name("kernel_add_row_c4_fuse_6")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<6>;
template [[host_name("kernel_add_row_c4_fuse_7")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<7>;
template [[host_name("kernel_add_row_c4_fuse_8")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<8>;
template <short F>
kernel void kernel_sub_row_c4_fuse_impl(
constant ggml_metal_kargs_bin & args,
device const float4 * src0,
device const float4 * src1,
device float4 * dst,
device const char * src0,
device const char * src1,
device char * dst,
uint tpig[[thread_position_in_grid]]) {
const uint nb = args.ne00/4;
dst[tpig] = src0[tpig] - src1[tpig % nb];
const uint i = tpig % nb;
device const float4 * src0_row = (device const float4 *) (src0);
device float4 * dst_row = (device float4 *) (dst);
device const float4 * src1_row[F];
for (short j = 0; j < F; ++j) {
src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
}
float4 res = src0_row[tpig];
#pragma unroll(F)
for (short j = 0; j < F; ++j) {
res -= src1_row[j][i];
}
dst_row[tpig] = res;
}
kernel void kernel_mul_row(
typedef decltype(kernel_sub_row_c4_fuse_impl<1>) kernel_sub_row_c4_fuse_t;
template [[host_name("kernel_sub_row_c4")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>;
template <short F>
kernel void kernel_mul_row_c4_fuse_impl(
constant ggml_metal_kargs_bin & args,
device const float4 * src0,
device const float4 * src1,
device float4 * dst,
device const char * src0,
device const char * src1,
device char * dst,
uint tpig[[thread_position_in_grid]]) {
const uint nb = args.ne00/4;
dst[tpig] = src0[tpig] * src1[tpig % nb];
const uint i = tpig % nb;
device const float4 * src0_row = (device const float4 *) (src0);
device float4 * dst_row = (device float4 *) (dst);
device const float4 * src1_row[F];
for (short j = 0; j < F; ++j) {
src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
}
float4 res = src0_row[tpig];
#pragma unroll(F)
for (short j = 0; j < F; ++j) {
res *= src1_row[j][i];
}
dst_row[tpig] = res;
}
kernel void kernel_div_row(
typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t;
template [[host_name("kernel_mul_row_c4")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>;
template <short F>
kernel void kernel_div_row_c4_fuse_impl(
constant ggml_metal_kargs_bin & args,
device const float4 * src0,
device const float4 * src1,
device float4 * dst,
device const char * src0,
device const char * src1,
device char * dst,
uint tpig[[thread_position_in_grid]]) {
const uint nb = args.ne00/4;
dst[tpig] = src0[tpig] / src1[tpig % nb];
const uint i = tpig % nb;
device const float4 * src0_row = (device const float4 *) (src0);
device float4 * dst_row = (device float4 *) (dst);
device const float4 * src1_row[F];
for (short j = 0; j < F; ++j) {
src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
}
float4 res = src0_row[tpig];
#pragma unroll(F)
for (short j = 0; j < F; ++j) {
res /= src1_row[j][i];
}
dst_row[tpig] = res;
}
typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t;
template [[host_name("kernel_div_row_c4")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>;
kernel void kernel_scale(
device const float * src0,
device float * dst,
@ -2116,26 +2239,39 @@ kernel void kernel_norm(
}
}
kernel void kernel_rms_norm(
// F == 1 : rms_norm (no fuse)
// F == 2 : rms_norm + mul
// F == 3 : rms_norm + mul + add
template <short F>
kernel void kernel_rms_norm_fuse_impl(
constant ggml_metal_kargs_rms_norm & args,
device const char * src0,
device const char * src1_0,
device const char * src1_1,
device char * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
ushort tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort ntg[[threads_per_threadgroup]]) {
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
}
device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
const int i01 = tgpig.x;
const int i02 = tgpig.y;
const int i03 = tgpig.z;
device const float4 * x = (device const float4 *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
device const float4 * f0 = (device const float4 *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
device const float4 * f1 = (device const float4 *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
float sumf = 0.0f;
// parallel sum
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
sumf += dot(x[i00], x[i00]);
}
sumf = simd_sum(sumf);
@ -2154,12 +2290,26 @@ kernel void kernel_rms_norm(
const float mean = sumf/args.ne00;
const float scale = 1.0f/sqrt(mean + args.eps);
device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
y[i00] = x[i00] * scale;
device float4 * y = (device float4 *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
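// F is a compile-time constant, so the untaken branches below are eliminated during compilation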
if (F == 1) {
y[i00] = (x[i00]*scale);
}
if (F == 2) {
y[i00] = (x[i00]*scale)*f0[i00];
}
if (F == 3) {
y[i00] = (x[i00]*scale)*f0[i00] + f1[i00];
}
}
}
typedef decltype(kernel_rms_norm_fuse_impl<1>) kernel_rms_norm_fuse_t;
template [[host_name("kernel_rms_norm")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<1>;
template [[host_name("kernel_rms_norm_mul")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<2>;
template [[host_name("kernel_rms_norm_mul_add")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<3>;
kernel void kernel_l2_norm(
constant ggml_metal_kargs_l2_norm & args,
device const char * src0,

View file

@ -344,6 +344,7 @@ struct vk_device_struct {
uint64_t max_memory_allocation_size;
uint64_t suballocation_block_size;
bool fp16;
bool bf16;
bool pipeline_robustness;
vk::Device device;
uint32_t vendor_id;
@ -498,6 +499,7 @@ struct vk_device_struct {
vk_pipeline pipeline_rwkv_wkv6_f32;
vk_pipeline pipeline_rwkv_wkv7_f32;
vk_pipeline pipeline_opt_step_adamw_f32;
vk_pipeline pipeline_conv2d_f32;
vk_pipeline pipeline_conv2d_dw_whcn_f32;
vk_pipeline pipeline_conv2d_dw_cwhn_f32;
@ -891,6 +893,38 @@ struct vk_op_rwkv_wkv7_push_constants {
uint32_t H;
};
struct vk_op_conv2d_push_constants {
uint32_t Cout;
uint32_t Cin;
uint32_t N;
uint32_t KW;
uint32_t KH;
uint32_t W;
uint32_t H;
uint32_t OW;
uint32_t OH;
uint32_t s0;
uint32_t s1;
uint32_t p0;
uint32_t p1;
uint32_t d0;
uint32_t d1;
uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t nb1;
uint32_t nb2;
uint32_t nb3;
};
struct vk_op_conv2d_dw_push_constants {
uint32_t ne;
uint32_t batches;
@ -990,18 +1024,45 @@ private:
#endif // GGML_VULKAN_MEMORY_DEBUG
class vk_perf_logger {
public:
public:
void print_timings() {
if (timings.empty()) {
return;
}
uint64_t total_all_op_times = 0;
std::cerr << "----------------\nVulkan Timings:" << std::endl;
for (const auto& t : timings) {
uint64_t total = 0;
for (const auto& time : t.second) {
total += time;
for (const auto & t : timings) {
uint64_t total_op_times = 0;
for (const auto & time : t.second) {
total_op_times += time;
}
std::cerr << t.first << ": " << t.second.size() << " x " << (total / t.second.size() / 1000.0) << " us" << std::endl;
std::cerr << t.first << ": " << t.second.size() << " x " << (total_op_times / t.second.size() / 1000.0)
<< " us";
// If we have as many flops entries as timing entries for the op, then compute and log the FLOP/s.
auto it = flops.find(t.first);
if (it != flops.end() && (it->second).size() == t.second.size()) {
uint64_t total_op_flops = 0;
for (const auto & elem : it->second) {
total_op_flops += elem;
}
std::cerr << " ("
<< (double(total_op_flops) / (1000.0 * 1000.0 * 1000.0)) /
(double(total_op_times) / (1000.0 * 1000.0 * 1000.0))
<< " GFLOPS/s)";
}
total_all_op_times += total_op_times;
std::cerr << std::endl;
}
if (timings.size() > 0) {
std::cerr << "Total time: " << total_all_op_times / 1000.0 << " us." << std::endl;
}
timings.clear();
flops.clear();
}
void log_timing(const ggml_tensor * node, uint64_t time) {
@ -1010,22 +1071,45 @@ public:
return;
}
if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
const uint64_t m = node->src[0]->ne[1];
const uint64_t n = node->src[1]->ne[1];
const uint64_t k = node->src[1]->ne[0];
std::string name = ggml_op_name(node->op);
const uint64_t m = node->src[0]->ne[1];
const uint64_t n = node->src[1]->ne[1];
const uint64_t k = node->src[1]->ne[0];
std::string name = ggml_op_name(node->op);
if (n == 1) {
name += "_VEC m=" + std::to_string(m) + " k=" + std::to_string(k);
} else {
name += " m=" + std::to_string(m) + " n=" + std::to_string(n) + " k=" + std::to_string(k);
}
timings[name].push_back(time);
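// 2*k - 1 flops per output element: k multiplies and k - 1 additions, over m*n outputs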
flops[name].push_back(m * n * (k + (k - 1)));
return;
}
if (node->op == GGML_OP_CONV_2D) {
std::string name = ggml_op_name(node->op);
ggml_tensor * knl = node->src[0];
uint64_t OW = node->ne[0];
uint64_t OH = node->ne[1];
uint64_t N = node->ne[3];
uint64_t Cout = node->ne[2];
uint64_t KW = knl->ne[0];
uint64_t KH = knl->ne[1];
uint64_t Cin = knl->ne[2];
// KxCRS @ CRSxNPQ = KxNPQ -> M=K, K=CRS, N=NPQ
uint64_t size_M = Cout;
uint64_t size_K = Cin * KW * KH;
uint64_t size_N = N * OW * OH;
uint64_t n_flops = size_M * size_N * (size_K + (size_K - 1));
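// same 2*K - 1 flops-per-output convention as the MUL_MAT case above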
name += " M=Cout=" + std::to_string(size_M) + ", K=Cin*KW*KH=" + std::to_string(size_K) +
", N=N*OW*OH=" + std::to_string(size_N);
flops[name].push_back(n_flops);
timings[name].push_back(time);
return;
}
timings[ggml_op_name(node->op)].push_back(time);
}
private:
private:
std::map<std::string, std::vector<uint64_t>> timings;
std::map<std::string, std::vector<uint64_t>> flops;
};
struct ggml_backend_vk_context {
@ -2128,6 +2212,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
}
compile_count++;
}
compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint,
parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
};
@ -2977,6 +3062,42 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
// conv2d
uint32_t conv2d_WG_SIZE = 256;
uint32_t conv2d_BS_K = 128;
uint32_t conv2d_BS_CRS = 16;
uint32_t conv2d_BS_CRS = 16;
uint32_t use_collectives = 0; // Enables subgroup ops to avoid recomputing indices.
if (device->subgroup_shuffle &&
device->vendor_id != VK_VENDOR_ID_INTEL) { // Do not enable collectives on Intel, see PR 14316
use_collectives = 1;
conv2d_BS_CRS = std::min(
device->subgroup_size,
conv2d_BS_CRS); // CRS block size should be capped at subgroup size for correctness when shuffle is used.
}
uint32_t conv2d_BS_NPQ = 128;
uint32_t conv2d_TS_K = 8;
uint32_t conv2d_shmem_req =
(conv2d_BS_K * (conv2d_BS_CRS + 1) + conv2d_BS_CRS * (conv2d_BS_NPQ + 1)) * sizeof(float);
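// worst-case bytes for the two shared tiles (A: BS_K x BS_CRS, B: BS_CRS x BS_NPQ), assuming one float of row padding each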
if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) {
conv2d_BS_CRS = 8;
if (use_collectives) {
conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS);
}
}
if (use_collectives) {
ggml_vk_create_pipeline(
device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
{ conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, true);
} else {
ggml_vk_create_pipeline(
device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
{ conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true,
false);
}
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
@ -3297,6 +3418,12 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->fp16 = device->fp16 && vk12_features.shaderFloat16;
#if defined(VK_KHR_shader_bfloat16)
device->bf16 = bfloat16_support && bfloat16_features.shaderBFloat16Type;
#else
device->bf16 = false;
#endif
device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
if (device->subgroup_size_control) {
@ -3639,6 +3766,7 @@ static void ggml_vk_print_gpu_info(size_t idx) {
bool coopmat_support = false;
bool coopmat2_support = false;
bool integer_dot_product = false;
bool bfloat16_support = false;
for (auto properties : ext_props) {
if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
@ -3659,6 +3787,11 @@ static void ggml_vk_print_gpu_info(size_t idx) {
} else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
!getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
integer_dot_product = true;
#endif
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
} else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 &&
!getenv("GGML_VK_DISABLE_BFLOAT16")) {
bfloat16_support = true;
#endif
}
}
@ -3725,10 +3858,25 @@ static void ggml_vk_print_gpu_info(size_t idx) {
last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features;
}
#if defined(VK_KHR_shader_bfloat16)
VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {};
bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR;
if (bfloat16_support) {
last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features;
last_struct = (VkBaseOutStructure *)&bfloat16_features;
}
#endif
vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);
fp16 = fp16 && vk12_features.shaderFloat16;
#if defined(VK_KHR_shader_bfloat16)
bool bf16 = bfloat16_support && bfloat16_features.shaderBFloat16Type;
#else
bool bf16 = false;
#endif
uint32_t default_subgroup_size = get_subgroup_size("", device_architecture);
const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
@ -3746,8 +3894,8 @@ static void ggml_vk_print_gpu_info(size_t idx) {
std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none";
std::string device_name = props2.properties.deviceName.data();
GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n",
idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size,
GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | bf16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n",
idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, bf16, subgroup_size,
props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str());
if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
@ -6833,6 +6981,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_leaky_relu_f32;
}
return nullptr;
case GGML_OP_CONV_2D:
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
return ctx->device->pipeline_conv2d_f32;
}
return nullptr;
case GGML_OP_CONV_2D_DW:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
if (ggml_is_contiguous(src1)) {
@ -7155,6 +7309,31 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
const uint32_t OW = dst->ne[0];
elements = { N * OC * OH * OW, 1, 1};
} break;
case GGML_OP_CONV_2D:
{
// src0 - kernel: [KW, KH, Cin, Cout]
// src1 - input: [W, H, Cin, N]
// dst - result: [OW, OH, Cout, N]
// Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d)
auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
};
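// e.g. ins = 64, ks = 3, s = 1, p = 1, d = 1: (64 + 2 - 2 - 1)/1 + 1 = 64, i.e. a "same"-size convolution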
// parallelize in {OW/BS_K, OH/BS_NPQ, 1}
int64_t W = src1->ne[0];
int64_t H = src1->ne[1];
int64_t KW = src0->ne[0];
int64_t KH = src0->ne[1];
int64_t Cout = src0->ne[3];
int64_t N = src1->ne[3];
int64_t OH = calc_conv_output_size(H, KH, dst->op_params[1], dst->op_params[3], dst->op_params[5]);
int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], dst->op_params[2], dst->op_params[4]);
int64_t NPQ = N * OW * OH;
// Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups
elements = { static_cast<uint32_t>(Cout), static_cast<uint32_t>(NPQ), 1 };
}
break;
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_DIV:
@ -8021,6 +8200,55 @@ static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c
}, dryrun);
}
static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_TENSOR_BINARY_OP_LOCALS
GGML_ASSERT(nb00 == sizeof(float));
GGML_ASSERT(nb10 == sizeof(float));
GGML_ASSERT(nb0 == sizeof(float));
vk_op_conv2d_push_constants p{};
p.Cout = static_cast<uint32_t>(ne03);
p.Cin = static_cast<uint32_t>(ne02);
p.N = static_cast<uint32_t>(ne13);
p.KW = static_cast<uint32_t>(ne00);
p.KH = static_cast<uint32_t>(ne01);
p.W = static_cast<uint32_t>(ne10);
p.H = static_cast<uint32_t>(ne11);
p.OW = static_cast<uint32_t>(ne0);
p.OH = static_cast<uint32_t>(ne1);
p.s0 = static_cast<uint32_t>(dst->op_params[0]);
p.s1 = static_cast<uint32_t>(dst->op_params[1]);
p.p0 = static_cast<uint32_t>(dst->op_params[2]);
p.p1 = static_cast<uint32_t>(dst->op_params[3]);
p.d0 = static_cast<uint32_t>(dst->op_params[4]);
p.d1 = static_cast<uint32_t>(dst->op_params[5]);
p.nb01 = static_cast<uint32_t>(nb01 / nb00);
p.nb02 = static_cast<uint32_t>(nb02 / nb00);
p.nb03 = static_cast<uint32_t>(nb03 / nb00);
p.nb11 = static_cast<uint32_t>(nb11 / nb10);
p.nb12 = static_cast<uint32_t>(nb12 / nb10);
p.nb13 = static_cast<uint32_t>(nb13 / nb10);
p.nb1 = static_cast<uint32_t>(nb1 / nb0);
p.nb2 = static_cast<uint32_t>(nb2 / nb0);
p.nb3 = static_cast<uint32_t>(nb3 / nb0);
GGML_ASSERT(ne03 == ne2);
GGML_ASSERT(ne02 == ne12);
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D, std::move(p), dryrun);
}
static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
vk_op_conv2d_dw_push_constants p{};
p.ne = ggml_nelements(dst);
@ -9083,6 +9311,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_CONV_TRANSPOSE_1D:
case GGML_OP_POOL_2D:
case GGML_OP_CONV_2D:
case GGML_OP_CONV_2D_DW:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
@ -9150,6 +9379,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_CONV_TRANSPOSE_1D:
case GGML_OP_POOL_2D:
case GGML_OP_CONV_2D:
case GGML_OP_CONV_2D_DW:
case GGML_OP_LEAKY_RELU:
{
@ -9356,6 +9586,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_POOL_2D:
ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun);
break;
case GGML_OP_CONV_2D:
ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node, dryrun);
break;
case GGML_OP_CONV_2D_DW:
ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node, dryrun);
@ -9486,6 +9720,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_CONV_TRANSPOSE_1D:
case GGML_OP_POOL_2D:
case GGML_OP_CONV_2D:
case GGML_OP_CONV_2D_DW:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
@ -10067,6 +10302,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
} else if (cgraph->nodes[i]->op == GGML_OP_CONV_2D) {
// Count CRS x NPQ x sizeof(type) bytes so that CONV_2D is weighted the same as the equivalent im2col->mul_mat path.
auto CRS_size =
cgraph->nodes[i]->src[0]->ne[0] * cgraph->nodes[i]->src[0]->ne[1] * cgraph->nodes[i]->src[0]->ne[2];
auto NPQ_size = cgraph->nodes[i]->ne[0] * cgraph->nodes[i]->ne[1] * cgraph->nodes[i]->ne[3];
total_mat_mul_bytes += NPQ_size * CRS_size * ggml_type_size(cgraph->nodes[i]->type);
}
i += ctx->num_additional_fused_ops;
ctx->num_additional_fused_ops = 0;
@ -10643,6 +10884,20 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
return true;
case GGML_OP_CONV_TRANSPOSE_1D:
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
case GGML_OP_CONV_2D:
{
// Op is disabled for Apple because it segfaults at pipeline create time on MoltenVK
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
const vk_device& device = ggml_vk_get_device(ctx->device);
const bool is_Apple = device->vendor_id == VK_VENDOR_ID_APPLE;
// Channel-contiguous format is not supported yet.
return (op->src[0]->type == GGML_TYPE_F32 &&
op->src[1]->type == GGML_TYPE_F32 &&
op->type == GGML_TYPE_F32 &&
ggml_is_contiguous(op->src[0]) &&
ggml_is_contiguous(op->src[1]) &&
ggml_is_contiguous(op)) && !is_Apple;
}
default:
return false;
}
@ -11201,6 +11456,14 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
const int32_t p1 = tensor->op_params[6];
tensor_clone = ggml_pool_2d(ggml_ctx, src_clone[0], op, k0, k1, s0, s1, p0, p1);
} else if (tensor->op == GGML_OP_CONV_2D) {
const int32_t s0 = tensor->op_params[0];
const int32_t s1 = tensor->op_params[1];
const int32_t p0 = tensor->op_params[2];
const int32_t p1 = tensor->op_params[3];
const int32_t d0 = tensor->op_params[4];
const int32_t d1 = tensor->op_params[5];
tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1);
} else if (tensor->op == GGML_OP_LEAKY_RELU) {
const float * op_params = (const float *)tensor->op_params;
tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false);

View file

@ -0,0 +1,265 @@
#version 450
#ifdef USE_COLLECTIVES
# extension GL_KHR_shader_subgroup_shuffle : enable
#endif
#include "types.comp"
// TODO: make SHMEM_PAD a specialization constant
#define SHMEM_PAD 0
// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j
layout(binding = 0) readonly buffer A {
A_TYPE knl_data[];
}; // src0 - kernel: [KW, KH, Cin, Cout]
layout(binding = 1) readonly buffer B {
B_TYPE src_data[];
}; // src1 - input: [W, H, Cin, N] -- channel_first format
layout(binding = 2) writeonly buffer D {
D_TYPE dst_data[];
}; // dst - result: [OW, OH, Cout, N]
layout(push_constant) uniform parameter {
// I/O channels, batch size
uint32_t Cout;
uint32_t Cin;
uint32_t N;
// Tensor spatial sizes: kernel, input, output
uint32_t KW;
uint32_t KH;
uint32_t W;
uint32_t H;
uint32_t OW;
uint32_t OH;
// Parameters: stride, padding, dilation - 0=y, 1=x
uint32_t s0;
uint32_t s1;
uint32_t p0;
uint32_t p1;
uint32_t d0;
uint32_t d1;
// Strides in elements
uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t nb1;
uint32_t nb2;
uint32_t nb3;
}
p;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
// Blocktile sizes
layout(constant_id = 1) const uint BS_K = 128;
layout(constant_id = 2) const uint BS_CRS = 16;
layout(constant_id = 3) const uint BS_NPQ = 128;
// Thread-tile sizes
layout(constant_id = 4) const uint TS_K = 8;
layout(constant_id = 5) const uint use_collectives = 1;
uint32_t tid = gl_LocalInvocationID.x;
const uint32_t WG_SIZE = gl_WorkGroupSize.x;
uint splitWork(uint work_size, uint block_size) {
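// integer ceil-division: number of block_size-sized blocks needed to cover work_size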
return (work_size + block_size - 1) / block_size;
}
uint32_t K = p.Cout;
uint32_t CRS = p.Cin * p.KH * p.KW;
uint32_t NPQ = p.N * p.OH * p.OW;
uint32_t n_elems_out = K * NPQ;
// Number of blocktiles per input
uint32_t NB_CRS = splitWork(CRS, BS_CRS);
const uint32_t Ash_stride = BS_CRS + SHMEM_PAD;
const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD;
const uint32_t Ash_numel = BS_K * BS_CRS;
const uint32_t Bsh_numel = BS_CRS * BS_NPQ;
const uint32_t Ash_len = BS_K * Ash_stride;
const uint32_t Bsh_len = BS_CRS * Bsh_stride;
shared float Ash[Ash_len]; // K x CRS
shared float Bsh[Bsh_len]; // CRS x NPQ
// Threadtile sizes
const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K;
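// each thread accumulates a TS_K x TS_NPQ register tile, so WG_SIZE * TS_K * TS_NPQ == BS_K * BS_NPQ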
// Number of threadtiles per blocktile
const uint32_t NT_K = BS_K / TS_K;
const uint32_t NT_NPQ = BS_NPQ / TS_NPQ;
float regA[TS_K];
float regB[TS_NPQ];
float regC[TS_K][TS_NPQ];
/*
Compute
KxCRS @ CRSxNPQ = K x NPQ
K=Cout
C=Cin
R,S=KH,KW
P,Q=OH,OW
*/
uint32_t B_idx_K = gl_WorkGroupID.x;
uint32_t B_idx_NPQ = gl_WorkGroupID.y;
uint32_t T_y = tid / NT_NPQ;
uint32_t T_x = tid % NT_NPQ;
uint32_t Ar = tid / BS_CRS;
uint32_t Ac = tid % BS_CRS;
const uint32_t ArpWg = WG_SIZE / BS_CRS;
uint32_t Br = tid / BS_NPQ;
uint32_t Bc = tid % BS_NPQ;
const uint32_t BrpWg = WG_SIZE / BS_NPQ;
void main() {
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
regC[T_ly][T_lx] = 0.0;
}
}
/* Advance block in CRS dim */
for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) {
uint32_t CRS_idx_a;
uint32_t Cin_idx_a;
uint32_t KH_idx_a;
uint32_t KW_idx_a;
#ifdef USE_COLLECTIVES
uint32_t cached_CRS_idx;
uint32_t cached_Cin_idx;
uint32_t cached_KH_idx;
uint32_t cached_KW_idx;
if (use_collectives == 1) {
cached_CRS_idx = B_idx_CRS * BS_CRS + gl_SubgroupInvocationID;
cached_Cin_idx = cached_CRS_idx / (p.KW * p.KH);
uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx * p.KW * p.KH);
cached_KH_idx = cached_CRS_remainder / p.KW;
cached_KW_idx = cached_CRS_remainder - cached_KH_idx * p.KW;
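// each lane cached the decomposition for one CRS index; subgroupShuffle lets every
// invocation read lane Ac's values without redoing the divisions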
CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac);
Cin_idx_a = subgroupShuffle(cached_Cin_idx, Ac);
KH_idx_a = subgroupShuffle(cached_KH_idx, Ac);
KW_idx_a = subgroupShuffle(cached_KW_idx, Ac);
} else {
CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A)
Cin_idx_a = CRS_idx_a / (p.KW * p.KH);
uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
KH_idx_a = CRS_remainder / p.KW;
KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
}
#else
CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A)
Cin_idx_a = CRS_idx_a / (p.KW * p.KH);
uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
KH_idx_a = CRS_remainder / p.KW;
KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
#endif
/* Load kernel to A_block: (BS_K x BS_CRS)*/
for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) {
uint32_t B_ly = r_offset + Ar;
uint32_t B_lx = Ac;
uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/
uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03, K * CRS - 1);
float val = knl_data[knl_idx];
if (K_idx >= K || CRS_idx_a >= CRS) {
val = 0.0;
}
Ash[B_ly * Ash_stride + B_lx] = val;
}
/* Load input to B_block: (BS_CRS x BS_NPQ) */
for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
uint32_t B_ly = r_offset + Br; /* Row index of B block */
uint32_t B_lx = Bc;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */
uint32_t N_idx = NPQ_idx / (p.OH * p.OW);
uint32_t NPQ_remainder = NPQ_idx - N_idx * p.OH * p.OW;
uint32_t OH_idx = NPQ_remainder / p.OW;
uint32_t OW_idx = NPQ_remainder - OH_idx * p.OW;
uint32_t CRS_idx_b;
uint32_t Cin_idx_b;
uint32_t KH_idx_b;
uint32_t KW_idx_b;
#ifdef USE_COLLECTIVES
if (use_collectives == 1) {
CRS_idx_b = subgroupShuffle(cached_CRS_idx, r_offset + Br);
Cin_idx_b = subgroupShuffle(cached_Cin_idx, r_offset + Br);
KH_idx_b = subgroupShuffle(cached_KH_idx, r_offset + Br);
KW_idx_b = subgroupShuffle(cached_KW_idx, r_offset + Br);
} else {
CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
Cin_idx_b = CRS_idx_b / (p.KW * p.KH);
uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
KH_idx_b = CRS_remainder / p.KW;
KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
}
#else
CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
Cin_idx_b = CRS_idx_b / (p.KW * p.KH);
uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
KH_idx_b = CRS_remainder / p.KW;
KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
#endif
uint32_t H_idx = OH_idx * p.s1 + KH_idx_b * p.d1 - p.p1;
uint32_t W_idx = OW_idx * p.s0 + KW_idx_b * p.d0 - p.p0;
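// H_idx and W_idx are unsigned, so positions that would be negative (due to padding)
// wrap to large values and are rejected by the >= p.H / >= p.W bounds checks below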
uint32_t src_idx =
min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1);
float val = src_data[src_idx];
if (CRS_idx_b >= CRS || NPQ_idx >= NPQ || H_idx < 0 || H_idx >= p.H || W_idx < 0 || W_idx >= p.W) {
val = 0.0;
}
Bsh[B_ly * Bsh_stride + B_lx] = val;
}
barrier();
for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
}
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];
}
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
}
}
}
barrier();
}
/* Save C* */
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
uint32_t N_idx = NPQ_idx / (p.OH * p.OW);
uint32_t OH_idx = (NPQ_idx - N_idx * p.OH * p.OW) / p.OW;
uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW;
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3;
if (K_idx < K && NPQ_idx < NPQ) {
dst_data[dst_idx] = regC[T_ly][T_lx];
}
}
}
}

View file

@ -669,6 +669,8 @@ void process_shaders() {
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}});
string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
@ -779,8 +781,8 @@ void write_output_files() {
len += "};\n";
}
}
fprintf(src, data.c_str());
fprintf(src, len.c_str());
fputs(data.c_str(), src);
fputs(len.c_str(), src);
}
fclose(hdr);
fclose(src);

View file

@ -233,6 +233,11 @@ class Keys:
TYPE = "adapter.type"
LORA_ALPHA = "adapter.lora.alpha"
class IMatrix:
CHUNK_COUNT = "imatrix.chunk_count"
CHUNK_SIZE = "imatrix.chunk_size"
DATASETS = "imatrix.datasets"
class Clip:
PROJECTOR_TYPE = "clip.projector_type"
HAS_VISION_ENCODER = "clip.has_vision_encoder"
@ -282,6 +287,7 @@ class Keys:
class GGUFType:
MODEL = "model"
ADAPTER = "adapter"
IMATRIX = "imatrix"
MMPROJ = "mmproj" # dummy, unused for now

View file

@ -3425,6 +3425,9 @@ Current version indicated by LITEVER below.
websearch_retain: false,
websearch_template: "",
wordsearch_enabled: false,
second_ep_qty:0,
second_ep_model:"gpt2",
second_ep_url:"",
max_context_length: (localflag?4096:2048),
max_length: (localflag?512:256),
@ -12674,6 +12677,9 @@ Current version indicated by LITEVER below.
document.getElementById("dry_base").value = localsettings.dry_base;
document.getElementById("dry_allowed_length").value = localsettings.dry_allowed_length;
document.getElementById("token_count_multiplier").value = localsettings.token_count_multiplier;
document.getElementById("second_ep_qty").value = localsettings.second_ep_qty;
document.getElementById("second_ep_model").value = localsettings.second_ep_model;
document.getElementById("second_ep_url").value = localsettings.second_ep_url;
if(is_using_kcpp_with_mirostat())
{
@ -13247,6 +13253,9 @@ Current version indicated by LITEVER below.
localsettings.token_count_multiplier = parseInt(document.getElementById("token_count_multiplier").value);
localsettings.guidance_scale = parseFloat(document.getElementById("guidance_scale").value);
localsettings.guidance_prompt = document.getElementById("guidance_prompt").value;
localsettings.second_ep_qty = parseInt(document.getElementById("second_ep_qty").value);
localsettings.second_ep_model = document.getElementById("second_ep_model").value;
localsettings.second_ep_url = document.getElementById("second_ep_url").value;
localsettings.extrastopseq = document.getElementById("extrastopseq").value;
localsettings.tokenbans = document.getElementById("tokenbans").value;
@ -13446,6 +13455,7 @@ Current version indicated by LITEVER below.
localsettings.xtc_threshold = cleannum(localsettings.xtc_threshold, 0.0, 1.0);
localsettings.sampler_seed = cleannum(localsettings.sampler_seed, -1, 999999);
localsettings.token_count_multiplier = cleannum(localsettings.token_count_multiplier, 70, 130);
localsettings.second_ep_qty = cleannum(Math.floor(localsettings.second_ep_qty), 0, 1024);
toggle_invert_colors();
toggle_sidepanel_mode();
@ -13901,6 +13911,7 @@ Current version indicated by LITEVER below.
on_searchsummary_done = onDoneFn;
submit_payload = finalize_submit_payload(submit_payload, false);
dispatch_submit_generation(submit_payload, false);
render_gametext();
}
@ -13975,7 +13986,8 @@ Current version indicated by LITEVER below.
//v2 api specific fields
submit_payload.workers = selected_workers.map((m) => { return m.id });
dispatch_submit_generation(submit_payload,false);
submit_payload = finalize_submit_payload(submit_payload, false);
dispatch_submit_generation(submit_payload, false);
render_gametext();
document.getElementById("memorytext").value = "[<|Generating summary, do not close window...|>]"
};
@ -16545,6 +16557,14 @@ Current version indicated by LITEVER below.
if (!doNotGenerate)
{
submit_payload = finalize_submit_payload(submit_payload, user_input_empty);
if(localsettings.second_ep_qty>0 && localsettings.second_ep_url!="")
{
let decensored_prefix = yield FetchDecensoredPrefix(submit_payload, localsettings.second_ep_url, localsettings.second_ep_model, localsettings.second_ep_qty);
console.log(`Obtained Prefix: ${decensored_prefix}`);
pending_context_preinjection += decensored_prefix;
submit_payload.prompt += decensored_prefix;
}
dispatch_submit_generation(submit_payload, user_input_empty);
}
else
@ -16714,10 +16734,8 @@ Current version indicated by LITEVER below.
return resp;
}
function dispatch_submit_generation(submit_payload, input_was_empty) //if input is not empty, always unban eos
function finalize_submit_payload(submit_payload, input_was_empty)
{
console.log(submit_payload);
//preprocess to add extra fields
if(custom_kobold_endpoint != "" && is_using_kcpp_with_mirostat())
{
@ -16772,6 +16790,28 @@ Current version indicated by LITEVER below.
submit_payload.params.negative_prompt = localsettings.guidance_prompt;
}
//for version 1.2.2 and later, send stopper tokens for chat and instruct
if (custom_kobold_endpoint != "" && kobold_endpoint_version && kobold_endpoint_version != "" && compare_version_str(kobold_endpoint_version, "1.2.2") >= 0) {
submit_payload.params.stop_sequence = get_stop_sequences();
}
//version 1.2.4 and later supports unban tokens
if (custom_kobold_endpoint != "" && kobold_endpoint_version && kobold_endpoint_version != "" && compare_version_str(kobold_endpoint_version, "1.2.4") >= 0)
{
submit_payload.params.use_default_badwordsids = determine_if_ban_eos(input_was_empty);
if(is_using_kcpp_with_added_memory())
{
submit_payload.params.bypass_eos = (localsettings.eos_ban_mode == 3?true:false);
}
}
return submit_payload;
}
function dispatch_submit_generation(submit_payload, input_was_empty) //if input is not empty, always unban eos
{
console.log(submit_payload);
start_time_taken(); //timestamp start request
if (is_using_custom_ep()) {
@ -16792,23 +16832,6 @@ Current version indicated by LITEVER below.
let prompt = submit_payload.prompt;
submit_payload = submit_payload.params;
submit_payload.prompt = prompt;
let showlog = false;
submit_payload.quiet = !showlog;
//for version 1.2.2 and later, send stopper tokens for chat and instruct
if (kobold_endpoint_version && kobold_endpoint_version != "" && compare_version_str(kobold_endpoint_version, "1.2.2") >= 0) {
submit_payload.stop_sequence = get_stop_sequences();
}
//version 1.2.4 and later supports unban tokens
if (kobold_endpoint_version && kobold_endpoint_version != "" && compare_version_str(kobold_endpoint_version, "1.2.4") >= 0)
{
submit_payload.use_default_badwordsids = determine_if_ban_eos(input_was_empty);
if(is_using_kcpp_with_added_memory())
{
submit_payload.bypass_eos = (localsettings.eos_ban_mode == 3?true:false);
}
}
last_request_str = JSON.stringify(submit_payload);
last_response_obj = null;
@ -17006,18 +17029,6 @@ Current version indicated by LITEVER below.
}
else
{
//apply custom logit bias for official OAI only
let needbaneos = (custom_oai_endpoint.toLowerCase().includes("api.openai.com") && determine_if_ban_eos(input_was_empty));
if(needbaneos)
{
if(oai_payload.logit_bias)
{
oai_payload.logit_bias["50256"] = -100;
}else{
oai_payload.logit_bias = { "50256": -100 };
}
}
oai_payload.prompt = submit_payload.prompt;
}
@ -19039,7 +19050,7 @@ Current version indicated by LITEVER below.
{
if(!localsettings.allow_continue_chat)
{
pending_context_preinjection = get_instructendplaceholder();
pending_context_preinjection = get_instructendplaceholder();
if(localsettings.inject_timestamps)
{
@ -20351,6 +20362,7 @@ Current version indicated by LITEVER below.
fulltxt = replace_placeholders(fulltxt,true);
}
let alreadyShowedStreaming = false; //sometimes, streaming is injected inline to preserve indentation
if(localsettings.opmode==4 && !inEditMode)
{
@ -20380,6 +20392,36 @@ Current version indicated by LITEVER below.
{
let curr = instruct_turns[i];
let currmsg = curr.msg;
//in some special cases, add streaming text inline
let allow_cont_prev_turn = (localsettings.opmode==4 || (localsettings.opmode==3 && localsettings.allow_continue_chat));
if(i>0 && (i==instruct_turns.length-1) && curr.myturn==false && currmsg && allow_cont_prev_turn)
{
//inject into previous turn, only for instruct OR continuechat
currmsg += `<span class="color_yellow pending_text">${escape_html(pending_context_preinjection) + format_streaming_text(escape_html(synchro_pending_stream))}</span>`;
alreadyShowedStreaming = true;
}
//apply stylization to time tags
if(localsettings.inject_timestamps && localsettings.instruct_has_markdown)
{
currmsg = currmsg.replace(/(\[\d{1,2}\/\d{1,2}\/\d{4}, \d{1,2}:\d{2} [AP]M\])/g, "$1\n");
}
if(localsettings.inject_chatnames_instruct && localsettings.instruct_has_markdown)
{
let m_name = localsettings.chatname + ": ";
currmsg = replaceAll(currmsg, m_name, `<b>` + escape_html(m_name) + `</b>`);
let m_opps = localsettings.chatopponent.split("||$||");
for(let i=0;i<m_opps.length;++i)
{
if(m_opps[i] && m_opps[i].trim()!="")
{
let m_opp = m_opps[i] + ": ";
currmsg = replaceAll(currmsg, m_opp, `<b>` + escape_html(m_opp) + `</b>`);
}
}
}
if(localsettings.instruct_has_markdown && (localsettings.render_streaming_markdown||synchro_pending_stream==""||(i+1)<instruct_turns.length))
{
//if a list has a starttag on the same line, add a newline before it
@ -20406,26 +20448,7 @@ Current version indicated by LITEVER below.
}
}
//apply stylization to time tags
if(localsettings.inject_timestamps && localsettings.instruct_has_markdown)
{
fulltxt = fulltxt.replace(/(\[\d{1,2}\/\d{1,2}\/\d{4}, \d{1,2}:\d{2} [AP]M\])/g, "$1\n");
}
if(localsettings.inject_chatnames_instruct && localsettings.instruct_has_markdown)
{
let m_name = localsettings.chatname + ": ";
fulltxt = replaceAll(fulltxt, m_name, `<b>` + escape_html(m_name) + `</b>`);
let m_opps = localsettings.chatopponent.split("||$||");
for(let i=0;i<m_opps.length;++i)
{
if(m_opps[i] && m_opps[i].trim()!="")
{
let m_opp = m_opps[i] + ": ";
fulltxt = replaceAll(fulltxt, m_opp, `<b>` + escape_html(m_opp) + `</b>`);
}
}
}
}else{
} else {
fulltxt = replaceAll(fulltxt, get_instruct_starttag(true), `%SclStg%`+escape_html(get_instruct_starttag(true))+`%SpnEtg%`);
fulltxt = replaceAll(fulltxt, get_instruct_endtag(true), `%SclStg%`+escape_html(get_instruct_endtag(true))+`%SpnEtg%`);
if (localsettings.separate_end_tags) {
@ -20463,7 +20486,7 @@ Current version indicated by LITEVER below.
{
if(!document.getElementById("allowediting").checked && !fulltxt.startsWith("\n"))
{
fulltxt = "\n"+fulltxt;
fulltxt = "\n" + fulltxt;
}
//for chat mode, highlight our name in blue and opponent in red
@ -20483,7 +20506,6 @@ Current version indicated by LITEVER below.
return `<span class="`+colormap[onametrim]+`">` + oname + `</span>`;
});
fulltxt = replaceAll(fulltxt,m_name, `<span class="color_blue">` + escape_html(m_name) + `</span>`);
}
//for adventure mode, highlight our actions in green
@ -20493,8 +20515,8 @@ Current version indicated by LITEVER below.
});
}
//streaming display
if(synchro_pending_stream!="" && waiting_for_tool_call==0)
//streaming display for all other cases
if(!alreadyShowedStreaming && synchro_pending_stream!="" && waiting_for_tool_call==0)
{
fulltxt += `<span class="color_yellow pending_text">${escape_html(pending_context_preinjection) + format_streaming_text(escape_html(synchro_pending_stream))}</span>`;
}
@ -23906,6 +23928,59 @@ Current version indicated by LITEVER below.
resultsContainer.appendChild(btn);
});
}
//queries the decensoring prefix from a second OAI-compatible endpoint and returns a string to append to our main request's prompt
const FetchDecensoredPrefix = asyncRunner(function* (submit_payload, endpoint_url, modelused, num_tokens)
{
let oaiheaders = {
'Content-Type': 'application/json'
};
if (custom_oai_key!="" && custom_oai_key!=dummy_api_key) {
oaiheaders['Authorization'] = 'Bearer ' + custom_oai_key;
}
let scaled_rep_pen = 0;
if(submit_payload.params.presence_penalty > 0)
{
scaled_rep_pen = submit_payload.params.presence_penalty;
}else{
//original rep_pen range is 1 to 3; scale it into the 0 to 2 presence_penalty range
scaled_rep_pen = (submit_payload.params.rep_pen - 1.0);
}
let oai_payload =
{
"prompt": submit_payload.prompt,
"max_tokens": num_tokens,
"model": modelused,
"temperature": submit_payload.params.temperature,
"top_p": submit_payload.params.top_p,
"stream": false,
"presence_penalty": scaled_rep_pen,
"stop": get_stop_sequences().slice(0, 4)
};
return fetch(endpoint_url, {
method: 'POST',
headers: oaiheaders,
body: JSON.stringify(oai_payload),
referrerPolicy: 'no-referrer',
})
.then((response) => response.json())
.then((data) => {
let reply = "";
if (data.choices != null && data.choices.length > 0) {
let dch = data.choices[0];
if (dch.text) {
reply = dch.text;
}
}
return reply;
})
.catch((error) => {
console.error('Error:', error);
return "";
});
});
</script>
</head>
@@ -24873,6 +24948,19 @@ Current version indicated by LITEVER below.
<div class="justifyleft settingsmall" style="width:100%">
<input title="Top N Sigma. 0 to deactivate. Range 0 to 5." class="settinglabel miniinput" type="text" inputmode="decimal" placeholder="0" value="0" id="nsigma"></div>
</div>
<div class="settinglabel settingcell">
<div class="justifyleft settingsmall" style="width:100%">2ndEpSampleQty <span class="helpicon">?<span class="helptext">
If enabled, samples N tokens from a second OpenAI-compatible endpoint before each main request. This can be useful for jailbreak purposes.</span></span></div>
<div class="justifyleft settingsmall" style="width:100%;">
<div style="display:flex;">
<div class="settingsmall" style="margin-left:2px;margin-right:2px;margin-top:4px;" title="Number of tokens to sample">Qty: </div> <input title="Token Quantity" class="settinglabel miniinput" type="text" inputmode="decimal" placeholder="0" value="0" id="second_ep_qty">
<div class="settingsmall" style="margin-left:2px;margin-right:2px;margin-top:4px;" title="Model name to use">Model: </div> <input title="Model Name" class="settinglabel miniinput" type="text" inputmode="decimal" placeholder="(Blank Model)" value="" id="second_ep_model">
</div>
<div style="display:flex; margin-top:2px">
<div class="settingsmall" style="margin-left:2px;margin-right:2px;margin-top:4px;" title="URL of second OpenAI Compatible Endpoint">Url: </div> <input title="Second OpenAI Compatible URL" class="settinglabel miniinput" type="text" placeholder="(Full /v1/completions URL)" value="" id="second_ep_url">
</div>
</div>
</div>
</div>
</div>
</div>

View file

@@ -63,7 +63,7 @@ dry_seq_break_max = 128
extra_images_max = 4
# global vars
KcppVersion = "1.96.1"
KcppVersion = "1.96.2"
showdebug = True
kcpp_instance = None #global running instance
global_memory = {"tunnel_url": "", "restart_target":"", "input_to_exit":False, "load_complete":False, "restart_override_config_target":""}

View file

@@ -428,6 +428,8 @@ void llm_graph_result::reset() {
t_embd = nullptr;
t_embd_pooled = nullptr;
params = {};
inputs.clear();
buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
@@ -905,20 +907,25 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
cb(cur, "ffn_moe_weighted", il);
}
ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr };
assert(n_expert_used > 0);
// order the views before the adds
for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]);
ggml_build_forward_expand(gf, cur_experts[i]);
}
// aggregate experts
// note: here we explicitly use hparams.n_expert_used instead of n_expert_used
// to avoid potentially a large number of add nodes during warmup
// ref: https://github.com/ggml-org/llama.cpp/pull/14753
ggml_tensor * moe_out = nullptr;
for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens,
experts->nb[2], i*experts->nb[1]);
if (i == 0) {
moe_out = cur_expert;
} else {
moe_out = ggml_add(ctx0, moe_out, cur_expert);
}
}
ggml_tensor * moe_out = cur_experts[0];
for (uint32_t i = 1; i < hparams.n_expert_used; ++i) {
moe_out = ggml_add(ctx0, moe_out, cur_experts[i]);
}
if (hparams.n_expert_used == 1) {
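For illustration, the node-count concern in that note can be shown with a toy standalone sketch (invented names, not ggml code): a left-fold aggregation emits one add node per extra expert, so a warmup pass that activates all experts would otherwise inflate the graph.

#include <cstdio>

// Toy stand-in for graph building: a left fold over k expert outputs
// creates k-1 add nodes (one ggml_add per extra expert).
static int count_add_nodes(int k_experts) {
    int nodes = 0;
    for (int i = 1; i < k_experts; ++i) {
        nodes++; // moe_out = add(moe_out, cur_experts[i])
    }
    return nodes;
}

int main() {
    printf("k=2   -> %d add nodes\n", count_add_nodes(2));   // typical decode: 1
    printf("k=128 -> %d add nodes\n", count_add_nodes(128)); // warmup worst case: 127
    return 0;
}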

View file

@@ -1,12 +1,14 @@
#include "build-info.h"
#include "common/common.h"
#include "llama.h"
#include "gguf.h"
#include <cstdio>
#include <cstring>
#include <vector>
#include <string>
#include <unordered_map>
#include <map>
#include <fstream>
#include <cmath>
#include <cctype>
@@ -69,6 +71,11 @@ static const char * const LLM_KV_QUANTIZE_IMATRIX_DATASET = "quantize.imatrix
static const char * const LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES = "quantize.imatrix.entries_count";
static const char * const LLM_KV_QUANTIZE_IMATRIX_N_CHUNKS = "quantize.imatrix.chunks_count";
// TODO: share with imatrix.cpp
static const char * const LLM_KV_IMATRIX_DATASETS = "imatrix.datasets";
static const char * const LLM_KV_IMATRIX_CHUNK_COUNT = "imatrix.chunk_count";
static const char * const LLM_KV_IMATRIX_CHUNK_SIZE = "imatrix.chunk_size";
static bool striequals(const char * a, const char * b) {
while (*a && *b) {
if (std::tolower(*a) != std::tolower(*b)) {
@@ -85,7 +92,7 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp
for (auto ch : ftype_str_in) {
ftype_str.push_back(std::toupper(ch));
}
for (auto & it : QUANT_OPTIONS) {
for (const auto & it : QUANT_OPTIONS) {
if (striequals(it.name.c_str(), ftype_str.c_str())) {
ftype = it.ftype;
ftype_str_out = it.name;
@@ -94,7 +101,7 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp
}
try {
int ftype_int = std::stoi(ftype_str);
for (auto & it : QUANT_OPTIONS) {
for (const auto & it : QUANT_OPTIONS) {
if (it.ftype == ftype_int) {
ftype = it.ftype;
ftype_str_out = it.name;
@@ -130,7 +137,7 @@ static void usage(const char * executable) {
printf(" Advanced option to override model metadata by key in the quantized model. May be specified multiple times.\n");
printf("Note: --include-weights and --exclude-weights cannot be used together\n");
printf("\nAllowed quantization types:\n");
for (auto & it : QUANT_OPTIONS) {
for (const auto & it : QUANT_OPTIONS) {
if (it.name != "COPY") {
printf(" %2d or ", it.ftype);
} else {
@@ -141,7 +148,7 @@ static void usage(const char * executable) {
exit(1);
}
static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_dataset, std::unordered_map<std::string, std::vector<float>> & imatrix_data) {
static int load_legacy_imatrix(const std::string & imatrix_file, std::vector<std::string> & imatrix_datasets, std::unordered_map<std::string, std::vector<float>> & imatrix_data) {
std::ifstream in(imatrix_file.c_str(), std::ios::binary);
if (!in) {
printf("%s: failed to open %s\n",__func__, imatrix_file.c_str());
@@ -181,7 +188,9 @@ static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_
exit(1);
}
if (ncall > 0) {
for (auto& v : e) v /= ncall;
for (auto & v : e) {
v /= ncall;
}
}
if (getenv("LLAMA_TRACE")) {
@@ -189,7 +198,7 @@ static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_
}
}
// latest imatrix version contains the dataset filename at the end of the file
// latest legacy imatrix version contains the dataset filename at the end of the file
int m_last_call = 0;
if (in.peek() != EOF) {
in.read((char *)&m_last_call, sizeof(m_last_call));
@@ -197,15 +206,130 @@ static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_
in.read((char *)&dataset_len, sizeof(dataset_len));
std::vector<char> dataset_as_vec(dataset_len);
in.read(dataset_as_vec.data(), dataset_len);
imatrix_dataset.assign(dataset_as_vec.begin(), dataset_as_vec.end());
printf("%s: imatrix dataset='%s'\n", __func__, imatrix_dataset.c_str());
imatrix_datasets.resize(1);
imatrix_datasets[0].assign(dataset_as_vec.begin(), dataset_as_vec.end());
printf("%s: imatrix dataset='%s'\n", __func__, imatrix_datasets[0].c_str());
}
printf("%s: loaded %d importance matrix entries from %s computed on %d chunks\n", __func__, int(imatrix_data.size()), imatrix_file.c_str(), m_last_call);
return m_last_call;
}
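For orientation, the legacy (pre-GGUF) file layout that load_legacy_imatrix() walks can be sketched as follows; this is reconstructed from the reads above, and the exact field order is an assumption rather than a documented spec:

// Assumed legacy imatrix layout (little-endian, reconstructed from the loader):
//   int32_t n_entries;
//   n_entries x {
//       int32_t len; char name[len];      // tensor name
//       int32_t ncall;                    // number of contributing chunks
//       int32_t nval; float data[nval];   // accumulated values; divided by ncall on load
//   }
//   // optional trailer, detected via in.peek() != EOF:
//   int32_t m_last_call;                  // chunk count reported in the log line
//   int32_t dataset_len; char dataset[dataset_len]; // calibration dataset filename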
static int load_imatrix(const std::string & imatrix_file, std::vector<std::string> & imatrix_datasets, std::unordered_map<std::string, std::vector<float>> & imatrix_data) {
struct ggml_context * ctx = nullptr;
struct gguf_init_params meta_gguf_params = {
/* .no_alloc = */ false, // the data is needed
/* .ctx = */ &ctx,
};
struct gguf_context * ctx_gguf = gguf_init_from_file(imatrix_file.c_str(), meta_gguf_params);
if (!ctx_gguf) {
fprintf(stderr, "%s: imatrix file '%s' is using old format\n", __func__, imatrix_file.c_str());
return load_legacy_imatrix(imatrix_file, imatrix_datasets, imatrix_data);
}
const int32_t n_entries = gguf_get_n_tensors(ctx_gguf);
if (n_entries < 1) {
fprintf(stderr, "%s: no data in file %s\n", __func__, imatrix_file.c_str());
gguf_free(ctx_gguf);
ggml_free(ctx);
exit(1);
}
const int dataset_idx = gguf_find_key(ctx_gguf, LLM_KV_IMATRIX_DATASETS);
const int chunk_count_idx = gguf_find_key(ctx_gguf, LLM_KV_IMATRIX_CHUNK_COUNT);
const int chunk_size_idx = gguf_find_key(ctx_gguf, LLM_KV_IMATRIX_CHUNK_SIZE);
if (dataset_idx < 0 || chunk_count_idx < 0 || chunk_size_idx < 0) {
fprintf(stderr, "%s: missing imatrix metadata in file %s\n", __func__, imatrix_file.c_str());
gguf_free(ctx_gguf);
ggml_free(ctx);
exit(1);
}
const uint32_t chunk_size = gguf_get_val_u32(ctx_gguf, chunk_size_idx);
const std::string sums_suffix{ ".in_sum2" };
const std::string counts_suffix{ ".counts" };
// Using an ordered map to get a deterministic iteration order.
std::map<std::string, std::pair<struct ggml_tensor *, struct ggml_tensor *>> sums_counts_for;
for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) {
std::string name = cur->name;
if (name.empty()) { continue; }
if (string_remove_suffix(name, sums_suffix)) {
// in_sum2
sums_counts_for[std::move(name)].first = cur;
} else if (string_remove_suffix(name, counts_suffix)) {
// counts
sums_counts_for[std::move(name)].second = cur;
} else {
// ignore other tensors
}
}
for (const auto & sc : sums_counts_for) {
const std::string & name = sc.first;
const struct ggml_tensor * sums = sc.second.first;
const struct ggml_tensor * counts = sc.second.second;
if (!sums || !counts) {
fprintf(stderr, "%s: mismatched sums and counts for %s\n", __func__, name.c_str());
gguf_free(ctx_gguf);
ggml_free(ctx);
exit(1);
}
const int64_t ne0 = sums->ne[0];
const int64_t ne1 = sums->ne[1];
auto & e = imatrix_data[name];
e.resize(ggml_nelements(sums));
float max_count = 0.0f;
for (int64_t j = 0; j < ne1; ++j) {
const float count = ((const float *) counts->data)[j];
if (count > 0.0f) {
for (int64_t i = 0; i < ne0; ++i) {
e[j*ne0 + i] = ((const float *) sums->data)[j*ne0 + i] / count;
}
} else {
// Partial imatrix data, this tensor never got any input during calibration
for (int64_t i = 0; i < ne0; ++i) {
e[j*ne0 + i] = 1;
}
}
if (count > max_count) {
max_count = count;
}
}
if (getenv("LLAMA_TRACE")) {
printf("%s: loaded data (size = %6d, n_tokens = %6d, n_chunks = %6d) for '%s'\n", __func__, int(e.size()), int(max_count), int(max_count / chunk_size), name.c_str());
}
}
int m_last_chunk = gguf_get_val_u32(ctx_gguf, chunk_count_idx);
int64_t n_datasets = gguf_get_arr_n(ctx_gguf, dataset_idx);
imatrix_datasets.reserve(n_datasets);
for (int64_t i = 0; i < n_datasets; ++i) {
imatrix_datasets.push_back(gguf_get_arr_str(ctx_gguf, dataset_idx, i));
}
printf("%s: imatrix datasets=['%s'", __func__, imatrix_datasets[0].c_str());
for (size_t i = 1; i < imatrix_datasets.size(); ++i) {
printf(", '%s'", imatrix_datasets[i].c_str());
}
printf("]\n");
printf("%s: loaded %d importance matrix entries from %s computed on %d chunks\n", __func__, int(imatrix_data.size()), imatrix_file.c_str(), m_last_chunk);
gguf_free(ctx_gguf);
ggml_free(ctx);
return m_last_chunk;
}
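Put differently, the GGUF form stores two companion tensors per weight, <name>.in_sum2 (summed squared activations) and <name>.counts (per-row token counts), and the values handed to quantization are per-row means. A minimal standalone sketch of that reduction, assuming flat row-major buffers (not the tool's actual API):

#include <cstdint>
#include <cstdio>
#include <vector>

// Mirror of the e[j*ne0 + i] = sums/count loop in load_imatrix() above:
// ne0 = values per row, ne1 = rows; rows without calibration input get 1.0f.
static std::vector<float> imatrix_means(const std::vector<float> & in_sum2,
                                        const std::vector<float> & counts,
                                        int64_t ne0, int64_t ne1) {
    std::vector<float> e((size_t)(ne0 * ne1));
    for (int64_t j = 0; j < ne1; ++j) {
        const float count = counts[(size_t)j];
        for (int64_t i = 0; i < ne0; ++i) {
            e[(size_t)(j*ne0 + i)] = count > 0.0f ? in_sum2[(size_t)(j*ne0 + i)] / count
                                                  : 1.0f; // neutral fallback
        }
    }
    return e;
}

int main() {
    // two rows of two values; the second row never saw input (count 0)
    const auto e = imatrix_means({8.0f, 2.0f, 0.0f, 0.0f}, {4.0f, 0.0f}, 2, 2);
    printf("%g %g %g %g\n", e[0], e[1], e[2], e[3]); // 2 0.5 1 1
    return 0;
}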
static int prepare_imatrix(const std::string & imatrix_file,
std::string & imatrix_dataset,
std::vector<std::string> & imatrix_dataset,
const std::vector<std::string> & included_weights,
const std::vector<std::string> & excluded_weights,
std::unordered_map<std::string, std::vector<float>> & imatrix_data) {
@@ -217,18 +341,21 @@ static int prepare_imatrix(const std::string & imatrix_file,
return m_last_call;
}
if (!excluded_weights.empty()) {
for (auto& name : excluded_weights) {
for (auto it = imatrix_data.begin(); it != imatrix_data.end(); ) {
for (const auto & name : excluded_weights) {
for (auto it = imatrix_data.begin(); it != imatrix_data.end();) {
auto pos = it->first.find(name);
if (pos != std::string::npos) it = imatrix_data.erase(it);
else ++it;
if (pos != std::string::npos) {
it = imatrix_data.erase(it);
} else {
++it;
}
}
}
}
if (!included_weights.empty()) {
std::unordered_map<std::string, std::vector<float>> tmp;
for (auto& name : included_weights) {
for (auto& e : imatrix_data) {
for (const auto & name : included_weights) {
for (auto & e : imatrix_data) {
auto pos = e.first.find(name);
if (pos != std::string::npos) {
tmp.emplace(std::move(e));
@@ -397,9 +524,9 @@ int main(int argc, char ** argv) {
usage(argv[0]);
}
std::string imatrix_dataset;
std::vector<std::string> imatrix_datasets;
std::unordered_map<std::string, std::vector<float>> imatrix_data;
int m_last_call = prepare_imatrix(imatrix_file, imatrix_dataset, included_weights, excluded_weights, imatrix_data);
int m_last_call = prepare_imatrix(imatrix_file, imatrix_datasets, included_weights, excluded_weights, imatrix_data);
if (!imatrix_data.empty()) {
params.imatrix = &imatrix_data;
{
@@ -410,11 +537,12 @@ int main(int argc, char ** argv) {
kvo.val_str[127] = '\0';
kv_overrides.emplace_back(std::move(kvo));
}
if (!imatrix_dataset.empty()) {
if (!imatrix_datasets.empty()) {
llama_model_kv_override kvo;
// TODO: list multiple datasets when there are more than one
std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_DATASET);
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR;
strncpy(kvo.val_str, imatrix_dataset.c_str(), 127);
strncpy(kvo.val_str, imatrix_datasets[0].c_str(), 127);
kvo.val_str[127] = '\0';
kv_overrides.emplace_back(std::move(kvo));
}