From bc091a4dc585af25c438c8473285a8cfec5c7695 Mon Sep 17 00:00:00 2001 From: Prajwal B Mehendarkar Date: Sat, 12 Apr 2025 21:03:39 +0530 Subject: [PATCH 01/20] common : Define cache directory on AIX (#12915) --- common/common.cpp | 2 +- examples/rpc/rpc-server.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index f18420794..94f545f81 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -830,7 +830,7 @@ std::string fs_get_cache_directory() { if (getenv("LLAMA_CACHE")) { cache_directory = std::getenv("LLAMA_CACHE"); } else { -#if defined(__linux__) || defined(__FreeBSD__) +#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) if (std::getenv("XDG_CACHE_HOME")) { cache_directory = std::getenv("XDG_CACHE_HOME"); } else { diff --git a/examples/rpc/rpc-server.cpp b/examples/rpc/rpc-server.cpp index b61749c3f..f8f2cb90e 100644 --- a/examples/rpc/rpc-server.cpp +++ b/examples/rpc/rpc-server.cpp @@ -126,7 +126,7 @@ static std::string fs_get_cache_directory() { if (getenv("LLAMA_CACHE")) { cache_directory = std::getenv("LLAMA_CACHE"); } else { -#if defined(__linux__) || defined(__FreeBSD__) +#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) if (std::getenv("XDG_CACHE_HOME")) { cache_directory = std::getenv("XDG_CACHE_HOME"); } else { From 71e90e8813f90097701e62f7fce137d96ddf41e2 Mon Sep 17 00:00:00 2001 From: Ed Addario <29247825+EAddario@users.noreply.github.com> Date: Sun, 13 Apr 2025 19:29:28 +0100 Subject: [PATCH 02/20] quantize: Handle user-defined quantization levels for additional tensors (#12511) * Add llama_model_quantize_params parameters * Add new quantize parameters parsing and validation * Update usage * Add new parameters defaults * Add new quantization parameters logic * Add llama_model_quantize_params parameters * Add new quantize parameters parsing and validation * Update usage * Add new parameters defaults * Add new quantization parameters logic * Minor refactoring as per the contributors' coding guidelines * Update descriptions to match existing style * Add llama_model_quantize_params parameters * Add new quantize parameters parsing and validation * Update usage * Add new parameters defaults * Add new quantization parameters logic * Minor refactoring as per the contributors' guidelines * Implement general --tensor-type instead of tensor-specific command option * Fix implied type bug * Restore missing #includes * Add regex capability for tensor selection * Refactor function name and update ALLOWED_TENSOR_TYPE * Add missing #include * Handle edge case when tensor name is cls.output * Minor logging improvement --- examples/quantize/quantize.cpp | 117 ++++++++++++++++++++++++++++++++- include/llama.h | 23 +++---- src/llama-quant.cpp | 35 ++++++++-- 3 files changed, 155 insertions(+), 20 deletions(-) diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index a4468b169..0355311dc 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -9,6 +9,7 @@ #include #include #include +#include struct quant_option { std::string name; @@ -16,7 +17,7 @@ struct quant_option { std::string desc; }; -static const std::vector QUANT_OPTIONS = { +static const std::vector QUANT_OPTIONS = { { "Q4_0", LLAMA_FTYPE_MOSTLY_Q4_0, " 4.34G, +0.4685 ppl @ Llama-3-8B", }, { "Q4_1", LLAMA_FTYPE_MOSTLY_Q4_1, " 4.78G, +0.4511 ppl @ Llama-3-8B", }, { "Q5_0", LLAMA_FTYPE_MOSTLY_Q5_0, " 5.21G, +0.1316 ppl @ Llama-3-8B", }, @@ -105,7 +106,8 @@ static bool try_parse_ftype(const 
std::string & ftype_str_in, llama_ftype & ftyp // [[noreturn]] static void usage(const char * executable) { - printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights] [--exclude-weights] [--output-tensor-type] [--token-embedding-type] [--override-kv] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n", executable); + printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights] [--exclude-weights] [--output-tensor-type]\n", executable); + printf(" [--token-embedding-type] [--tensor-type] [--keep-split] [--override-kv] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n"); printf(" --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n"); printf(" --leave-output-tensor: Will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n"); printf(" --pure: Disable k-quant mixtures and quantize all tensors to the same type\n"); @@ -114,6 +116,8 @@ static void usage(const char * executable) { printf(" --exclude-weights tensor_name: use importance matrix for this/these tensor(s)\n"); printf(" --output-tensor-type ggml_type: use this ggml_type for the output.weight tensor\n"); printf(" --token-embedding-type ggml_type: use this ggml_type for the token embeddings tensor\n"); + printf(" --tensor-type TENSOR=TYPE: quantize this tensor to this ggml_type. example: --tensor-type attn_q=q8_0\n"); + printf(" Advanced option to selectively quantize tensors. May be specified multiple times.\n"); printf(" --keep-split: will generate quantized model in the same shards as input\n"); printf(" --override-kv KEY=TYPE:VALUE\n"); printf(" Advanced option to override model metadata by key in the quantized model. 
May be specified multiple times.\n"); @@ -244,6 +248,107 @@ static ggml_type parse_ggml_type(const char * arg) { return GGML_TYPE_COUNT; } +// Allowed tensors for arbitrary quantization with --tensor-type option +static const std::vector<std::string> ALLOWED_TENSOR_TYPE = { + "attn_k", + "attn_kv_a_mqa", + "attn_kv_b", + "attn_o", + "attn_output", + "attn_q", + "attn_q_a", + "attn_q_b", + "attn_qkv", + "attn_v", + "channel_mix_key", + "channel_mix_receptance", + "channel_mix_value", + "cls", + "cls.output", + "cross_attn_k", + "cross_attn_o", + "cross_attn_q", + "cross_attn_v", + "ffn_act", + "ffn_down", + "ffn_down_exps", + "ffn_down_shexp", + "ffn_gate", + "ffn_gate_exps", + "ffn_gate_shexp", + "ffn_up", + "ffn_up_exps", + "ffn_up_shexp", + "ssm_in", + "ssm_out", + "time_mix_gate", + "time_mix_key", + "time_mix_output", + "time_mix_receptance", + "time_mix_value", +}; + +// changes to this struct must be replicated in llama-quant.cpp +struct tensor_quantization { + std::string name; + ggml_type quant = GGML_TYPE_COUNT; +}; + +static bool parse_tensor_type(const char * data, std::vector<tensor_quantization> & tensor_type) { + const char * sep = strchr(data, '='); + if (sep == nullptr) { + printf("\n%s: malformed tensor type '%s'\n\n", __func__, data); + return false; + } + + const size_t tn_len = sep - data; + if (tn_len == 0) { + printf("\n%s: missing tensor name\n\n", __func__); + return false; + } + + if (const size_t qt_len = strlen(sep); qt_len == 1) { + printf("\n%s: missing quantization type\n\n", __func__); + return false; + } + + std::string tn(data, tn_len); + std::transform(tn.begin(), tn.end(), tn.begin(), tolower); + sep++; + const std::string qt(sep); + + bool found = false; + for (const auto & allowed : ALLOWED_TENSOR_TYPE) { + std::string tensor; + tensor = tn.rfind('.') != std::string::npos ?
tn.substr(tn.rfind('.') + 1) : tn; + // handle special case of cls.output + std::string cls_output = "cls.output"; + if (tn.find(cls_output) != std::string::npos) { + tensor = "cls.output"; + } + // check if an allowed tensor exists and it's at the end of the kv string + if (tensor == allowed) { + found = true; + break; + } + } + if (!found) { + printf("\n%s: invalid tensor name '%s'\n\n", __func__, tn.c_str()); + return false; + } + + if (parse_ggml_type(qt.c_str()) == GGML_TYPE_COUNT) { + printf("\n%s: invalid quantization type '%s'\n\n", __func__, qt.c_str()); + return false; + } + + tensor_quantization tqz; + tqz.name = tn; + tqz.quant = parse_ggml_type(qt.c_str()); + tensor_type.emplace_back(std::move(tqz)); + return true; +} + int main(int argc, char ** argv) { if (argc < 3) { usage(argv[0]); @@ -255,6 +360,7 @@ int main(int argc, char ** argv) { std::string imatrix_file; std::vector<std::string> included_weights, excluded_weights; std::vector<llama_model_kv_override> kv_overrides; + std::vector<tensor_quantization> tensor_types; for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) { if (strcmp(argv[arg_idx], "--leave-output-tensor") == 0) { @@ -277,6 +383,10 @@ } else { usage(argv[0]); } + } else if (strcmp(argv[arg_idx], "--tensor-type") == 0) { + if (arg_idx == argc-1 || !parse_tensor_type(argv[++arg_idx], tensor_types)) { + usage(argv[0]); + } } else if (strcmp(argv[arg_idx], "--override-kv") == 0) { if (arg_idx == argc-1 || !string_parse_kv_override(argv[++arg_idx], kv_overrides)) { usage(argv[0]); } @@ -361,6 +471,9 @@ int main(int argc, char ** argv) { kv_overrides.back().key[0] = 0; params.kv_overrides = &kv_overrides; } + if (!tensor_types.empty()) { + params.tensor_types = &tensor_types; + } llama_backend_init(); diff --git a/include/llama.h b/include/llama.h index 5fd99f4c8..5657fbf0a 100644 --- a/include/llama.h +++ b/include/llama.h @@ -367,17 +367,18 @@ extern "C" { // model quantization parameters typedef struct llama_model_quantize_params { - int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() - enum llama_ftype ftype; // quantize to this llama_ftype - enum ggml_type output_tensor_type; // output tensor type - enum ggml_type token_embedding_type; // token embeddings tensor type - bool allow_requantize; // allow quantizing non-f32/f16 tensors - bool quantize_output_tensor; // quantize output.weight - bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored - bool pure; // quantize all tensors to the default type - bool keep_split; // quantize to the same number of shards - void * imatrix; // pointer to importance matrix data - void * kv_overrides; // pointer to vector containing overrides + int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() + enum llama_ftype ftype; // quantize to this llama_ftype + enum ggml_type output_tensor_type; // output tensor type + enum ggml_type token_embedding_type; // token embeddings tensor type + bool allow_requantize; // allow quantizing non-f32/f16 tensors + bool quantize_output_tensor; // quantize output.weight + bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored + bool pure; // quantize all tensors to the default type + bool keep_split; // quantize to the same number of shards + void * imatrix; // pointer to importance matrix data + void * kv_overrides; // pointer to vector containing overrides + void * tensor_types; //
pointer to vector containing tensor types } llama_model_quantize_params; typedef struct llama_logit_bias { diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index e3e10fa6c..7dc542276 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -10,6 +10,7 @@ #include #include #include +#include <regex> #include #include @@ -47,8 +48,14 @@ struct quantize_state_impl { {} }; +// changes to this struct must be replicated in quantize.cpp +struct tensor_quantization { + std::string name; + ggml_type quant = GGML_TYPE_COUNT; +}; + static void llama_tensor_dequantize_impl( - struct ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers, + ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers, const size_t nelements, const int nthread ) { if (output.size() < nelements) { @@ -536,7 +543,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: model.load_hparams(ml); model.load_stats (ml); - struct quantize_state_impl qs(model, params); + quantize_state_impl qs(model, params); if (params->only_copy) { ftype = ml.ftype; @@ -661,7 +668,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // populate the original tensors so we get an initial meta data for (const auto * it : tensors) { uint16_t i_split = params->keep_split ? it->idx : 0; - struct ggml_tensor * tensor = it->tensor; + ggml_tensor * tensor = it->tensor; if (!ctx_outs[i_split]) { ctx_outs[i_split].reset(gguf_init_empty()); } @@ -710,7 +717,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: new_ofstream(0); for (const auto * it : tensors) { const auto & weight = *it; - struct ggml_tensor * tensor = weight.tensor; + ggml_tensor * tensor = weight.tensor; if (weight.idx != cur_split && params->keep_split) { close_ofstream(); new_ofstream(weight.idx); } @@ -776,7 +783,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // do not quantize relative position bias (T5) quantize &= name.find("attn_rel_b.weight") == std::string::npos; - enum ggml_type new_type; + ggml_type new_type; void * new_data; size_t new_size; @@ -786,6 +793,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // get more optimal quantization type based on the tensor shape, layer, etc.
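// Illustrative note (not part of the patch): the override below is where the
// --tensor-type entries collected in quantize.cpp take effect. Each entry's name is
// applied as a regex to the full tensor name, so a hypothetical invocation such as
//   llama-quantize --tensor-type attn_q=q8_0 --tensor-type ffn_down=q6_k model-f32.gguf model-quant.gguf Q4_K_M
// would force tensors matching attn_q (e.g. blk.0.attn_q.weight) to Q8_0 and ffn_down
// tensors to Q6_K, while every other tensor keeps the type chosen by llama_tensor_get_type().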
if (!params->pure && ggml_is_quantized(default_type)) { new_type = llama_tensor_get_type(qs, new_type, tensor, ftype); + // unless the user specifies a type + if (params->tensor_types) { + const std::vector<tensor_quantization> & tensor_types = *static_cast<const std::vector<tensor_quantization> *>(params->tensor_types); + for (const auto & [tname, qtype] : tensor_types) { + if (std::regex pattern(tname); std::regex_search(tensor->name, pattern)) { + if (qtype != new_type) { + LLAMA_LOG_DEBUG("(overriding %s -> %s), ", ggml_type_name(new_type), ggml_type_name(qtype)); + } + new_type = qtype; + break; + } + } + } } if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) { new_type = params->token_embedding_type; } @@ -910,8 +930,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // interface implementation // -struct llama_model_quantize_params llama_model_quantize_default_params() { - struct llama_model_quantize_params result = { +llama_model_quantize_params llama_model_quantize_default_params() { + llama_model_quantize_params result = { /*.nthread =*/ 0, /*.ftype =*/ LLAMA_FTYPE_MOSTLY_Q5_1, /*.output_tensor_type =*/ GGML_TYPE_COUNT, @@ -923,6 +943,7 @@ struct llama_model_quantize_params llama_model_quantize_default_params() { /*.keep_split =*/ false, /*.imatrix =*/ nullptr, /*.kv_overrides =*/ nullptr, + /*.tensor_types =*/ nullptr, }; return result; From 307bfa253dea07c9270e78fa53b133504e9c3c9d Mon Sep 17 00:00:00 2001 From: Alan Gray Date: Sun, 13 Apr 2025 22:12:21 +0100 Subject: [PATCH 03/20] ggml: disable CUDA graphs for unsupported DUP and CONT node types (#12891) Fixes #12798 --- ggml/src/ggml-cuda/ggml-cuda.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index fafe9633e..4af189701 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2488,10 +2488,10 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud #endif } - if (node->op == GGML_OP_MUL_MAT_ID) { + if (node->op == GGML_OP_MUL_MAT_ID || node->op == GGML_OP_CONT || node->op == GGML_OP_DUP) { use_cuda_graph = false; // This node type is not supported by CUDA graph capture #ifndef NDEBUG - GGML_LOG_DEBUG("%s: disabling CUDA graphs due to mul_mat_id\n", __func__); + GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__); #endif } From e959d32b1c16c02c0bc0bcdc7bc3f7077bc32f99 Mon Sep 17 00:00:00 2001 From: SXX Date: Mon, 14 Apr 2025 13:47:55 +0800 Subject: [PATCH 04/20] ggml: use _mm[512/256]_dpbusd[_avx]_epi32 to directly accumulate into the result register (#12773) * ggml: use _mm[512/256]_dpbusd[_avx]_epi32 to directly accumulate into the result register * simplifies the codebase by removing redundant functions --- ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp | 162 ++++++++++--------------- 1 file changed, 65 insertions(+), 97 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp b/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp index 74a31abb2..ced09c05c 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp @@ -183,67 +183,63 @@ static inline __m256 __avx_rearranged_f32cx8_load(ggml_fp16_t *x, __m128i arrang #if defined(__AVX2__) || defined(__AVX512F__) #if defined(__AVX512F__) -// add int16_t pairwise and return as 512 bit int vector -static inline __m512i sum_i16_pairs_int_32x16(const __m512i x) { +// add int16_t pairwise and return as 512 bit int vector, then add the accumulator +static
inline __m512i sum_i16_pairs_acc_int32x16(const __m512i acc, const __m512i x) { const __m512i ones = _mm512_set1_epi16(1); - return _mm512_madd_epi16(ones, x); + return _mm512_add_epi32(acc, _mm512_madd_epi16(ones, x)); } -static inline __m512i mul_sum_us8_pairs_int32x16(const __m512i ax, const __m512i sy) { +static inline __m512i mul_sum_us8_pairs_acc_int32x16(const __m512i acc, const __m512i ax, const __m512i sy) { #if defined(__AVX512VNNI__) - const __m512i zero = _mm512_setzero_si512(); - return _mm512_dpbusd_epi32(zero, ax, sy); + return _mm512_dpbusd_epi32(acc, ax, sy); #else // Perform multiplication and create 16-bit values const __m512i dot = _mm512_maddubs_epi16(ax, sy); + return sum_i16_pairs_acc_int32x16(acc, dot); #endif } -// multiply int8_t, add results pairwise twice and return as 512 bit int vector -static inline __m512i mul_sum_i8_pairs_int32x16(const __m512i x, const __m512i y) { +// multiply int8_t, add results pairwise twice and return as 512 bit int vector, then add the accumulator +static inline __m512i mul_sum_i8_pairs_acc_int32x16(const __m512i acc, const __m512i x, const __m512i y) { const __m512i zero = _mm512_setzero_si512(); // Get absolute values of x vectors const __m512i ax = _mm512_abs_epi8(x); // Sign the values of the y vectors __mmask64 blt0 = _mm512_movepi8_mask(x); const __m512i sy = _mm512_mask_sub_epi8(y, blt0, zero, y); - return mul_sum_us8_pairs_int32x16(ax, sy); + return mul_sum_us8_pairs_acc_int32x16(acc, ax, sy); } #endif -// add int16_t pairwise and return as 256 bit int vector -static inline __m256i sum_i16_pairs_int32x8(const __m256i x) { +// add int16_t pairwise and return as 256 bit int vector, then add the accumulator +static inline __m256i sum_i16_pairs_acc_int32x8(const __m256i acc, const __m256i x) { const __m256i ones = _mm256_set1_epi16(1); - return _mm256_madd_epi16(ones, x); + return _mm256_add_epi32(acc, _mm256_madd_epi16(ones, x)); } -static inline __m256i mul_sum_us8_pairs_int32x8(const __m256i ax, const __m256i sy) { +static inline __m256i mul_sum_us8_pairs_acc_int32x8(const __m256i acc, const __m256i ax, const __m256i sy) { #if defined(__AVX512VNNI__) && defined(__AVX512VL__) - const __m256i zero = _mm256_setzero_si256(); - return _mm256_dpbusd_epi32(zero, ax, sy); + return _mm256_dpbusd_epi32(acc, ax, sy); #elif defined(__AVXVNNI__) - const __m256i zero = _mm256_setzero_si256(); - return _mm256_dpbusd_avx_epi32(zero, ax, sy); + return _mm256_dpbusd_avx_epi32(acc, ax, sy); #else // Perform multiplication and create 16-bit values const __m256i dot = _mm256_maddubs_epi16(ax, sy); - return sum_i16_pairs_int32x8(dot); + return sum_i16_pairs_acc_int32x8(acc, dot); #endif } // Integer variant of the function defined in ggml-quants.c -// multiply int8_t, add results pairwise twice and return as 256 bit int vector -static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y) { -#if __AVXVNNIINT8__ - const __m256i zero = _mm256_setzero_si256(); - return _mm256_dpbssd_epi32(zero, x, y); +// multiply int8_t, add results pairwise twice and return as 256 bit int vector, then add the accumulator +static inline __m256i mul_sum_i8_pairs_acc_int32x8(const __m256i acc, const __m256i x, const __m256i y) { +#if defined(__AVXVNNIINT8__) + return _mm256_dpbssd_epi32(acc, x, y); #else // Get absolute values of x vectors const __m256i ax = _mm256_sign_epi8(x, x); // Sign the values of the y vectors const __m256i sy = _mm256_sign_epi8(y, x); - return mul_sum_us8_pairs_int32x8(ax, sy); + return
mul_sum_us8_pairs_acc_int32x8(acc, ax, sy); #endif } #endif @@ -1175,17 +1171,17 @@ static void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c // ........................................................................... // B0(28-31) B4(28-31) B1(28-31) B5(28-31) B2(28-31) B6(28-31) B3(28-31) B7(28-31) with A0(28-31) - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_0 ,_mm256_shuffle_epi32(rhs_vec_4567_0, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 0))); - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_0, 177) ,rhs_vec_4567_0, 170), _mm256_shuffle_epi32(lhs_vec_0, 85))); + iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(rhs_vec_0123_0 ,_mm256_shuffle_epi32(rhs_vec_4567_0, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 0)); + iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_0, 177) ,rhs_vec_4567_0, 170), _mm256_shuffle_epi32(lhs_vec_0, 85)); - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_1 ,_mm256_shuffle_epi32(rhs_vec_4567_1, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 170))); - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_1, 177) ,rhs_vec_4567_1, 170), _mm256_shuffle_epi32(lhs_vec_0, 255))); + iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(rhs_vec_0123_1 ,_mm256_shuffle_epi32(rhs_vec_4567_1, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 170)); + iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_1, 177) ,rhs_vec_4567_1, 170), _mm256_shuffle_epi32(lhs_vec_0, 255)); - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_2 ,_mm256_shuffle_epi32(rhs_vec_4567_2, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 0))); - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_2, 177) ,rhs_vec_4567_2, 170), _mm256_shuffle_epi32(lhs_vec_1, 85))); + iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(rhs_vec_0123_2 ,_mm256_shuffle_epi32(rhs_vec_4567_2, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 0)); + iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_2, 177) ,rhs_vec_4567_2, 170), _mm256_shuffle_epi32(lhs_vec_1, 85)); - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_3 ,_mm256_shuffle_epi32(rhs_vec_4567_3, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 170))); - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_3, 177) ,rhs_vec_4567_3, 170), _mm256_shuffle_epi32(lhs_vec_1, 255))); + iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(rhs_vec_0123_3 ,_mm256_shuffle_epi32(rhs_vec_4567_3, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 170)); + iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_3, 177) ,rhs_vec_4567_3, 170), _mm256_shuffle_epi32(lhs_vec_1, 255)); // Accumulated values multiplied with appropriate scales acc_row = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc), _mm256_mul_ps(col_scale_f32, row_scale_f32), acc_row); @@ -3239,22 +3235,15 @@ static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers
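// Illustration (not part of the patch): the accumulate-form helpers fold the separate
// vector add into the dot product, so call sites in the gemv/gemm kernels change from
//   iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(x, y));
// to
//   iacc = mul_sum_i8_pairs_acc_int32x8(iacc, x, y);
// On VNNI hardware this maps directly onto _mm256_dpbusd_epi32(acc, ax, sy), which
// computes acc plus the per-lane byte dot product in a single instruction, removing
// one vector add and one zeroed register per call.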
within 32 bit lane // Resembles MMLAs into 2x2 matrices in ARM Version - __m512i iacc_mat_00_sp1 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1)); - __m512i iacc_mat_01_sp1 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1)); - __m512i iacc_mat_10_sp1 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1)); - __m512i iacc_mat_11_sp1 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1)); - __m512i iacc_mat_00_sp2 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2)); - __m512i iacc_mat_01_sp2 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2)); - __m512i iacc_mat_10_sp2 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2)); - __m512i iacc_mat_11_sp2 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2)); + const __m512i zero = _mm512_setzero_epi32(); + __m512i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1); + __m512i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_01_2_sp1, 
rhs_mat_2367ABEF_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1); + __m512i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1); + __m512i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1); + __m512i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2); + __m512i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2); + __m512i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2); + __m512i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2); // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block __m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); @@ -3430,22 +3419,15 @@ static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane // Resembles MMLAs into 2x2 matrices in ARM Version - __m512i iacc_mat_00_sp1 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1)); - __m512i iacc_mat_01_sp1 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1)); - __m512i iacc_mat_10_sp1 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1)), 
mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1)); - __m512i iacc_mat_11_sp1 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1)); - __m512i iacc_mat_00_sp2 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2)); - __m512i iacc_mat_01_sp2 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2)); - __m512i iacc_mat_10_sp2 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2)); - __m512i iacc_mat_11_sp2 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2)); + const __m512i zero = _mm512_setzero_epi32(); + __m512i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1); + __m512i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1); + __m512i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1); + __m512i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1); + __m512i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_01_1_sp2, 
rhs_mat_014589CD_1_sp2), lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2); + __m512i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2); + __m512i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2); + __m512i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2); // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block __m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); @@ -3605,22 +3587,15 @@ static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane // Resembles MMLAs into 2x2 matrices in ARM Version - __m256i iacc_mat_00_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1)); - __m256i iacc_mat_01_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1)); - __m256i iacc_mat_10_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1)); - __m256i iacc_mat_11_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1)); - __m256i iacc_mat_00_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2)); - __m256i iacc_mat_01_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), 
mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2)); - __m256i iacc_mat_10_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2)); - __m256i iacc_mat_11_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2)); + const __m256i zero = _mm256_setzero_si256(); + __m256i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1); + __m256i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1); + __m256i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1); + __m256i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1); + __m256i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2); + __m256i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2); + __m256i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2); + __m256i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2); // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); @@ -3769,22 +3744,15 @@ static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c // The values arranged in 
shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane // Resembles MMLAs into 2x2 matrices in ARM Version - __m256i iacc_mat_00_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1)); - __m256i iacc_mat_01_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1)); - __m256i iacc_mat_10_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1)); - __m256i iacc_mat_11_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1)); - __m256i iacc_mat_00_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2)); - __m256i iacc_mat_01_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2)); - __m256i iacc_mat_10_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2)); - __m256i iacc_mat_11_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2)); + const __m256i zero = _mm256_setzero_si256(); + __m256i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1); + __m256i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_01_1_sp1, 
rhs_mat_2367_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1); + __m256i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1); + __m256i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1); + __m256i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2); + __m256i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2); + __m256i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2); + __m256i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2); // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); From a25355e2643b1feb2fcbbc4c00312aa86370d858 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Fri, 11 Apr 2025 12:14:19 +0530 Subject: [PATCH 05/20] cpu: fix cpu backend's supports-op for GET_ROWS_BACK. 
fixes a fatal error when running test-backend-ops with only the CPU backend (ggml/1190) --- ggml/src/ggml-cpu/ggml-cpu.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index 09f8382b9..4b688a67e 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -425,6 +425,8 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st } case GGML_OP_IM2COL_BACK: return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32; + case GGML_OP_GET_ROWS_BACK: + return src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16; case GGML_OP_OUT_PROD: return (src0->type == GGML_TYPE_F32 || (ggml_is_quantized(src0->type) && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3])) && src1->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; From 526739b879c9d6a702ed681138611197fa4d3f18 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 14 Apr 2025 08:52:10 +0300 Subject: [PATCH 06/20] sync : ggml ggml-ci --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 7111936ba..cad082a90 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -2abf606f098844faebee578996cae9c6d63a40e2 +f71d538ece3fb32a04824dc6d1e73e360be9d22f From 81c7e64fc239288e91a58adad9145110e0353822 Mon Sep 17 00:00:00 2001 From: Neo Zhang Jianyu Date: Mon, 14 Apr 2025 18:19:07 +0800 Subject: [PATCH 07/20] disable curl lib check; this action was missed by commit bd3f59f81289b920bcc597a208c14f55e39ed37e (#12761) (#12937) --- examples/sycl/build.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/sycl/build.sh b/examples/sycl/build.sh index 8fe0a6790..e72b2e261 100755 --- a/examples/sycl/build.sh +++ b/examples/sycl/build.sh @@ -8,10 +8,10 @@ cd build source /opt/intel/oneapi/setvars.sh #for FP16 -#cmake .. -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON # faster for long-prompt inference +#cmake .. -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON -DLLAMA_CURL=OFF # faster for long-prompt inference #for FP32 -cmake .. -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx +cmake .. -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DLLAMA_CURL=OFF #build example/main #cmake --build .
--config Release --target main From c772d549264c1be058411312a54049e0dc86a037 Mon Sep 17 00:00:00 2001 From: Radoslav Gerganov Date: Mon, 14 Apr 2025 13:59:34 +0300 Subject: [PATCH 08/20] rpc : use ggml_context_ptr (#12938) --- ggml/src/ggml-rpc/ggml-rpc.cpp | 45 +++++++++++++++++----------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 862b9b666..3189ae85d 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -1,6 +1,7 @@ #include "ggml-rpc.h" #include "ggml-impl.h" #include "ggml-backend-impl.h" +#include "ggml-cpp.h" #include #include @@ -853,12 +854,13 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_ /*.no_alloc =*/ true, }; - struct ggml_context * ctx = ggml_init(params); + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor); if (tensor == nullptr) { GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n"); - ggml_free(ctx); return false; } @@ -871,7 +873,6 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_ response.alloc_size = ggml_backend_buft_get_alloc_size(buft,tensor); - ggml_free(ctx); return true; } @@ -985,11 +986,12 @@ bool rpc_server::set_tensor(const std::vector & input) { /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; - struct ggml_context * ctx = ggml_init(params); + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor); if (tensor == nullptr) { GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__); - ggml_free(ctx); return false; } GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size); @@ -1016,7 +1018,6 @@ bool rpc_server::set_tensor(const std::vector & input) { printf("[%s] saved to '%s'\n", __func__, cache_file.c_str()); } ggml_backend_tensor_set(tensor, data, offset, size); - ggml_free(ctx); return true; } @@ -1060,11 +1061,12 @@ bool rpc_server::set_tensor_hash(const std::vector & input, rpc_msg_set /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; - struct ggml_context * ctx = ggml_init(params); + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor); if (tensor == nullptr) { GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__); - ggml_free(ctx); return false; } GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size, *hash); @@ -1080,7 +1082,6 @@ bool rpc_server::set_tensor_hash(const std::vector & input, rpc_msg_set } ggml_backend_tensor_set(tensor, cached_file.data(), offset, size); response.result = 1; - ggml_free(ctx); return true; } @@ -1090,11 +1091,12 @@ bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) { /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; - struct ggml_context * ctx = ggml_init(params); + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor); if (tensor == nullptr) { GGML_LOG_ERROR("Null tensor pointer 
passed to server init_tensor function.\n"); - ggml_free(ctx); return false; } @@ -1110,11 +1112,9 @@ bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) { // This pointer can either be passed around client/server, or probably better stored server-side and kept track of. // Currently unimplemented. GGML_LOG_ERROR("tensor->extra populated by the backend, this is currently unsupported.\n"); - ggml_free(ctx); return false; } - ggml_free(ctx); return true; } @@ -1124,11 +1124,12 @@ bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector< /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; - struct ggml_context * ctx = ggml_init(params); + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor); if (tensor == nullptr) { GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__); - ggml_free(ctx); return false; } GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, request.offset, request.size); @@ -1147,7 +1148,6 @@ bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector< response.resize(request.size, 0); ggml_backend_tensor_get(tensor, response.data(), request.offset, request.size); - ggml_free(ctx); return true; } @@ -1157,12 +1157,14 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; - struct ggml_context * ctx = ggml_init(params); + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); + ggml_tensor * src = deserialize_tensor(ctx, &request.src); ggml_tensor * dst = deserialize_tensor(ctx, &request.dst); if (src == nullptr || dst == nullptr) { GGML_LOG_ERROR("[%s] error deserializing tensors\n", __func__); - ggml_free(ctx); return false; } @@ -1180,7 +1182,6 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co dst_data + src_size, dst_base, dst_base + dst_buf_sz); - ggml_free(ctx); return false; } @@ -1188,7 +1189,6 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co __func__, (void*) src->buffer, (void*) dst->buffer); response.result = ggml_backend_buffer_copy_tensor(src, dst); - ggml_free(ctx); return true; } @@ -1242,7 +1242,9 @@ bool rpc_server::graph_compute(const std::vector & input, rpc_msg_graph /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; - struct ggml_context * ctx = ggml_init(params); + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false); graph->n_nodes = n_nodes; std::unordered_map tensor_ptrs; @@ -1257,7 +1259,6 @@ bool rpc_server::graph_compute(const std::vector & input, rpc_msg_graph } ggml_status status = ggml_backend_graph_compute(backend, graph); response.result = status; - ggml_free(ctx); return true; } From 75afa0ae31f0a51aaadcc5ff146eb7a32a7f9088 Mon Sep 17 00:00:00 2001 From: Akarshan Biswas Date: Mon, 14 Apr 2025 17:53:53 +0530 Subject: [PATCH 09/20] SYCL: Fix im2col (#12910) * SYCL: Fix im2col * restore local workgroup size adjustments for large inputs * restore format --- ggml/src/ggml-sycl/ggml-sycl.cpp | 3 +- ggml/src/ggml-sycl/im2col.cpp | 141 +++++++++++++++++-------------- 2 files changed, 79 insertions(+), 65 deletions(-) diff --git 
a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 3e48a9244..09800d340 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4018,8 +4018,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g return ggml_is_contiguous(op->src[0]); } case GGML_OP_IM2COL: - // TODO: add support for the new F32 operations - return op->src[0]->type == GGML_TYPE_F16; + return true; case GGML_OP_UPSCALE: return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST; case GGML_OP_POOL_2D: diff --git a/ggml/src/ggml-sycl/im2col.cpp b/ggml/src/ggml-sycl/im2col.cpp index 009b42035..aa19c2527 100644 --- a/ggml/src/ggml-sycl/im2col.cpp +++ b/ggml/src/ggml-sycl/im2col.cpp @@ -12,110 +12,125 @@ #include "im2col.hpp" +#include +#include // For std::is_same_v + +#include "ggml.h" + template -static void im2col_kernel( - const float *x, T *dst, int64_t batch_offset, int64_t offset_delta, - int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, - int64_t pelements, int64_t CHW, int s0, int s1, int p0, int p1, int d0, int d1, - const sycl::nd_item<3> &item_ct1) { +static void im2col_kernel(const float * x, T * dst, int64_t batch_offset, int64_t offset_delta, int64_t IC, int64_t IW, + int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, int64_t pelements, int64_t CHW, + int s0, int s1, int p0, int p1, int d0, int d1, const sycl::nd_item<3> & item_ct1) { const int64_t work_group_size = item_ct1.get_local_range(2); - const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2); + const int64_t global_id = item_ct1.get_local_id(2) + (work_group_size * item_ct1.get_group(2)); // make each work-item deal with more elements since sycl global range can not exceed max int - for (int64_t i = global_id; i < pelements; i += work_group_size * item_ct1.get_group_range(2)) { - + for (int64_t i = global_id; i < pelements; i += (work_group_size * item_ct1.get_group_range(2))) { const int64_t ksize = OW * (KH > 1 ? 
KW : 1); - const int64_t kx = i / ksize; - const int64_t kd = kx * ksize; - const int64_t ky = (i - kd) / OW; - const int64_t ix = i % OW; + const int64_t kx = i / ksize; + const int64_t kd = kx * ksize; + const int64_t ky = (i - kd) / OW; + const int64_t ix = i % OW; - const int64_t oh = item_ct1.get_group(1); - const int64_t batch = item_ct1.get_group(0) / IC; - const int64_t ic = item_ct1.get_group(0) % IC; + const int64_t oh = item_ct1.get_group(1); + const int64_t batch = item_ct1.get_group(0) / IC; + const int64_t ic = item_ct1.get_group(0) % IC; - const int64_t iiw = ix * s0 + kx * d0 - p0; - const int64_t iih = oh * s1 + ky * d1 - p1; + const int64_t iiw = (ix * s0) + (kx * d0) - p0; + const int64_t iih = (oh * s1) + (ky * d1) - p1; - const int64_t offset_dst = - ((batch * OH + oh) * OW + ix) * CHW + - (ic * (KW * KH) + ky * KW + kx); + const int64_t offset_dst = (((batch * OH + oh) * OW + ix) * CHW) + (ic * (KW * KH) + ky * KW + kx); - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { - dst[offset_dst] = - sycl::vec(0.0f) - .convert()[0]; - } else { - const int64_t offset_src = ic * offset_delta + batch * batch_offset; - dst[offset_dst] = - sycl::vec(x[offset_src + iih * IW + iiw]) - .convert()[0]; + const int64_t offset_src_base = (ic * offset_delta) + (batch * batch_offset); + const int64_t offset_src = offset_src_base + (iih * IW) + iiw; + + const bool out_of_bounds = (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW); + const float src_val = out_of_bounds ? 0.0f : x[offset_src]; + + if constexpr (std::is_same_v) { + dst[offset_dst] = sycl::half(src_val); + } else if constexpr (std::is_same_v) { + dst[offset_dst] = src_val; } } } template -static void im2col_sycl( - const float *x, T *dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, - int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, int64_t offset_delta, - int s0, int s1, int p0, int p1, int d0, int d1, - queue_ptr stream) { +static void im2col_sycl_internal(const float * x, T * dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, + int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, int64_t offset_delta, + int s0, int s1, int p0, int p1, int d0, int d1, queue_ptr stream) { const int64_t parallel_elements = OW * KW * KH; - const int64_t num_blocks = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE; + const int64_t num_blocks = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE; // decrease global range when it exceeds the max int int64_t local_size = downsample_sycl_global_range(batch * IC * OH * num_blocks, SYCL_IM2COL_BLOCK_SIZE); + sycl::range<3> block_nums(batch * IC, OH, num_blocks); sycl::range<3> local_range(1, 1, local_size); - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); + const int64_t CHW = IC * KH * KW; - stream->parallel_for( - sycl::nd_range<3>(block_nums * local_range, local_range), - [=](sycl::nd_item<3> item_ct1) { - im2col_kernel(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, - parallel_elements, (IC * KH * KW), s0, s1, p0, - p1, d0, d1, item_ct1); - }); - } + stream->parallel_for(sycl::nd_range<3>(block_nums * local_range, local_range), [=](sycl::nd_item<3> item_ct1) { + im2col_kernel(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, CHW, s0, s1, + p0, p1, d0, d1, item_ct1); + }); } -void ggml_sycl_op_im2col(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { +static void im2col_sycl_f16(const float * x, sycl::half * dst, int64_t 
IW, int64_t IH, int64_t OW, int64_t OH, + int64_t KW, int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, + int64_t offset_delta, int s0, int s1, int p0, int p1, int d0, int d1, queue_ptr stream) { + if (!stream->get_device().has(sycl::aspect::fp16)) { + throw sycl::exception(sycl::make_error_code(sycl::errc::kernel_not_supported), + "Device does not support half precision (fp16) operations!"); + } + im2col_sycl_internal(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, + p1, d0, d1, stream); +} + +static void im2col_sycl_f32(const float * x, float * dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, + int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, int64_t offset_delta, int s0, + int s1, int p0, int p1, int d0, int d1, queue_ptr stream) { + im2col_sycl_internal(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1, + d0, d1, stream); +} + +void ggml_sycl_op_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; - GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); - const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; - const int32_t s1 = ((const int32_t*)(dst->op_params))[1]; - const int32_t p0 = ((const int32_t*)(dst->op_params))[2]; - const int32_t p1 = ((const int32_t*)(dst->op_params))[3]; - const int32_t d0 = ((const int32_t*)(dst->op_params))[4]; - const int32_t d1 = ((const int32_t*)(dst->op_params))[5]; + const int32_t s0 = ((const int32_t *) (dst->op_params))[0]; + const int32_t s1 = ((const int32_t *) (dst->op_params))[1]; + const int32_t p0 = ((const int32_t *) (dst->op_params))[2]; + const int32_t p1 = ((const int32_t *) (dst->op_params))[3]; + const int32_t d0 = ((const int32_t *) (dst->op_params))[4]; + const int32_t d1 = ((const int32_t *) (dst->op_params))[5]; - const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1; + const bool is_2D = ((const int32_t *) (dst->op_params))[6] == 1; const int64_t IC = src1->ne[is_2D ? 2 : 1]; const int64_t IH = is_2D ? src1->ne[1] : 1; - const int64_t IW = src1->ne[0]; + const int64_t IW = src1->ne[0]; const int64_t KH = is_2D ? src0->ne[1] : 1; - const int64_t KW = src0->ne[0]; + const int64_t KW = src0->ne[0]; const int64_t OH = is_2D ? dst->ne[2] : 1; - const int64_t OW = dst->ne[1]; + const int64_t OW = dst->ne[1]; - const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 - const int64_t batch = src1->ne[3]; - const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32 + const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / sizeof(float); + const int64_t batch = src1->ne[is_2D ? 3 : 2]; + const size_t batch_offset = src1->nb[is_2D ? 
3 : 2] / sizeof(float);
+
+    queue_ptr stream = ctx.stream();
 
     if (dst->type == GGML_TYPE_F16) {
-        im2col_sycl((const float *) src1->data, (sycl::half *)dst->data, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, ctx.stream());
+        im2col_sycl_f16((const float *) src1->data, (sycl::half *) dst->data, IW, IH, OW, OH, KW, KH, IC, batch,
+                        batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
     } else {
-        im2col_sycl((const float *) src1->data, (float *)dst->data, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, ctx.stream());
+        im2col_sycl_f32((const float *) src1->data, (float *) dst->data, IW, IH, OW, OH, KW, KH, IC, batch,
+                        batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
     }
 }

From d6d2c2ab8c8865784ba9fef37f2b2de3f2134d33 Mon Sep 17 00:00:00 2001
From: Russyyds <161207317+Russyyds@users.noreply.github.com>
Date: Tue, 15 Apr 2025 01:18:20 +0800
Subject: [PATCH 10/20] Add performance print for gemma3 in example (#12929)

---
 examples/llava/gemma3-cli.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/llava/gemma3-cli.cpp b/examples/llava/gemma3-cli.cpp
index 91a07e2a8..693b738a8 100644
--- a/examples/llava/gemma3-cli.cpp
+++ b/examples/llava/gemma3-cli.cpp
@@ -317,6 +317,6 @@ int main(int argc, char ** argv) {
             is_first_msg = false;
         }
     }
-
+    llama_perf_context_print(ctx.lctx);
     return 0;
 }

From b0c75ac9f93322b45e3766e149417334b4fd1ed9 Mon Sep 17 00:00:00 2001
From: Xinpeng Dou <15529241576@163.com>
Date: Tue, 15 Apr 2025 10:04:24 +0800
Subject: [PATCH 11/20] CANN: Optimize CANN buffer pool memory management (#12875)

Multiple optional memory pools are provided for CANN, including VMM,
priority queue-based, and traditional memory pools.

1. When the device supports VMM and the GGML_CANN_DISABLE_VMM_POOL
   environment variable is not set, the VMM pool is selected by default.
2. Otherwise, if the GGML_CANN_ENABLE_BUF_PRIO_POOL environment variable
   is set, the priority queue-based memory pool is used.
3. If neither condition is met, the default memory pool is used.
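
For illustration, the selection described above boils down to the
following (a condensed sketch of new_pool_for_device as changed by this
patch, with the log messages omitted; note that the two GGML_CANN_*
switches are runtime environment variables read via getenv, not
compile-time defines):

    std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(int device) {
        // 1. VMM pool: the default whenever the device supports VMM
        //    and the user has not opted out.
        bool disable_vmm = (getenv("GGML_CANN_DISABLE_VMM_POOL") != nullptr);
        if (!disable_vmm && ggml_cann_info().devices[device].vmm) {
            return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device));
        }
        // 2. Priority queue-based pool: explicit opt-in.
        if (getenv("GGML_CANN_ENABLE_BUF_PRIO_POOL") != nullptr) {
            return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_buf_prio(device));
        }
        // 3. Fallback: the plain buffer pool.
        return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_buf(device));
    }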
---
 ggml/src/ggml-cann/aclnn_ops.cpp |   2 +-
 ggml/src/ggml-cann/ggml-cann.cpp | 397 ++++++++++++++++++++++++++-----
 2 files changed, 335 insertions(+), 64 deletions(-)

diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp
index 37d411797..f312a620c 100644
--- a/ggml/src/ggml-cann/aclnn_ops.cpp
+++ b/ggml/src/ggml-cann/aclnn_ops.cpp
@@ -1783,7 +1783,7 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         src0->data, ACL_INT8, sizeof(int8_t), weight_ne, weight_nb,
         GGML_MAX_DIMS + 1);
     aclTensor* acl_scale_tensor = ggml_cann_create_tensor(
-        src0->data, ACL_FLOAT16, sizeof(float16_t), scale_ne, scale_nb,
+        src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb,
         GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset);
     aclTensor* dequant_tensor = ggml_cann_create_tensor(
         dequant_buffer_allocator.get(), ACL_FLOAT, sizeof(float_t),
diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp
index cec36b36e..8f8acaf99 100644
--- a/ggml/src/ggml-cann/ggml-cann.cpp
+++ b/ggml/src/ggml-cann/ggml-cann.cpp
@@ -29,6 +29,8 @@
 #include
 #include
 #include
+#include
+#include
 
 #include "ggml-impl.h"
 #include "ggml-backend-impl.h"
@@ -119,9 +121,10 @@ static ggml_cann_device_info ggml_cann_init() {
         prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
         prop.location.id = id;
         prop.reserve = 0;
-        ACL_CHECK(aclrtMemGetAllocationGranularity(
+        err = aclrtMemGetAllocationGranularity(
             &prop, ACL_RT_MEM_ALLOC_GRANULARITY_RECOMMENDED,
-            &info.devices[id].vmm_granularity));
+            &info.devices[id].vmm_granularity);
+        info.devices[id].vmm = err == ACL_SUCCESS;
 
         size_t free, total;
         ggml_backend_cann_get_device_memory(id, &free, &total);
@@ -148,11 +151,222 @@ const ggml_cann_device_info& ggml_cann_info() {
 //#define DEBUG_CANN_MALLOC
 
 /**
- * @brief A pool of CANN buffers(legacy).
+ * @brief A pool of CANN buffers (priority segment buffer).
  *
  * This class manages a pool of CANN buffers for a specific device.
  */
-struct ggml_cann_pool_leg : public ggml_cann_pool {
+struct ggml_cann_pool_buf_prio : public ggml_cann_pool {
+    /**
+     * @brief The maximum reuse margin for a buffer.
+     */
+    static const size_t max_reuse_margin = 1ull << 22;  // 4MB
+
+    /**
+     * @brief The minimum free margin for a buffer.
+     */
+    static const size_t min_free_margin = 1ull << 20;  // 1MB
+
+    /**
+     * @brief The alignment for buffer allocation.
+     */
+    static const size_t alignment = 128;
+
+    /**
+     * @brief The device ID associated with this buffer pool.
+     */
+    int device;
+
+    /**
+     * @brief Whether to disable clean during buffer allocation.
+     */
+    bool disable_clean = false;
+
+    /**
+     * @brief Structure representing a CANN buffer.
+     */
+    struct ggml_cann_buffer {
+        void* ptr = nullptr;                              ///< Pointer to the buffer.
+        size_t size = 0;                                  ///< Size of the buffer.
+        std::chrono::steady_clock::time_point last_used;  ///< Last used time.
+
+        bool operator>(const ggml_cann_buffer& other) const {
+            return size > other.size;
+        }
+    };
+
+    /**
+     * @brief Array of CANN buffers in the pool.
+     */
+    std::unordered_map<void*, size_t> buffer_pool;
+    std::priority_queue<ggml_cann_buffer, std::vector<ggml_cann_buffer>,
+                        std::greater<>> free_buffers;
+
+    /**
+     * @brief Total size of all buffers in the pool.
+     */
+    size_t pool_size = 0;
+
+    /**
+     * @brief Constructor to initialize the buffer pool for a specific device.
+     *
+     * @param device The device ID to associate with this buffer pool.
+ */ + explicit ggml_cann_pool_buf_prio(int device) : device(device) { + disable_clean = getenv("GGML_CANN_DISABLE_BUF_POOL_CLEAN") != nullptr; + } + + /** + * @brief Destructor to free all buffers in the pool. + */ + ~ggml_cann_pool_buf_prio() { + ggml_cann_set_device(device); + for (auto& [b_ptr, b_size] : buffer_pool) { + aclrtFree(b_ptr); + pool_size -= b_size; + } + buffer_pool.clear(); + GGML_ASSERT(pool_size == 0); + } + + /** + * @brief Allocate a buffer of the given size. + * + * @param size The size of the buffer to allocate. + * @param actual_size A pointer to a variable to receive the actual size of + * the allocated buffer. + * @return A pointer to the allocated buffer. + */ + void* alloc(size_t size, size_t* actual_size) override { + size = GGML_PAD(size, alignment); + if (size == 0) { + size = alignment; + } + + void* ptr = nullptr; + auto now = std::chrono::steady_clock::now(); + + std::vector free_buffers_rest; + free_buffers_rest.reserve(free_buffers.size()); + while (!free_buffers.empty()) { + auto b = free_buffers.top(); + free_buffers.pop(); + + if (b.size >= size) { + // reuse the buffer if the size is enough + const size_t margin = b.size - size; + if (margin <= max_reuse_margin) { + *actual_size = b.size; + ptr = b.ptr; + #ifdef DEBUG_CANN_MALLOC + GGML_LOG_INFO( + "cann pool[%d]: reused %p, " + "pool_size = %5u MB, " + "size = %5u MB, " + "margin = %5u MB\n", + device, b.ptr, + (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576), + (uint32_t)(GGML_PAD(size, 1048576) / 1048576), + (uint32_t)(GGML_PAD(margin, 1048576) / 1048576)); + #endif + break; + } + } + + bool should_clean = !disable_clean && + b.size > min_free_margin && + std::chrono::duration_cast(now - b.last_used).count() > 100; + if (should_clean) { + // free the buffer if the size is needed to be freed + ACL_CHECK(aclrtFree(b.ptr)); + pool_size -= b.size; + buffer_pool.erase(b.ptr); + #ifdef DEBUG_CANN_MALLOC + GGML_LOG_INFO( + "cann pool[%d]: clean %p, " + "pool_size = %5u MB, " + "size = %5u MB\n", + device, b.ptr, + (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576), + (uint32_t)(GGML_PAD(b.size, 1048576) / 1048576)); + #endif + continue; + } + free_buffers_rest.push_back(b); + } + for (ggml_cann_buffer &b : free_buffers_rest) { + free_buffers.push(std::move(b)); + } + + #ifdef DEBUG_CANN_MALLOC + GGML_LOG_INFO("cann pool[%d] free pool_size = %5u MB\n\n", device, (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576)); + #endif + if (ptr != nullptr) { + return ptr; + } + + // allocate a new buffer if no buffer can be reused + ggml_cann_set_device(device); + ACL_CHECK(aclrtMalloc(&ptr, size, ACL_MEM_MALLOC_HUGE_FIRST)); + *actual_size = size; + pool_size += size; + #ifdef DEBUG_CANN_MALLOC + GGML_LOG_INFO( + "cann pool[%d]: allocate %p, " + "pool_size = %5u MB, " + "size = %5u MB\n", + device, ptr, (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576), + (uint32_t)(GGML_PAD(size, 1048576) / 1048576)); + #endif + buffer_pool.emplace(ptr, size); + return ptr; + } + + /** + * @brief Free a buffer and return it to the pool. + * + * @param ptr Pointer to the buffer to free. + * @param size Size of the buffer to free. 
+ */ + void free(void* ptr, size_t size) override { + auto it = buffer_pool.find(ptr); + if (it == buffer_pool.end()) { + GGML_ABORT("cann pool[%d]: buffer %p not found in pool\n", device, ptr); + } + + auto now = std::chrono::steady_clock::now(); + free_buffers.emplace(ggml_cann_buffer{ptr, it->second, now}); + #ifdef DEBUG_CANN_MALLOC + GGML_LOG_INFO( + "cann pool[%d]: return %p, " + "pool_size = %5u MB\n", + device, ptr, + (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576)); + #endif + } + }; + +/** + * @brief A pool of CANN buffers(segment buffer). + * + * This class manages a pool of CANN buffers for a specific device. + */ +struct ggml_cann_pool_buf : public ggml_cann_pool { + /** + * @brief The maximum reuse margin for a buffer. + */ + static const size_t max_reuse_margin = 1ull << 22; // 4MB + + /** + * @brief The minimum free margin for a buffer. + */ + static const size_t min_free_margin = 1ull << 20; // 1MB + + /** + * @brief The alignment for buffer allocation. + */ + static const size_t alignment = 128; + /** * @brief The maximum number of buffers in the pool. */ @@ -163,12 +377,19 @@ struct ggml_cann_pool_leg : public ggml_cann_pool { */ int device; + /** + * @brief Whether to disable clean during buffer allocation. + */ + bool disable_clean = false; + /** * @brief Structure representing a CANN buffer. */ struct ggml_cann_buffer { void* ptr = nullptr; ///< Pointer to the buffer memory. size_t size = 0; ///< Size of the buffer. + bool used = false; ///< Whether the buffer is currently in use. + std::chrono::steady_clock::time_point last_used; ///< Last used time. }; /** @@ -186,17 +407,19 @@ struct ggml_cann_pool_leg : public ggml_cann_pool { * * @param device The device ID to associate with this buffer pool. */ - explicit ggml_cann_pool_leg(int device) : device(device) {} + explicit ggml_cann_pool_buf(int device) : device(device) { + disable_clean = getenv("GGML_CANN_DISABLE_BUF_POOL_CLEAN") != nullptr; + } /** * @brief Destructor to free all buffers in the pool. */ - ~ggml_cann_pool_leg() { + ~ggml_cann_pool_buf() { ggml_cann_set_device(device); for (int i = 0; i < MAX_BUFFERS; ++i) { ggml_cann_buffer& b = buffer_pool[i]; if (b.ptr != nullptr) { - ACL_CHECK(aclrtFree(b.ptr)); + aclrtFree(b.ptr); pool_size -= b.size; } } @@ -212,63 +435,93 @@ struct ggml_cann_pool_leg : public ggml_cann_pool { * @return A pointer to the allocated buffer. 
*/ void* alloc(size_t size, size_t* actual_size) override { - const size_t alignment = 128; size = GGML_PAD(size, alignment); if (size == 0) { size = alignment; } -#ifdef DEBUG_CANN_MALLOC - int nnz = 0; - size_t max_size = 0; -#endif - size_t best_diff = 1ull << 36; - int ibest = -1; - for (int i = 0; i < MAX_BUFFERS; ++i) { + + void* ptr = nullptr; + auto now = std::chrono::steady_clock::now(); + + int i = 0; + for (; i < MAX_BUFFERS; ++i) { ggml_cann_buffer& b = buffer_pool[i]; - if (b.ptr != nullptr) { + if (b.ptr == nullptr) { + break; + } + if (b.used) { + continue; + } + if (b.size >= size) { + // reuse the buffer if the size is enough + const size_t margin = b.size - size; + if (margin <= max_reuse_margin) { + *actual_size = b.size; + b.used = true; + ptr = b.ptr; #ifdef DEBUG_CANN_MALLOC - ++nnz; - if (b.size > max_size) max_size = b.size; + GGML_LOG_INFO( + "cann pool[%d]: reused %p, " + "pool_size = %5u MB, " + "size = %5u MB, " + "margin = %5u MB\n", + device, b.ptr, + (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576), + (uint32_t)(GGML_PAD(size, 1048576) / 1048576), + (uint32_t)(GGML_PAD(margin, 1048576) / 1048576)); #endif - if (b.size >= size) { - size_t diff = b.size - size; - if (diff < best_diff) { - best_diff = diff; - ibest = i; - if (!best_diff) { - void* ptr = b.ptr; - *actual_size = b.size; - b.ptr = nullptr; - b.size = 0; - return ptr; - } - } + break; } } + + bool should_clean = !disable_clean && + b.size > min_free_margin && + std::chrono::duration_cast(now - b.last_used).count() > 100; + if (should_clean) { + // free the buffer if the size is needed to be freed + ACL_CHECK(aclrtFree(b.ptr)); + pool_size -= b.size; +#ifdef DEBUG_CANN_MALLOC + GGML_LOG_INFO( + "cann pool[%d]: clean %p, " + "pool_size = %5u MB, " + "size = %5u MB\n", + device, b.ptr, + (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576), + (uint32_t)(GGML_PAD(b.size, 1048576) / 1048576)); +#endif + b.ptr = nullptr; + } } - if (ibest >= 0) { - ggml_cann_buffer& b = buffer_pool[ibest]; - void* ptr = b.ptr; - *actual_size = b.size; - b.ptr = nullptr; - b.size = 0; + if (ptr != nullptr) { return ptr; } - void* ptr; - ggml_cann_set_device(device); - ACL_CHECK( - aclrtMalloc(&ptr, size, ACL_MEM_MALLOC_HUGE_FIRST)); - *actual_size = size; - pool_size += size; + + if (i < MAX_BUFFERS) { + // allocate a new buffer if no buffer can be reused + ggml_cann_buffer& b = buffer_pool[i]; + ggml_cann_set_device(device); + ACL_CHECK(aclrtMalloc(&b.ptr, size, ACL_MEM_MALLOC_HUGE_FIRST)); + pool_size += size; + *actual_size = size; + b.size = size; + b.used = true; + if (i >= MAX_BUFFERS - 8) { + GGML_LOG_WARN("cann pool[%d]: slots almost full\n", device); + } #ifdef DEBUG_CANN_MALLOC - GGML_LOG_INFO( - "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, " - "requested %u MB\n", - __func__, device, nnz, (uint32_t)(max_size / 1024 / 1024), - (uint32_t)(pool_size / 1024 / 1024), - (uint32_t)(size / 1024 / 1024)); + GGML_LOG_INFO( + "cann pool[%d]: allocate %p, " + "pool_size = %5u MB, " + "size = %5u MB\n", + device, b.ptr, + (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576), + (uint32_t)(GGML_PAD(b.size, 1048576) / 1048576)); #endif - return ptr; + return b.ptr; + } + + GGML_ABORT("cann pool[%d]: slots full\n", device); } /** @@ -280,16 +533,21 @@ struct ggml_cann_pool_leg : public ggml_cann_pool { void free(void* ptr, size_t size) override { for (int i = 0; i < MAX_BUFFERS; ++i) { ggml_cann_buffer& b = buffer_pool[i]; - if (b.ptr == nullptr) { - b.ptr = ptr; - b.size = size; - return; + if (b.ptr != ptr) { + 
continue; } + b.used = false; + b.last_used = std::chrono::steady_clock::now(); +#ifdef DEBUG_CANN_MALLOC + GGML_LOG_INFO( + "cann pool[%d]: return %p, " + "pool_size = %5u MB\n", + device, b.ptr, + (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576)); +#endif + return; } - // memory should always buffered. these memory may still needed by - // tasks in stream. - // TODO, fix me. - GGML_ABORT("Cann buffer pool full, increase MAX_CANN_BUFFERS\n"); + GGML_ABORT("cann pool[%d]: slots full\n", device); } }; @@ -347,8 +605,7 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool { * @param device The device ID to associate with this buffer pool. */ explicit ggml_cann_pool_vmm(int device) - : device(device), - granularity(ggml_cann_info().devices[device].vmm_granularity) { + : device(device) { auto dev = ggml_cann_info().devices[device]; granularity = dev.vmm_granularity; max_size = dev.total_vram; @@ -471,7 +728,18 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool { */ std::unique_ptr ggml_backend_cann_context::new_pool_for_device( int device) { - return std::unique_ptr(new ggml_cann_pool_vmm(device)); + bool disable_vmm = (getenv("GGML_CANN_DISABLE_VMM_POOL") != nullptr); + if (!disable_vmm && ggml_cann_info().devices[device].vmm) { + GGML_LOG_INFO("%s: device %d use vmm pool\n", __func__, device); + return std::unique_ptr(new ggml_cann_pool_vmm(device)); + } + bool enable_buf_prio = (getenv("GGML_CANN_ENABLE_BUF_PRIO_POOL") != nullptr); + if (enable_buf_prio) { + GGML_LOG_INFO("%s: device %d use buffer pool with priority queue\n", __func__, device); + return std::unique_ptr(new ggml_cann_pool_buf_prio(device)); + } + GGML_LOG_INFO("%s: device %d use buffer pool\n", __func__, device); + return std::unique_ptr(new ggml_cann_pool_buf(device)); } // cann buffer @@ -1020,8 +1288,11 @@ ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, ggml_cann_set_device(buft_ctx->device); - size = std::max(size, (size_t)1); - + const size_t alignment = 128; + size = GGML_PAD(size, alignment); + if (size == 0) { + size = alignment; + } void* dev_ptr; aclError err = aclrtMalloc(&dev_ptr, size, ACL_MEM_MALLOC_HUGE_FIRST); if (err != ACL_SUCCESS) { From 0019279bb56f028e4eee19b59b19750736c719f7 Mon Sep 17 00:00:00 2001 From: Chenguang Li <757486878@qq.com> Date: Tue, 15 Apr 2025 10:09:35 +0800 Subject: [PATCH 12/20] CANN: Opt ROPE optimization (#12865) * [CANN]Opt ROPE optimization * [CANN]Codestyle adjustment * [CANN]Fix the ROPE precision issue * [CANN]codestyle fix * [CANN]add rope unsupport case Signed-off-by: noemotiovon --- ggml/src/ggml-cann/aclnn_ops.cpp | 202 +++++++++++++------------------ ggml/src/ggml-cann/ggml-cann.cpp | 3 + 2 files changed, 84 insertions(+), 121 deletions(-) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index f312a620c..2c5cdcae3 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -64,6 +64,7 @@ #include #include #include +#include #include #include @@ -144,23 +145,6 @@ static void aclnn_cast(ggml_backend_cann_context& ctx, aclTensor* acl_src, GGML_CANN_CALL_ACLNN_OP(Cast, acl_src, cast_data_type, acl_dst); } -/** - * @brief Casts the elements of a tensor to a specified data type using the CANN backend. - * - * @details This function performs a type conversion on the elements of the input tensor `acl_src` - * and stores the results in the destination tensor `acl_dst`. The conversion type is - * determined based on the `dst` tensor's data type. 
- * - * @param ctx The context for the CANN backend operations. - * @param acl_src The source tensor whose elements will be cast. - * @param acl_dst The destination tensor that will store the casted elements. - * @param dst The ggml tensor specifying the target data type. - */ -static void aclnn_cast(ggml_backend_cann_context& ctx, aclTensor* acl_src, - aclTensor* acl_dst, ggml_tensor* dst) { - aclnn_cast(ctx, acl_src, acl_dst, ggml_cann_type_mapping(dst->type)); -} - void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ggml_tensor* src = dst->src[0]; GGML_ASSERT(ggml_can_repeat(src, dst)); @@ -767,7 +751,7 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) { if (dst->type == src0->type) { cann_copy(ctx, acl_src, acl_dst); } else { - aclnn_cast(ctx, acl_src, acl_dst, dst); + aclnn_cast(ctx, acl_src, acl_dst, ggml_cann_type_mapping(dst->type)); } } else { if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) { @@ -792,7 +776,7 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS); - aclnn_cast(ctx, acl_src, src_trans_tensor, dst); + aclnn_cast(ctx, acl_src, src_trans_tensor, ggml_cann_type_mapping(dst->type)); size_t cpy_size = ggml_nbytes(dst); ACL_CHECK(aclrtMemcpyAsync( dst->data, cpy_size, src_trans_buffer, cpy_size, @@ -814,7 +798,7 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS); - aclnn_cast(ctx, acl_src, src_trans_tensor, dst); + aclnn_cast(ctx, acl_src, src_trans_tensor, ggml_cann_type_mapping(dst->type)); size_t cpy_size = ggml_nbytes(dst); ACL_CHECK(aclrtMemcpyAsync(dst->data, cpy_size, src_trans_buffer, @@ -1158,7 +1142,7 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) { tmp_cast_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), tmp_im2col_ne, temp_cast_nb, GGML_MAX_DIMS - 1, ACL_FORMAT_ND); - aclnn_cast(ctx, tmp_im2col_tensor, tmp_cast_tensor, dst); + aclnn_cast(ctx, tmp_im2col_tensor, tmp_cast_tensor, ggml_cann_type_mapping(dst->type)); } // post-processing @@ -1733,7 +1717,7 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) { aclTensor* src_trans_tensor = ggml_cann_create_tensor( src_trans_buffer, ACL_FLOAT, ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS); - aclnn_cast(ctx, acl_src0, src_trans_tensor, dst); + aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type)); aclnn_embedding_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, src1, dst); ACL_CHECK(aclDestroyTensor(acl_src0)); @@ -2074,7 +2058,7 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx, output_buffer, ACL_FLOAT16, output_elem_size, output_cast_ne, output_cast_nb, GGML_MAX_DIMS); aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst); - aclnn_cast(ctx, acl_output_tensor, acl_dst_tensor, dst); + aclnn_cast(ctx, acl_output_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type)); ACL_CHECK(aclDestroyTensor(acl_output_tensor)); ACL_CHECK(aclDestroyTensor(acl_dst_tensor)); @@ -2159,37 +2143,29 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, ggml_tensor* src1 = dst->src[1]; // position ggml_tensor* src2 = dst->src[2]; // freq_factors - // arange, [0,1,...,ne0/2] - int64_t arange_length = src0->ne[0] / 2; - ggml_cann_pool_alloc arange_allocator(ctx.pool(), - arange_length * sizeof(float_t)); - void* arange_buffer = 
arange_allocator.get(); - int64_t arange_ne[] = {arange_length, 1, 1, 1}; - size_t arange_nb[] = {sizeof(float_t), sizeof(float_t), sizeof(float_t), - arange_length * sizeof(float_t)}; + GGML_TENSOR_BINARY_OP_LOCALS - aclTensor* acl_arange_tensor = - ggml_cann_create_tensor(arange_buffer, ACL_FLOAT, sizeof(float_t), - arange_ne, arange_nb, GGML_MAX_DIMS); + // theta_scale arange, [0,1,...,ne00/2 - 1] + int64_t theta_scale_length = ne00 / 2; + ggml_cann_pool_alloc theta_scale_allocator(ctx.pool(), + theta_scale_length * sizeof(float_t)); + void* theta_scale_buffer = theta_scale_allocator.get(); + int64_t theta_scale_ne[] = {theta_scale_length, 1, 1, 1}; + size_t theta_scale_nb[] = {sizeof(float_t), sizeof(float_t), sizeof(float_t), + theta_scale_length * sizeof(float_t)}; + + aclTensor* acl_theta_scale_tensor = + ggml_cann_create_tensor(theta_scale_buffer, ACL_FLOAT, sizeof(float_t), + theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); float start = 0; float step = 1; - float stop = src0->ne[0] / 2; - float n_elements = src0->ne[0] / 2; - aclnn_arange(ctx, acl_arange_tensor, start, stop, step, n_elements); + float stop = ne00 / 2; + float n_elements = ne00 / 2; + aclnn_arange(ctx, acl_theta_scale_tensor, start, stop, step, n_elements); // power - // aclnnPowScalarTensor(): @param self is tensor which should be scalar, so - // use aclnn_pow_tensor_tensor() until fixed. aclScalar* acl_theta_scale = - // aclCreateScalar(&theta_scale, aclDataType::ACL_FLOAT); - // aclnn_power_scalar_tensor(ctx, acl_theta_scale, acl_arange_tensor, - // acl_power_tensor); - ggml_cann_pool_alloc theta_scale_allocator(ctx.pool(), - arange_length * sizeof(float_t)); - void* theta_scale_buffer = theta_scale_allocator.get(); - aclTensor* acl_theta_scale_tensor = aclnn_values( - ctx, theta_scale_buffer, arange_length * sizeof(float_t), arange_ne, - GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), theta_scale); - aclnn_pow_tensor_tensor(ctx, acl_theta_scale_tensor, acl_arange_tensor); + aclScalar* acl_theta_scale = aclCreateScalar(&theta_scale, aclDataType::ACL_FLOAT); + GGML_CANN_CALL_ACLNN_OP(PowScalarTensor, acl_theta_scale, acl_theta_scale_tensor, acl_theta_scale_tensor); // freq_scale if (freq_scale != 1) { @@ -2200,7 +2176,7 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, if (src2) { aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor( src2->data, ggml_cann_type_mapping(src2->type), - ggml_type_size(src2->type), arange_ne, arange_nb, GGML_MAX_DIMS); + ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor); ACL_CHECK(aclDestroyTensor(acl_freq_factors_tensor)); } @@ -2208,20 +2184,19 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, // position GGML_ASSERT(src1->type == GGML_TYPE_I32); int64_t position_length = src1->ne[0]; - int64_t position_ne[] = {1, position_length, 1, 1}; - size_t position_nb[] = {sizeof(int32_t), sizeof(int32_t), - sizeof(int32_t) * position_length, + int64_t position_ne[] = {1, 1, position_length, 1}; + size_t position_nb[] = {sizeof(int32_t), sizeof(int32_t), sizeof(int32_t), sizeof(int32_t) * position_length}; aclTensor* acl_position_tensor = ggml_cann_create_tensor( src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type), position_ne, position_nb, GGML_MAX_DIMS); // power * position - int64_t theta_length = arange_length * position_length; + int64_t theta_length = theta_scale_length * position_length; ggml_cann_pool_alloc 
theta_allocator(ctx.pool(), theta_length * sizeof(float_t)); void* theta_buffer = theta_allocator.get(); - int64_t theta_ne[] = {arange_length, position_length, 1, 1}; + int64_t theta_ne[] = {theta_scale_length, 1, position_length, 1}; size_t theta_nb[GGML_MAX_DIMS]; theta_nb[0] = sizeof(float_t); for (int i = 1; i < GGML_MAX_DIMS; i++) { @@ -2233,40 +2208,22 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor, acl_theta_tensor); - // permute: [0,1,2,3]->[0,2,1,3] - int64_t permute_ne[] = {arange_length, 1, position_length, 1}; - size_t permute_nb[GGML_MAX_DIMS]; - permute_nb[0] = sizeof(float_t); - for (int i = 1; i < GGML_MAX_DIMS; i++) { - permute_nb[i] = permute_nb[i - 1] * permute_ne[i - 1]; - } - ggml_cann_pool_alloc permute_allocator(ctx.pool(), - theta_length * sizeof(float_t)); - void* permute_buffer = permute_allocator.get(); - aclTensor* acl_permute_tensor = ggml_cann_create_tensor( - permute_buffer, ACL_FLOAT, sizeof(float_t), permute_ne, permute_nb, - GGML_MAX_DIMS, ACL_FORMAT_ND); - int64_t permute_dim[] = {0, 2, 1, 3}; - int64_t num_dims = 4; - aclnn_permute(ctx, acl_theta_tensor, acl_permute_tensor, permute_dim, - num_dims); - // sin/cos ggml_cann_pool_alloc sin_allocator(ctx.pool(), theta_length * sizeof(float_t)); void* sin_buffer = sin_allocator.get(); aclTensor* acl_sin_tensor = ggml_cann_create_tensor( - sin_buffer, ACL_FLOAT, sizeof(float_t), permute_ne, permute_nb, + sin_buffer, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); - aclnn_sin(ctx, acl_permute_tensor, acl_sin_tensor); + aclnn_sin(ctx, acl_theta_tensor, acl_sin_tensor); ggml_cann_pool_alloc cos_allocator(ctx.pool(), theta_length * sizeof(float_t)); void* cos_buffer = cos_allocator.get(); aclTensor* acl_cos_tensor = ggml_cann_create_tensor( - cos_buffer, ACL_FLOAT, sizeof(float_t), permute_ne, permute_nb, + cos_buffer, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); - aclnn_cos(ctx, acl_permute_tensor, acl_cos_tensor); + aclnn_cos(ctx, acl_theta_tensor, acl_cos_tensor); // attn_factor if (attn_factor != 1) { @@ -2282,7 +2239,7 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, } else { int64_t num_repeats = 2; int64_t dim = 3; - int64_t output_size = arange_length * num_repeats; + int64_t output_size = theta_scale_length * num_repeats; aclnn_repeat_interleave(ctx, acl_sin_tensor, acl_sin_repeat_tensor, dim, num_repeats, output_size); aclnn_repeat_interleave(ctx, acl_cos_tensor, acl_cos_repeat_tensor, dim, @@ -2290,13 +2247,12 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, } // release - ACL_CHECK(aclDestroyTensor(acl_arange_tensor)); ACL_CHECK(aclDestroyTensor(acl_theta_scale_tensor)); ACL_CHECK(aclDestroyTensor(acl_position_tensor)); ACL_CHECK(aclDestroyTensor(acl_theta_tensor)); - ACL_CHECK(aclDestroyTensor(acl_permute_tensor)); ACL_CHECK(aclDestroyTensor(acl_sin_tensor)); ACL_CHECK(aclDestroyTensor(acl_cos_tensor)); + ACL_CHECK(aclDestroyScalar(acl_theta_scale)); } #ifdef __cplusplus @@ -2318,7 +2274,6 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { // TODO: use ascendc // Only test with LLAMA model. ggml_tensor* src0 = dst->src[0]; // input - // ggml_tensor* src2 = dst->src[2]; // freq_factors, not used now. 
// param float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; @@ -2353,13 +2308,13 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { // init cos/sin cache ggml_cann_pool_alloc sin_allocator( - ctx.pool(), src0->ne[0] * src0->ne[2] * sizeof(float_t)); + ctx.pool(), ne00 * ne02 * sizeof(float_t)); ggml_cann_pool_alloc cos_allocator( - ctx.pool(), src0->ne[0] * src0->ne[2] * sizeof(float_t)); + ctx.pool(), ne00 * ne02 * sizeof(float_t)); void* sin_buffer = sin_allocator.get(); void* cos_buffer = cos_allocator.get(); - int64_t sin_reshape_ne[4] = {src0->ne[0], 1, src0->ne[2], 1}; + int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1}; size_t sin_reshape_nb[GGML_MAX_DIMS]; sin_reshape_nb[0] = sizeof(float_t); for (int i = 1; i < GGML_MAX_DIMS; i++) { @@ -2372,7 +2327,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float_t), sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); aclnn_cache_init(ctx, dst, acl_cos_reshape_tensor, acl_sin_reshape_tensor, - theta_scale, freq_scale, attn_factor, is_neox); + theta_scale, freq_scale, attn_factor, is_neox); aclTensor* acl_src = ggml_cann_create_tensor(src0); aclTensor* acl_dst = ggml_cann_create_tensor(dst); @@ -2549,46 +2504,51 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { return; #endif - // src0 == GGML_TYPE_F16 - // TODO: optimization this `if` code - if (src0->type == GGML_TYPE_F16) { - ggml_cann_pool_alloc sin_final_allocator( - ctx.pool(), src0->ne[0] * src0->ne[2] * ggml_type_size(src0->type)); - ggml_cann_pool_alloc cos_final_allocator( - ctx.pool(), src0->ne[0] * src0->ne[2] * ggml_type_size(src0->type)); - void* sin_final_buffer = sin_final_allocator.get(); - void* cos_final_buffer = cos_final_allocator.get(); + // ggml_mode = 0 --> aclnn_model = 1 + int64_t acl_mode = mode == 0 ? 
1 : mode; - int64_t sin_final_ne[4] = {src0->ne[0], 1, src0->ne[2], 1}; - size_t sin_final_nb[GGML_MAX_DIMS]; - sin_final_nb[0] = ggml_type_size(src0->type); - for (int i = 1; i < GGML_MAX_DIMS; i++) { - sin_final_nb[i] = sin_final_nb[i - 1] * sin_final_ne[i - 1]; + switch (src0->type) { + case GGML_TYPE_F32: { + GGML_CANN_CALL_ACLNN_OP(RotaryPositionEmbedding, acl_src, acl_cos_reshape_tensor, + acl_sin_reshape_tensor, acl_mode, acl_dst); + break; } - aclTensor* acl_sin_final_tensor = ggml_cann_create_tensor( - sin_final_buffer, ggml_cann_type_mapping(src0->type), - ggml_type_size(src0->type), sin_final_ne, sin_final_nb, - GGML_MAX_DIMS); - aclTensor* acl_cos_final_tensor = ggml_cann_create_tensor( - cos_final_buffer, ggml_cann_type_mapping(src0->type), - ggml_type_size(src0->type), sin_final_ne, sin_final_nb, - GGML_MAX_DIMS); + case GGML_TYPE_F16: { + ggml_cann_pool_alloc src_trans_allocator( + ctx.pool(), ggml_nelements(src0) * sizeof(float)); + void* src_trans_buffer = src_trans_allocator.get(); + ggml_cann_pool_alloc dst_trans_allocator( + ctx.pool(), ggml_nelements(dst) * sizeof(float)); + void* dst_trans_buffer = dst_trans_allocator.get(); - aclnn_cast(ctx, acl_sin_reshape_tensor, acl_sin_final_tensor, dst); - aclnn_cast(ctx, acl_cos_reshape_tensor, acl_cos_final_tensor, dst); - ACL_CHECK(aclDestroyTensor(acl_cos_reshape_tensor)); - ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor)); - acl_sin_reshape_tensor = acl_sin_final_tensor; - acl_cos_reshape_tensor = acl_cos_final_tensor; + size_t src_trans_nb[GGML_MAX_DIMS]; + src_trans_nb[0] = sizeof(float); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1]; + } + + aclTensor* acl_src_trans_tensor = ggml_cann_create_tensor( + src_trans_buffer, ACL_FLOAT, sizeof(float), src0->ne, src_trans_nb, + GGML_MAX_DIMS); + aclTensor* acl_dst_trans_tensor = ggml_cann_create_tensor( + dst_trans_buffer, ACL_FLOAT, sizeof(float), dst->ne, src_trans_nb, + GGML_MAX_DIMS); + + aclnn_cast(ctx, acl_src, acl_src_trans_tensor, ACL_FLOAT); + + GGML_CANN_CALL_ACLNN_OP(RotaryPositionEmbedding, acl_src_trans_tensor, acl_cos_reshape_tensor, + acl_sin_reshape_tensor, acl_mode, acl_dst_trans_tensor); + + aclnn_cast(ctx, acl_dst_trans_tensor, acl_dst, ACL_FLOAT16); + + ACL_CHECK(aclDestroyTensor(acl_src_trans_tensor)); + ACL_CHECK(aclDestroyTensor(acl_dst_trans_tensor)); + break; + } + default: + GGML_ABORT("Unsupported tensor type for GGML_OP_ROPE"); + break; } - - int acl_mode = mode; - if (mode == 0) { - acl_mode = 1; - } - - GGML_CANN_CALL_ACLNN_OP(RotaryPositionEmbedding, acl_src, acl_cos_reshape_tensor, - acl_sin_reshape_tensor, acl_mode, acl_dst); ACL_CHECK(aclDestroyTensor(acl_src)); ACL_CHECK(aclDestroyTensor(acl_cos_reshape_tensor)); ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor)); diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 8f8acaf99..db8ae260a 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2087,6 +2087,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, return false; } + if(!ggml_is_contiguous(op->src[0])){ + return false; + } return true; } case GGML_OP_UPSCALE: { From eccc7a1602f0752507de4aaad1008b9618a282c8 Mon Sep 17 00:00:00 2001 From: Srihari-mcw <96763064+Srihari-mcw@users.noreply.github.com> Date: Tue, 15 Apr 2025 11:52:36 +0530 Subject: [PATCH 13/20] ggml : Add AVX512 implementation of GEMM - Q4_Kx8 (#12829) * Add AVX512 implementation of GEMM - q4kx8 * Update changes to remove unnecessary 
whitespaces --- ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp | 748 ++++++++++++++++++++++++- 1 file changed, 744 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp b/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp index ced09c05c..175cba329 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp @@ -4044,7 +4044,7 @@ static void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, c UNUSED(ncols_interleaved); UNUSED(blocklen); -#if defined(__AVX2__) +#if defined(__AVX2__) || defined(__AVX512F__) const block_q4_Kx8 * b_ptr_start = (const block_q4_Kx8 * ) vx; const block_q8_Kx4 * a_ptr_start = (const block_q8_Kx4 * ) vy; int64_t b_nb = n / QK_K; @@ -4054,8 +4054,748 @@ static void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, c const __m256i m4b = _mm256_set1_epi8(0x0F); // Permute mask used for easier vector processing at later stages __m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4); - + int64_t xstart = 0; int anr = nr - nr % 16;; // Used to align nr with boundary of 16 +#ifdef __AVX512F__ + int anc = nc - nc % 16; // Used to align nc with boundary of 16 + // Mask to mask out nibbles from packed bytes expanded to 512 bit length + const __m512i m4bexpanded = _mm512_set1_epi8(0x0F); + //Take group of four block_q8_Kx4 structures at each pass of the loop and perform dot product operation + for (; y < anr / 4; y += 4) { + + const block_q8_Kx4 * a_ptrs[4]; + + a_ptrs[0] = a_ptr_start + (y * nb); + for (int i = 0; i < 3; ++i) { + a_ptrs[i + 1] = a_ptrs[i] + nb; + } + + // Take group of eight block_q4_kx8 structures at each pass of the loop and perform dot product operation + for (int64_t x = 0; x < anc / 8; x += 2) { + + const block_q4_Kx8 * b_ptr_0 = b_ptr_start + ((x) * b_nb); + const block_q4_Kx8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb); + + // Master FP accumulators + __m512 acc_rows[16]; + for (int i = 0; i < 16; i++) { + acc_rows[i] = _mm512_setzero_ps(); + } + + __m512 acc_min_rows[16]; + for (int i = 0; i < 16; i++) { + acc_min_rows[i] = _mm512_setzero_ps(); + } + + // For super block + for (int64_t b = 0; b < nb; b++) { + // Scale values - Load the sixteen scale values from two block_q4_kx8 structures + const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d); + + // dmin values - Load the sixteen dmin values from two block_q4_kx8 structures + const __m512 col_dmin_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].dmin, b_ptr_1[b].dmin); + + // Loop to iterate over the eight sub blocks of a super block - two sub blocks are processed per iteration + for (int sb = 0; sb < QK_K / 64; sb++) { + + const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + sb * 256)); + const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 32 + sb * 256)); + const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 64 + sb * 256)); + const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 96 + sb * 256)); + const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 128 + sb * 256)); + const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 160 + sb * 256)); + const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 192 + sb * 256)); + const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 224 + sb * 256)); + + const __m256i rhs_raw_mat_89AB_0 
= _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + sb * 256)); + const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 32 + sb * 256)); + const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 64 + sb * 256)); + const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 96 + sb * 256)); + const __m256i rhs_raw_mat_89AB_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 128 + sb * 256)); + const __m256i rhs_raw_mat_CDEF_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 160 + sb * 256)); + const __m256i rhs_raw_mat_89AB_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 192 + sb * 256)); + const __m256i rhs_raw_mat_CDEF_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 224 + sb * 256)); + + const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240); + const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240); + const __m256i rhs_raw_mat_0145_2 = _mm256_blend_epi32(rhs_raw_mat_0123_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_2, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_2, requiredOrder), rhs_raw_mat_4567_2, 240); + const __m256i rhs_raw_mat_0145_3 = _mm256_blend_epi32(rhs_raw_mat_0123_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_3, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_3, requiredOrder), rhs_raw_mat_4567_3, 240); + + const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240); + const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240); + const __m256i rhs_raw_mat_89CD_2 = _mm256_blend_epi32(rhs_raw_mat_89AB_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_2, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_2, requiredOrder), rhs_raw_mat_CDEF_2, 240); + const __m256i rhs_raw_mat_89CD_3 = _mm256_blend_epi32(rhs_raw_mat_89AB_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_3, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_3, requiredOrder), rhs_raw_mat_CDEF_3, 240); + + const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1); + const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1); + const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1); 
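+                    // Note: these _mm512_inserti32x8 pairs keep b_ptr_0's interleaved
+                    // columns in the lower 256 bits and b_ptr_1's in the upper 256 bits,
+                    // so each 512-bit op below covers sixteen Q4_K columns at once.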
+ const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1); + + const __m512i rhs_raw_mat_014589CD_2 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_2), rhs_raw_mat_89CD_2, 1); + const __m512i rhs_raw_mat_2367ABEF_2 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_2), rhs_raw_mat_ABEF_2, 1); + const __m512i rhs_raw_mat_014589CD_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_3), rhs_raw_mat_89CD_3, 1); + const __m512i rhs_raw_mat_2367ABEF_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_3), rhs_raw_mat_ABEF_3, 1); + + //4-bit -> 8-bit + const __m512i rhs_mat_014589CD_00 = _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded); //B00(0-7) B01(0-7) B04(0-7) B05(0-7) B08(0-7) B09(0-7) B0C(0-7) B0D(0-7) + const __m512i rhs_mat_2367ABEF_00 = _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded); //B02(0-7) B03(0-7) B06(0-7) B07(0-7) B0A(0-7) B0B(0-7) B0E(0-7) B0F(0-7) + const __m512i rhs_mat_014589CD_01 = _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded); //B00(8-15) B01(8-15) B04(8-15) B05(8-15) B08(8-15) B09(8-15) B0C(8-15) B0D(8-15) + const __m512i rhs_mat_2367ABEF_01 = _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded); //B02(8-15) B03(8-15) B06(8-15) B07(8-15) B0A(8-15) B0B(8-15) B0E(8-15) B0F(8-15) + + const __m512i rhs_mat_014589CD_02 = _mm512_and_si512(rhs_raw_mat_014589CD_2, m4bexpanded); //B00(16-23) B01(16-23) B04(16-23) B05(16-23) B08(16-23) B09(16-23) B0C(16-23) B0D(16-23) + const __m512i rhs_mat_2367ABEF_02 = _mm512_and_si512(rhs_raw_mat_2367ABEF_2, m4bexpanded); //B02(16-23) B03(16-23) B06(16-23) B07(16-23) B0A(16-23) B0B(16-23) B0E(16-23) B0F(16-23) + const __m512i rhs_mat_014589CD_03 = _mm512_and_si512(rhs_raw_mat_014589CD_3, m4bexpanded); //B00(24-31) B01(24-31) B04(24-31) B05(24-31) B08(24-31) B09(24-31) B0C(24-31) B0D(24-31) + const __m512i rhs_mat_2367ABEF_03 = _mm512_and_si512(rhs_raw_mat_2367ABEF_3, m4bexpanded); //B02(24-31) B03(24-31) B06(24-31) B07(24-31) B0A(24-31) B0B(24-31) B0E(24-31) B0F(24-31) + + const __m512i rhs_mat_014589CD_10 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded); //B10(0-7) B11(0-7) B14(0-7) B15(0-7) B18(0-7) B19(0-7) B1C(0-7) B1D(0-7) + const __m512i rhs_mat_2367ABEF_10 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded); //B12(0-7) B13(0-7) B16(0-7) B17(0-7) B1A(0-7) B1B(0-7) B1E(0-7) B1F(0-7) + const __m512i rhs_mat_014589CD_11 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded); //B10(8-15) B11(8-15) B14(8-15) B15(8-15) B18(8-15) B19(8-15) B1C(8-15) B1D(8-15) + const __m512i rhs_mat_2367ABEF_11 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded); //B12(8-15) B13(8-15) B16(8-15) B17(8-15) B1A(8-15) B1B(8-15) B1E(8-15) B1F(8-15) + + const __m512i rhs_mat_014589CD_12 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_2, 4), m4bexpanded); //B10(16-23) B11(16-23) B14(16-23) B15(16-23) B18(16-23) B19(16-23) B1C(16-23) B1D(16-23) + const __m512i rhs_mat_2367ABEF_12 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_2, 4), m4bexpanded); //B12(16-23) B13(16-23) B16(16-23) B17(16-23) B1A(16-23) B1B(16-23) B1E(16-23) B1F(16-23) + const __m512i rhs_mat_014589CD_13 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_3, 4), m4bexpanded); //B10(24-31) B11(24-31) B14(24-31) B15(24-31) B18(24-31) B19(24-31) B1C(24-31) B1D(24-31) + const __m512i rhs_mat_2367ABEF_13 = 
_mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_3, 4), m4bexpanded); //B12(24-31) B13(24-31) B16(24-31) B17(24-31) B1A(24-31) B1B(24-31) B1E(24-31) B1F(24-31) + + // Shuffle pattern one - right side input + const __m512i rhs_mat_014589CD_00_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3) B08(0-3) B09(0-3) B08(0-3) B09(0-3) B0C(0-3) B0D(0-3) B0C(0-3) B0D(0-3) + const __m512i rhs_mat_2367ABEF_00_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)136); //B02(0-3) B03(0-3) B02(0-3) B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3) B0A(0-3) B0B(0-3) B0A(0-3) B0B(0-3) B0E(0-3) B0F(0-3) B0E(0-3) B0F(0-3) + const __m512i rhs_mat_014589CD_01_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11) B08(8-11) B09(8-11) B08(8-11) B09(8-11) B0C(8-11) B0D(8-11) B0C(8-11) B0D(8-11) + const __m512i rhs_mat_2367ABEF_01_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11) B0A(8-11) B0B(8-11) B0A(8-11) B0B(8-11) B0E(8-11) B0F(8-11) B0E(8-11) B0F(8-11) + const __m512i rhs_mat_014589CD_02_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_02, (_MM_PERM_ENUM)136); //B00(16-19) B01(16-19) B00(16-19) B01(16-19) B04(16-19) B05(16-19) B04(16-19) B05(16-19) B08(16-19) B09(16-19) B08(16-19) B09(16-19) B0C(16-19) B0D(16-19) B0C(16-19) B0D(16-19) + const __m512i rhs_mat_2367ABEF_02_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_02, (_MM_PERM_ENUM)136); //B02(16-19) B03(16-19) B02(16-19) B03(16-19) B06(16-19) B07(16-19) B06(16-19) B07(16-19) B0A(16-19) B0B(16-19) B0A(16-19) B0B(16-19) B0E(16-19) B0F(16-19) B0E(16-19) B0F(16-19) + const __m512i rhs_mat_014589CD_03_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_03, (_MM_PERM_ENUM)136); //B00(24-27) B01(24-27) B00(24-27) B01(24-27) B04(24-27) B05(24-27) B04(24-27) B05(24-27) B08(24-27) B09(24-27) B08(24-27) B09(24-27) B0C(24-27) B0D(24-27) B0C(24-27) B0D(24-27) + const __m512i rhs_mat_2367ABEF_03_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_03, (_MM_PERM_ENUM)136); //B02(24-27) B03(24-27) B02(24-27) B03(24-27) B06(24-27) B07(24-27) B06(24-27) B07(24-27) B0A(24-27) B0B(24-27) B0A(24-27) B0B(24-27) B0E(24-27) B0F(24-27) B0E(24-27) B0F(24-27) + + const __m512i rhs_mat_014589CD_10_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3) B18(0-3) B19(0-3) B18(0-3) B19(0-3) B1C(0-3) B1D(0-3) B1C(0-3) B1D(0-3) + const __m512i rhs_mat_2367ABEF_10_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) B16(0-3) B17(0-3) B1A(0-3) B1B(0-3) B1A(0-3) B1B(0-3) B1E(0-3) B1F(0-3) B1E(0-3) B1F(0-3) + const __m512i rhs_mat_014589CD_11_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11) B18(8-11) B19(8-11) B18(8-11) B19(8-11) B1C(8-11) B1D(8-11) B1C(8-11) B1D(8-11) + const __m512i rhs_mat_2367ABEF_11_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11) B1A(8-11) B1B(8-11) B1A(8-11) B1B(8-11) B1E(8-11) B1F(8-11) B1E(8-11) B1F(8-11) + const __m512i rhs_mat_014589CD_12_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_12, (_MM_PERM_ENUM)136); //B10(16-19) B11(16-19) 
B10(16-19) B11(16-19) B14(16-19) B15(16-19) B14(16-19) B15(16-19) B18(16-19) B19(16-19) B18(16-19) B19(16-19) B1C(16-19) B1D(16-19) B1C(16-19) B1D(16-19) + const __m512i rhs_mat_2367ABEF_12_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_12, (_MM_PERM_ENUM)136); //B12(16-19) B13(16-19) B12(16-19) B13(16-19) B16(16-19) B17(16-19) B16(16-19) B17(16-19) B1A(16-19) B1B(16-19) B1A(16-19) B1B(16-19) B1E(16-19) B1F(16-19) B1E(16-19) B1F(16-19) + const __m512i rhs_mat_014589CD_13_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_13, (_MM_PERM_ENUM)136); //B10(24-27) B11(24-27) B10(24-27) B11(24-27) B14(24-27) B15(24-27) B14(24-27) B15(24-27) B18(24-27) B19(24-27) B18(24-27) B19(24-27) B1C(24-27) B1D(24-27) B1C(24-27) B1D(24-27) + const __m512i rhs_mat_2367ABEF_13_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_13, (_MM_PERM_ENUM)136); //B12(24-27) B13(24-27) B12(24-27) B13(24-27) B16(24-27) B17(24-27) B16(24-27) B17(24-27) B1A(24-27) B1B(24-27) B1A(24-27) B1B(24-27) B1E(24-27) B1F(24-27) B1E(24-27) B1F(24-27) + + // Shuffle pattern two - right side input + const __m512i rhs_mat_014589CD_00_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)221); //B00(4-7) B01(4-7) B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7) B08(4-7) B09(4-7) B08(4-7) B09(4-7) B0C(4-7) B0D(4-7) B0C(4-7) B0D(4-7) + const __m512i rhs_mat_2367ABEF_00_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)221); //B02(4-7) B03(4-7) B02(4-7) B03(4-7) B06(4-7) B07(4-7) B06(4-7) B07(4-7) B0A(4-7) B0B(4-7) B0A(4-7) B0B(4-7) B0E(4-7) B0F(4-7) B0E(4-7) B0F(4-7) + const __m512i rhs_mat_014589CD_01_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15) B08(12-15) B09(12-15) B08(12-15) B09(12-15) B0C(12-15) B0D(12-15) B0C(12-15) B0D(12-15) + const __m512i rhs_mat_2367ABEF_01_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15) B0A(12-15) B0B(12-15) B0A(12-15) B0B(12-15) B0E(12-15) B0F(12-15) B0E(12-15) B0F(12-15) + const __m512i rhs_mat_014589CD_02_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_02, (_MM_PERM_ENUM)221); //B00(20-23) B01(20-23) B00(20-23) B01(20-23) B04(20-23) B05(20-23) B04(20-23) B05(20-23) B08(20-23) B09(20-23) B08(20-23) B09(20-23) B0C(20-23) B0D(20-23) B0C(20-23) B0D(20-23) + const __m512i rhs_mat_2367ABEF_02_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_02, (_MM_PERM_ENUM)221); //B02(20-23) B03(20-23) B02(20-23) B03(20-23) B06(20-23) B07(20-23) B06(20-23) B07(20-23) B0A(20-23) B0B(20-23) B0A(20-23) B0B(20-23) B0E(20-23) B0F(20-23) B0E(20-23) B0F(20-23) + const __m512i rhs_mat_014589CD_03_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_03, (_MM_PERM_ENUM)221); //B00(28-31) B01(28-31) B00(28-31) B01(28-31) B04(28-31) B05(28-31) B04(28-31) B05(28-31) B08(28-31) B09(28-31) B08(28-31) B09(28-31) B0C(28-31) B0D(28-31) B0C(28-31) 0BD(28-31) + const __m512i rhs_mat_2367ABEF_03_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_03, (_MM_PERM_ENUM)221); //B02(28-31) B03(28-31) B02(28-31) B03(28-31) B06(28-31) B07(28-31) B06(28-31) B07(28-31) B0A(28-31) B0B(28-31) B0A(28-31) B0B(28-31) B0E(28-31) B0F(28-31) B0E(28-31) B0F(28-31) + + const __m512i rhs_mat_014589CD_10_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7) B18(4-7) B19(4-7) B18(4-7) B19(4-7) B1C(4-7) B1D(4-7) B1C(4-7) B1D(4-7) + const __m512i rhs_mat_2367ABEF_10_sp2 = 
_mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7) B1A(4-7) B1B(4-7) B1A(4-7) B1B(4-7) B1E(4-7) B1F(4-7) B1E(4-7) B1F(4-7) + const __m512i rhs_mat_014589CD_11_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15) B18(12-15) B19(12-15) B18(12-15) B19(12-15) B1C(12-15) B1D(12-15) B1C(12-15) B1D(12-15) + const __m512i rhs_mat_2367ABEF_11_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15) B1A(12-15) B1B(12-15) B1A(12-15) B1B(12-15) B1E(12-15) B1F(12-15) B1E(12-15) B1F(12-15) + const __m512i rhs_mat_014589CD_12_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_12, (_MM_PERM_ENUM)221); //B10(20-23) B11(20-23) B10(20-23) B11(20-23) B14(20-23) B15(20-23) B14(20-23) B15(20-23) B18(20-23) B19(20-23) B18(20-23) B19(20-23) B1C(20-23) B1D(20-23) B1C(20-23) B1D(20-23) + const __m512i rhs_mat_2367ABEF_12_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_12, (_MM_PERM_ENUM)221); //B12(20-23) B13(20-23) B12(20-23) B13(20-23) B16(20-23) B17(20-23) B16(20-23) B17(20-23) B1A(20-23) B1B(20-23) B1A(20-23) B1B(20-23) B1E(20-23) B1F(20-23) B1E(20-23) B1F(20-23) + const __m512i rhs_mat_014589CD_13_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_13, (_MM_PERM_ENUM)221); //B10(28-31) B11(28-31) B10(28-31) B11(28-31) B14(28-31) B15(28-31) B14(28-31) B15(28-31) B18(28-31) B19(28-31) B18(28-31) B19(28-31) B1C(28-31) B1D(28-31) B1C(28-31) B1D(28-31) + const __m512i rhs_mat_2367ABEF_13_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_13, (_MM_PERM_ENUM)221); //B12(28-31) B13(28-31) B12(28-31) B13(28-31) B16(28-31) B17(28-31) B16(28-31) B17(28-31) B1A(28-31) B1B(28-31) B1A(28-31) B1B(28-31) B1E(28-31) B1F(28-31) B1E(28-31) B1F(28-31) + + uint32_t utmp_00[4], utmp_01[4], utmp_10[4], utmp_11[4]; + + // Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together + // For example, the below block extracts the first sub block's scales and mins from the different Q4_K structures for the sb loop + memcpy(utmp_00, b_ptr_0[b].scales + 24 * sb, 12); + utmp_00[3] = ((utmp_00[2] >> 4) & kmask2) | (((utmp_00[1] >> 6) & kmask3) << 4); + const uint32_t uaux_00 = utmp_00[1] & kmask1; + utmp_00[1] = (utmp_00[2] & kmask2) | (((utmp_00[0] >> 6) & kmask3) << 4); + utmp_00[2] = uaux_00; + utmp_00[0] &= kmask1; + + // For example, the below block extracts the second sub block's scales and mins from the different Q4_K structures for the sb loop + memcpy(utmp_01, b_ptr_0[b].scales + 12 + sb * 24, 12); + utmp_01[3] = ((utmp_01[2] >> 4) & kmask2) | (((utmp_01[1] >> 6) & kmask3) << 4); + const uint32_t uaux_01 = utmp_01[1] & kmask1; + utmp_01[1] = (utmp_01[2] & kmask2) | (((utmp_01[0] >> 6) & kmask3) << 4); + utmp_01[2] = uaux_01; + utmp_01[0] &= kmask1; + + // For example, the below block extracts the first sub block's scales and mins from the different Q4_K structures for the sb loop + memcpy(utmp_10, b_ptr_1[b].scales + sb * 24, 12); + utmp_10[3] = ((utmp_10[2] >> 4) & kmask2) | (((utmp_10[1] >> 6) & kmask3) << 4); + const uint32_t uaux_10 = utmp_10[1] & kmask1; + utmp_10[1] = (utmp_10[2] & kmask2) | (((utmp_10[0] >> 6) & kmask3) << 4); + utmp_10[2] = uaux_10; + utmp_10[0] &= kmask1; + + // For example, the below block extracts the second sub block's scales and mins from the different Q4_K structures for the sb loop + memcpy(utmp_11, b_ptr_1[b].scales + 12 + sb * 24, 12); + utmp_11[3] = ((utmp_11[2] >> 4) & kmask2) | (((utmp_11[1] >> 6) & kmask3) << 4); + const uint32_t uaux_11 = utmp_11[1] &
kmask1; + utmp_11[1] = (utmp_11[2] & kmask2) | (((utmp_11[0] >> 6) & kmask3) << 4); + utmp_11[2] = uaux_11; + utmp_11[0] &= kmask1; + + // Scales of first sub block in the sb loop + const __m256i mins_and_scales_0 = _mm256_set_epi32(utmp_10[3], utmp_10[2], utmp_10[1], utmp_10[0], utmp_00[3], utmp_00[2], utmp_00[1], utmp_00[0]); + const __m512i scales_0 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(mins_and_scales_0, mins_and_scales_0)); + + // Scales of second sub block in the sb loop + const __m256i mins_and_scales_1 = _mm256_set_epi32(utmp_11[3], utmp_11[2], utmp_11[1], utmp_11[0], utmp_01[3], utmp_01[2], utmp_01[1], utmp_01[0]); + const __m512i scales_1 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(mins_and_scales_1, mins_and_scales_1)); + + // Mins of first and second sub block of Q4_K block are arranged side by side + const __m512i mins_01 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(_mm256_shuffle_epi32(mins_and_scales_0, 78), _mm256_shuffle_epi32(mins_and_scales_1, 78))); + + const __m512i scale_014589CD_0 = _mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)68); + const __m512i scale_2367ABEF_0 = _mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)238); + + const __m512i scale_014589CD_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)68); + const __m512i scale_2367ABEF_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)238); + + for (int rp = 0; rp < 4; rp++) { + + // Load the four block_q8_k quantized values interleaved with each other in chunks of eight bytes - A0,A1,A2,A3 + // Loaded as set of 128 bit vectors and repeated and stored into a 256 bit vector before again repeating into 512 bit vector + __m256i lhs_mat_ymm_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 256 * sb))); + __m256i lhs_mat_ymm_01_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 0); + __m256i lhs_mat_ymm_23_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 17); + __m256i lhs_mat_ymm_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 32 + 256 * sb))); + __m256i lhs_mat_ymm_01_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 0); + __m256i lhs_mat_ymm_23_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 17); + __m256i lhs_mat_ymm_0123_02 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 64 + 256 * sb))); + __m256i lhs_mat_ymm_01_02 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_02, lhs_mat_ymm_0123_02, 0); + __m256i lhs_mat_ymm_23_02 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_02, lhs_mat_ymm_0123_02, 17); + __m256i lhs_mat_ymm_0123_03 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 96 + 256 * sb))); + __m256i lhs_mat_ymm_01_03 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_03, lhs_mat_ymm_0123_03, 0); + __m256i lhs_mat_ymm_23_03 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_03, lhs_mat_ymm_0123_03, 17); + __m256i lhs_mat_ymm_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 128 + 256 * sb))); + __m256i lhs_mat_ymm_01_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 0); + __m256i lhs_mat_ymm_23_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 17); + __m256i lhs_mat_ymm_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 160 + 256 * sb))); + __m256i lhs_mat_ymm_01_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 0); + __m256i lhs_mat_ymm_23_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 17); + __m256i lhs_mat_ymm_0123_12 = _mm256_loadu_si256((const 
__m256i * )((a_ptrs[rp][b].qs + 192 + 256 * sb))); + __m256i lhs_mat_ymm_01_12 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_12, lhs_mat_ymm_0123_12, 0); + __m256i lhs_mat_ymm_23_12 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_12, lhs_mat_ymm_0123_12, 17); + __m256i lhs_mat_ymm_0123_13 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 224 + 256 * sb))); + __m256i lhs_mat_ymm_01_13 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_13, lhs_mat_ymm_0123_13, 0); + __m256i lhs_mat_ymm_23_13 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_13, lhs_mat_ymm_0123_13, 17); + + __m512i lhs_mat_01_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_00), lhs_mat_ymm_01_00, 1); + __m512i lhs_mat_23_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_00), lhs_mat_ymm_23_00, 1); + __m512i lhs_mat_01_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_01), lhs_mat_ymm_01_01, 1); + __m512i lhs_mat_23_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_01), lhs_mat_ymm_23_01, 1); + __m512i lhs_mat_01_02 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_02), lhs_mat_ymm_01_02, 1); + __m512i lhs_mat_23_02 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_02), lhs_mat_ymm_23_02, 1); + __m512i lhs_mat_01_03 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_03), lhs_mat_ymm_01_03, 1); + __m512i lhs_mat_23_03 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_03), lhs_mat_ymm_23_03, 1); + + __m512i lhs_mat_01_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_10), lhs_mat_ymm_01_10, 1); + __m512i lhs_mat_23_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_10), lhs_mat_ymm_23_10, 1); + __m512i lhs_mat_01_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_11), lhs_mat_ymm_01_11, 1); + __m512i lhs_mat_23_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_11), lhs_mat_ymm_23_11, 1); + __m512i lhs_mat_01_12 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_12), lhs_mat_ymm_01_12, 1); + __m512i lhs_mat_23_12 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_12), lhs_mat_ymm_23_12, 1); + __m512i lhs_mat_01_13 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_13), lhs_mat_ymm_01_13, 1); + __m512i lhs_mat_23_13 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_13), lhs_mat_ymm_23_13, 1); + + // Bsums are loaded - four bsums are loaded (for two sub blocks) for the different Q8_K blocks + __m256i lhs_bsums_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].bsums + 16 * sb))); + __m256i lhs_bsums_hsum_ymm_0123_01 = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(lhs_bsums_0123_01), _mm256_extractf128_si256(lhs_bsums_0123_01, 1))); + lhs_bsums_hsum_ymm_0123_01 = _mm256_permute2x128_si256(lhs_bsums_hsum_ymm_0123_01, lhs_bsums_hsum_ymm_0123_01, 0); + __m512i lhs_bsums_hsum_0123_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_hsum_ymm_0123_01), lhs_bsums_hsum_ymm_0123_01, 1); + + // Shuffle pattern one - left side input + const __m512i lhs_mat_01_00_sp1 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) + const __m512i lhs_mat_23_00_sp1 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)160); //A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) + const __m512i lhs_mat_01_01_sp1 = 
_mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)160); //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) + const __m512i lhs_mat_23_01_sp1 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)160); //A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) + const __m512i lhs_mat_01_02_sp1 = _mm512_shuffle_epi32(lhs_mat_01_02, (_MM_PERM_ENUM)160); //A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19) + const __m512i lhs_mat_23_02_sp1 = _mm512_shuffle_epi32(lhs_mat_23_02, (_MM_PERM_ENUM)160); //A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19) + const __m512i lhs_mat_01_03_sp1 = _mm512_shuffle_epi32(lhs_mat_01_03, (_MM_PERM_ENUM)160); //A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27) + const __m512i lhs_mat_23_03_sp1 = _mm512_shuffle_epi32(lhs_mat_23_03, (_MM_PERM_ENUM)160); //A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27) + + const __m512i lhs_mat_01_10_sp1 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) + const __m512i lhs_mat_23_10_sp1 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)160); //A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) + const __m512i lhs_mat_01_11_sp1 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)160); //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) + const __m512i lhs_mat_23_11_sp1 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)160); //A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) + const __m512i lhs_mat_01_12_sp1 = _mm512_shuffle_epi32(lhs_mat_01_12, (_MM_PERM_ENUM)160); //A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19) + const __m512i lhs_mat_23_12_sp1 = _mm512_shuffle_epi32(lhs_mat_23_12, (_MM_PERM_ENUM)160); //A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19) + const __m512i lhs_mat_01_13_sp1 = _mm512_shuffle_epi32(lhs_mat_01_13, (_MM_PERM_ENUM)160); //A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27) + const __m512i lhs_mat_23_13_sp1 = _mm512_shuffle_epi32(lhs_mat_23_13, (_MM_PERM_ENUM)160); //A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) 
A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27) + + const __m512i lhs_mat_01_00_sp2 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) + const __m512i lhs_mat_23_00_sp2 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)245); //A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) + const __m512i lhs_mat_01_01_sp2 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) + const __m512i lhs_mat_23_01_sp2 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)245); //A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) + const __m512i lhs_mat_01_02_sp2 = _mm512_shuffle_epi32(lhs_mat_01_02, (_MM_PERM_ENUM)245); //A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23) + const __m512i lhs_mat_23_02_sp2 = _mm512_shuffle_epi32(lhs_mat_23_02, (_MM_PERM_ENUM)245); //A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23) + const __m512i lhs_mat_01_03_sp2 = _mm512_shuffle_epi32(lhs_mat_01_03, (_MM_PERM_ENUM)245); //A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31) + const __m512i lhs_mat_23_03_sp2 = _mm512_shuffle_epi32(lhs_mat_23_03, (_MM_PERM_ENUM)245); //A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31) + + const __m512i lhs_mat_01_10_sp2 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) + const __m512i lhs_mat_23_10_sp2 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)245); //A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) + const __m512i lhs_mat_01_11_sp2 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) + const __m512i lhs_mat_23_11_sp2 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)245); //A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) + const __m512i lhs_mat_01_12_sp2 = _mm512_shuffle_epi32(lhs_mat_01_12, (_MM_PERM_ENUM)245); //A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23) + const __m512i lhs_mat_23_12_sp2 = _mm512_shuffle_epi32(lhs_mat_23_12, 
(_MM_PERM_ENUM)245); //A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23) + const __m512i lhs_mat_01_13_sp2 = _mm512_shuffle_epi32(lhs_mat_01_13, (_MM_PERM_ENUM)245); //A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31) + const __m512i lhs_mat_23_13_sp2 = _mm512_shuffle_epi32(lhs_mat_23_13, (_MM_PERM_ENUM)245); //A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31) + + // The values arranged in the shuffle patterns are combined with a dot product operation within each 32-bit lane, i.e. corresponding bytes are multiplied and the products added into 32-bit integers within that lane + __m512i iacc_mat_00_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp1, lhs_mat_01_03_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp1, lhs_mat_01_02_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_01_01_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_01_00_sp1)); + __m512i iacc_mat_01_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp1, lhs_mat_01_03_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp1, lhs_mat_01_02_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_01_01_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_01_00_sp1)); + __m512i iacc_mat_10_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp1, lhs_mat_23_03_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp1, lhs_mat_23_02_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_23_01_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_23_00_sp1)); + __m512i iacc_mat_11_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp1, lhs_mat_23_03_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp1, lhs_mat_23_02_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_23_01_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_23_00_sp1)); + __m512i iacc_mat_00_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp1, lhs_mat_01_13_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp1, lhs_mat_01_12_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_01_11_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_01_10_sp1)); + __m512i iacc_mat_01_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp1, lhs_mat_01_13_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp1, lhs_mat_01_12_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_01_11_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_01_10_sp1)); + __m512i iacc_mat_10_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp1, lhs_mat_23_13_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp1, lhs_mat_23_12_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_23_11_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_23_10_sp1)); + __m512i iacc_mat_11_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp1, lhs_mat_23_13_sp1),
_mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp1, lhs_mat_23_12_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_23_11_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_23_10_sp1)); + + __m512i iacc_mat_00_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp2, lhs_mat_01_03_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp2, lhs_mat_01_02_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_01_01_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_01_00_sp2)); + __m512i iacc_mat_01_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp2, lhs_mat_01_03_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp2, lhs_mat_01_02_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_01_01_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_01_00_sp2)); + __m512i iacc_mat_10_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp2, lhs_mat_23_03_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp2, lhs_mat_23_02_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_23_01_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_23_00_sp2)); + __m512i iacc_mat_11_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp2, lhs_mat_23_03_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp2, lhs_mat_23_02_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_23_01_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_23_00_sp2)); + __m512i iacc_mat_00_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp2, lhs_mat_01_13_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp2, lhs_mat_01_12_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_01_11_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_01_10_sp2)); + __m512i iacc_mat_01_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp2, lhs_mat_01_13_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp2, lhs_mat_01_12_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_01_11_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_01_10_sp2)); + __m512i iacc_mat_10_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp2, lhs_mat_23_13_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp2, lhs_mat_23_12_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_23_11_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_23_10_sp2)); + __m512i iacc_mat_11_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp2, lhs_mat_23_13_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp2, lhs_mat_23_12_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_23_11_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_23_10_sp2)); + + // The outputs of both shuffle patterns are added together in order to sum the dot product outputs of all 32 values in the block + __m512i iacc_mat_00_0 = _mm512_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2); + __m512i iacc_mat_01_0 = _mm512_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2); + __m512i iacc_mat_10_0 = _mm512_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2); + __m512i iacc_mat_11_0 = _mm512_add_epi16(iacc_mat_11_0_sp1, iacc_mat_11_0_sp2); + + __m512i iacc_mat_00_1 = _mm512_add_epi16(iacc_mat_00_1_sp1, iacc_mat_00_1_sp2); + __m512i iacc_mat_01_1 =
_mm512_add_epi16(iacc_mat_01_1_sp1, iacc_mat_01_1_sp2); + __m512i iacc_mat_10_1 = _mm512_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2); + __m512i iacc_mat_11_1 = _mm512_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2); + + iacc_mat_00_0 = _mm512_madd_epi16(iacc_mat_00_0, scale_014589CD_0); + iacc_mat_01_0 = _mm512_madd_epi16(iacc_mat_01_0, scale_2367ABEF_0); + iacc_mat_10_0 = _mm512_madd_epi16(iacc_mat_10_0, scale_014589CD_0); + iacc_mat_11_0 = _mm512_madd_epi16(iacc_mat_11_0, scale_2367ABEF_0); + + iacc_mat_00_1 = _mm512_madd_epi16(iacc_mat_00_1, scale_014589CD_1); + iacc_mat_01_1 = _mm512_madd_epi16(iacc_mat_01_1, scale_2367ABEF_1); + iacc_mat_10_1 = _mm512_madd_epi16(iacc_mat_10_1, scale_014589CD_1); + iacc_mat_11_1 = _mm512_madd_epi16(iacc_mat_11_1, scale_2367ABEF_1); + + // Straighten out to make 4 row vectors (4 for each sub block, which are accumulated together in the next step) + __m512i iacc_row_0_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00_0, _mm512_shuffle_epi32(iacc_mat_01_0, (_MM_PERM_ENUM)78)); + __m512i iacc_row_1_0 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00_0, (_MM_PERM_ENUM)78), iacc_mat_01_0); + __m512i iacc_row_2_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10_0, _mm512_shuffle_epi32(iacc_mat_11_0, (_MM_PERM_ENUM)78)); + __m512i iacc_row_3_0 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10_0, (_MM_PERM_ENUM)78), iacc_mat_11_0); + __m512i iacc_row_0_1 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00_1, _mm512_shuffle_epi32(iacc_mat_01_1, (_MM_PERM_ENUM)78)); + __m512i iacc_row_1_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00_1, (_MM_PERM_ENUM)78), iacc_mat_01_1); + __m512i iacc_row_2_1 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10_1, _mm512_shuffle_epi32(iacc_mat_11_1, (_MM_PERM_ENUM)78)); + __m512i iacc_row_3_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10_1, (_MM_PERM_ENUM)78), iacc_mat_11_1); + + __m512i iacc_row_0 = _mm512_add_epi32(iacc_row_0_0, iacc_row_0_1); + __m512i iacc_row_1 = _mm512_add_epi32(iacc_row_1_0, iacc_row_1_1); + __m512i iacc_row_2 = _mm512_add_epi32(iacc_row_2_0, iacc_row_2_1); + __m512i iacc_row_3 = _mm512_add_epi32(iacc_row_3_0, iacc_row_3_1); + + // Load the scale values (d) for all four Q8_K blocks and repeat them across lanes + const __m128 row_scale_f32_sse = _mm_load_ps(a_ptrs[rp][b].d); + const __m256 row_scale_f32_ymm = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse); + const __m512 row_scale_f32 = _mm512_insertf32x8(_mm512_castps256_ps512(row_scale_f32_ymm), row_scale_f32_ymm, 1); + + // Multiply with appropriate scales and accumulate (for both d and dmin) below + acc_rows[rp * 4] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]); + acc_rows[rp * 4 + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]); + acc_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]); + acc_rows[rp * 4 + 3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]); + + __m512i iacc_row_min_0 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)0), mins_01); + __m512i iacc_row_min_1 =
_mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)85), mins_01); + __m512i iacc_row_min_2 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)170), mins_01); + __m512i iacc_row_min_3 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)255), mins_01); + + acc_min_rows[rp * 4] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_0), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[rp * 4]); + acc_min_rows[rp * 4 + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_1), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[rp * 4 + 1]); + acc_min_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_2), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[rp * 4 + 2]); + acc_min_rows[rp * 4 + 3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_3), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[rp * 4 + 3]); + } + } + } + // Store the accumulated values + for (int i = 0; i < 16; i++) { + _mm512_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm512_sub_ps(acc_rows[i], acc_min_rows[i])); + } + } + } + + for (; y < nr / 4; y++) { + + const block_q8_Kx4 * a_ptr = a_ptr_start + (y * nb); + + // Take a group of two block_q4_Kx8 structures at each pass of the loop and perform dot product operation + for (int64_t x = 0; x < anc / 8; x += 2) { + + const block_q4_Kx8 * b_ptr_0 = b_ptr_start + ((x) * b_nb); + const block_q4_Kx8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb); + + // Master FP accumulators + __m512 acc_rows[4]; + for (int i = 0; i < 4; i++) { + acc_rows[i] = _mm512_setzero_ps(); + } + + __m512 acc_min_rows[4]; + for (int i = 0; i < 4; i++) { + acc_min_rows[i] = _mm512_setzero_ps(); + } + + // For super block + for (int64_t b = 0; b < nb; b++) { + // Scale values - Load the sixteen scale values from two block_q4_Kx8 structures + const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d); + + // dmin values - Load the sixteen dmin values from two block_q4_Kx8 structures + const __m512 col_dmin_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].dmin, b_ptr_1[b].dmin); + + // Loop to iterate over the eight sub blocks of a super block - two sub blocks are processed per iteration + for (int sb = 0; sb < QK_K / 64; sb++) { + + const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + sb * 256)); + const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 32 + sb * 256)); + const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 64 + sb * 256)); + const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 96 + sb * 256)); + const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 128 + sb * 256)); + const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 160 + sb * 256)); + const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 192 + sb * 256)); + const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 224 + sb * 256)); + + const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + sb * 256)); + const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 32 + sb * 256)); + const __m256i rhs_raw_mat_89AB_1 =
_mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 64 + sb * 256)); + const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 96 + sb * 256)); + const __m256i rhs_raw_mat_89AB_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 128 + sb * 256)); + const __m256i rhs_raw_mat_CDEF_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 160 + sb * 256)); + const __m256i rhs_raw_mat_89AB_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 192 + sb * 256)); + const __m256i rhs_raw_mat_CDEF_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 224 + sb * 256)); + + const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240); + const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240); + const __m256i rhs_raw_mat_0145_2 = _mm256_blend_epi32(rhs_raw_mat_0123_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_2, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_2, requiredOrder), rhs_raw_mat_4567_2, 240); + const __m256i rhs_raw_mat_0145_3 = _mm256_blend_epi32(rhs_raw_mat_0123_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_3, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_3, requiredOrder), rhs_raw_mat_4567_3, 240); + + const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240); + const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240); + const __m256i rhs_raw_mat_89CD_2 = _mm256_blend_epi32(rhs_raw_mat_89AB_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_2, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_2, requiredOrder), rhs_raw_mat_CDEF_2, 240); + const __m256i rhs_raw_mat_89CD_3 = _mm256_blend_epi32(rhs_raw_mat_89AB_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_3, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_3, requiredOrder), rhs_raw_mat_CDEF_3, 240); + + const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1); + const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1); + const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1); + const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1); + + const __m512i rhs_raw_mat_014589CD_2 = 
_mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_2), rhs_raw_mat_89CD_2, 1); + const __m512i rhs_raw_mat_2367ABEF_2 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_2), rhs_raw_mat_ABEF_2, 1); + const __m512i rhs_raw_mat_014589CD_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_3), rhs_raw_mat_89CD_3, 1); + const __m512i rhs_raw_mat_2367ABEF_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_3), rhs_raw_mat_ABEF_3, 1); + + //4-bit -> 8-bit + const __m512i rhs_mat_014589CD_00 = _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded); //B00(0-7) B01(0-7) B04(0-7) B05(0-7) B08(0-7) B09(0-7) B0C(0-7) B0D(0-7) + const __m512i rhs_mat_2367ABEF_00 = _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded); //B02(0-7) B03(0-7) B06(0-7) B07(0-7) B0A(0-7) B0B(0-7) B0E(0-7) B0F(0-7) + const __m512i rhs_mat_014589CD_01 = _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded); //B00(8-15) B01(8-15) B04(8-15) B05(8-15) B08(8-15) B09(8-15) B0C(8-15) B0D(8-15) + const __m512i rhs_mat_2367ABEF_01 = _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded); //B02(8-15) B03(8-15) B06(8-15) B07(8-15) B0A(8-15) B0B(8-15) B0E(8-15) B0F(8-15) + + const __m512i rhs_mat_014589CD_02 = _mm512_and_si512(rhs_raw_mat_014589CD_2, m4bexpanded); //B00(16-23) B01(16-23) B04(16-23) B05(16-23) B08(16-23) B09(16-23) B0C(16-23) B0D(16-23) + const __m512i rhs_mat_2367ABEF_02 = _mm512_and_si512(rhs_raw_mat_2367ABEF_2, m4bexpanded); //B02(16-23) B03(16-23) B06(16-23) B07(16-23) B0A(16-23) B0B(16-23) B0E(16-23) B0F(16-23) + const __m512i rhs_mat_014589CD_03 = _mm512_and_si512(rhs_raw_mat_014589CD_3, m4bexpanded); //B00(24-31) B01(24-31) B04(24-31) B05(24-31) B08(24-31) B09(24-31) B0C(24-31) B0D(24-31) + const __m512i rhs_mat_2367ABEF_03 = _mm512_and_si512(rhs_raw_mat_2367ABEF_3, m4bexpanded); //B02(24-31) B03(24-31) B06(24-31) B07(24-31) B0A(24-31) B0B(24-31) B0E(24-31) B0F(24-31) + + const __m512i rhs_mat_014589CD_10 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded); //B10(0-7) B11(0-7) B14(0-7) B15(0-7) B18(0-7) B19(0-7) B1C(0-7) B1D(0-7) + const __m512i rhs_mat_2367ABEF_10 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded); //B12(0-7) B13(0-7) B16(0-7) B17(0-7) B1A(0-7) B1B(0-7) B1E(0-7) B1F(0-7) + const __m512i rhs_mat_014589CD_11 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded); //B10(8-15) B11(8-15) B14(8-15) B15(8-15) B18(8-15) B19(8-15) B1C(8-15) B1D(8-15) + const __m512i rhs_mat_2367ABEF_11 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded); //B12(8-15) B13(8-15) B16(8-15) B17(8-15) B1A(8-15) B1B(8-15) B1E(8-15) B1F(8-15) + + const __m512i rhs_mat_014589CD_12 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_2, 4), m4bexpanded); //B10(16-23) B11(16-23) B14(16-23) B15(16-23) B18(16-23) B19(16-23) B1C(16-23) B1D(16-23) + const __m512i rhs_mat_2367ABEF_12 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_2, 4), m4bexpanded); //B12(16-23) B13(16-23) B16(16-23) B17(16-23) B1A(16-23) B1B(16-23) B1E(16-23) B1F(16-23) + const __m512i rhs_mat_014589CD_13 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_3, 4), m4bexpanded); //B10(24-31) B11(24-31) B14(24-31) B15(24-31) B18(24-31) B19(24-31) B1C(24-31) B1D(24-31) + const __m512i rhs_mat_2367ABEF_13 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_3, 4), m4bexpanded); //B12(24-31) B13(24-31) B16(24-31) B17(24-31) B1A(24-31) B1B(24-31) B1E(24-31) B1F(24-31) + + // Shuffle pattern one - right side input + 
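// Note (added for clarity): _mm512_shuffle_epi32 applies its 8-bit selector to the four dwords of every 128-bit lane, so 136 (0b10001000) picks dwords {0,2,0,2} - the first 4-byte group of each interleaved column pair - while pattern two below uses 221 (0b11011101) to pick dwords {1,3,1,3}, the second 4-byte group +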
const __m512i rhs_mat_014589CD_00_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3) B08(0-3) B09(0-3) B08(0-3) B09(0-3) B0C(0-3) B0D(0-3) B0C(0-3) B0D(0-3) + const __m512i rhs_mat_2367ABEF_00_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)136); //B02(0-3) B03(0-3) B02(0-3) B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3) B0A(0-3) B0B(0-3) B0A(0-3) B0B(0-3) B0E(0-3) B0F(0-3) B0E(0-3) B0F(0-3) + const __m512i rhs_mat_014589CD_01_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11) B08(8-11) B09(8-11) B08(8-11) B09(8-11) B0C(8-11) B0D(8-11) B0C(8-11) B0D(8-11) + const __m512i rhs_mat_2367ABEF_01_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11) B0A(8-11) B0B(8-11) B0A(8-11) B0B(8-11) B0E(8-11) B0F(8-11) B0E(8-11) B0F(8-11) + const __m512i rhs_mat_014589CD_02_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_02, (_MM_PERM_ENUM)136); //B00(16-19) B01(16-19) B00(16-19) B01(16-19) B04(16-19) B05(16-19) B04(16-19) B05(16-19) B08(16-19) B09(16-19) B08(16-19) B09(16-19) B0C(16-19) B0D(16-19) B0C(16-19) B0D(16-19) + const __m512i rhs_mat_2367ABEF_02_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_02, (_MM_PERM_ENUM)136); //B02(16-19) B03(16-19) B02(16-19) B03(16-19) B06(16-19) B07(16-19) B06(16-19) B07(16-19) B0A(16-19) B0B(16-19) B0A(16-19) B0B(16-19) B0E(16-19) B0F(16-19) B0E(16-19) B0F(16-19) + const __m512i rhs_mat_014589CD_03_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_03, (_MM_PERM_ENUM)136); //B00(24-27) B01(24-27) B00(24-27) B01(24-27) B04(24-27) B05(24-27) B04(24-27) B05(24-27) B08(24-27) B09(24-27) B08(24-27) B09(24-27) B0C(24-27) B0D(24-27) B0C(24-27) B0D(24-27) + const __m512i rhs_mat_2367ABEF_03_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_03, (_MM_PERM_ENUM)136); //B02(24-27) B03(24-27) B02(24-27) B03(24-27) B06(24-27) B07(24-27) B06(24-27) B07(24-27) B0A(24-27) B0B(24-27) B0A(24-27) B0B(24-27) B0E(24-27) B0F(24-27) B0E(24-27) B0F(24-27) + + const __m512i rhs_mat_014589CD_10_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3) B18(0-3) B19(0-3) B18(0-3) B19(0-3) B1C(0-3) B1D(0-3) B1C(0-3) B1D(0-3) + const __m512i rhs_mat_2367ABEF_10_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) B16(0-3) B17(0-3) B1A(0-3) B1B(0-3) B1A(0-3) B1B(0-3) B1E(0-3) B1F(0-3) B1E(0-3) B1F(0-3) + const __m512i rhs_mat_014589CD_11_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11) B18(8-11) B19(8-11) B18(8-11) B19(8-11) B1C(8-11) B1D(8-11) B1C(8-11) B1D(8-11) + const __m512i rhs_mat_2367ABEF_11_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11) B1A(8-11) B1B(8-11) B1A(8-11) B1B(8-11) B1E(8-11) B1F(8-11) B1E(8-11) B1F(8-11) + const __m512i rhs_mat_014589CD_12_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_12, (_MM_PERM_ENUM)136); //B10(16-19) B11(16-19) B10(16-19) B11(16-19) B14(16-19) B15(16-19) B14(16-19) B15(16-19) B18(16-19) B19(16-19) B18(16-19) B19(16-19) B1C(16-19) B1D(16-19) B1C(16-19) B1D(16-19) + const __m512i rhs_mat_2367ABEF_12_sp1 = 
_mm512_shuffle_epi32(rhs_mat_2367ABEF_12, (_MM_PERM_ENUM)136); //B12(16-19) B13(16-19) B12(16-19) B13(16-19) B16(16-19) B17(16-19) B16(16-19) B17(16-19) B1A(16-19) B1B(16-19) B1A(16-19) B1B(16-19) B1E(16-19) B1F(16-19) B1E(16-19) B1F(16-19) + const __m512i rhs_mat_014589CD_13_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_13, (_MM_PERM_ENUM)136); //B10(24-27) B11(24-27) B10(24-27) B11(24-27) B14(24-27) B15(24-27) B14(24-27) B15(24-27) B18(24-27) B19(24-27) B18(24-27) B19(24-27) B1C(24-27) B1D(24-27) B1C(24-27) B1D(24-27) + const __m512i rhs_mat_2367ABEF_13_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_13, (_MM_PERM_ENUM)136); //B12(24-27) B13(24-27) B12(24-27) B13(24-27) B16(24-27) B17(24-27) B16(24-27) B17(24-27) B1A(24-27) B1B(24-27) B1A(24-27) B1B(24-27) B1E(24-27) B1F(24-27) B1E(24-27) B1F(24-27) + + // Shuffle pattern two - right side input + const __m512i rhs_mat_014589CD_00_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)221); //B00(4-7) B01(4-7) B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7) B08(4-7) B09(4-7) B08(4-7) B09(4-7) B0C(4-7) B0D(4-7) B0C(4-7) B0D(4-7) + const __m512i rhs_mat_2367ABEF_00_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)221); //B02(4-7) B03(4-7) B02(4-7) B03(4-7) B06(4-7) B07(4-7) B06(4-7) B07(4-7) B0A(4-7) B0B(4-7) B0A(4-7) B0B(4-7) B0E(4-7) B0F(4-7) B0E(4-7) B0F(4-7) + const __m512i rhs_mat_014589CD_01_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15) B08(12-15) B09(12-15) B08(12-15) B09(12-15) B0C(12-15) B0D(12-15) B0C(12-15) B0D(12-15) + const __m512i rhs_mat_2367ABEF_01_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15) B0A(12-15) B0B(12-15) B0A(12-15) B0B(12-15) B0E(12-15) B0F(12-15) B0E(12-15) B0F(12-15) + const __m512i rhs_mat_014589CD_02_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_02, (_MM_PERM_ENUM)221); //B00(20-23) B01(20-23) B00(20-23) B01(20-23) B04(20-23) B05(20-23) B04(20-23) B05(20-23) B08(20-23) B09(20-23) B08(20-23) B09(20-23) B0C(20-23) B0D(20-23) B0C(20-23) B0D(20-23) + const __m512i rhs_mat_2367ABEF_02_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_02, (_MM_PERM_ENUM)221); //B02(20-23) B03(20-23) B02(20-23) B03(20-23) B06(20-23) B07(20-23) B06(20-23) B07(20-23) B0A(20-23) B0B(20-23) B0A(20-23) B0B(20-23) B0E(20-23) B0F(20-23) B0E(20-23) B0F(20-23) + const __m512i rhs_mat_014589CD_03_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_03, (_MM_PERM_ENUM)221); //B00(28-31) B01(28-31) B00(28-31) B01(28-31) B04(28-31) B05(28-31) B04(28-31) B05(28-31) B08(28-31) B09(28-31) B08(28-31) B09(28-31) B0C(28-31) B0D(28-31) B0C(28-31) B0D(28-31) + const __m512i rhs_mat_2367ABEF_03_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_03, (_MM_PERM_ENUM)221); //B02(28-31) B03(28-31) B02(28-31) B03(28-31) B06(28-31) B07(28-31) B06(28-31) B07(28-31) B0A(28-31) B0B(28-31) B0A(28-31) B0B(28-31) B0E(28-31) B0F(28-31) B0E(28-31) B0F(28-31) + + const __m512i rhs_mat_014589CD_10_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7) B18(4-7) B19(4-7) B18(4-7) B19(4-7) B1C(4-7) B1D(4-7) B1C(4-7) B1D(4-7) + const __m512i rhs_mat_2367ABEF_10_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7) B1A(4-7) B1B(4-7) B1A(4-7) B1B(4-7) B1E(4-7) B1F(4-7) B1E(4-7)
B1F(4-7) + const __m512i rhs_mat_014589CD_11_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15) B18(12-15) B19(12-15) B18(12-15) B19(12-15) B1C(12-15) B1D(12-15) B1C(12-15) B1D(12-15) + const __m512i rhs_mat_2367ABEF_11_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15) B1A(12-15) B1B(12-15) B1A(12-15) B1B(12-15) B1E(12-15) B1F(12-15) B1E(12-15) B1F(12-15) + const __m512i rhs_mat_014589CD_12_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_12, (_MM_PERM_ENUM)221); //B10(20-23) B11(20-23) B10(20-23) B11(20-23) B14(20-23) B15(20-23) B14(20-23) B15(20-23) B18(20-23) B19(20-23) B18(20-23) B19(20-23) B1C(20-23) B1D(20-23) B1C(20-23) B1D(20-23) + const __m512i rhs_mat_2367ABEF_12_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_12, (_MM_PERM_ENUM)221); //B12(20-23) B13(20-23) B12(20-23) B13(20-23) B16(20-23) B17(20-23) B16(20-23) B17(20-23) B1A(20-23) B1B(20-23) B1A(20-23) B1B(20-23) B1E(20-23) B1F(20-23) B1E(20-23) B1F(20-23) + const __m512i rhs_mat_014589CD_13_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_13, (_MM_PERM_ENUM)221); //B10(28-31) B11(28-31) B10(28-31) B11(28-31) B14(28-31) B15(28-31) B14(28-31) B15(28-31) B18(28-31) B19(28-31) B18(28-31) B19(28-31) B1C(28-31) B1D(28-31) B1C(28-31) B1D(28-31) + const __m512i rhs_mat_2367ABEF_13_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_13, (_MM_PERM_ENUM)221); //B12(28-31) B13(28-31) B12(28-31) B13(28-31) B16(28-31) B17(28-31) B16(28-31) B17(28-31) B1A(28-31) B1B(28-31) B1A(28-31) B1B(28-31) B1E(28-31) B1F(28-31) B1E(28-31) B1F(28-31) + + uint32_t utmp_00[4], utmp_01[4], utmp_10[4], utmp_11[4]; + + // Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together + // For example, the below block extracts the first sub block's scales and mins from the different Q4_K structures for the sb loop + memcpy(utmp_00, b_ptr_0[b].scales + 24 * sb, 12); + utmp_00[3] = ((utmp_00[2] >> 4) & kmask2) | (((utmp_00[1] >> 6) & kmask3) << 4); + const uint32_t uaux_00 = utmp_00[1] & kmask1; + utmp_00[1] = (utmp_00[2] & kmask2) | (((utmp_00[0] >> 6) & kmask3) << 4); + utmp_00[2] = uaux_00; + utmp_00[0] &= kmask1; + + // For example, the below block extracts the second sub block's scales and mins from the different Q4_K structures for the sb loop + memcpy(utmp_01, b_ptr_0[b].scales + 12 + sb * 24, 12); + utmp_01[3] = ((utmp_01[2] >> 4) & kmask2) | (((utmp_01[1] >> 6) & kmask3) << 4); + const uint32_t uaux_01 = utmp_01[1] & kmask1; + utmp_01[1] = (utmp_01[2] & kmask2) | (((utmp_01[0] >> 6) & kmask3) << 4); + utmp_01[2] = uaux_01; + utmp_01[0] &= kmask1; + + // For example, the below block extracts the first sub block's scales and mins from the different Q4_K structures for the sb loop + memcpy(utmp_10, b_ptr_1[b].scales + sb * 24, 12); + utmp_10[3] = ((utmp_10[2] >> 4) & kmask2) | (((utmp_10[1] >> 6) & kmask3) << 4); + const uint32_t uaux_10 = utmp_10[1] & kmask1; + utmp_10[1] = (utmp_10[2] & kmask2) | (((utmp_10[0] >> 6) & kmask3) << 4); + utmp_10[2] = uaux_10; + utmp_10[0] &= kmask1; + + // For example, the below block extracts the second sub block's scales and mins from the different Q4_K structures for the sb loop + memcpy(utmp_11, b_ptr_1[b].scales + 12 + sb * 24, 12); + utmp_11[3] = ((utmp_11[2] >> 4) & kmask2) | (((utmp_11[1] >> 6) & kmask3) << 4); + const uint32_t uaux_11 = utmp_11[1] & kmask1; + utmp_11[1] = (utmp_11[2] & kmask2) | (((utmp_11[0] >> 6) & kmask3)
<< 4); + utmp_11[2] = uaux_11; + utmp_11[0] &= kmask1; + + // Scales of first sub block in the sb loop + const __m256i mins_and_scales_0 = _mm256_set_epi32(utmp_10[3], utmp_10[2], utmp_10[1], utmp_10[0], utmp_00[3], utmp_00[2], utmp_00[1], utmp_00[0]); + const __m512i scales_0 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(mins_and_scales_0, mins_and_scales_0)); + + // Scales of second sub block in the sb loop + const __m256i mins_and_scales_1 = _mm256_set_epi32(utmp_11[3], utmp_11[2], utmp_11[1], utmp_11[0], utmp_01[3], utmp_01[2], utmp_01[1], utmp_01[0]); + const __m512i scales_1 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(mins_and_scales_1, mins_and_scales_1)); + + // Mins of first and second sub block of Q4_K block are arranged side by side + const __m512i mins_01 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(_mm256_shuffle_epi32(mins_and_scales_0, 78), _mm256_shuffle_epi32(mins_and_scales_1, 78))); + + const __m512i scale_014589CD_0 = _mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)68); + const __m512i scale_2367ABEF_0 = _mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)238); + + const __m512i scale_014589CD_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)68); + const __m512i scale_2367ABEF_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)238); + + // Load the four block_q8_k quantized values interleaved with each other in chunks of eight bytes - A0,A1,A2,A3 + // Loaded as set of 128 bit vectors and repeated into a 256 bit vector + __m256i lhs_mat_ymm_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 256 * sb))); + __m256i lhs_mat_ymm_01_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 0); + __m256i lhs_mat_ymm_23_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 17); + __m256i lhs_mat_ymm_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 32 + 256 * sb))); + __m256i lhs_mat_ymm_01_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 0); + __m256i lhs_mat_ymm_23_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 17); + __m256i lhs_mat_ymm_0123_02 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 64 + 256 * sb))); + __m256i lhs_mat_ymm_01_02 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_02, lhs_mat_ymm_0123_02, 0); + __m256i lhs_mat_ymm_23_02 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_02, lhs_mat_ymm_0123_02, 17); + __m256i lhs_mat_ymm_0123_03 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 96 + 256 * sb))); + __m256i lhs_mat_ymm_01_03 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_03, lhs_mat_ymm_0123_03, 0); + __m256i lhs_mat_ymm_23_03 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_03, lhs_mat_ymm_0123_03, 17); + __m256i lhs_mat_ymm_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 128 + 256 * sb))); + __m256i lhs_mat_ymm_01_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 0); + __m256i lhs_mat_ymm_23_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 17); + __m256i lhs_mat_ymm_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 160 + 256 * sb))); + __m256i lhs_mat_ymm_01_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 0); + __m256i lhs_mat_ymm_23_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 17); + __m256i lhs_mat_ymm_0123_12 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 192 + 256 * sb))); + __m256i lhs_mat_ymm_01_12 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_12, lhs_mat_ymm_0123_12, 0); + __m256i lhs_mat_ymm_23_12 = 
_mm256_permute2f128_si256(lhs_mat_ymm_0123_12, lhs_mat_ymm_0123_12, 17); + __m256i lhs_mat_ymm_0123_13 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 224 + 256 * sb))); + __m256i lhs_mat_ymm_01_13 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_13, lhs_mat_ymm_0123_13, 0); + __m256i lhs_mat_ymm_23_13 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_13, lhs_mat_ymm_0123_13, 17); + + //Loaded as set of 128 bit vectors and repeated and stored into a 256 bit vector before again repeating into a 512 bit vector + __m512i lhs_mat_01_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_00), lhs_mat_ymm_01_00, 1); + __m512i lhs_mat_23_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_00), lhs_mat_ymm_23_00, 1); + __m512i lhs_mat_01_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_01), lhs_mat_ymm_01_01, 1); + __m512i lhs_mat_23_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_01), lhs_mat_ymm_23_01, 1); + __m512i lhs_mat_01_02 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_02), lhs_mat_ymm_01_02, 1); + __m512i lhs_mat_23_02 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_02), lhs_mat_ymm_23_02, 1); + __m512i lhs_mat_01_03 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_03), lhs_mat_ymm_01_03, 1); + __m512i lhs_mat_23_03 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_03), lhs_mat_ymm_23_03, 1); + + __m512i lhs_mat_01_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_10), lhs_mat_ymm_01_10, 1); + __m512i lhs_mat_23_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_10), lhs_mat_ymm_23_10, 1); + __m512i lhs_mat_01_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_11), lhs_mat_ymm_01_11, 1); + __m512i lhs_mat_23_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_11), lhs_mat_ymm_23_11, 1); + __m512i lhs_mat_01_12 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_12), lhs_mat_ymm_01_12, 1); + __m512i lhs_mat_23_12 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_12), lhs_mat_ymm_23_12, 1); + __m512i lhs_mat_01_13 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_13), lhs_mat_ymm_01_13, 1); + __m512i lhs_mat_23_13 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_13), lhs_mat_ymm_23_13, 1); + + // Bsums are loaded - four bsums are loaded (for two sub blocks) for the different Q8_K blocks + __m256i lhs_bsums_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].bsums + 16 * sb))); + __m256i lhs_bsums_hsum_ymm_0123_01 = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(lhs_bsums_0123_01), _mm256_extractf128_si256(lhs_bsums_0123_01, 1))); + lhs_bsums_hsum_ymm_0123_01 = _mm256_permute2x128_si256(lhs_bsums_hsum_ymm_0123_01, lhs_bsums_hsum_ymm_0123_01, 0); + __m512i lhs_bsums_hsum_0123_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_hsum_ymm_0123_01), lhs_bsums_hsum_ymm_0123_01, 1); + + // Shuffle pattern one - left side input + const __m512i lhs_mat_01_00_sp1 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) + const __m512i lhs_mat_23_00_sp1 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)160); //A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) + const __m512i lhs_mat_01_01_sp1 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)160); //A00(8-11) A00(8-11) 
A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) + const __m512i lhs_mat_23_01_sp1 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)160); //A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) + const __m512i lhs_mat_01_02_sp1 = _mm512_shuffle_epi32(lhs_mat_01_02, (_MM_PERM_ENUM)160); //A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19) + const __m512i lhs_mat_23_02_sp1 = _mm512_shuffle_epi32(lhs_mat_23_02, (_MM_PERM_ENUM)160); //A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19) + const __m512i lhs_mat_01_03_sp1 = _mm512_shuffle_epi32(lhs_mat_01_03, (_MM_PERM_ENUM)160); //A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27) + const __m512i lhs_mat_23_03_sp1 = _mm512_shuffle_epi32(lhs_mat_23_03, (_MM_PERM_ENUM)160); //A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27) + + const __m512i lhs_mat_01_10_sp1 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) + const __m512i lhs_mat_23_10_sp1 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)160); //A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) + const __m512i lhs_mat_01_11_sp1 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)160); //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) + const __m512i lhs_mat_23_11_sp1 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)160); //A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) + const __m512i lhs_mat_01_12_sp1 = _mm512_shuffle_epi32(lhs_mat_01_12, (_MM_PERM_ENUM)160); //A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19) + const __m512i lhs_mat_23_12_sp1 = _mm512_shuffle_epi32(lhs_mat_23_12, (_MM_PERM_ENUM)160); //A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19) + const __m512i lhs_mat_01_13_sp1 = _mm512_shuffle_epi32(lhs_mat_01_13, (_MM_PERM_ENUM)160); //A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27) + const __m512i lhs_mat_23_13_sp1 = _mm512_shuffle_epi32(lhs_mat_23_13, (_MM_PERM_ENUM)160); //A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27) + 
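+ // Note (added for clarity): on the left-hand side, 160 (0b10100000) above picks dwords {0,0,2,2} and 245 (0b11110101) below picks dwords {1,1,3,3} per 128-bit lane, broadcasting each row's 4-byte group so _mm512_maddubs_epi16 can pair it with both B columns of the matching rhs shuffle; summing the sp1 and sp2 products then covers all 8 bytes of every group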
+ const __m512i lhs_mat_01_00_sp2 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) + const __m512i lhs_mat_23_00_sp2 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)245); //A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) + const __m512i lhs_mat_01_01_sp2 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) + const __m512i lhs_mat_23_01_sp2 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)245); //A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) + const __m512i lhs_mat_01_02_sp2 = _mm512_shuffle_epi32(lhs_mat_01_02, (_MM_PERM_ENUM)245); //A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23) + const __m512i lhs_mat_23_02_sp2 = _mm512_shuffle_epi32(lhs_mat_23_02, (_MM_PERM_ENUM)245); //A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23) + const __m512i lhs_mat_01_03_sp2 = _mm512_shuffle_epi32(lhs_mat_01_03, (_MM_PERM_ENUM)245); //A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31) + const __m512i lhs_mat_23_03_sp2 = _mm512_shuffle_epi32(lhs_mat_23_03, (_MM_PERM_ENUM)245); //A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31) + + const __m512i lhs_mat_01_10_sp2 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) + const __m512i lhs_mat_23_10_sp2 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)245); //A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) + const __m512i lhs_mat_01_11_sp2 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) + const __m512i lhs_mat_23_11_sp2 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)245); //A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) + const __m512i lhs_mat_01_12_sp2 = _mm512_shuffle_epi32(lhs_mat_01_12, (_MM_PERM_ENUM)245); //A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23) + const __m512i lhs_mat_23_12_sp2 = _mm512_shuffle_epi32(lhs_mat_23_12, (_MM_PERM_ENUM)245); //A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) 
A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23) + const __m512i lhs_mat_01_13_sp2 = _mm512_shuffle_epi32(lhs_mat_01_13, (_MM_PERM_ENUM)245); //A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31) + const __m512i lhs_mat_23_13_sp2 = _mm512_shuffle_epi32(lhs_mat_23_13, (_MM_PERM_ENUM)245); //A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31) + + // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane + __m512i iacc_mat_00_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp1, lhs_mat_01_03_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp1, lhs_mat_01_02_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_01_01_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_01_00_sp1)); + __m512i iacc_mat_01_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp1, lhs_mat_01_03_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp1, lhs_mat_01_02_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_01_01_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_01_00_sp1)); + __m512i iacc_mat_10_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp1, lhs_mat_23_03_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp1, lhs_mat_23_02_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_23_01_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_23_00_sp1)); + __m512i iacc_mat_11_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp1, lhs_mat_23_03_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp1, lhs_mat_23_02_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_23_01_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_23_00_sp1)); + __m512i iacc_mat_00_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp1, lhs_mat_01_13_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp1, lhs_mat_01_12_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_01_11_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_01_10_sp1)); + __m512i iacc_mat_01_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp1, lhs_mat_01_13_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp1, lhs_mat_01_12_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_01_11_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_01_10_sp1)); + __m512i iacc_mat_10_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp1, lhs_mat_23_13_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp1, lhs_mat_23_12_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_23_11_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_23_10_sp1)); + __m512i iacc_mat_11_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp1, lhs_mat_23_13_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp1, lhs_mat_23_12_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, 
lhs_mat_23_11_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_23_10_sp1)); + + __m512i iacc_mat_00_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp2, lhs_mat_01_03_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp2, lhs_mat_01_02_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_01_01_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_01_00_sp2)); + __m512i iacc_mat_01_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp2, lhs_mat_01_03_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp2, lhs_mat_01_02_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_01_01_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_01_00_sp2)); + __m512i iacc_mat_10_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp2, lhs_mat_23_03_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp2, lhs_mat_23_02_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_23_01_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_23_00_sp2)); + __m512i iacc_mat_11_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp2, lhs_mat_23_03_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp2, lhs_mat_23_02_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_23_01_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_23_00_sp2)); + __m512i iacc_mat_00_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp2, lhs_mat_01_13_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp2, lhs_mat_01_12_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_01_11_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_01_10_sp2)); + __m512i iacc_mat_01_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp2, lhs_mat_01_13_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp2, lhs_mat_01_12_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_01_11_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_01_10_sp2)); + __m512i iacc_mat_10_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp2, lhs_mat_23_13_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp2, lhs_mat_23_12_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_23_11_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_23_10_sp2)); + __m512i iacc_mat_11_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp2, lhs_mat_23_13_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp2, lhs_mat_23_12_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_23_11_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_23_10_sp2)); + + // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block + __m512i iacc_mat_00_0 = _mm512_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2); + __m512i iacc_mat_01_0 = _mm512_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2); + __m512i iacc_mat_10_0 = _mm512_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2); + __m512i iacc_mat_11_0 = _mm512_add_epi16(iacc_mat_11_0_sp1, iacc_mat_11_0_sp2); + + __m512i iacc_mat_00_1 = _mm512_add_epi16(iacc_mat_00_1_sp1, iacc_mat_00_1_sp2); + __m512i iacc_mat_01_1 = _mm512_add_epi16(iacc_mat_01_1_sp1, iacc_mat_01_1_sp2); + __m512i iacc_mat_10_1 = _mm512_add_epi16(iacc_mat_10_1_sp1, 
iacc_mat_10_1_sp2);
+        __m512i iacc_mat_11_1 = _mm512_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2);
+
+        iacc_mat_00_0 = _mm512_madd_epi16(iacc_mat_00_0, scale_014589CD_0);
+        iacc_mat_01_0 = _mm512_madd_epi16(iacc_mat_01_0, scale_2367ABEF_0);
+        iacc_mat_10_0 = _mm512_madd_epi16(iacc_mat_10_0, scale_014589CD_0);
+        iacc_mat_11_0 = _mm512_madd_epi16(iacc_mat_11_0, scale_2367ABEF_0);
+
+        iacc_mat_00_1 = _mm512_madd_epi16(iacc_mat_00_1, scale_014589CD_1);
+        iacc_mat_01_1 = _mm512_madd_epi16(iacc_mat_01_1, scale_2367ABEF_1);
+        iacc_mat_10_1 = _mm512_madd_epi16(iacc_mat_10_1, scale_014589CD_1);
+        iacc_mat_11_1 = _mm512_madd_epi16(iacc_mat_11_1, scale_2367ABEF_1);
+
+        // Straighten out to make 4 row vectors (4 for each sub block which are accumulated together in the next step)
+        __m512i iacc_row_0_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00_0, _mm512_shuffle_epi32(iacc_mat_01_0, (_MM_PERM_ENUM)78));
+        __m512i iacc_row_1_0 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00_0, (_MM_PERM_ENUM)78), iacc_mat_01_0);
+        __m512i iacc_row_2_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10_0, _mm512_shuffle_epi32(iacc_mat_11_0, (_MM_PERM_ENUM)78));
+        __m512i iacc_row_3_0 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10_0, (_MM_PERM_ENUM)78), iacc_mat_11_0);
+        __m512i iacc_row_0_1 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00_1, _mm512_shuffle_epi32(iacc_mat_01_1, (_MM_PERM_ENUM)78));
+        __m512i iacc_row_1_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00_1, (_MM_PERM_ENUM)78), iacc_mat_01_1);
+        __m512i iacc_row_2_1 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10_1, _mm512_shuffle_epi32(iacc_mat_11_1, (_MM_PERM_ENUM)78));
+        __m512i iacc_row_3_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10_1, (_MM_PERM_ENUM)78), iacc_mat_11_1);
+
+        __m512i iacc_row_0 = _mm512_add_epi32(iacc_row_0_0, iacc_row_0_1);
+        __m512i iacc_row_1 = _mm512_add_epi32(iacc_row_1_0, iacc_row_1_1);
+        __m512i iacc_row_2 = _mm512_add_epi32(iacc_row_2_0, iacc_row_2_1);
+        __m512i iacc_row_3 = _mm512_add_epi32(iacc_row_3_0, iacc_row_3_1);
+
+        // Load the scale(d) values for all the 4 Q8_k blocks and repeat it across lanes
+        const __m128 row_scale_f32_sse = _mm_load_ps(a_ptr[b].d);
+        const __m256 row_scale_f32_ymm = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse);
+        const __m512 row_scale_f32 = _mm512_insertf32x8(_mm512_castps256_ps512(row_scale_f32_ymm), row_scale_f32_ymm, 1);
+
+        // Multiply with appropriate scales and accumulate (for both d and dmin) below
+        acc_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]);
+        acc_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]);
+        acc_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);
+        acc_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);
+
+        __m512i iacc_row_min_0 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)0), mins_01);
+        __m512i iacc_row_min_1 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)85), mins_01);
+        __m512i iacc_row_min_2 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)170), mins_01);
+        __m512i iacc_row_min_3 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)255), mins_01);
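For orientation, the arithmetic all of the intrinsics above implement is the standard Q4_K x Q8_K dot product, with the scale (d) and minimum (dmin) contributions accumulated separately (acc_rows vs acc_min_rows). A hedged scalar reference, with the sub-block layout simplified for illustration (this function does not exist in the codebase, and ggml actually stores bsums per 16 values):

    #include <stdint.h>
    // one Q4_K block (256 values, 8 sub-blocks of 32) against one Q8_K row
    static float q4K_q8K_dot_ref(float d4, float dmin4, float d8,
                                 const uint8_t scales[8], const uint8_t mins[8],
                                 const uint8_t q4[256],   // unpacked 4-bit values in [0,15]
                                 const int8_t  q8[256],
                                 const int16_t bsums[8]) { // sum of the 32 q8 values per sub-block
        int32_t acc = 0, acc_min = 0;
        for (int j = 0; j < 8; j++) {
            int32_t dot = 0;
            for (int i = 0; i < 32; i++) {
                dot += q4[32*j + i] * q8[32*j + i];  // the maddubs/madd chain above
            }
            acc     += (int32_t) scales[j] * dot;      // per-sub-block scale
            acc_min += (int32_t) mins[j]   * bsums[j]; // per-sub-block minimum
        }
        return d4 * d8 * acc - dmin4 * d8 * acc_min;   // matches the final acc_rows - acc_min_rows store
    }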
+
+        acc_min_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_0), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[0]);
+        acc_min_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_1), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[1]);
+        acc_min_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_2), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[2]);
+        acc_min_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_3), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[3]);
+        }
+    }
+    // Store accumulated values
+    for (int i = 0; i < 4; i++) {
+        _mm512_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm512_sub_ps(acc_rows[i], acc_min_rows[i]));
+    }
+    }
+    }
+    if (anc != nc) {
+        xstart = anc/8;
+        y = 0;
+    }
+#endif //AVX512F
+
     // Take group of four block_q8_Kx4 structures at each pass of the loop and perform dot product operation
     for (; y < anr / 4; y += 4) {
@@ -4067,7 +4807,7 @@ static void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, c
         }
 
         // Take group of eight block_q4_kx8 structures at each pass of the loop and perform dot product operation
-        for (int64_t x = 0; x < nc / 8; x++) {
+        for (int64_t x = xstart; x < nc / 8; x++) {
 
             const block_q4_Kx8 * b_ptr = b_ptr_start + (x * b_nb);
 
@@ -4401,7 +5141,7 @@ static void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, c
 
         const block_q8_Kx4 * a_ptr = a_ptr_start + (y * nb);
 
-        for (int64_t x = 0; x < nc / 8; x++) {
+        for (int64_t x = xstart; x < nc / 8; x++) {
 
             const block_q4_Kx8 * b_ptr = b_ptr_start + (x * b_nb);
 

From daa422881a0ec7944771bcc8ff8de34d11f5bd3b Mon Sep 17 00:00:00 2001
From: Juk Armstrong <69222624+jukofyork@users.noreply.github.com>
Date: Tue, 15 Apr 2025 07:49:57 +0100
Subject: [PATCH 14/20] llama : DeepSeek V2/V3 MLA implementation (#12801)

* Merged using squash to remove all noise commit messages
* Force flash attention off for `LLM_ARCH_DEEPSEEK2` - embedding too large
* Removed 3 conts (2x RoPE and 1x RMS-norm)
* Changed to use `` instead of ``
* Reverted removal of the 3 conts
* Used `reshape` in `llm_graph_context::build_attn_mha()`
* Use `k_pe = ggml_reshape`
* Removed the 3 conts again
* Removed the 3D views of `wk_b` and `wv_b`, and just save as 3D in GGUF
* Removed MQA optimisation from `build_attn_mha()` as no gains now
* Simplified `is_mla` branch in `llm_build_deepseek2()`
* Removed `build_attn_mla` and added `nullptr` to all `build_attn` calls
* Fixed call to `build_attn` in `llm_build_t5_enc`
---
 convert_hf_to_gguf.py          |  33 +++-
 gguf-py/gguf/constants.py      |   8 +
 gguf-py/gguf/gguf_writer.py    |   6 +
 gguf-py/gguf/tensor_mapping.py |   8 +
 src/llama-arch.cpp             |  23 +--
 src/llama-arch.h               |   4 +
 src/llama-context.cpp          |  15 +-
 src/llama-graph.cpp            |  19 +-
 src/llama-graph.h              |  10 +-
 src/llama-hparams.h            |   4 +
 src/llama-kv-cache.cpp         |   2 +-
 src/llama-model.cpp            | 320 +++++++++++++++++++--------------
 src/llama-model.h              |   2 +
 13 files changed, 289 insertions(+), 165 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 2bf97475f..89522dee8 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -4422,6 +4422,10 @@ class DeepseekV2Model(Model):
         self._set_vocab_gpt2()
 
     def set_gguf_parameters(self):
+
+        # note: deepseek2 using MLA converts into MQA (ie: GQA with 1
group) + self.hparams["num_key_value_heads"] = 1 + super().set_gguf_parameters() hparams = self.hparams @@ -4430,8 +4434,13 @@ class DeepseekV2Model(Model): if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None: self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"]) self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"]) - self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) - self.gguf_writer.add_value_length(hparams["v_head_dim"]) + + # note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA + self.gguf_writer.add_key_length(hparams["kv_lora_rank"] + hparams["qk_rope_head_dim"]) + self.gguf_writer.add_value_length(hparams["kv_lora_rank"]) + self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) + self.gguf_writer.add_value_length_mla(hparams["v_head_dim"]) + self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) self.gguf_writer.add_expert_count(hparams["n_routed_experts"]) self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"]) @@ -4500,6 +4509,26 @@ class DeepseekV2Model(Model): else: return [] + # note: MLA with the absorption optimization, needs these two split and k_b_proj transposed + if name.endswith("kv_b_proj.weight"): + name_kb = name.replace("kv_b_proj", "k_b_proj") + name_vb = name.replace("kv_b_proj", "v_b_proj") + + n_head_kv = self.hparams["num_key_value_heads"] + v_head_dim = self.hparams["v_head_dim"] + qk_nope_head_dim = self.hparams["qk_nope_head_dim"] + + assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim) + + kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1]) + k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1) + k_b = k_b.transpose(1, 2) + + return [ + (self.map_tensor_name(name_kb), k_b), + (self.map_tensor_name(name_vb), v_b) + ] + return [(self.map_tensor_name(name), data_torch)] def prepare_tensors(self): diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 162070e6e..8fcde2626 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -139,6 +139,8 @@ class Keys: REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count" SLIDING_WINDOW = "{arch}.attention.sliding_window" SCALE = "{arch}.attention.scale" + KEY_LENGTH_MLA = "{arch}.attention.key_length_mla" + VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla" class Rope: DIMENSION_COUNT = "{arch}.rope.dimension_count" @@ -382,6 +384,8 @@ class MODEL_TENSOR(IntEnum): ATTN_Q_B = auto() ATTN_KV_A_MQA = auto() ATTN_KV_B = auto() + ATTN_K_B = auto() + ATTN_V_B = auto() ATTN_Q_A_NORM = auto() ATTN_KV_A_NORM = auto() FFN_SUB_NORM = auto() @@ -590,6 +594,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b", MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa", MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b", + MODEL_TENSOR.ATTN_K_B: "blk.{bid}.attn_k_b", + MODEL_TENSOR.ATTN_V_B: "blk.{bid}.attn_v_b", MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm", MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm", MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm", @@ -1517,6 +1523,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.ATTN_Q_B, MODEL_TENSOR.ATTN_KV_A_MQA, MODEL_TENSOR.ATTN_KV_B, + MODEL_TENSOR.ATTN_K_B, + MODEL_TENSOR.ATTN_V_B, MODEL_TENSOR.ATTN_Q_A_NORM, MODEL_TENSOR.ATTN_KV_A_NORM, MODEL_TENSOR.ATTN_OUT, diff --git a/gguf-py/gguf/gguf_writer.py 
b/gguf-py/gguf/gguf_writer.py index 485550aad..aef03db15 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -689,6 +689,12 @@ class GGUFWriter: def add_value_length(self, length: int) -> None: self.add_uint32(Keys.Attention.VALUE_LENGTH.format(arch=self.arch), length) + def add_key_length_mla(self, length: int) -> None: + self.add_uint32(Keys.Attention.KEY_LENGTH_MLA.format(arch=self.arch), length) + + def add_value_length_mla(self, length: int) -> None: + self.add_uint32(Keys.Attention.VALUE_LENGTH_MLA.format(arch=self.arch), length) + def add_max_alibi_bias(self, bias: float) -> None: self.add_float32(Keys.Attention.MAX_ALIBI_BIAS.format(arch=self.arch), bias) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 35154e9b5..0bc75cf51 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -677,6 +677,14 @@ class TensorNameMap: "model.layers.{bid}.self_attn.kv_b_proj", # deepseek2 ), + MODEL_TENSOR.ATTN_K_B: ( + "model.layers.{bid}.self_attn.k_b_proj", # deepseek2 + ), + + MODEL_TENSOR.ATTN_V_B: ( + "model.layers.{bid}.self_attn.v_b_proj", # deepseek2 + ), + MODEL_TENSOR.ATTN_Q_A_NORM: ( "model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2 ), diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index a6fddc7fd..62e1480bb 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -140,6 +140,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" }, { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" }, { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, + { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, + { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, @@ -1103,6 +1105,8 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, + { LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" }, + { LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" }, { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, @@ -1563,23 +1567,8 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_DOWN_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_GATE_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_UP_SHEXP, 
{LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_Q_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, diff --git a/src/llama-arch.h b/src/llama-arch.h index 2c2099b3c..98ca00a1b 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -144,6 +144,8 @@ enum llm_kv { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, LLM_KV_ATTENTION_SLIDING_WINDOW, LLM_KV_ATTENTION_SCALE, + LLM_KV_ATTENTION_KEY_LENGTH_MLA, + LLM_KV_ATTENTION_VALUE_LENGTH_MLA, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_DIMENSION_SECTIONS, @@ -306,6 +308,8 @@ enum llm_tensor { LLM_TENSOR_ATTN_Q_B, LLM_TENSOR_ATTN_KV_A_MQA, LLM_TENSOR_ATTN_KV_B, + LLM_TENSOR_ATTN_K_B, + LLM_TENSOR_ATTN_V_B, LLM_TENSOR_ATTN_Q_A_NORM, LLM_TENSOR_ATTN_KV_A_NORM, LLM_TENSOR_ATTN_SUB_NORM, diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 4735e98ea..d3ef1cbde 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -10,6 +10,7 @@ #include #include #include +#include // // llama_context @@ -473,7 +474,6 @@ ggml_tensor * llama_context::build_rope_shift( const auto & n_ctx_orig = cparams.n_ctx_orig_yarn; const auto & yarn_ext_factor = cparams.yarn_ext_factor; - const auto & yarn_attn_factor = cparams.yarn_attn_factor; const auto & yarn_beta_fast = cparams.yarn_beta_fast; const auto & yarn_beta_slow = cparams.yarn_beta_slow; @@ -482,6 +482,10 @@ ggml_tensor * llama_context::build_rope_shift( const auto & n_rot = hparams.n_rot; const auto & rope_type = hparams.rope_type; + // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly. + // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation. + const float yarn_attn_factor_scaled = model.arch == LLM_ARCH_DEEPSEEK2 ? 
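The DeepSeek2 branch of the conditional expression completed just below has a simple closed form. As a hedged sanity check (freq_scale = 1/40 is an assumed, illustrative value matching DeepSeek-V2's 40x YaRN context extension, not something this patch pins down):

    // yarn_attn_factor_scaled = 1 / (1 + 0.1 * ln(1 / freq_scale))
    //                         = 1 / (1 + 0.1 * ln(40)) ~= 1 / 1.3689 ~= 0.731

This is the same corrected attention factor that llm_build_deepseek2() computes, so K-cache shift rotations stay consistent with how the K entries were roped in the first place.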
1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor; + ggml_tensor * tmp; if (ggml_is_quantized(cur->type)) { @@ -500,14 +504,14 @@ ggml_tensor * llama_context::build_rope_shift( tmp = ggml_rope_ext_inplace(ctx0, tmp, shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); + yarn_ext_factor, yarn_attn_factor_scaled, yarn_beta_fast, yarn_beta_slow); tmp = ggml_cpy(ctx0, tmp, cur); } else { // we rotate only the first n_rot dimensions tmp = ggml_rope_ext_inplace(ctx0, cur, shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); + yarn_ext_factor, yarn_attn_factor_scaled, yarn_beta_fast, yarn_beta_slow); } return tmp; @@ -2274,6 +2278,11 @@ llama_context * llama_init_from_model( params.flash_attn = false; } + if (params.flash_attn && model->arch == LLM_ARCH_DEEPSEEK2) { + LLAMA_LOG_WARN("%s: flash_attn is not compatible with Deepseek2 - forcing off\n", __func__); + params.flash_attn = false; + } + if (ggml_is_quantized(params.type_v) && !params.flash_attn) { LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__); return nullptr; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index cd955d63b..5d0222b98 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1188,6 +1188,7 @@ ggml_tensor * llm_graph_context::build_attn_mha( ggml_tensor * v, ggml_tensor * kq_b, ggml_tensor * kq_mask, + ggml_tensor * v_mla, bool v_trans, float kq_scale) const { //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); @@ -1199,7 +1200,8 @@ ggml_tensor * llm_graph_context::build_attn_mha( //const auto & n_embd_head_k = hparams.n_embd_head_k; //const auto & n_embd_head_v = hparams.n_embd_head_v; - const auto n_embd_head_v = v_trans ? v->ne[1] : v->ne[0]; + // note: for MLA with the absorption optimization, the final embedding size will be changed via v_mla + const auto n_embd_head_v = v_mla == nullptr ? v_trans ? 
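The v_mla tensor being wired into build_attn_mha() here exists because of the MLA absorption identity: the per-head key up-projection can be folded into the query, and the value up-projection deferred until after attention. A hedged shape sketch (dimension names follow the comments in this patch; the flow is illustrative, not verbatim code):

    // score_h = q_nope_h^T (W_kb_h c) = (W_kb_h^T q_nope_h)^T c
    //   -> K and V can both be the shared compressed vector c (MQA, 1 KV head)
    // after the usual softmax(QK^T) V:
    //   kqv  : [kv_lora_rank, n_tokens, n_head]   attention-weighted compressed values
    //   kqv' = v_mla applied per head (W_vb), "decompressing"
    //        : [n_embd_head_v_mla, n_tokens, n_head]

This is why the final head size is taken from v_mla when it is present, instead of from v itself.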
v->ne[1] : v->ne[0] : v_mla->ne[1]; const auto n_tokens = q->ne[1]; const auto n_head = q->ne[2]; @@ -1267,6 +1269,11 @@ ggml_tensor * llm_graph_context::build_attn_mha( ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); + // for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA + if (v_mla) { + kqv = ggml_mul_mat(ctx0, v_mla, kqv); + } + ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens); @@ -1304,6 +1311,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * k_cur, ggml_tensor * v_cur, ggml_tensor * kq_b, + ggml_tensor * v_mla, float kq_scale, int il) const { GGML_UNUSED(n_tokens); @@ -1325,7 +1333,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3); //cb(k, "v", il); - ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale); + ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale); cb(cur, "kqv_out", il); @@ -1379,6 +1387,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * k_cur, ggml_tensor * v_cur, ggml_tensor * kq_b, + ggml_tensor * v_mla, float kq_scale, int il) const { // these nodes are added to the graph together so that they are not reordered @@ -1464,7 +1473,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v, 0); - ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_trans, kq_scale); + ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale); cb(cur, "kqv_out", il); if (wo) { @@ -1504,6 +1513,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * k_cur, ggml_tensor * v_cur, ggml_tensor * kq_b, + ggml_tensor * v_mla, float kq_scale, int il) const { // these nodes are added to the graph together so that they are not reordered @@ -1523,7 +1533,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3); //cb(k, "v", il); - ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale); + ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale); cb(cur, "kqv_out", il); @@ -1692,4 +1702,3 @@ void llm_graph_context::build_pooling( ggml_build_forward_expand(gf, cur); } - diff --git a/src/llama-graph.h b/src/llama-graph.h index 5b6618f9e..d192dc149 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -505,11 +505,12 @@ struct llm_graph_context { ggml_tensor * build_attn_mha( ggml_cgraph * gf, - ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q] - ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k] - ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false) + ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q] + ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k] + ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false) ggml_tensor * kq_b, ggml_tensor * kq_mask, + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] bool v_trans, float kq_scale) const; @@ -524,6 +525,7 @@ struct llm_graph_context { ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] ggml_tensor * kq_b, + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] float kq_scale, int il) const; @@ -538,6 +540,7 @@ struct llm_graph_context { ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] ggml_tensor * v_cur, // 
[n_embd_head_v, n_head_v, n_tokens] ggml_tensor * kq_b, + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] float kq_scale, int il) const; @@ -552,6 +555,7 @@ struct llm_graph_context { ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] ggml_tensor * kq_b, + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] float kq_scale, int il) const; diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 4e0b57190..80fcd65df 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -43,6 +43,10 @@ struct llama_hparams { uint32_t n_expert_used = 0; uint32_t n_rel_attn_bkts = 0; + // note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA + uint32_t n_embd_head_k_mla = 0; + uint32_t n_embd_head_v_mla = 0; + // for WavTokenizer struct llama_hparams_posnet posnet; struct llama_hparams_convnext convnext; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index dbf5f1187..7c9d46d81 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -27,7 +27,7 @@ bool llama_kv_cache_unified::init( recurrent = llama_model_is_recurrent(&model); v_trans = !recurrent && !cparams.flash_attn; - can_shift = !recurrent && model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA + can_shift = !recurrent; LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n", __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index b74dd72cf..248c61748 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1156,6 +1156,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); } ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla, false); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); @@ -3205,8 +3207,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) { { const bool is_lite = (hparams.n_layer == 27); + const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0); + + // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA + const int64_t n_embd_head_k_mla = is_mla ? hparams.n_embd_head_k_mla : hparams.n_embd_head_k; + const int64_t n_embd_head_v_mla = is_mla ? 
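For reference, the per-layer MLA tensors created in the hunk continued below have the following shapes, taken directly from the create_tensor calls (the side-by-side comparison is illustrative):

    // wk_b  : {n_embd_head_qk_nope, kv_lora_rank, n_head}   k_b_proj, stored pre-transposed by the converter
    // wv_b  : {kv_lora_rank, n_embd_head_v_mla, n_head}     v_b_proj
    // wkv_b : {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v_mla)}   legacy fused tensor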
hparams.n_embd_head_v_mla : hparams.n_embd_head_v; + const int64_t n_embd_head_qk_rope = hparams.n_rot; - const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; const int64_t q_lora_rank = hparams.n_lora_q; const int64_t kv_lora_rank = hparams.n_lora_kv; @@ -3232,14 +3240,22 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (!is_lite) { layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); - layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}, 0); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, 0); } else { - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); } - layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0); - layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0); + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}, 0); + + // note: only old legacy GGUF files will have the unsplit wkv_b tensor in + if (is_mla) { + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, 0); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, 0); + } else { + layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v_mla)}, 0); + } + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -4290,6 +4306,8 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); + LLAMA_LOG_INFO("%s: n_embd_head_k_mla = %d\n", __func__, hparams.n_embd_head_k_mla); + LLAMA_LOG_INFO("%s: n_embd_head_v_mla = %d\n", __func__, hparams.n_embd_head_v_mla); LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); @@ -4496,7 +4514,7 @@ struct llm_build_llama : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } @@ -4709,7 +4727,7 @@ struct llm_build_deci : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1) { @@ -4851,7 +4869,7 @@ struct llm_build_baichuan : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 
1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -4966,7 +4984,7 @@ struct llm_build_xverse : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -5091,7 +5109,7 @@ struct llm_build_falcon : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -5221,7 +5239,7 @@ struct llm_build_grok : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1) { @@ -5372,7 +5390,7 @@ struct llm_build_dbrx : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -5486,7 +5504,7 @@ struct llm_build_starcoder : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -5585,7 +5603,7 @@ struct llm_build_refact : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -5739,7 +5757,7 @@ struct llm_build_bert : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cb(cur, "kqv_out", il); if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) { @@ -5856,7 +5874,7 @@ struct llm_build_bloom : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -5997,7 +6015,7 @@ struct llm_build_mpt : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -6143,7 +6161,7 @@ struct llm_build_stablelm : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -6266,7 +6284,7 @@ struct llm_build_qwen : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -6386,7 +6404,7 @@ 
struct llm_build_qwen2 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -6507,7 +6525,7 @@ struct llm_build_qwen2vl : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -6634,7 +6652,7 @@ struct llm_build_qwen2moe : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -6787,7 +6805,7 @@ struct llm_build_qwen3 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -6908,7 +6926,7 @@ struct llm_build_qwen3moe : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -7048,7 +7066,7 @@ struct llm_build_phi2 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1) { @@ -7177,7 +7195,7 @@ struct llm_build_phi3 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1) { @@ -7312,7 +7330,7 @@ struct llm_build_plamo : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } ggml_tensor * sa_out = cur; @@ -7419,7 +7437,7 @@ struct llm_build_gpt2 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -7535,7 +7553,7 @@ struct llm_build_codeshell : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -7664,7 +7682,7 @@ struct llm_build_orion : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -7791,7 +7809,7 @@ struct llm_build_internlm2 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, 
Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -7988,7 +8006,7 @@ struct llm_build_minicpm3 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - q_states, k_states, v_states, nullptr, kq_scale, il); + q_states, k_states, v_states, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1) { @@ -8118,7 +8136,7 @@ struct llm_build_gemma : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1) { @@ -8240,7 +8258,7 @@ struct llm_build_gemma2 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); } cur = build_norm(cur, @@ -8381,7 +8399,7 @@ struct llm_build_gemma3 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, hparams.f_attention_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il); } cur = build_norm(cur, @@ -8521,7 +8539,7 @@ struct llm_build_starcoder2 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -8856,7 +8874,7 @@ struct llm_build_command_r : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -8991,7 +9009,7 @@ struct llm_build_cohere2 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -9122,7 +9140,7 @@ struct llm_build_olmo : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, nullptr, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -9242,7 +9260,7 @@ struct llm_build_olmo2 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } cur = build_norm(cur, @@ -9375,7 +9393,7 @@ struct llm_build_olmoe : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -9508,7 +9526,7 @@ struct llm_build_openelm : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -9622,7 +9640,7 @@ struct llm_build_gptneox : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, 
nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -9772,7 +9790,7 @@ struct llm_build_arctic : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -9927,7 +9945,7 @@ struct llm_build_deepseek : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1) { @@ -10017,16 +10035,23 @@ struct llm_build_deepseek2 : public llm_graph_context { llm_build_deepseek2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { bool is_lite = (hparams.n_layer == 27); + const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0); + + // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA + const int64_t n_embd_head_k = is_mla ? hparams.n_embd_head_k_mla : hparams.n_embd_head_k; + const int64_t n_embd_head_v = is_mla ? hparams.n_embd_head_v_mla : hparams.n_embd_head_v; + + const int64_t n_embd_head_qk_rope = hparams.n_rot; + const int64_t n_embd_head_qk_nope = n_embd_head_k - n_embd_head_qk_rope; + + const uint32_t kv_lora_rank = hparams.n_lora_kv; + // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly. // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation. const float mscale = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale)); - const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(hparams.n_embd_head_k)); + const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(n_embd_head_k)); const float attn_factor_scaled = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)); - const uint32_t n_embd_head_qk_rope = hparams.n_rot; - const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; - const uint32_t kv_lora_rank = hparams.n_lora_kv; - ggml_tensor * cur; ggml_tensor * inpL; @@ -10051,16 +10076,14 @@ struct llm_build_deepseek2 : public llm_graph_context { { ggml_tensor * q = NULL; if (!is_lite) { - // {n_embd, q_lora_rank} * {n_embd, n_tokens} -> {q_lora_rank, n_tokens} q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur); cb(q, "q", il); q = build_norm(q, - model.layers[il].attn_q_a_norm, NULL, + model.layers[il].attn_q_a_norm, nullptr, LLM_NORM_RMS, il); cb(q, "q", il); - // {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens} q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q); cb(q, "q", il); } else { @@ -10068,96 +10091,125 @@ struct llm_build_deepseek2 : public llm_graph_context { cb(q, "q", il); } - // split into {n_head * n_embd_head_qk_nope, n_tokens} - ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens, - ggml_row_size(q->type, hparams.n_embd_head_k), - ggml_row_size(q->type, hparams.n_embd_head_k * n_head), + // split into {n_embd_head_qk_nope, n_head, n_tokens} + ggml_tensor * q_nope = ggml_view_3d(ctx0, q, + n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(q->type, n_embd_head_k), + ggml_row_size(q->type, n_embd_head_k) * n_head, 0); cb(q_nope, "q_nope", il); - // and {n_head * n_embd_head_qk_rope, n_tokens} - ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, 
n_tokens, - ggml_row_size(q->type, hparams.n_embd_head_k), - ggml_row_size(q->type, hparams.n_embd_head_k * n_head), + // and {n_embd_head_qk_rope, n_head, n_tokens} + ggml_tensor * q_pe = ggml_view_3d(ctx0, q, + n_embd_head_qk_rope, n_head, n_tokens, + ggml_row_size(q->type, n_embd_head_k), + ggml_row_size(q->type, n_embd_head_k) * n_head, ggml_row_size(q->type, n_embd_head_qk_nope)); cb(q_pe, "q_pe", il); - // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens} - ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); - cb(kv_pe_compresseed, "kv_pe_compresseed", il); + ggml_tensor * kv_cmpr_pe = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); + cb(kv_cmpr_pe, "kv_cmpr_pe", il); // split into {kv_lora_rank, n_tokens} - ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens, - kv_pe_compresseed->nb[1], + ggml_tensor * kv_cmpr = ggml_view_2d(ctx0, kv_cmpr_pe, + kv_lora_rank, n_tokens, + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), 0); - cb(kv_compressed, "kv_compressed", il); + cb(kv_cmpr, "kv_cmpr", il); - // and {n_embd_head_qk_rope, n_tokens} - ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens, - kv_pe_compresseed->nb[1], - kv_pe_compresseed->nb[1], - ggml_row_size(kv_pe_compresseed->type, kv_lora_rank)); + // and {n_embd_head_qk_rope, 1, n_tokens} + ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_cmpr_pe, + n_embd_head_qk_rope, 1, n_tokens, + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank)); cb(k_pe, "k_pe", il); - // TODO: the CUDA backend used to not support non-cont. (RMS) norm, investigate removing ggml_cont - kv_compressed = ggml_cont(ctx0, kv_compressed); - kv_compressed = build_norm(kv_compressed, - model.layers[il].attn_kv_a_norm, NULL, - LLM_NORM_RMS, il); - cb(kv_compressed, "kv_compressed", il); - - // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens} - ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed); - cb(kv, "kv", il); - - // split into {n_head * n_embd_head_qk_nope, n_tokens} - ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, - ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), - ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), - 0); - cb(k_nope, "k_nope", il); - - // and {n_head * n_embd_head_v, n_tokens} - ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, - ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)), - ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head), - ggml_row_size(kv->type, (n_embd_head_qk_nope))); - cb(v_states, "v_states", il); - - v_states = ggml_cont(ctx0, v_states); - cb(v_states, "v_states", il); - - v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens, - ggml_row_size(kv->type, hparams.n_embd_head_v * n_head), - 0); - cb(v_states, "v_states", il); - - q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. 
RoPE, investigate removing this
-        q_pe = ggml_rope_ext(
-                ctx0, q_pe, inp_pos, nullptr,
+        q_pe = ggml_rope_ext(ctx0, q_pe, inp_pos, nullptr,
                 n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                 ext_factor, attn_factor_scaled, beta_fast, beta_slow
-        );
+        );
         cb(q_pe, "q_pe", il);
 
-        // shared RoPE key
-        k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
-        k_pe = ggml_rope_ext(
-                ctx0, k_pe, inp_pos, nullptr,
+        k_pe = ggml_rope_ext(ctx0, k_pe, inp_pos, nullptr,
                 n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                 ext_factor, attn_factor_scaled, beta_fast, beta_slow
-        );
+        );
         cb(k_pe, "k_pe", il);
 
-        ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
-        cb(q_states, "q_states", il);
+        kv_cmpr = build_norm(kv_cmpr,
+                model.layers[il].attn_kv_a_norm, nullptr,
+                LLM_NORM_RMS, il);
+        cb(kv_cmpr, "kv_cmpr", il);
 
-        ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
-        cb(k_states, "k_states", il);
+        if (is_mla) {
+            // {n_embd_head_qk_nope, n_tokens, n_head}
+            q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
+            cb(q_nope, "q_nope_perm", il);
 
-        cur = build_attn(inp_attn, gf,
-                model.layers[il].wo, NULL,
-                q_states, k_states, v_states, nullptr, kq_scale, il);
+            // {n_embd_head_qk_nope, kv_lora_rank, n_head} x {n_embd_head_qk_nope, n_tokens, n_head}
+            ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, model.layers[il].wk_b, q_nope);
+            cb(q_nope_absorbed, "q_nope_absorbed", il);
+
+            // {kv_lora_rank, n_head, n_tokens}
+            q_nope_absorbed = ggml_permute(ctx0, q_nope_absorbed, 0, 2, 1, 3);
+            cb(q_nope_absorbed, "q_nope_absorbed_perm", il);
+
+            // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens}
+            // note: rope must go first for in-place context shifting in build_rope_shift()
+            ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope_absorbed, 0);
+            cb(Qcur, "Qcur", il);
+
+            kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens);
+            cb(kv_cmpr, "kv_cmpr_reshape", il);
+
+            // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens}
+            ggml_tensor * Kcur = ggml_concat(ctx0, k_pe, kv_cmpr, 0);
+            cb(Kcur, "Kcur", il);
+
+            // {kv_lora_rank, 1, n_tokens}
+            ggml_tensor * Vcur = kv_cmpr;
+            cb(Vcur, "Vcur", il);
+
+            // note: MLA with the absorption optimization converts into MQA (ie: GQA with 1 group)
+            cur = build_attn(inp_attn, gf,
+                    model.layers[il].wo, NULL,
+                    Qcur, Kcur, Vcur, nullptr, model.layers[il].wv_b, kq_scale, il);
+        } else {
+            ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cmpr);
+            cb(kv, "kv", il);
+
+            // split into {n_embd_head_qk_nope, n_head, n_tokens}
+            ggml_tensor * k_nope = ggml_view_3d(ctx0, kv,
+                    n_embd_head_qk_nope, n_head, n_tokens,
+                    ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v),
+                    ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v) * n_head,
+                    0);
+            cb(k_nope, "k_nope_view", il);
+
+            // and {n_embd_head_v, n_head, n_tokens}
+            ggml_tensor * Vcur = ggml_view_3d(ctx0, kv,
+                    n_embd_head_v, n_head, n_tokens,
+                    ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v),
+                    ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v) * n_head,
+                    ggml_row_size(kv->type, n_embd_head_qk_nope));
+            cb(Vcur, "Vcur_view", il);
+
+            Vcur = ggml_cont(ctx0, Vcur);
+            cb(Vcur, "Vcur_cont", il);
+
+            // note: rope must go first for in-place context shifting in build_rope_shift()
+            ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope, 0);
+            cb(Qcur, "Qcur", il);
+
+            ggml_tensor * Kcur = ggml_concat(ctx0, ggml_repeat(ctx0, k_pe, q_pe), k_nope, 0);
+            cb(Kcur, 
"Kcur", il); + + // note: MLA without the absorption optimization converts into MHA (ie: GQA with full n_head groups) + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + } } if (il == n_layer - 1) { @@ -10323,7 +10375,7 @@ struct llm_build_bitnet : public llm_graph_context { cur = build_attn(inp_attn, gf, NULL, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cur = build_norm(cur, model.layers[il].attn_sub_norm, NULL, @@ -10446,7 +10498,7 @@ struct llm_build_t5_enc : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo_enc, nullptr, - Qcur, Kcur, Vcur, kq_b, 1.0f, il); + Qcur, Kcur, Vcur, kq_b, nullptr, 1.0f, il); cb(cur, "kqv_out", il); } @@ -10552,7 +10604,7 @@ struct llm_build_t5_dec : public llm_graph_context { cur = build_attn(inp_attn_self, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, kq_b, 1.0f, il); + Qcur, Kcur, Vcur, kq_b, nullptr, 1.0f, il); cb(cur, "kqv_out", il); } @@ -10584,7 +10636,7 @@ struct llm_build_t5_dec : public llm_graph_context { cur = build_attn(inp_attn_cross, gf, model.layers[il].wo_cross, nullptr, - Qcur, Kcur, Vcur, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); cb(cur, "kqv_out", il); //ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); @@ -10717,7 +10769,7 @@ struct llm_build_jais : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/float(n_embd_head), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/float(n_embd_head), il); } if (il == n_layer - 1) { @@ -10849,7 +10901,7 @@ struct llm_build_chatglm : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -10982,7 +11034,7 @@ struct llm_build_glm4 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -11126,7 +11178,7 @@ struct llm_build_nemotron : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -11257,7 +11309,7 @@ struct llm_build_exaone : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -12159,7 +12211,7 @@ struct llm_build_chameleon : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, nullptr, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); if (hparams.swin_norm) { cur = build_norm(cur, @@ -12515,7 +12567,7 @@ struct llm_build_plm : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - q_states, k_states, v_states, nullptr, kq_scale, il); + q_states, k_states, v_states, nullptr, nullptr, kq_scale, il); } if 
(il == n_layer - 1) {
@@ -12638,7 +12690,7 @@ struct llm_build_bailingmoe : public llm_graph_context {
 
             cur = build_attn(inp_attn, gf,
                     model.layers[il].wo, model.layers[il].bo,
-                    Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_rot)), il);
+                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_rot)), il);
         }
 
         if (il == n_layer - 1) {
diff --git a/src/llama-model.h b/src/llama-model.h
index 0f18dac16..fd82d106c 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -171,6 +171,8 @@ struct llama_layer {
     struct ggml_tensor * wq_b      = nullptr;
     struct ggml_tensor * wkv_a_mqa = nullptr;
     struct ggml_tensor * wkv_b     = nullptr;
+    struct ggml_tensor * wk_b      = nullptr;
+    struct ggml_tensor * wv_b      = nullptr;
     struct ggml_tensor * wq_cross  = nullptr;
     struct ggml_tensor * wk_cross  = nullptr;
     struct ggml_tensor * wv_cross  = nullptr;

From 510676475f885ec064ff147af9f20ee7a9b12a50 Mon Sep 17 00:00:00 2001
From: Akarshan Biswas
Date: Tue, 15 Apr 2025 14:07:42 +0530
Subject: [PATCH 15/20] SYCL: Add ROPE vision kernel (#12887)

* SYCL: Add ROPE vision kernel

* Add comment about rope mode
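For reference, the vision mode below splits each head's rotary dimensions into
two sections and reads the position for each section from a different channel
block of `pos`. A minimal host-side sketch of the per-element angle
computation, mirroring the kernel's indexing (the helper name is purely
illustrative and not part of the patch):

```
#include <cmath>
#include <cstdint>

// Compute theta_base for the element pair at index i0/2 of one row, as in rope_vision:
// sections[0] and sections[1] partition the rotary dims between two position streams.
static float vision_theta_base(const int32_t * pos, int channel_x, int ne2,
                               int i0, const int sections[2], float theta_scale) {
    const int sect_dims = sections[0] + sections[1];
    const int sector    = (i0 / 2) % sect_dims;
    if (sector < sections[0]) {
        // first section: position from the first channel block
        return pos[channel_x] * std::pow(theta_scale, (float) sector);
    }
    // second section: position from the next channel block; the exponent restarts
    return pos[channel_x + ne2] * std::pow(theta_scale, (float) (sector - sections[0]));
}
```

The exponent restarting at the section boundary means each section behaves as
an independent rotary embedding over its own position stream.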
---
 ggml/src/ggml-sycl/ggml-sycl.cpp |   6 +-
 ggml/src/ggml-sycl/rope.cpp      | 106 ++++++++++++++++++++++++++++++-
 2 files changed, 105 insertions(+), 7 deletions(-)

diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
index 09800d340..4d2fda0bf 100644
--- a/ggml/src/ggml-sycl/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -4009,10 +4009,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_ROPE:
         {
             const int mode = ((const int32_t *) op->op_params)[2];
-            if (mode & GGML_ROPE_TYPE_MROPE) {
-                return false;
-            }
-            if (mode & GGML_ROPE_TYPE_VISION) {
+            // mode is not used as a bitmask in practice; the various rope type modes are independent implementations
+            if (mode == GGML_ROPE_TYPE_MROPE) {
                 return false;
             }
             return ggml_is_contiguous(op->src[0]);
diff --git a/ggml/src/ggml-sycl/rope.cpp b/ggml/src/ggml-sycl/rope.cpp
index bbcb356e9..80e050f24 100644
--- a/ggml/src/ggml-sycl/rope.cpp
+++ b/ggml/src/ggml-sycl/rope.cpp
@@ -1,9 +1,15 @@
 #include "rope.hpp"
+#include "ggml-sycl/common.hpp"
+#include "ggml.h"
 
 struct rope_corr_dims {
     float v[2];
 };
 
+struct mrope_sections {
+    int v[4];
+};
+
 static float rope_yarn_ramp(const float low, const float high, const int i0) {
     const float y = (i0 / 2 - low) / sycl::max(0.001f, high - low);
     return 1.0f - sycl::min(1.0f, sycl::max(0.0f, y));
@@ -114,6 +120,48 @@ static void rope_neox(
     dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
 }
 
+template <typename T, bool has_ff>
+static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
+                        const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
+                        const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
+                        const float theta_scale, const float * freq_factors, const mrope_sections sections,
+                        const sycl::nd_item<3> & item_ct1) {
+    // get index pos
+    const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
+    if (i0 >= ne0) {
+        return;
+    }
+    const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
+    const int row_x = row_dst % ne1;
+    const int channel_x = row_dst / ne1;
+    const int idst = (row_dst * ne0) + (i0 / 2);
+    const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
+
+    const int sect_dims = sections.v[0] + sections.v[1];
+    const int sector = (i0 / 2) % sect_dims;
+
+    float theta_base = 0.0f;
+    if (sector < sections.v[0]) {
+        const int p = sector;
+        theta_base = pos[channel_x] * sycl::pow(theta_scale, (float) p);
+    } else {
+        // Simplified from CUDA backend code: if (sector >= sections.v[0] && sector < sec_w), which is just sector >= sections.v[0]
+        const int p = sector - sections.v[0];
+        theta_base = pos[channel_x + ne2] * sycl::pow(theta_scale, (float) p);
+    }
+
+    const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
+    float cos_theta;
+    float sin_theta;
+    rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    const float x0 = x[ix + 0];
+    const float x1 = x[ix + n_dims];
+
+    // store results in dst
+    dst[idst + 0]      = x0 * cos_theta - x1 * sin_theta;
+    dst[idst + n_dims] = x0 * sin_theta + x1 * cos_theta;
+}
+
 template <typename T>
 static void rope_norm_sycl(
     const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows,
@@ -192,21 +240,58 @@ static void rope_neox_sycl(
     }
 }
 
+// rope vision
+template <typename T>
+static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
+                             const size_t s2, const int n_dims, const int nr, const int32_t * pos,
+                             const float freq_scale, const float freq_base, const float ext_factor,
+                             const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
+                             const mrope_sections sections, queue_ptr stream) {
+    GGML_ASSERT(ne0 % 2 == 0);
+    const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
+    const int n_blocks_y = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
+    const sycl::range<3> grid_dims(1, n_blocks_y, nr);
+    const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
+
+    const float theta_scale = std::pow(freq_base, -2.0f / n_dims);
+    // Add FP16 capability check if T could be sycl::half
+    if constexpr (std::is_same_v<T, sycl::half>) {
+        dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
+    }
+    // launch kernel
+    if (freq_factors == nullptr) {
+        stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
+            rope_vision<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
+                                  corr_dims, theta_scale, freq_factors, sections, item_ct1);
+        });
+    } else {
+        stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
+            rope_vision<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
+                                 corr_dims, theta_scale, freq_factors, sections, item_ct1);
+        });
+    }
+}
+
 void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
     GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
     GGML_ASSERT(dst->src[0]->type == dst->type);
-
-    const int64_t ne00 = dst->src[0]->ne[0];
-    const int64_t ne01 = dst->src[0]->ne[1];
+    const int64_t ne00 = dst->src[0]->ne[0]; // head dims
+    const int64_t ne01 = dst->src[0]->ne[1]; // num heads
+    const int64_t ne02 = dst->src[0]->ne[2]; // num heads
     const int64_t nr = ggml_nrows(dst->src[0]);
 
+    const size_t s01 = dst->src[0]->nb[1] / ggml_type_size(dst->src[0]->type);
+    const size_t s02 = dst->src[0]->nb[2] / ggml_type_size(dst->src[0]->type);
+
+    //const int n_past     = ((int32_t *) dst->op_params)[0];
     const int n_dims     = ((int32_t *) dst->op_params)[1];
     const int mode       = ((int32_t *) dst->op_params)[2];
     //const int n_ctx      = ((int32_t *) dst->op_params)[3];
     const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+    mrope_sections sections;
 
     // RoPE alteration for extended context
    float freq_base;
@@ -222,8 +307,10 @@ void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
     memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
     memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
     memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
+    memcpy(&sections.v,  (int32_t *) dst->op_params + 11, sizeof(int)*4);
 
     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
 
     const int32_t * pos = (const int32_t *) dst->src[1]->data;
 
@@ -240,6 +327,7 @@ void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
 
     // compute
     if (is_neox) {
+        GGML_SYCL_DEBUG("%s: neox path\n", __func__);
         if (dst->src[0]->type == GGML_TYPE_F32) {
             rope_neox_sycl(
                 (const float *)dst->src[0]->data, (float *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
@@ -253,7 +341,19 @@ void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
         } else {
             GGML_ABORT("fatal error");
         }
+    } else if (is_vision) {
+        GGML_SYCL_DEBUG("%s: vision path\n", __func__);
+        if (dst->src[0]->type == GGML_TYPE_F16) {
+            rope_vision_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
+                             freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, main_stream);
+        } else if (dst->src[0]->type == GGML_TYPE_F32) {
+            rope_vision_sycl((const float *) dst->src[0]->data, (float *)dst->data, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
+                             freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, main_stream);
+        } else {
+            GGML_ABORT("Fatal error: Tensor type unsupported!");
+        }
     } else {
+        GGML_SYCL_DEBUG("%s: norm path\n", __func__);
         if (dst->src[0]->type == GGML_TYPE_F32) {
             rope_norm_sycl(
                 (const float *)dst->src[0]->data, (float *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,

From 84778e97703740d8ac5fb64e14d83b80eafa0f3c Mon Sep 17 00:00:00 2001
From: David Huang <1969802+hjc4869@users.noreply.github.com>
Date: Tue, 15 Apr 2025 17:20:38 +0800
Subject: [PATCH 16/20] CUDA/HIP: Share the same unified memory allocation
 logic. (#12934)

Replace compile-time `GGML_HIP_UMA` with environment variable
`GGML_CUDA_ENABLE_UNIFIED_MEMORY`. This unifies the usage on NVIDIA and
AMD GPUs, and allows a single binary to be shared between integrated
and dedicated GPUs.
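On both backends, allocation now follows the same pattern. A simplified
standalone sketch (error handling trimmed; the actual change below only falls
back on HIP's `hipErrorNotSupported`, and the helper name here is
illustrative):

```
#include <cstdlib>
#include <cuda_runtime.h>

// Prefer managed (unified) memory when GGML_CUDA_ENABLE_UNIFIED_MEMORY is set;
// otherwise, or if the managed allocation fails, use a plain device allocation.
static cudaError_t alloc_device_buffer(void ** ptr, size_t size) {
    if (std::getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr) {
        const cudaError_t err = cudaMallocManaged(ptr, size);
        if (err == cudaSuccess) {
            return err;
        }
        // fall through to the plain device allocation
    }
    return cudaMalloc(ptr, size);
}
```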
---
 Makefile                         |  4 ----
 docs/build.md                    |  6 ++++--
 ggml/CMakeLists.txt              |  1 -
 ggml/src/ggml-cuda/ggml-cuda.cu  | 31 ++++++++++++++++---------------
 ggml/src/ggml-cuda/vendors/hip.h |  2 ++
 ggml/src/ggml-hip/CMakeLists.txt |  4 ----
 6 files changed, 22 insertions(+), 26 deletions(-)

diff --git a/Makefile b/Makefile
index 1f9455eff..772993ada 100644
--- a/Makefile
+++ b/Makefile
@@ -780,10 +780,6 @@ ifdef GGML_HIP
 
 	MK_CPPFLAGS += -DGGML_USE_HIP -DGGML_USE_CUDA
 
-ifdef GGML_HIP_UMA
-	MK_CPPFLAGS += -DGGML_HIP_UMA
-endif # GGML_HIP_UMA
-
 	MK_LDFLAGS += -L$(ROCM_PATH)/lib -Wl,-rpath=$(ROCM_PATH)/lib
 	MK_LDFLAGS += -L$(ROCM_PATH)/lib64 -Wl,-rpath=$(ROCM_PATH)/lib64
 	MK_LDFLAGS += -lhipblas -lamdhip64 -lrocblas
diff --git a/docs/build.md b/docs/build.md
index 3f1b04399..c9027c0b5 100644
--- a/docs/build.md
+++ b/docs/build.md
@@ -259,8 +259,6 @@ You can download it from your Linux distro's package manager or from here: [ROCm
     cmake -S . -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \
     && cmake --build build --config Release -- -j 16
     ```
-  On Linux it is also possible to use unified memory architecture (UMA) to share main memory between the CPU and integrated GPU by setting `-DGGML_HIP_UMA=ON`.
-  However, this hurts performance for non-integrated GPUs (but enables working with integrated GPUs).
 
   To enhance flash attention performance on RDNA3+ or CDNA architectures, you can utilize the rocWMMA library by enabling the `-DGGML_HIP_ROCWMMA_FATTN=ON` option. This requires rocWMMA headers to be installed on the build system.
 
@@ -296,6 +294,10 @@ You can download it from your Linux distro's package manager or from here: [ROCm
   The environment variable [`HIP_VISIBLE_DEVICES`](https://rocm.docs.amd.com/en/latest/understand/gpu_isolation.html#hip-visible-devices) can be used to specify which GPU(s) will be used.
   If your GPU is not officially supported you can use the environment variable [`HSA_OVERRIDE_GFX_VERSION`] set to a similar GPU, for example 10.3.0 on RDNA2 (e.g. gfx1030, gfx1031, or gfx1035) or 11.0.0 on RDNA3.
 
+### Unified Memory
+
+On Linux it is possible to use unified memory architecture (UMA) to share main memory between the CPU and integrated GPU by setting the environment variable `GGML_CUDA_ENABLE_UNIFIED_MEMORY=1`. However, this hurts performance on discrete (non-integrated) GPUs, while it enables working with integrated GPUs.
+
 ## Vulkan
 
 **Windows**
diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt
index d33f843b4..438c2a730 100644
--- a/ggml/CMakeLists.txt
+++ b/ggml/CMakeLists.txt
@@ -170,7 +170,6 @@ option(GGML_HIP                             "ggml: use HIP"
 option(GGML_HIP_GRAPHS                      "ggml: use HIP graph, experimental, slow"         OFF)
 option(GGML_HIP_NO_VMM                      "ggml: do not try to use HIP VMM"                 ON)
 option(GGML_HIP_ROCWMMA_FATTN               "ggml: enable rocWMMA for FlashAttention"         OFF)
-option(GGML_HIP_UMA                         "ggml: use HIP unified memory architecture"       OFF)
 option(GGML_VULKAN                          "ggml: use Vulkan"                                OFF)
 option(GGML_VULKAN_CHECK_RESULTS            "ggml: run Vulkan op checks"                      OFF)
 option(GGML_VULKAN_DEBUG                    "ggml: enable Vulkan debug output"                OFF)
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 4af189701..9ced46651 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -96,31 +96,32 @@ int ggml_cuda_get_device() {
 
 static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
     ggml_cuda_set_device(device);
-#if defined(GGML_USE_HIP) && defined(GGML_HIP_UMA)
-    auto res = hipMallocManaged(ptr, size);
-    if (res == hipSuccess) {
-        // if error we "need" to know why...
-        CUDA_CHECK(hipMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device));
-    }
-    return res;
-#else
-
-#if !defined(GGML_USE_HIP)
     cudaError_t err;
     if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr) {
        err = cudaMallocManaged(ptr, size);
+#if defined(GGML_USE_HIP)
+        if (err == hipSuccess) {
+            CUDA_CHECK(cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device));
+        }
+
+        // fall back to cudaMalloc if not supported (e.g.
on Windows) + if (err == hipErrorNotSupported) { + static bool warned_unsupported = false; + if (!warned_unsupported) { + GGML_LOG_WARN("hipMallocManaged unsupported, falling back to hipMalloc.\n"); + warned_unsupported = true; + } + + err = cudaMalloc(ptr, size); + } +#endif // defined(GGML_USE_HIP) } else { err = cudaMalloc(ptr, size); } return err; -#else - return cudaMalloc(ptr, size); -#endif // !defined(GGML_USE_HIP) - -#endif } #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 420b41b8d..1a28831b7 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -71,6 +71,8 @@ #define cudaLaunchHostFunc hipLaunchHostFunc #define cudaMalloc hipMalloc #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault) +#define cudaMallocManaged hipMallocManaged +#define cudaMemAdvise hipMemAdvise #define cudaMemcpy hipMemcpy #define cudaMemcpyAsync hipMemcpyAsync #define cudaMemcpyPeerAsync hipMemcpyPeerAsync diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index e3762649f..1fe8fe3b8 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -89,10 +89,6 @@ endif() add_compile_definitions(GGML_USE_HIP) -if (GGML_HIP_UMA) - add_compile_definitions(GGML_HIP_UMA) -endif() - if (GGML_CUDA_FORCE_MMQ) add_compile_definitions(GGML_CUDA_FORCE_MMQ) endif() From 54a72720432810c89c2693f34908b60b88da1e46 Mon Sep 17 00:00:00 2001 From: hipudding Date: Tue, 15 Apr 2025 19:08:55 +0800 Subject: [PATCH 17/20] CANN: Add x86 build ci (#12950) * CANN: Add x86 build ci * CANN: fix code format --- .github/workflows/build.yml | 5 +- ggml/src/ggml-cann/ggml-cann.cpp | 326 ++++++++++++++++--------------- 2 files changed, 167 insertions(+), 164 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index bcfcf08ac..34417985d 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1766,16 +1766,17 @@ jobs: if: ${{ github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'Ascend NPU') }} defaults: run: - shell: bash -el {0} - runs-on: ubuntu-24.04-arm + shell: bash -el {0} strategy: matrix: + arch: [x86, aarch64] cann: - '8.1.RC1.alpha001-910b-openeuler22.03-py3.10' device: - 'ascend910b3' build: - 'Release' + runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }} container: ascendai/cann:${{ matrix.cann }} steps: - name: Checkout diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index db8ae260a..08b9ca301 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -156,195 +156,196 @@ const ggml_cann_device_info& ggml_cann_info() { * This class manages a pool of CANN buffers for a specific device. */ struct ggml_cann_pool_buf_prio : public ggml_cann_pool { - /** - * @brief The maximum reuse margin for a buffer. - */ - static const size_t max_reuse_margin = 1ull << 22; // 4MB + /** + * @brief The maximum reuse margin for a buffer. + */ + static const size_t max_reuse_margin = 1ull << 22; // 4MB - /** - * @brief The minimum free margin for a buffer. - */ - static const size_t min_free_margin = 1ull << 20; // 1MB + /** + * @brief The minimum free margin for a buffer. + */ + static const size_t min_free_margin = 1ull << 20; // 1MB - /** - * @brief The alignment for buffer allocation. 
- */ - static const size_t alignment = 128; + /** + * @brief The alignment for buffer allocation. + */ + static const size_t alignment = 128; - /** - * @brief The device ID associated with this buffer pool. - */ - int device; + /** + * @brief The device ID associated with this buffer pool. + */ + int device; - /** - * @brief Whether to disable clean during buffer allocation. - */ - bool disable_clean = false; + /** + * @brief Whether to disable clean during buffer allocation. + */ + bool disable_clean = false; - /** - * @brief Structure representing a CANN buffer. - */ - struct ggml_cann_buffer { - void* ptr = nullptr; ///< Pointer to the buffer. - size_t size = 0; ///< Size of the buffer. - std::chrono::steady_clock::time_point last_used; ///< Last used time. + /** + * @brief Structure representing a CANN buffer. + */ + struct ggml_cann_buffer { + void* ptr = nullptr; ///< Pointer to the buffer. + size_t size = 0; ///< Size of the buffer. + std::chrono::steady_clock::time_point last_used; ///< Last used time. - bool operator>(const ggml_cann_buffer& other) const { - return size > other.size; - } - }; + bool operator>(const ggml_cann_buffer& other) const { + return size > other.size; + } + }; - /** - * @brief Array of CANN buffers in the pool. - */ - std::unordered_map buffer_pool; - std::priority_queue, - std::greater<>> free_buffers ; + /** + * @brief Array of CANN buffers in the pool. + */ + std::unordered_map buffer_pool; + std::priority_queue, + std::greater<>> free_buffers ; - /** - * @brief Total size of all buffers in the pool. - */ - size_t pool_size = 0; + /** + * @brief Total size of all buffers in the pool. + */ + size_t pool_size = 0; - /** - * @brief Constructor to initialize the buffer pool for a specific device. - * - * @param device The device ID to associate with this buffer pool. - */ - explicit ggml_cann_pool_buf_prio(int device) : device(device) { - disable_clean = getenv("GGML_CANN_DISABLE_BUF_POOL_CLEAN") != nullptr; + /** + * @brief Constructor to initialize the buffer pool for a specific device. + * + * @param device The device ID to associate with this buffer pool. + */ + explicit ggml_cann_pool_buf_prio(int device) : device(device) { + disable_clean = getenv("GGML_CANN_DISABLE_BUF_POOL_CLEAN") != nullptr; + } + + /** + * @brief Destructor to free all buffers in the pool. + */ + ~ggml_cann_pool_buf_prio() { + ggml_cann_set_device(device); + for (auto& [b_ptr, b_size] : buffer_pool) { + aclrtFree(b_ptr); + pool_size -= b_size; + } + buffer_pool.clear(); + GGML_ASSERT(pool_size == 0); + } + + /** + * @brief Allocate a buffer of the given size. + * + * @param size The size of the buffer to allocate. + * @param actual_size A pointer to a variable to receive the actual size of + * the allocated buffer. + * @return A pointer to the allocated buffer. + */ + void* alloc(size_t size, size_t* actual_size) override { + size = GGML_PAD(size, alignment); + if (size == 0) { + size = alignment; } - /** - * @brief Destructor to free all buffers in the pool. - */ - ~ggml_cann_pool_buf_prio() { - ggml_cann_set_device(device); - for (auto& [b_ptr, b_size] : buffer_pool) { - aclrtFree(b_ptr); - pool_size -= b_size; - } - buffer_pool.clear(); - GGML_ASSERT(pool_size == 0); - } + void* ptr = nullptr; + auto now = std::chrono::steady_clock::now(); - /** - * @brief Allocate a buffer of the given size. - * - * @param size The size of the buffer to allocate. - * @param actual_size A pointer to a variable to receive the actual size of - * the allocated buffer. 
- * @return A pointer to the allocated buffer. - */ - void* alloc(size_t size, size_t* actual_size) override { - size = GGML_PAD(size, alignment); - if (size == 0) { - size = alignment; - } + std::vector free_buffers_rest; + free_buffers_rest.reserve(free_buffers.size()); + while (!free_buffers.empty()) { + auto b = free_buffers.top(); + free_buffers.pop(); - void* ptr = nullptr; - auto now = std::chrono::steady_clock::now(); - - std::vector free_buffers_rest; - free_buffers_rest.reserve(free_buffers.size()); - while (!free_buffers.empty()) { - auto b = free_buffers.top(); - free_buffers.pop(); - - if (b.size >= size) { - // reuse the buffer if the size is enough - const size_t margin = b.size - size; - if (margin <= max_reuse_margin) { - *actual_size = b.size; - ptr = b.ptr; - #ifdef DEBUG_CANN_MALLOC - GGML_LOG_INFO( - "cann pool[%d]: reused %p, " - "pool_size = %5u MB, " - "size = %5u MB, " - "margin = %5u MB\n", - device, b.ptr, - (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576), - (uint32_t)(GGML_PAD(size, 1048576) / 1048576), - (uint32_t)(GGML_PAD(margin, 1048576) / 1048576)); - #endif - break; - } - } - - bool should_clean = !disable_clean && - b.size > min_free_margin && - std::chrono::duration_cast(now - b.last_used).count() > 100; - if (should_clean) { - // free the buffer if the size is needed to be freed - ACL_CHECK(aclrtFree(b.ptr)); - pool_size -= b.size; - buffer_pool.erase(b.ptr); - #ifdef DEBUG_CANN_MALLOC + if (b.size >= size) { + // reuse the buffer if the size is enough + const size_t margin = b.size - size; + if (margin <= max_reuse_margin) { + *actual_size = b.size; + ptr = b.ptr; +#ifdef DEBUG_CANN_MALLOC GGML_LOG_INFO( - "cann pool[%d]: clean %p, " + "cann pool[%d]: reused %p, " "pool_size = %5u MB, " - "size = %5u MB\n", + "size = %5u MB, " + "margin = %5u MB\n", device, b.ptr, (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576), - (uint32_t)(GGML_PAD(b.size, 1048576) / 1048576)); - #endif - continue; + (uint32_t)(GGML_PAD(size, 1048576) / 1048576), + (uint32_t)(GGML_PAD(margin, 1048576) / 1048576)); +#endif + break; } - free_buffers_rest.push_back(b); - } - for (ggml_cann_buffer &b : free_buffers_rest) { - free_buffers.push(std::move(b)); } - #ifdef DEBUG_CANN_MALLOC - GGML_LOG_INFO("cann pool[%d] free pool_size = %5u MB\n\n", device, (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576)); - #endif - if (ptr != nullptr) { - return ptr; + bool should_clean = !disable_clean && + b.size > min_free_margin && + std::chrono::duration_cast(now - b.last_used).count() > 100; + if (should_clean) { + // free the buffer if the size is needed to be freed + ACL_CHECK(aclrtFree(b.ptr)); + pool_size -= b.size; + buffer_pool.erase(b.ptr); +#ifdef DEBUG_CANN_MALLOC + GGML_LOG_INFO( + "cann pool[%d]: clean %p, " + "pool_size = %5u MB, " + "size = %5u MB\n", + device, b.ptr, + (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576), + (uint32_t)(GGML_PAD(b.size, 1048576) / 1048576)); +#endif + continue; } + free_buffers_rest.push_back(b); + } + for (ggml_cann_buffer &b : free_buffers_rest) { + free_buffers.push(std::move(b)); + } - // allocate a new buffer if no buffer can be reused - ggml_cann_set_device(device); - ACL_CHECK(aclrtMalloc(&ptr, size, ACL_MEM_MALLOC_HUGE_FIRST)); - *actual_size = size; - pool_size += size; - #ifdef DEBUG_CANN_MALLOC - GGML_LOG_INFO( - "cann pool[%d]: allocate %p, " - "pool_size = %5u MB, " - "size = %5u MB\n", - device, ptr, (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576), - (uint32_t)(GGML_PAD(size, 1048576) / 1048576)); - #endif - 
buffer_pool.emplace(ptr, size); +#ifdef DEBUG_CANN_MALLOC + GGML_LOG_INFO("cann pool[%d] free pool_size = %5u MB\n\n", device, (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576)); +#endif + if (ptr != nullptr) { return ptr; } - /** - * @brief Free a buffer and return it to the pool. - * - * @param ptr Pointer to the buffer to free. - * @param size Size of the buffer to free. - */ - void free(void* ptr, size_t size) override { - auto it = buffer_pool.find(ptr); - if (it == buffer_pool.end()) { - GGML_ABORT("cann pool[%d]: buffer %p not found in pool\n", device, ptr); - } + // allocate a new buffer if no buffer can be reused + ggml_cann_set_device(device); + ACL_CHECK(aclrtMalloc(&ptr, size, ACL_MEM_MALLOC_HUGE_FIRST)); + *actual_size = size; + pool_size += size; +#ifdef DEBUG_CANN_MALLOC + GGML_LOG_INFO( + "cann pool[%d]: allocate %p, " + "pool_size = %5u MB, " + "size = %5u MB\n", + device, ptr, (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576), + (uint32_t)(GGML_PAD(size, 1048576) / 1048576)); +#endif + buffer_pool.emplace(ptr, size); + return ptr; + } - auto now = std::chrono::steady_clock::now(); - free_buffers.emplace(ggml_cann_buffer{ptr, it->second, now}); - #ifdef DEBUG_CANN_MALLOC - GGML_LOG_INFO( - "cann pool[%d]: return %p, " - "pool_size = %5u MB\n", - device, ptr, - (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576)); - #endif + /** + * @brief Free a buffer and return it to the pool. + * + * @param ptr Pointer to the buffer to free. + * @param size Size of the buffer to free. + */ + void free(void* ptr, size_t size) override { + GGML_UNUSED(size); + auto it = buffer_pool.find(ptr); + if (it == buffer_pool.end()) { + GGML_ABORT("cann pool[%d]: buffer %p not found in pool\n", device, ptr); } - }; + + auto now = std::chrono::steady_clock::now(); + free_buffers.emplace(ggml_cann_buffer{ptr, it->second, now}); +#ifdef DEBUG_CANN_MALLOC + GGML_LOG_INFO( + "cann pool[%d]: return %p, " + "pool_size = %5u MB\n", + device, ptr, + (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576)); +#endif + } +}; /** * @brief A pool of CANN buffers(segment buffer). @@ -531,6 +532,7 @@ struct ggml_cann_pool_buf : public ggml_cann_pool { * @param size Size of the buffer to free. 
*/ void free(void* ptr, size_t size) override { + GGML_UNUSED(size); for (int i = 0; i < MAX_BUFFERS; ++i) { ggml_cann_buffer& b = buffer_pool[i]; if (b.ptr != ptr) { From f8f820cc4dc37032d5375972ba904ce53043445d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 15 Apr 2025 14:45:05 +0300 Subject: [PATCH 18/20] metal : add FA-vec kernels for head size 96 (#12952) ggml-ci --- ggml/src/ggml-metal/ggml-metal.m | 34 +++++++++++++++++++++++++++- ggml/src/ggml-metal/ggml-metal.metal | 10 ++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 9f1c6c6cc..16ca08383 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -402,6 +402,13 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, @@ -1059,6 +1066,13 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, flash_attn_ext_q8_0_h192, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, flash_attn_ext_vec_bf16_h96, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, flash_attn_ext_vec_q4_0_h96, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96, flash_attn_ext_vec_q4_1_h96, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96, flash_attn_ext_vec_q5_0_h96, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96, flash_attn_ext_vec_q5_1_h96, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96, flash_attn_ext_vec_q8_0_h96, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && use_bfloat); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction); @@ -3843,7 +3857,7 @@ static void ggml_metal_encode_node( // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0) // for now avoiding mainly to keep the number of templates/kernels a bit lower // these are now trivial to add after: 
https://github.com/ggml-org/llama.cpp/pull/12612 - if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 192)) { + if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 96 && ne00 != 192)) { switch (src1->type) { case GGML_TYPE_F16: { @@ -4010,6 +4024,24 @@ static void ggml_metal_encode_node( use_vec_kernel = true; switch (ne00) { + case 96: + { + switch (src1->type) { + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported type: %d\n", src1->type); + GGML_LOG_ERROR("add template specialization for this type\n"); + GGML_ABORT("add template specialization for this type"); + } + } + } break; case 128: { switch (src1->type) { diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index b08666e27..0fb28238a 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -3959,6 +3959,16 @@ kernel void kernel_flash_attn_ext_vec( typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t; +template [[host_name("kernel_flash_attn_ext_vec_f16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#endif +template [[host_name("kernel_flash_attn_ext_vec_q4_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_USE_BF16) template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; From 80f19b41869728eeb6a26569957b92a773a2b2c6 Mon Sep 17 00:00:00 2001 From: lhez Date: Tue, 15 Apr 2025 12:26:00 -0700 Subject: [PATCH 19/20] opencl: split `ggml-opencl.cl` into multiple files and cleanup (#12886) * opencl: refactor - split the kernel files --------- Co-authored-by: Shangqing Gu * opencl: split more kernels into separate files * opencl: specify subgroup size instead of querying it * opencl: refine Adreno cl compiler version parsing * opencl: skip some kernels not used by Adreno on old compilers * opencl: refine logic for selecting Adreno kernels * opencl: refine Adreno cl compiler version * opencl: cleanup preprocessor for kernels * opencl: consider 
Adreno CL compiler on Windows * opencl: add final newline for `mul_mv_f16_f16.cl` --------- Co-authored-by: Shangqing Gu --- ggml/src/ggml-opencl/CMakeLists.txt | 45 +- ggml/src/ggml-opencl/ggml-opencl.cpp | 1053 ++++-- ggml/src/ggml-opencl/kernels/add.cl | 83 + ggml/src/ggml-opencl/kernels/clamp.cl | 20 + ggml/src/ggml-opencl/kernels/cpy.cl | 184 + .../kernels/{ggml-opencl_cvt.cl => cvt.cl} | 66 +- ggml/src/ggml-opencl/kernels/diag_mask_inf.cl | 58 + ggml/src/ggml-opencl/kernels/gelu.cl | 62 + ...cl_gemv_noshuffle.cl => gemv_noshuffle.cl} | 0 ...e_general.cl => gemv_noshuffle_general.cl} | 0 ggml/src/ggml-opencl/kernels/get_rows.cl | 163 + ggml/src/ggml-opencl/kernels/ggml-opencl.cl | 3231 ----------------- .../ggml-opencl/kernels/ggml-opencl_im2col.cl | 146 - .../src/ggml-opencl/kernels/ggml-opencl_mm.cl | 1225 ------- .../kernels/ggml-opencl_transpose_16.cl | 26 - .../kernels/ggml-opencl_transpose_32.cl | 25 - .../kernels/ggml-opencl_transpose_32_16.cl | 35 - ggml/src/ggml-opencl/kernels/im2col_f16.cl | 57 + ggml/src/ggml-opencl/kernels/im2col_f32.cl | 57 + ggml/src/ggml-opencl/kernels/mul.cl | 79 + ..._mat_Ab_Bi_8x4.cl => mul_mat_Ab_Bi_8x4.cl} | 0 .../src/ggml-opencl/kernels/mul_mv_f16_f16.cl | 118 + .../src/ggml-opencl/kernels/mul_mv_f16_f32.cl | 118 + .../kernels/mul_mv_f16_f32_1row.cl | 94 + .../ggml-opencl/kernels/mul_mv_f16_f32_l4.cl | 84 + .../src/ggml-opencl/kernels/mul_mv_f32_f32.cl | 118 + .../ggml-opencl/kernels/mul_mv_q4_0_f32.cl | 192 + .../kernels/mul_mv_q4_0_f32_1d_16x_flat.cl | 307 ++ .../kernels/mul_mv_q4_0_f32_1d_8x_flat.cl | 265 ++ .../kernels/mul_mv_q4_0_f32_8x_flat.cl | 272 ++ .../ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl | 254 ++ ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl | 190 + ggml/src/ggml-opencl/kernels/norm.cl | 81 + ggml/src/ggml-opencl/kernels/relu.cl | 16 + ggml/src/ggml-opencl/kernels/rms_norm.cl | 96 + ggml/src/ggml-opencl/kernels/rope.cl | 721 ++++ ggml/src/ggml-opencl/kernels/scale.cl | 16 + ggml/src/ggml-opencl/kernels/silu.cl | 30 + ggml/src/ggml-opencl/kernels/softmax_4_f16.cl | 87 + ggml/src/ggml-opencl/kernels/softmax_4_f32.cl | 87 + ggml/src/ggml-opencl/kernels/softmax_f16.cl | 86 + ggml/src/ggml-opencl/kernels/softmax_f32.cl | 86 + ggml/src/ggml-opencl/kernels/transpose.cl | 84 + 43 files changed, 5021 insertions(+), 4996 deletions(-) create mode 100644 ggml/src/ggml-opencl/kernels/add.cl create mode 100644 ggml/src/ggml-opencl/kernels/clamp.cl create mode 100644 ggml/src/ggml-opencl/kernels/cpy.cl rename ggml/src/ggml-opencl/kernels/{ggml-opencl_cvt.cl => cvt.cl} (69%) create mode 100644 ggml/src/ggml-opencl/kernels/diag_mask_inf.cl create mode 100644 ggml/src/ggml-opencl/kernels/gelu.cl rename ggml/src/ggml-opencl/kernels/{ggml-opencl_gemv_noshuffle.cl => gemv_noshuffle.cl} (100%) rename ggml/src/ggml-opencl/kernels/{ggml-opencl_gemv_noshuffle_general.cl => gemv_noshuffle_general.cl} (100%) create mode 100644 ggml/src/ggml-opencl/kernels/get_rows.cl delete mode 100644 ggml/src/ggml-opencl/kernels/ggml-opencl.cl delete mode 100644 ggml/src/ggml-opencl/kernels/ggml-opencl_im2col.cl delete mode 100644 ggml/src/ggml-opencl/kernels/ggml-opencl_mm.cl delete mode 100644 ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_16.cl delete mode 100644 ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32.cl delete mode 100644 ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32_16.cl create mode 100644 ggml/src/ggml-opencl/kernels/im2col_f16.cl create mode 100644 ggml/src/ggml-opencl/kernels/im2col_f32.cl create mode 100644 
ggml/src/ggml-opencl/kernels/mul.cl rename ggml/src/ggml-opencl/kernels/{ggml-opencl_mul_mat_Ab_Bi_8x4.cl => mul_mat_Ab_Bi_8x4.cl} (100%) create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl create mode 100644 ggml/src/ggml-opencl/kernels/norm.cl create mode 100644 ggml/src/ggml-opencl/kernels/relu.cl create mode 100644 ggml/src/ggml-opencl/kernels/rms_norm.cl create mode 100644 ggml/src/ggml-opencl/kernels/rope.cl create mode 100644 ggml/src/ggml-opencl/kernels/scale.cl create mode 100644 ggml/src/ggml-opencl/kernels/silu.cl create mode 100644 ggml/src/ggml-opencl/kernels/softmax_4_f16.cl create mode 100644 ggml/src/ggml-opencl/kernels/softmax_4_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/softmax_f16.cl create mode 100644 ggml/src/ggml-opencl/kernels/softmax_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/transpose.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 624cb1b9d..352deb321 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -54,16 +54,41 @@ function(ggml_opencl_add_kernel KNAME) endfunction() set(GGML_OPENCL_KERNELS - ggml-opencl - ggml-opencl_mm - ggml-opencl_cvt - ggml-opencl_gemv_noshuffle - ggml-opencl_gemv_noshuffle_general - ggml-opencl_mul_mat_Ab_Bi_8x4 - ggml-opencl_transpose_16 - ggml-opencl_transpose_32 - ggml-opencl_transpose_32_16 - ggml-opencl_im2col + add + clamp + cpy + cvt + diag_mask_inf + gelu + gemv_noshuffle_general + gemv_noshuffle + get_rows + im2col_f32 + im2col_f16 + mul_mat_Ab_Bi_8x4 + mul_mv_f16_f16 + mul_mv_f16_f32_1row + mul_mv_f16_f32_l4 + mul_mv_f16_f32 + mul_mv_f32_f32 + mul_mv_q4_0_f32 + mul_mv_q4_0_f32_v + mul_mv_q4_0_f32_8x_flat + mul_mv_q4_0_f32_1d_8x_flat + mul_mv_q4_0_f32_1d_16x_flat + mul_mv_q6_k + mul + norm + relu + rms_norm + rope + scale + silu + softmax_4_f32 + softmax_4_f16 + softmax_f32 + softmax_f16 + transpose ) foreach (K ${GGML_OPENCL_KERNELS}) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index b8b5cbd3e..fa40abc33 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -64,11 +64,33 @@ enum ADRENO_GPU_GEN { X1E, }; +enum ADRENO_CL_COMPILER_TYPE { + E031, + DX, +}; + struct ggml_cl_version { cl_uint major = 0; cl_uint minor = 0; }; +struct ggml_cl_compiler_version { + ADRENO_CL_COMPILER_TYPE type; + int major = -1; + int minor = -1; + int patch = -1; + + bool same(ADRENO_CL_COMPILER_TYPE t, int x, int y, int z) const { + return major == x && minor == y && patch == z && type == t; + } + bool newer_than(ADRENO_CL_COMPILER_TYPE t, int x, int y, int z) const { + return major*10000 + minor*100 + patch > x*10000 + y*100 + z && type == t; + } + bool newer_than_or_same(ADRENO_CL_COMPILER_TYPE t, int x, int y, int z) const { + return same(t, x, y, 
z) || newer_than(t, x, y, z); + } +}; + // Parses a version string of form "XX.YY ". On an error returns ggml_cl_version with all zeroes. static ggml_cl_version parse_cl_version(std::string_view str) { size_t major_str_begin = 0; @@ -173,24 +195,30 @@ static ADRENO_GPU_GEN get_adreno_gpu_gen(const char *device_name) { return ADRENO_GPU_GEN::ADRENO_UNKNOWN; } -static int get_adreno_cl_compiler_version(const char *driver_version) { +static ggml_cl_compiler_version get_adreno_cl_compiler_version(const char *driver_version) { std::string driver_ver_str(driver_version); + ADRENO_CL_COMPILER_TYPE type = ADRENO_CL_COMPILER_TYPE::E031; size_t compiler_ver_pos = driver_ver_str.find("E031"); size_t compiler_ver_len = 13; - size_t compiler_ver_offset = 5; + size_t compiler_major_offset = 5; + size_t compiler_minor_offset = 8; + size_t compiler_patch_offset = 11; if (compiler_ver_pos == std::string::npos) { compiler_ver_pos = driver_ver_str.find("DX"); if (compiler_ver_pos == std::string::npos) { - return -1; + return {}; } + type = ADRENO_CL_COMPILER_TYPE::DX; compiler_ver_len = 11; - compiler_ver_offset = 3; + compiler_major_offset = 3; } std::string compiler_ver_str = driver_ver_str.substr(compiler_ver_pos, compiler_ver_len); - std::string major_ver_str = compiler_ver_str.substr(compiler_ver_offset, 2); - return std::atoi(major_ver_str.c_str()); + int major = std::atoi(compiler_ver_str.substr(compiler_major_offset, 2).c_str()); + int minor = std::atoi(compiler_ver_str.substr(compiler_minor_offset, 2).c_str()); + int patch = std::atoi(compiler_ver_str.substr(compiler_patch_offset, 2).c_str()); + return { type, major, minor, patch }; } // backend device context @@ -215,16 +243,48 @@ struct ggml_backend_opencl_context { cl_int alignment; size_t max_alloc_size; bool fp16_support; + bool has_vector_subgroup_broadcast; + ggml_cl_compiler_version adreno_cl_compiler_version; int adreno_wave_size; cl_context context; cl_command_queue queue; - cl_program program; - cl_program program_1; - cl_program program_2; - cl_program program_im2col; + cl_program program_add; + cl_program program_clamp; + cl_program program_cpy; + cl_program program_cvt; + cl_program program_diag_mask_inf; + cl_program program_gelu; + cl_program program_gemv_noshuffle_general; + cl_program program_gemv_noshuffle; + cl_program program_get_rows; + cl_program program_im2col_f16; + cl_program program_im2col_f32; + cl_program program_mul_mat_Ab_Bi_8x4; + cl_program program_mul_mv_q4_0_f32; + cl_program program_mul_mv_q4_0_f32_v; + cl_program program_mul_mv_q4_0_f32_8x_flat; + cl_program program_mul_mv_q4_0_f32_1d_8x_flat; + cl_program program_mul_mv_q4_0_f32_1d_16x_flat; + cl_program program_mul_mv_q6_K; + cl_program program_mul_mv_f16_f16; + cl_program program_mul_mv_f16_f32_1row; + cl_program program_mul_mv_f16_f32_l4; + cl_program program_mul_mv_f16_f32; + cl_program program_mul_mv_f32_f32; + cl_program program_mul; + cl_program program_norm; + cl_program program_relu; + cl_program program_rms_norm; + cl_program program_rope; + cl_program program_scale; + cl_program program_silu; + cl_program program_softmax_f32; + cl_program program_softmax_f16; + cl_program program_softmax_4_f32; + cl_program program_softmax_4_f16; cl_kernel kernel_add, kernel_add_row; cl_kernel kernel_mul, kernel_mul_row; @@ -249,19 +309,17 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mat_f16_f32; cl_kernel kernel_mul_mat_f16_f32_l4; cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v; - cl_kernel kernel_convert_block_q4_0, 
kernel_restore_block_q4_0, kernel_mul_mat_q4_0_f32_flat; + cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0; cl_kernel kernel_mul_mat_q4_0_f32_8x_flat; - cl_kernel kernel_convert_block_q4_0_noshuffle, kernel_mul_mat_q4_0_f32_flat_v0, - kernel_mul_mat_q4_0_f32_flat_img_v0; + cl_kernel kernel_convert_block_q4_0_noshuffle; cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; cl_kernel kernel_mul_mv_q6_K_f32; cl_kernel kernel_im2col_f32, kernel_im2col_f16; #ifdef GGML_OPENCL_USE_ADRENO_KERNELS // Transpose kernels - cl_program program_transpose_32; - cl_program program_transpose_32_16; - cl_program program_transpose_16; + cl_program program_transpose; + cl_kernel kernel_transpose_32; cl_kernel kernel_transpose_32_16; cl_kernel kernel_transpose_16; @@ -374,6 +432,681 @@ static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, co return p; } +static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_version opencl_c_version) { + cl_int err; + + // compiler options for general kernels + auto opencl_c_std = + std::string("CL") + std::to_string(opencl_c_version.major) + "." + std::to_string(opencl_c_version.minor); + std::string compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable -cl-unsafe-math-optimizations" + " -cl-finite-math-only -cl-fast-relaxed-math"; + + GGML_LOG_INFO("ggml_opencl: loading OpenCL kernels"); + + // add + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "add.cl.h" + }; +#else + const std::string kernel_src = read_file("add.cl"); +#endif + backend_ctx->program_add = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_add = clCreateKernel(backend_ctx->program_add, "kernel_add", &err), err)); + CL_CHECK((backend_ctx->kernel_add_row = clCreateKernel(backend_ctx->program_add, "kernel_add_row", &err), err)); + GGML_LOG_CONT("."); + } + + // clamp + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "clamp.cl.h" + }; +#else + const std::string kernel_src = read_file("clamp.cl"); +#endif + backend_ctx->program_clamp = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_clamp = clCreateKernel(backend_ctx->program_clamp, "kernel_clamp", &err), err)); + GGML_LOG_CONT("."); + } + + // cpy + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "cpy.cl.h" + }; +#else + const std::string kernel_src = read_file("cpy.cl"); +#endif + backend_ctx->program_cpy = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_cpy_f16_f16 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f16_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f16_f32 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f16_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f32_f16 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f32_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f32_f32 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f32_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // cvt + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "cvt.cl.h" + }; +#else + const std::string kernel_src = read_file("cvt.cl"); +#endif + backend_ctx->program_cvt = + build_program_from_source(backend_ctx->context, 
backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_convert_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0", &err), err)); + GGML_LOG_CONT("."); + } + + // diag_mask_inf + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "diag_mask_inf.cl.h" + }; +#else + const std::string kernel_src = read_file("diag_mask_inf.cl"); +#endif + backend_ctx->program_diag_mask_inf = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_diag_mask_inf_8 = clCreateKernel(backend_ctx->program_diag_mask_inf, "kernel_diag_mask_inf_8", &err), err)); + CL_CHECK((backend_ctx->kernel_diag_mask_inf = clCreateKernel(backend_ctx->program_diag_mask_inf, "kernel_diag_mask_inf", &err), err)); + GGML_LOG_CONT("."); + } + + // gelu + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gelu.cl.h" + }; +#else + const std::string kernel_src = read_file("gelu.cl"); +#endif + backend_ctx->program_gelu = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_gelu = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu", &err), err)); + CL_CHECK((backend_ctx->kernel_gelu_4 = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_4", &err), err)); + CL_CHECK((backend_ctx->kernel_gelu_quick = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_quick", &err), err)); + CL_CHECK((backend_ctx->kernel_gelu_quick_4 = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_quick_4", &err), err)); + GGML_LOG_CONT("."); + } + + // get_rows + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "get_rows.cl.h" + }; +#else + const std::string kernel_src = read_file("get_rows.cl"); +#endif + backend_ctx->program_get_rows = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_get_rows_f32 = clCreateKernel(backend_ctx->program_get_rows, "kernel_get_rows_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_get_rows_f16 = clCreateKernel(backend_ctx->program_get_rows, "kernel_get_rows_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_get_rows_q4_0 = clCreateKernel(backend_ctx->program_get_rows, "kernel_get_rows_q4_0", &err), err)); + GGML_LOG_CONT("."); + } + + // im2col_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "im2col_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("im2col_f32.cl"); +#endif + backend_ctx->program_im2col_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_im2col_f32 = clCreateKernel(backend_ctx->program_im2col_f32, "kernel_im2col_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // im2col_f16 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "im2col_f16.cl.h" + }; +#else + const std::string kernel_src = read_file("im2col_f16.cl"); +#endif + backend_ctx->program_im2col_f16 = + build_program_from_source(backend_ctx->context, backend_ctx->device, 
kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_im2col_f16 = clCreateKernel(backend_ctx->program_im2col_f16, "kernel_im2col_f16", &err), err));
+        GGML_LOG_CONT(".");
+    }
+
+    // mul_mv_q4_0_f32
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "mul_mv_q4_0_f32.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("mul_mv_q4_0_f32.cl");
+#endif
+        backend_ctx->program_mul_mv_q4_0_f32 =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32 = clCreateKernel(backend_ctx->program_mul_mv_q4_0_f32, "kernel_mul_mat_q4_0_f32", &err), err));
+        GGML_LOG_CONT(".");
+    }
+
+    // mul_mv_q4_0_f32_v
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "mul_mv_q4_0_f32_v.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("mul_mv_q4_0_f32_v.cl");
+#endif
+        backend_ctx->program_mul_mv_q4_0_f32_v =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_v = clCreateKernel(backend_ctx->program_mul_mv_q4_0_f32_v, "kernel_mul_mat_q4_0_f32_v", &err), err));
+        GGML_LOG_CONT(".");
+    }
+
+    // mul_mv_q4_0_f32_8x_flat
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "mul_mv_q4_0_f32_8x_flat.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("mul_mv_q4_0_f32_8x_flat.cl");
+#endif
+        backend_ctx->program_mul_mv_q4_0_f32_8x_flat =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_8x_flat = clCreateKernel(backend_ctx->program_mul_mv_q4_0_f32_8x_flat, "kernel_mul_mat_q4_0_f32_8x_flat", &err), err));
+        GGML_LOG_CONT(".");
+    }
+
+    // mul_mv_q4_0_f32_1d_8x_flat
+    // This kernel does not compile with Adreno cl compiler 38.01. Skip it for
+    // those compiler versions, since it is not used for Adreno anyway.
+    if (backend_ctx->gpu_family != ADRENO ||
+        backend_ctx->adreno_cl_compiler_version.newer_than_or_same(E031, 38, 11, 0) ||
+        backend_ctx->adreno_cl_compiler_version.type == DX) {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "mul_mv_q4_0_f32_1d_8x_flat.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("mul_mv_q4_0_f32_1d_8x_flat.cl");
+#endif
+        backend_ctx->program_mul_mv_q4_0_f32_1d_8x_flat =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_1d_8x_flat = clCreateKernel(backend_ctx->program_mul_mv_q4_0_f32_1d_8x_flat, "kernel_mul_mat_q4_0_f32_1d_8x_flat", &err), err));
+        GGML_LOG_CONT(".");
+    }
+
+    // mul_mv_q4_0_f32_1d_16x_flat
+    // This kernel does not compile with Adreno cl compiler 38.01. Skip it for
+    // those compiler versions, since it is not used for Adreno anyway.
+ if (backend_ctx->gpu_family != ADRENO || + backend_ctx->adreno_cl_compiler_version.newer_than_or_same(E031, 38, 11, 0) || + backend_ctx->adreno_cl_compiler_version.type == DX) { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q4_0_f32_1d_16x_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q4_0_f32_1d_16x_flat.cl"); +#endif + backend_ctx->program_mul_mv_q4_0_f32_1d_16x_flat = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_1d_16x_flat = clCreateKernel(backend_ctx->program_mul_mv_q4_0_f32_1d_16x_flat, "kernel_mul_mat_q4_0_f32_1d_16x_flat", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_q6_k + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q6_k.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q6_k.cl"); +#endif + backend_ctx->program_mul_mv_q6_K = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q6_K_f32 = clCreateKernel(backend_ctx->program_mul_mv_q6_K, "kernel_mul_mv_q6_K_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_f16_f16 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_f16_f16.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_f16_f16.cl"); +#endif + backend_ctx->program_mul_mv_f16_f16 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_f16_f16 = clCreateKernel(backend_ctx->program_mul_mv_f16_f16, "kernel_mul_mat_f16_f16", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_f16_f32_1row + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_f16_f32_1row.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_f16_f32_1row.cl"); +#endif + backend_ctx->program_mul_mv_f16_f32_1row = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_1row = clCreateKernel(backend_ctx->program_mul_mv_f16_f32_1row, "kernel_mul_mat_f16_f32_1row", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_f16_f32_l4 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_f16_f32_l4.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_f16_f32_l4.cl"); +#endif + backend_ctx->program_mul_mv_f16_f32_l4 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_l4 = clCreateKernel(backend_ctx->program_mul_mv_f16_f32_l4, "kernel_mul_mat_f16_f32_l4", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_f16_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_f16_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_f16_f32.cl"); +#endif + backend_ctx->program_mul_mv_f16_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32 = clCreateKernel(backend_ctx->program_mul_mv_f16_f32, "kernel_mul_mat_f16_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_f32_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include 
"mul_mv_f32_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_f32_f32.cl"); +#endif + backend_ctx->program_mul_mv_f32_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_f32_f32 = clCreateKernel(backend_ctx->program_mul_mv_f32_f32, "kernel_mul_mat_f32_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // mul + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul.cl.h" + }; +#else + const std::string kernel_src = read_file("mul.cl"); +#endif + backend_ctx->program_mul = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul = clCreateKernel(backend_ctx->program_mul, "kernel_mul", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_row = clCreateKernel(backend_ctx->program_mul, "kernel_mul_row", &err), err)); + GGML_LOG_CONT("."); + } + + // norm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "norm.cl.h" + }; +#else + const std::string kernel_src = read_file("norm.cl"); +#endif + backend_ctx->program_norm = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_norm = clCreateKernel(backend_ctx->program_norm, "kernel_norm", &err), err)); + GGML_LOG_CONT("."); + } + + // relu + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "relu.cl.h" + }; +#else + const std::string kernel_src = read_file("relu.cl"); +#endif + backend_ctx->program_relu = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_relu = clCreateKernel(backend_ctx->program_relu, "kernel_relu", &err), err)); + GGML_LOG_CONT("."); + } + + // rms_norm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "rms_norm.cl.h" + }; +#else + const std::string kernel_src = read_file("rms_norm.cl"); +#endif + backend_ctx->program_rms_norm = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_rms_norm = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm", &err), err)); + GGML_LOG_CONT("."); + } + + // rope + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "rope.cl.h" + }; +#else + const std::string kernel_src = read_file("rope.cl"); +#endif + backend_ctx->program_rope = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_rope_norm_f32 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_norm_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_norm_f16 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_norm_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_neox_f32 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_neox_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_neox_f16 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_neox_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_multi_f32 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_multi_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_multi_f16 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_multi_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_vision_f32 = 
clCreateKernel(backend_ctx->program_rope, "kernel_rope_vision_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_vision_f16 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_vision_f16", &err), err)); + GGML_LOG_CONT("."); + } + + // scale + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "scale.cl.h" + }; +#else + const std::string kernel_src = read_file("scale.cl"); +#endif + backend_ctx->program_scale = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_scale = clCreateKernel(backend_ctx->program_scale, "kernel_scale", &err), err)); + GGML_LOG_CONT("."); + } + + // silu + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "silu.cl.h" + }; +#else + const std::string kernel_src = read_file("silu.cl"); +#endif + backend_ctx->program_silu = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_silu = clCreateKernel(backend_ctx->program_silu, "kernel_silu", &err), err)); + CL_CHECK((backend_ctx->kernel_silu_4 = clCreateKernel(backend_ctx->program_silu, "kernel_silu_4", &err), err)); + GGML_LOG_CONT("."); + } + + // softmax_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "softmax_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("softmax_f32.cl"); +#endif + backend_ctx->program_softmax_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_soft_max = clCreateKernel(backend_ctx->program_softmax_f32, "kernel_soft_max", &err), err)); + GGML_LOG_CONT("."); + } + + // softmax_f16 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "softmax_f16.cl.h" + }; +#else + const std::string kernel_src = read_file("softmax_f16.cl"); +#endif + backend_ctx->program_softmax_f16 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_soft_max_f16 = clCreateKernel(backend_ctx->program_softmax_f16, "kernel_soft_max_f16", &err), err)); + GGML_LOG_CONT("."); + } + + // softmax_4_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "softmax_4_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("softmax_4_f32.cl"); +#endif + backend_ctx->program_softmax_4_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_soft_max_4 = clCreateKernel(backend_ctx->program_softmax_4_f32, "kernel_soft_max_4", &err), err)); + GGML_LOG_CONT("."); + } + + // softmax_4_f16 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "softmax_4_f16.cl.h" + }; +#else + const std::string kernel_src = read_file("softmax_4_f16.cl"); +#endif + backend_ctx->program_softmax_4_f16 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_soft_max_4_f16 = clCreateKernel(backend_ctx->program_softmax_4_f16, "kernel_soft_max_4_f16", &err), err)); + GGML_LOG_CONT("."); + } + + // Adreno kernels +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // transpose + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "transpose.cl.h" + }; +#else + const std::string kernel_src = read_file("transpose.cl"); 
+#endif + backend_ctx->program_transpose = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_transpose_32_16 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_32_16", &err), err)); + CL_CHECK((backend_ctx->kernel_transpose_32 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_32", &err), err)); + CL_CHECK((backend_ctx->kernel_transpose_16 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_16", &err), err)); + GGML_LOG_CONT("."); + } + + // gemv_noshuffle_general + { + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable " + " -DSIMDGROUP_WIDTH=" + + std::to_string(backend_ctx->adreno_wave_size); + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_CL_gemv_general { + #include "gemv_noshuffle_general.cl.h" + }; +#else + const std::string kernel_src_CL_gemv_general = read_file("gemv_noshuffle_general.cl"); +#endif + + backend_ctx->program_CL_gemv_general = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv_general.c_str(), CL_gemv_compile_opts); + + CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_general = clCreateKernel(backend_ctx->program_CL_gemv_general, "kernel_gemv_noshuffle", &err), err)); + GGML_LOG_CONT("."); + } + + // gemv_noshuffle + { + // Gemv 2048, 16384 + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable " + " -DLINE_STRIDE_A=2048 " + " -DBLOCK_STRIDE_A=16384 " + " -DSIMDGROUP_WIDTH=" + + std::to_string(backend_ctx->adreno_wave_size); + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_CL_gemv { + #include "gemv_noshuffle.cl.h" + }; +#else + const std::string kernel_src_CL_gemv = read_file("gemv_noshuffle.cl"); +#endif + + backend_ctx->program_CL_gemv_4096_1_4096 = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_4096_1_4096, "kernel_gemv_noshuffle", &err), err)); + GGML_LOG_CONT("."); + + // Gemv 2048, 16384 + CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable " + " -DLINE_STRIDE_A=2048 " + " -DBLOCK_STRIDE_A=16384 " + " -DSIMDGROUP_WIDTH=" + + std::to_string(backend_ctx->adreno_wave_size); + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + + backend_ctx->program_CL_gemv_4096_1_11008 = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_11008 = clCreateKernel(backend_ctx->program_CL_gemv_4096_1_11008, "kernel_gemv_noshuffle", &err), err)); + GGML_LOG_CONT("."); + + // Gemv 5504, 44032 + CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable " + " -DLINE_STRIDE_A=5504 " + " -DBLOCK_STRIDE_A=44032 " + " -DSIMDGROUP_WIDTH=" + + std::to_string(backend_ctx->adreno_wave_size); + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + + 
backend_ctx->program_CL_gemv_11008_1_4096 = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_11008_1_4096, "kernel_gemv_noshuffle", &err), err)); + GGML_LOG_CONT("."); + + // Gemv 16000, 128000 + CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable " + " -DLINE_STRIDE_A=16000 " + " -DBLOCK_STRIDE_A=128000 " + " -DSIMDGROUP_WIDTH=" + + std::to_string(backend_ctx->adreno_wave_size); + + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + + backend_ctx->program_CL_gemv_32000_1_4096 = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_32000_1_4096, "kernel_gemv_noshuffle", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mat_Ab_Bi_8x4 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_CL_gemm { + #include "mul_mat_Ab_Bi_8x4.cl.h" + }; +#else + const std::string kernel_src_CL_gemm = read_file("mul_mat_Ab_Bi_8x4.cl"); +#endif + backend_ctx->program_CL_gemm = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_CL_gemm.c_str(), compile_opts); + CL_CHECK((backend_ctx->CL_mul_mat_Ab_Bi_8x4 = clCreateKernel(backend_ctx->program_CL_gemm, "kernel_mul_mat_Ab_Bi_8x4", &err), err)); + GGML_LOG_CONT("."); + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + GGML_LOG_CONT("\n"); +} + static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { static bool initialized = false; static ggml_backend_opencl_context *backend_ctx = nullptr; @@ -612,11 +1345,12 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { GGML_LOG_INFO("ggml_opencl: OpenCL driver: %s\n", driver_version); backend_ctx->driver_version = driver_version; - int adreno_cl_compiler_version = get_adreno_cl_compiler_version(driver_version); - bool has_vector_subgroup_broadcast = - adreno_cl_compiler_version >= 47 || adreno_cl_compiler_version == 17; + backend_ctx->adreno_cl_compiler_version = get_adreno_cl_compiler_version(driver_version); + backend_ctx->has_vector_subgroup_broadcast = + backend_ctx->adreno_cl_compiler_version.major >= 47 || + backend_ctx->adreno_cl_compiler_version.major == 17; GGML_LOG_INFO("ggml_opencl: vector subgroup broadcast support: %s\n", - has_vector_subgroup_broadcast ? "true" : "false"); + backend_ctx->has_vector_subgroup_broadcast ? "true" : "false"); size_t ext_str_size; clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, 0, NULL, &ext_str_size); @@ -691,247 +1425,10 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { #endif CL_CHECK((backend_ctx->queue = clCreateCommandQueue(context, device, command_queue_props, &err), err)); -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src { - #include "ggml-opencl.cl.h" - }; -#else - const std::string kernel_src = read_file("ggml-opencl.cl"); -#endif + // Load kernels + load_cl_kernels(backend_ctx, opencl_c_version); - auto opencl_c_std = - std::string("CL") + std::to_string(opencl_c_version.major) + "." 
+ std::to_string(opencl_c_version.minor); - - std::string compile_opts = std::string("-cl-std=") + opencl_c_std + - " -cl-mad-enable -cl-unsafe-math-optimizations" - " -cl-finite-math-only -cl-fast-relaxed-math"; - backend_ctx->program = build_program_from_source(context, device, kernel_src.c_str(), compile_opts); - - // Non matmul kernels. - CL_CHECK((backend_ctx->kernel_get_rows_f32 = clCreateKernel(backend_ctx->program, "kernel_get_rows_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_get_rows_f16 = clCreateKernel(backend_ctx->program, "kernel_get_rows_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_get_rows_q4_0 = clCreateKernel(backend_ctx->program, "kernel_get_rows_q4_0", &err), err)); - CL_CHECK((backend_ctx->kernel_add = clCreateKernel(backend_ctx->program, "kernel_add", &err), err)); - CL_CHECK((backend_ctx->kernel_add_row = clCreateKernel(backend_ctx->program, "kernel_add_row", &err), err)); - CL_CHECK((backend_ctx->kernel_mul = clCreateKernel(backend_ctx->program, "kernel_mul", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_row = clCreateKernel(backend_ctx->program, "kernel_mul_row", &err), err)); - CL_CHECK((backend_ctx->kernel_scale = clCreateKernel(backend_ctx->program, "kernel_scale", &err), err)); - CL_CHECK((backend_ctx->kernel_silu = clCreateKernel(backend_ctx->program, "kernel_silu", &err), err)); - CL_CHECK((backend_ctx->kernel_silu_4 = clCreateKernel(backend_ctx->program, "kernel_silu_4", &err), err)); - CL_CHECK((backend_ctx->kernel_gelu = clCreateKernel(backend_ctx->program, "kernel_gelu", &err), err)); - CL_CHECK((backend_ctx->kernel_gelu_4 = clCreateKernel(backend_ctx->program, "kernel_gelu_4", &err), err)); - CL_CHECK((backend_ctx->kernel_gelu_quick = clCreateKernel(backend_ctx->program, "kernel_gelu_quick", &err), err)); - CL_CHECK((backend_ctx->kernel_gelu_quick_4 = clCreateKernel(backend_ctx->program, "kernel_gelu_quick_4", &err), err)); - CL_CHECK((backend_ctx->kernel_relu = clCreateKernel(backend_ctx->program, "kernel_relu", &err), err)); - CL_CHECK((backend_ctx->kernel_clamp = clCreateKernel(backend_ctx->program, "kernel_clamp", &err), err)); - CL_CHECK((backend_ctx->kernel_norm = clCreateKernel(backend_ctx->program, "kernel_norm", &err), err)); - CL_CHECK((backend_ctx->kernel_rms_norm = clCreateKernel(backend_ctx->program, "kernel_rms_norm", &err), err)); - CL_CHECK((backend_ctx->kernel_diag_mask_inf = clCreateKernel(backend_ctx->program, "kernel_diag_mask_inf", &err), err)); - CL_CHECK((backend_ctx->kernel_diag_mask_inf_8 = clCreateKernel(backend_ctx->program, "kernel_diag_mask_inf_8", &err), err)); - CL_CHECK((backend_ctx->kernel_soft_max = clCreateKernel(backend_ctx->program, "kernel_soft_max", &err), err)); - CL_CHECK((backend_ctx->kernel_soft_max_4 = clCreateKernel(backend_ctx->program, "kernel_soft_max_4", &err), err)); - CL_CHECK((backend_ctx->kernel_soft_max_f16 = clCreateKernel(backend_ctx->program, "kernel_soft_max_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_soft_max_4_f16 = clCreateKernel(backend_ctx->program, "kernel_soft_max_4_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_rope_norm_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_norm_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_rope_norm_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_norm_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_rope_neox_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_neox_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_rope_neox_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_neox_f16", &err), 
err)); - CL_CHECK((backend_ctx->kernel_rope_multi_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_multi_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_rope_multi_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_multi_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_rope_vision_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_vision_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_rope_vision_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_vision_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_cpy_f16_f16 = clCreateKernel(backend_ctx->program, "kernel_cpy_f16_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_cpy_f16_f32 = clCreateKernel(backend_ctx->program, "kernel_cpy_f16_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_cpy_f32_f16 = clCreateKernel(backend_ctx->program, "kernel_cpy_f32_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_cpy_f32_f32 = clCreateKernel(backend_ctx->program, "kernel_cpy_f32_f32", &err), err)); - - // Matmul kernels. - CL_CHECK((backend_ctx->kernel_mul_mat_f32_f32 = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f32_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_f16_f16 = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f16_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_1row = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f16_f32_1row", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32 = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f16_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_l4 = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f16_f32_l4", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32 = clCreateKernel(backend_ctx->program, "kernel_mul_mat_q4_0_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_v = clCreateKernel(backend_ctx->program, "kernel_mul_mat_q4_0_f32_v", &err), err)); - - CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_flat = clCreateKernel(backend_ctx->program, "kernel_mul_mat_q4_0_f32_flat", &err), err)); - CL_CHECK((backend_ctx->kernel_convert_block_q4_0 = clCreateKernel(backend_ctx->program, "kernel_convert_block_q4_0", &err), err)); - CL_CHECK((backend_ctx->kernel_restore_block_q4_0 = clCreateKernel(backend_ctx->program, "kernel_restore_block_q4_0", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_8x_flat = clCreateKernel(backend_ctx->program, "kernel_mul_mat_q4_0_f32_8x_flat", &err), err)); - - // Load additional mulmat kernels. 
-#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src_1 { - #include "ggml-opencl_mm.cl.h" - }; -#else - const std::string kernel_src_1 = read_file("ggml-opencl_mm.cl"); -#endif - backend_ctx->program_1 = build_program_from_source(context, device, kernel_src_1.c_str(), compile_opts); - - CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_1d_8x_flat = clCreateKernel(backend_ctx->program_1, "kernel_mul_mat_q4_0_f32_1d_8x_flat", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_1d_16x_flat = clCreateKernel(backend_ctx->program_1, "kernel_mul_mat_q4_0_f32_1d_16x_flat", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mv_q6_K_f32 = clCreateKernel(backend_ctx->program_1, "kernel_mul_mv_q6_K_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_flat_v0 = clCreateKernel(backend_ctx->program_1, "kernel_mul_mat_q4_0_f32_flat_v0", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_flat_img_v0 = clCreateKernel(backend_ctx->program_1, "kernel_mul_mat_q4_0_f32_flat_img_v0", &err), err)); - - // Load additional data conversion kernels. -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src_2 { - #include "ggml-opencl_cvt.cl.h" - }; -#else - const std::string kernel_src_2 = read_file("ggml-opencl_cvt.cl"); -#endif - backend_ctx->program_2 = build_program_from_source(context, device, kernel_src_2.c_str(), compile_opts); - - CL_CHECK((backend_ctx->kernel_convert_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_2, "kernel_convert_block_q4_0_noshuffle", &err), err)); - - // im2col kernels -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src_im2col { - #include "ggml-opencl_im2col.cl.h" - }; -#else - const std::string kernel_src_im2col = read_file("ggml-opencl_im2col.cl"); -#endif - backend_ctx->program_im2col = build_program_from_source(context, device, kernel_src_im2col.c_str(), compile_opts); - - CL_CHECK((backend_ctx->kernel_im2col_f32 = clCreateKernel(backend_ctx->program_im2col, "kernel_im2col_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_im2col_f16 = clCreateKernel(backend_ctx->program_im2col, "kernel_im2col_f16", &err), err)); - - // Kernels for Adreno #ifdef GGML_OPENCL_USE_ADRENO_KERNELS -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string transpose_32_src { - #include "ggml-opencl_transpose_32.cl.h" - }; -#else - const std::string transpose_32_src = read_file("ggml-opencl_transpose_32.cl"); -#endif - backend_ctx->program_transpose_32 = build_program_from_source(context, device, transpose_32_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_transpose_32 = clCreateKernel(backend_ctx->program_transpose_32, "kernel_transpose_32", &err), err)); - -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string transpose_32_16_src { - #include "ggml-opencl_transpose_32_16.cl.h" - }; -#else - const std::string transpose_32_16_src = read_file("ggml-opencl_transpose_32_16.cl"); -#endif - backend_ctx->program_transpose_32_16 = build_program_from_source(context, device, transpose_32_16_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_transpose_32_16 = clCreateKernel(backend_ctx->program_transpose_32_16, "kernel_transpose_32_16", &err), err)); - -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string transpose_16_src { - #include "ggml-opencl_transpose_16.cl.h" - }; -#else - const std::string transpose_16_src = read_file("ggml-opencl_transpose_16.cl"); -#endif - backend_ctx->program_transpose_16 = build_program_from_source(context, device, transpose_16_src.c_str(), compile_opts); - 
CL_CHECK((backend_ctx->kernel_transpose_16 = clCreateKernel(backend_ctx->program_transpose_16, "kernel_transpose_16", &err), err)); - - // Gemv general - std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + - " -cl-mad-enable " - " -DSIMDGROUP_WIDTH=" + - std::to_string(backend_ctx->adreno_wave_size); - if (has_vector_subgroup_broadcast) { - CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; - } -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src_CL_gemv_general { - #include "ggml-opencl_gemv_noshuffle_general.cl.h" - }; -#else - const std::string kernel_src_CL_gemv_general = read_file("ggml-opencl_gemv_noshuffle_general.cl"); -#endif - - backend_ctx->program_CL_gemv_general = build_program_from_source( - context, device, kernel_src_CL_gemv_general.c_str(), CL_gemv_compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_general = clCreateKernel(backend_ctx->program_CL_gemv_general, "kernel_gemv_noshuffle", &err), err)); - - // Gemv 2048, 16384 - CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + - " -cl-mad-enable " - " -DLINE_STRIDE_A=2048 " - " -DBLOCK_STRIDE_A=16384 " - " -DSIMDGROUP_WIDTH=" + - std::to_string(backend_ctx->adreno_wave_size); - if (has_vector_subgroup_broadcast) { - CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; - } -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src_CL_gemv { - #include "ggml-opencl_gemv_noshuffle.cl.h" - }; -#else - const std::string kernel_src_CL_gemv = read_file("ggml-opencl_gemv_noshuffle.cl"); -#endif - - backend_ctx->program_CL_gemv_4096_1_4096 = build_program_from_source( - context, device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_4096_1_4096, "kernel_gemv_noshuffle", &err), err)); - - // Gemv 2048, 16384 - CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + - " -cl-mad-enable " - " -DLINE_STRIDE_A=2048 " - " -DBLOCK_STRIDE_A=16384 " - " -DSIMDGROUP_WIDTH=" + - std::to_string(backend_ctx->adreno_wave_size); - if (has_vector_subgroup_broadcast) { - CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; - } - - backend_ctx->program_CL_gemv_4096_1_11008 = build_program_from_source( - context, device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_11008 = clCreateKernel(backend_ctx->program_CL_gemv_4096_1_11008, "kernel_gemv_noshuffle", &err), err)); - - // Gemv 5504, 44032 - CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + - " -cl-mad-enable " - " -DLINE_STRIDE_A=5504 " - " -DBLOCK_STRIDE_A=44032 " - " -DSIMDGROUP_WIDTH=" + - std::to_string(backend_ctx->adreno_wave_size); - if (has_vector_subgroup_broadcast) { - CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; - } - - backend_ctx->program_CL_gemv_11008_1_4096 = build_program_from_source( - context, device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_11008_1_4096, "kernel_gemv_noshuffle", &err), err)); - - // Gemv 16000, 128000 - CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + - " -cl-mad-enable " - " -DLINE_STRIDE_A=16000 " - " -DBLOCK_STRIDE_A=128000 " - " -DSIMDGROUP_WIDTH=" + - std::to_string(backend_ctx->adreno_wave_size); - if (has_vector_subgroup_broadcast) { - CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; - } - - 
backend_ctx->program_CL_gemv_32000_1_4096 = build_program_from_source(context, device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_32000_1_4096, "kernel_gemv_noshuffle", &err), err)); - - // Gemm -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src_CL_gemm { - #include "ggml-opencl_mul_mat_Ab_Bi_8x4.cl.h" - }; -#else - const std::string kernel_src_CL_gemm = read_file("ggml-opencl_mul_mat_Ab_Bi_8x4.cl"); -#endif - backend_ctx->program_CL_gemm = build_program_from_source(context, device, kernel_src_CL_gemm.c_str(), compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_Ab_Bi_8x4 = clCreateKernel(backend_ctx->program_CL_gemm, "kernel_mul_mat_Ab_Bi_8x4", &err), err)); - - // TODO: fixme: these sizes are hardcoded for now. - // they should be allocated based on the model's size - // and the device's max alloc size // Allocate intermediate buffers and images size_t required_A_q_d_bytes = 311164928; size_t required_A_s_d_bytes = 38895616; @@ -1495,8 +1992,15 @@ static enum ggml_status ggml_backend_opencl_buffer_init_tensor(ggml_backend_buff // The optimized gemm and gemv kernels are used for large matrices without batch. // tensor is the quantized weights matrix. -inline bool use_adreno_kernels(const ggml_tensor *tensor) { - return tensor->ne[0] >= 512 && tensor->ne[1] >= 512 && +inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { + int64_t threshold_ne0 = 512; + int64_t threshold_ne1 = 512; + if (!backend_ctx->adreno_cl_compiler_version.newer_than_or_same(E031, 38, 11, 0) && + backend_ctx->adreno_cl_compiler_version.type != DX) { + threshold_ne0 = 128; + threshold_ne1 = 128; + } + return tensor->ne[0] >= threshold_ne0 && tensor->ne[1] >= threshold_ne1 && tensor->ne[2] == 1 && tensor->ne[3] == 1; } @@ -1574,7 +2078,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0; // The optimized kernels need weights in natural order, so unshuffle. 
- if (use_adreno_kernels(tensor)) { + if (use_adreno_kernels(backend_ctx, tensor)) { kernel = backend_ctx->kernel_convert_block_q4_0_noshuffle; } #else @@ -1598,7 +2102,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, #ifdef GGML_OPENCL_USE_ADRENO_KERNELS // Only do transpose for large, non batched matrix // TODO: use preallocated images instead of sub-buffer then image - if (use_adreno_kernels(tensor)) { + if (use_adreno_kernels(backend_ctx, tensor)) { // <----------------------------------------------------------------------------------> // // start transpose // <----------------------------------------------------------------------------------> // @@ -2899,8 +3403,8 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; cl_command_queue queue = backend_ctx->queue; - ggml_backend_opencl_device_context * dev_ctx = - (ggml_backend_opencl_device_context *)backend->device->context; + //ggml_backend_opencl_device_context * dev_ctx = + // (ggml_backend_opencl_device_context *)backend->device->context; ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; @@ -2931,13 +3435,20 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c // Note, this kernel declares local memory in kernel args and the size // depends on subgroup size. - // Retrieve subgroup size. // Note, this requires OpenCL 2.1 and above + // For now we use fixed subgroup size to simplify support for OpenCL 2.0. size_t sgs; - CL_CHECK(clGetKernelSubGroupInfo(kernel, dev_ctx->device, - CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE, - sizeof(local_work_size), local_work_size, - sizeof(size_t), &sgs, NULL)); + //CL_CHECK(clGetKernelSubGroupInfo(kernel, dev_ctx->device, + // CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE, + // sizeof(local_work_size), local_work_size, + // sizeof(size_t), &sgs, NULL)); + if (backend_ctx->gpu_family == ADRENO) { + sgs = 64; + } else if (backend_ctx->gpu_family == INTEL) { + sgs = 32; + } else { + GGML_ASSERT(false && "Unsupported GPU"); + } CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); @@ -3030,7 +3541,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co #ifdef GGML_OPENCL_USE_ADRENO_KERNELS cl_context context = backend_ctx->context; - if (ne01 && ne1 && use_adreno_kernels(src0)) { + if (ne01 && ne1 && use_adreno_kernels(backend_ctx, src0)) { // init CL objects // <--------------------------------------------> // diff --git a/ggml/src/ggml-opencl/kernels/add.cl b/ggml/src/ggml-opencl/kernels/add.cl new file mode 100644 index 000000000..f73f3c013 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/add.cl @@ -0,0 +1,83 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// add +//------------------------------------------------------------------------------ + +// general-purpose kernel for addition of two tensors +// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3 +// cons: not very efficient +kernel void kernel_add( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong 
nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + int ne13, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int i13 = i03 % ne13; + int i12 = i02 % ne12; + int i11 = i01 % ne11; + + global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const int i10 = i0 % ne10; + *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) + *((global float *)(src1_ptr + i10*nb10)); + } +} + +// assumption: src1 is a row +// broadcast src1 into src0 +kernel void kernel_add_row( + global float4 * src0, + ulong offset0, + global float4 * src1, + ulong offset1, + global float4 * dst, + ulong offsetd, + int ne +) { + src0 = (global float4*)((global char*)src0 + offset0); + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float4*)((global char*)dst + offsetd); + + // This performs better than using %. + uint gid = get_global_id(0); + uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne + dst[gid] = src0[gid] + src1[idx1]; +} diff --git a/ggml/src/ggml-opencl/kernels/clamp.cl b/ggml/src/ggml-opencl/kernels/clamp.cl new file mode 100644 index 000000000..ae6032444 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/clamp.cl @@ -0,0 +1,20 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// clamp +//------------------------------------------------------------------------------ +kernel void kernel_clamp( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + float min, + float max +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = src0[get_global_id(0)] < min ? + min : + (src0[get_global_id(0)] > max ? 
max : src0[get_global_id(0)]); +} diff --git a/ggml/src/ggml-opencl/kernels/cpy.cl b/ggml/src/ggml-opencl/kernels/cpy.cl new file mode 100644 index 000000000..9369351a6 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/cpy.cl @@ -0,0 +1,184 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// cpy +//------------------------------------------------------------------------------ + +kernel void kernel_cpy_f16_f16( + global half * src0, + ulong offset0, + global half * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = (global half*)((global char*)src0 + offset0); + dst = (global half*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + int i3 = n / (ne2*ne1*ne0); + int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + global const half * src = (global half *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + dst_data[i00] = src[0]; + } +} + +kernel void kernel_cpy_f16_f32( + global half * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + + src0 = (global half*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + int i3 = n / (ne2*ne1*ne0); + int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + global half * src = (global half *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + dst_data[i00] = src[0]; + } +} + +kernel void kernel_cpy_f32_f16( + global float * src0, + ulong offset0, + global half * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global half*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + int i3 = n / (ne2*ne1*ne0); + int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += 
get_local_size(0)) { + global const float * src = (global float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + dst_data[i00] = src[0]; + } +} + +kernel void kernel_cpy_f32_f32( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + int i3 = n / (ne2*ne1*ne0); + int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + global const float * src = (global float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + dst_data[i00] = src[0]; + } +} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl similarity index 69% rename from ggml/src/ggml-opencl/kernels/ggml-opencl_cvt.cl rename to ggml/src/ggml-opencl/kernels/cvt.cl index e2024332f..fe7975e3d 100644 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl_cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -1,39 +1,20 @@ //------------------------------------------------------------------------------ -// This file is contains additional kernels for data conversion. +// This file contains kernels for data conversion. // These kernels are used when loading the model, so its performance is less // important. //------------------------------------------------------------------------------ -#ifdef cl_khr_fp16 #pragma OPENCL EXTENSION cl_khr_fp16 : enable -#elif defined(cl_amd_fp16) -#pragma OPENCL EXTENSION cl_amd_fp16 : enable -#else -#error "Half precision floating point not supportedby OpenCL implementation on your device." -#endif - -#ifdef cl_khr_subgroups -#pragma OPENCL EXTENSION cl_khr_subgroups : enable -#elif defined(cl_intel_subgroups) -#pragma OPENCL EXTENSION cl_intel_subgroups : enable -#else -#error "Subgroup not supported on your device." -#endif #ifdef cl_intel_required_subgroup_size -// Always use subgroup size of 32 on Intel. #pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable #define INTEL_GPU 1 #define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) #define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) #elif defined(cl_qcom_reqd_sub_group_size) -// Always use subgroups size of 64 on Adreno. #pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable #define ADRENO_GPU 1 #define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) #define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) -#else -// TODO: do not know how to choose subgroup size on other GPUs. -#error "Selecting subgroup size is not supported on your device." 
#endif #define QK4_0 32 @@ -66,13 +47,44 @@ struct block_q4_0 }; //------------------------------------------------------------------------------ -// mul_vec_q_n_f32_flat_noshuffle -// -// This variation uses flat arrays (struct of arrays, SOA) representation for -// quant tensors. It also uses non shuffled bit order for weights. -// -// The shuffled version is kept in the original file because moving it here -// seems to result in worse performance for adreno. +// kernel_convert_block_q4_0 +// Convert the block_q4_0 format to 2 separate arrays (AOS -> SOA). +// This kernel does not deshuffle the bits. +//------------------------------------------------------------------------------ +kernel void kernel_convert_block_q4_0( + global struct block_q4_0 * src0, + global uchar * dst_q, + global half * dst_d +) { + global struct block_q4_0 * b = (global struct block_q4_0 *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK4_0/2*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + + *d = b->d; + + for (int i = 0; i < QK4_0/2; ++i) { + q[i] = b->qs[i]; + } +} + +kernel void kernel_restore_block_q4_0( + global uchar * src_q, + global half * src_d, + global struct block_q4_0 * dst +) { + global struct block_q4_0 * b = (global struct block_q4_0 *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK4_0/2*get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + + b->d = *d; + for (int i = 0; i < QK4_0/2; ++i) { + b->qs[i] = q[i]; + } +} + +//------------------------------------------------------------------------------ +// kernel_convert_block_q4_0_noshuffle +// Flatten q4_0 weights and unshuffle the bits //------------------------------------------------------------------------------ kernel void kernel_convert_block_q4_0_noshuffle( diff --git a/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl b/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl new file mode 100644 index 000000000..36eff0439 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl @@ -0,0 +1,58 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// diag_mask_inf kernels +//------------------------------------------------------------------------------ +kernel void kernel_diag_mask_inf( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int n_past +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + int i02 = get_global_id(2); + int i01 = get_global_id(1); + int i00 = get_global_id(0); + + if (i00 > n_past + i01) { + dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY; + } else { + dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00]; + } +} + +kernel void kernel_diag_mask_inf_8( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd, + int ne00, + int ne01, + int n_past +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + int i = 2*get_global_id(0); + + dst[i+0] = src0[i+0]; + dst[i+1] = src0[i+1]; + int i4 = 4*i; + int i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01; + int i01 = i4/(ne00); i4 -= i01*ne00; + int i00 = i4; + for (int k = 3; k >= 0; --k) { + if (i00 + 4 + k <= n_past + i01) { + break; + } + (&dst[i+1])[k] = -INFINITY; + if (i00 + k > n_past + i01) { + (&dst[i])[k] = -INFINITY; + } + } +} diff --git 
a/ggml/src/ggml-opencl/kernels/gelu.cl b/ggml/src/ggml-opencl/kernels/gelu.cl new file mode 100644 index 000000000..71c310cc9 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gelu.cl @@ -0,0 +1,62 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// gelu +//------------------------------------------------------------------------------ +#define GELU_COEF_A 0.044715f +#define GELU_QUICK_COEF -1.702f +#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f + +kernel void kernel_gelu( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + float x = src0[get_global_id(0)]; + + dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +kernel void kernel_gelu_4( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + float4 x = src0[get_global_id(0)]; + + dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +kernel void kernel_gelu_quick( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + float x = src0[get_global_id(0)]; + dst[get_global_id(0)] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); +} + +kernel void kernel_gelu_quick_4( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + float4 x = src0[get_global_id(0)]; + dst[get_global_id(0)] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); +} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl similarity index 100% rename from ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle.cl rename to ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle_general.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl similarity index 100% rename from ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle_general.cl rename to ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl diff --git a/ggml/src/ggml-opencl/kernels/get_rows.cl b/ggml/src/ggml-opencl/kernels/get_rows.cl new file mode 100644 index 000000000..b3fea2923 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/get_rows.cl @@ -0,0 +1,163 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +#define QK4_0 32 + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + + +//------------------------------------------------------------------------------ +// dequantize_q4_0_f32, dequantize_q4_0_f16 +//------------------------------------------------------------------------------ +void dequantize_q4_0_f32(global struct block_q4_0 * xb, short il, float16 * reg) { + global ushort * qs = ((global ushort *)xb + 1); + float d1 = il ? 
(xb->d / 16.h) : xb->d; + float d2 = d1 / 256.f; + float md = -8.h * xb->d; + ushort mask0 = il ? 0x00F0 : 0x000F; + ushort mask1 = mask0 << 8; + + reg->s0 = d1 * (qs[0] & mask0) + md; + reg->s1 = d2 * (qs[0] & mask1) + md; + + reg->s2 = d1 * (qs[1] & mask0) + md; + reg->s3 = d2 * (qs[1] & mask1) + md; + + reg->s4 = d1 * (qs[2] & mask0) + md; + reg->s5 = d2 * (qs[2] & mask1) + md; + + reg->s6 = d1 * (qs[3] & mask0) + md; + reg->s7 = d2 * (qs[3] & mask1) + md; + + reg->s8 = d1 * (qs[4] & mask0) + md; + reg->s9 = d2 * (qs[4] & mask1) + md; + + reg->sa = d1 * (qs[5] & mask0) + md; + reg->sb = d2 * (qs[5] & mask1) + md; + + reg->sc = d1 * (qs[6] & mask0) + md; + reg->sd = d2 * (qs[6] & mask1) + md; + + reg->se = d1 * (qs[7] & mask0) + md; + reg->sf = d2 * (qs[7] & mask1) + md; +} + + +//------------------------------------------------------------------------------ +// get_rows +//------------------------------------------------------------------------------ +kernel void kernel_get_rows_f32( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + ulong nb01, + ulong nb02, + int ne10, + ulong nb10, + ulong nb11, + ulong nb1, + ulong nb2 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int i10 = get_group_id(0); + int i11 = get_group_id(1); + + int r = ((global int *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; + + int i02 = i11; + + for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) { + ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] = + ((global float *) ((global char *) src0 + r*nb01 + i02*nb02))[ind]; + } +} + +kernel void kernel_get_rows_f16( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + ulong nb01, + ulong nb02, + int ne10, + ulong nb10, + ulong nb11, + ulong nb1, + ulong nb2 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int i10 = get_group_id(0); + int i11 = get_group_id(1); + + int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; + + int i02 = i11; + + for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) { + ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] = + ((global half *) ((global char *) src0 + r*nb01 + i02*nb02))[ind]; + } +} + +kernel void kernel_get_rows_q4_0( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + ulong nb01, + ulong nb02, + int ne10, + ulong nb10, + ulong nb11, + ulong nb1, + ulong nb2 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + const int NL = 2; + + int i10 = get_group_id(0); + int i11 = get_group_id(1); + + int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; + + int i02 = i11; + + for (int ind = get_local_id(0); ind < ne00/16; ind += get_local_size(0)) { + float16 temp; + dequantize_q4_0_f32( + ((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02)) + ind/NL, ind%NL, &temp); + *(((global float16 *) ((global char *) dst + i11*nb2 + i10*nb1)) + ind) = temp; + } +} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl.cl 
b/ggml/src/ggml-opencl/kernels/ggml-opencl.cl deleted file mode 100644 index b88792887..000000000 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl.cl +++ /dev/null @@ -1,3231 +0,0 @@ -#ifdef cl_khr_fp16 -#pragma OPENCL EXTENSION cl_khr_fp16 : enable -#elif defined(cl_amd_fp16) -#pragma OPENCL EXTENSION cl_amd_fp16 : enable -#else -#error "Half precision floating point not supportedby OpenCL implementation on your device." -#endif - -#ifdef cl_khr_subgroups -#pragma OPENCL EXTENSION cl_khr_subgroups : enable -#elif defined(cl_intel_subgroups) -#pragma OPENCL EXTENSION cl_intel_subgroups : enable -#else -#error "Subgroup not supported on your device." -#endif - -#ifdef cl_intel_required_subgroup_size -// Always use subgroup size of 32 on Intel. -#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable -#define INTEL_GPU 1 -#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) -#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) -#elif defined(cl_qcom_reqd_sub_group_size) -// Always use subgroups size of 64 on Adreno. -#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable -#define ADRENO_GPU 1 -#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) -#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) -#else -// TODO: do not know how to choose subgroup size on other GPUs. -#error "Selecting subgroup size is not supported on your device." -#endif - -#define QK4_0 32 -#define QR4_0 2 -#define QK4_1 32 -#define QR4_1 2 -#define QK5_0 32 -#define QR5_0 2 -#define QK5_1 32 -#define QR5_1 2 -#define QK8_0 32 -#define QR8_0 1 -#define QK_K 256 -#define K_QUANTS_PER_ITERATION 2 - -typedef char int8_t; -typedef uchar uint8_t; -typedef short int16_t; -typedef ushort uint16_t; -typedef int int32_t; -typedef uint uint32_t; - -//------------------------------------------------------------------------------ -// block_q4_0 -//------------------------------------------------------------------------------ -struct block_q4_0 -{ - half d; - uint8_t qs[QK4_0 / 2]; -}; - -//------------------------------------------------------------------------------ -// block_q4_1 -//------------------------------------------------------------------------------ -struct block_q4_1 -{ - half d; - half m; - uint8_t qs[QK4_1 / 2]; -}; - -//------------------------------------------------------------------------------ -// block_q5_0 -//------------------------------------------------------------------------------ -struct block_q5_0 -{ - half d; - uint32_t qh; - uint8_t qs[QK5_0 / 2]; -}; - -//------------------------------------------------------------------------------ -// block_q5_1 -//------------------------------------------------------------------------------ -struct block_q5_1 -{ - half d; - half m; - uint32_t qh; - uint8_t qs[QK5_1 / 2]; -}; - -//------------------------------------------------------------------------------ -// block_q8_0 -//------------------------------------------------------------------------------ -struct block_q8_0 -{ - half d; - int8_t qs[QK8_0]; -}; - -//------------------------------------------------------------------------------ -// block_q2_K -//------------------------------------------------------------------------------ -struct block_q2_K -{ - uint8_t scales[16]; - uint8_t qs[64]; - half d; - half dmin; -}; - -//------------------------------------------------------------------------------ -// block_q3_K 
-//------------------------------------------------------------------------------ -struct block_q3_K -{ - uint8_t hmask[32]; - uint8_t qs[64]; - uint8_t scales[12]; - half d; -}; - -//------------------------------------------------------------------------------ -// block_q4_K -//------------------------------------------------------------------------------ -struct block_q4_K -{ - half d; - half dmin; - uint8_t scales[12]; - uint8_t qs[128]; -}; - -//------------------------------------------------------------------------------ -// block_q5_K -//------------------------------------------------------------------------------ -struct block_q5_K -{ - half d; - half dmin; - uint8_t scales[12]; - uint8_t qh[32]; - uint8_t qs[128]; -}; - -//------------------------------------------------------------------------------ -// block_q6_K -//------------------------------------------------------------------------------ -struct block_q6_K -{ - uint8_t ql[128]; - uint8_t qh[64]; - int8_t scales[16]; - half d; -}; - -//------------------------------------------------------------------------------ -// dequantize_q4_0_f32, dequantize_q4_0_f16 -//------------------------------------------------------------------------------ -void dequantize_q4_0_f32(global struct block_q4_0 * xb, short il, float16 * reg) { - global ushort * qs = ((global ushort *)xb + 1); - float d1 = il ? (xb->d / 16.h) : xb->d; - float d2 = d1 / 256.f; - float md = -8.h * xb->d; - ushort mask0 = il ? 0x00F0 : 0x000F; - ushort mask1 = mask0 << 8; - - reg->s0 = d1 * (qs[0] & mask0) + md; - reg->s1 = d2 * (qs[0] & mask1) + md; - - reg->s2 = d1 * (qs[1] & mask0) + md; - reg->s3 = d2 * (qs[1] & mask1) + md; - - reg->s4 = d1 * (qs[2] & mask0) + md; - reg->s5 = d2 * (qs[2] & mask1) + md; - - reg->s6 = d1 * (qs[3] & mask0) + md; - reg->s7 = d2 * (qs[3] & mask1) + md; - - reg->s8 = d1 * (qs[4] & mask0) + md; - reg->s9 = d2 * (qs[4] & mask1) + md; - - reg->sa = d1 * (qs[5] & mask0) + md; - reg->sb = d2 * (qs[5] & mask1) + md; - - reg->sc = d1 * (qs[6] & mask0) + md; - reg->sd = d2 * (qs[6] & mask1) + md; - - reg->se = d1 * (qs[7] & mask0) + md; - reg->sf = d2 * (qs[7] & mask1) + md; -} - -void dequantize_q4_0_f16(global struct block_q4_0 * xb, short il, half16 * reg) { - global ushort * qs = ((global ushort *)xb + 1); - half d1 = il ? (xb->d / 16.h) : xb->d; - half d2 = d1 / 256.h; - half md = -8.h * xb->d; - ushort mask0 = il ? 
0x00F0 : 0x000F; - ushort mask1 = mask0 << 8; - - reg->s0 = d1 * (qs[0] & mask0) + md; - reg->s1 = d2 * (qs[0] & mask1) + md; - - reg->s2 = d1 * (qs[1] & mask0) + md; - reg->s3 = d2 * (qs[1] & mask1) + md; - - reg->s4 = d1 * (qs[2] & mask0) + md; - reg->s5 = d2 * (qs[2] & mask1) + md; - - reg->s6 = d1 * (qs[3] & mask0) + md; - reg->s7 = d2 * (qs[3] & mask1) + md; - - reg->s8 = d1 * (qs[4] & mask0) + md; - reg->s9 = d2 * (qs[4] & mask1) + md; - - reg->sa = d1 * (qs[5] & mask0) + md; - reg->sb = d2 * (qs[5] & mask1) + md; - - reg->sc = d1 * (qs[6] & mask0) + md; - reg->sd = d2 * (qs[6] & mask1) + md; - - reg->se = d1 * (qs[7] & mask0) + md; - reg->sf = d2 * (qs[7] & mask1) + md; -} - -//------------------------------------------------------------------------------ -// add -//------------------------------------------------------------------------------ - -// general-purpose kernel for addition of two tensors -// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3 -// cons: not very efficient -kernel void kernel_add( - global char * src0, - ulong offset0, - global char * src1, - ulong offset1, - global char * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne10, - int ne11, - int ne12, - int ne13, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3 -) { - src0 = src0 + offset0; - src1 = src1 + offset1; - dst = dst + offsetd; - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - int i13 = i03 % ne13; - int i12 = i02 % ne12; - int i11 = i01 % ne11; - - global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; - global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; - - for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { - const int i10 = i0 % ne10; - *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) + *((global float *)(src1_ptr + i10*nb10)); - } -} - -// assumption: src1 is a row -// broadcast src1 into src0 -kernel void kernel_add_row( - global float4 * src0, - ulong offset0, - global float4 * src1, - ulong offset1, - global float4 * dst, - ulong offsetd, - int ne -) { - src0 = (global float4*)((global char*)src0 + offset0); - src1 = (global float4*)((global char*)src1 + offset1); - dst = (global float4*)((global char*)dst + offsetd); - - // This performs better than using %. 
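// NOTE: src1 here is a single row of ne elements broadcast across src0, so the
// element to fetch from src1 is get_global_id(0) % ne. The subtraction form
// below, gid - (gid/ne)*ne, computes that remainder with one integer division
// because the % operator reportedly performs worse on the target GPUs (per the
// comment above); kernel_mul_row further down uses the same pattern.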
- uint gid = get_global_id(0); - uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne - dst[gid] = src0[gid] + src1[idx1]; -} - -//------------------------------------------------------------------------------ -// mul -//------------------------------------------------------------------------------ -kernel void kernel_mul( - global char * src0, - ulong offset0, - global char * src1, - ulong offset1, - global char * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne10, - int ne11, - int ne12, - int ne13, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3 -) { - src0 = src0 + offset0; - src1 = src1 + offset1; - dst = dst + offsetd; - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - int i13 = i03 % ne13; - int i12 = i02 % ne12; - int i11 = i01 % ne11; - - global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; - global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; - - for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { - const int i10 = i0 % ne10; - *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) * *((global float *)(src1_ptr + i10*nb10)); - } -} - -// assumption: src1 is a row -// broadcast src1 into src0 -kernel void kernel_mul_row( - global float4 * src0, - ulong offset0, - global float4 * src1, - ulong offset1, - global float4 * dst, - ulong offsetd, - int ne -) { - src0 = (global float4*)((global char*)src0 + offset0); - src1 = (global float4*)((global char*)src1 + offset1); - dst = (global float4*)((global char*)dst + offsetd); - - // This performs better than using %. 
-    uint gid = get_global_id(0);
-    uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
-    dst[gid] = src0[gid] * src1[idx1];
-}
-
-//------------------------------------------------------------------------------
-// scale
-//------------------------------------------------------------------------------
-kernel void kernel_scale(
-        global float4 * src0,
-        ulong offset0,
-        global float4 * dst,
-        ulong offsetd,
-        float scale
-) {
-    src0 = (global float4*)((global char*)src0 + offset0);
-    dst = (global float4*)((global char*)dst + offsetd);
-    dst[get_global_id(0)] = src0[get_global_id(0)] * scale;
-}
-
-//------------------------------------------------------------------------------
-// gelu
-//------------------------------------------------------------------------------
-#define GELU_COEF_A     0.044715f
-#define GELU_QUICK_COEF -1.702f
-#define SQRT_2_OVER_PI  0.79788456080286535587989211986876f
-
-kernel void kernel_gelu(
-        global float * src0,
-        ulong offset0,
-        global float * dst,
-        ulong offsetd
-) {
-    src0 = (global float*)((global char*)src0 + offset0);
-    dst = (global float*)((global char*)dst + offsetd);
-
-    float x = src0[get_global_id(0)];
-
-    dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
-}
-
-kernel void kernel_gelu_4(
-        global float4 * src0,
-        ulong offset0,
-        global float4 * dst,
-        ulong offsetd
-) {
-    src0 = (global float4*)((global char*)src0 + offset0);
-    dst = (global float4*)((global char*)dst + offsetd);
-
-    float4 x = src0[get_global_id(0)];
-
-    dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
-}
-
-kernel void kernel_gelu_quick(
-        global float * src0,
-        ulong offset0,
-        global float * dst,
-        ulong offsetd
-) {
-    src0 = (global float*)((global char*)src0 + offset0);
-    dst = (global float*)((global char*)dst + offsetd);
-
-    float x = src0[get_global_id(0)];
-    dst[get_global_id(0)] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
-}
-
-kernel void kernel_gelu_quick_4(
-        global float4 * src0,
-        ulong offset0,
-        global float4 * dst,
-        ulong offsetd
-) {
-    src0 = (global float4*)((global char*)src0 + offset0);
-    dst = (global float4*)((global char*)dst + offsetd);
-
-    float4 x = src0[get_global_id(0)];
-    dst[get_global_id(0)] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
-}
-
-//------------------------------------------------------------------------------
-// silu
-//------------------------------------------------------------------------------
-kernel void kernel_silu(
-        global float * src0,
-        ulong offset0,
-        global float * dst,
-        ulong offsetd
-) {
-    src0 = (global float*)((global char*)src0 + offset0);
-    dst = (global float*)((global char*)dst + offsetd);
-
-    float x = src0[get_global_id(0)];
-    dst[get_global_id(0)] = x / (1.0f + exp(-x));
-}
-
-kernel void kernel_silu_4(
-        global float4 * src0,
-        ulong offset0,
-        global float4 * dst,
-        ulong offsetd
-) {
-    src0 = (global float4*)((global char*)src0 + offset0);
-    dst = (global float4*)((global char*)dst + offsetd);
-
-    float4 x = src0[get_global_id(0)];
-    dst[get_global_id(0)] = x / (1.0f + exp(-x));
-}
-
-//------------------------------------------------------------------------------
-// relu
-//------------------------------------------------------------------------------
-kernel void kernel_relu(
-        global float * src0,
-        ulong offset0,
-        global float * dst,
-        ulong offsetd
-) {
-    src0 = (global float*)((global char*)src0 + offset0);
-    dst = (global float*)((global char*)dst + offsetd);
-
-    dst[get_global_id(0)] = fmax(0.0f, src0[get_global_id(0)]);
-}
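The activation kernels above are pure element-wise maps, so their math is easy to check against a scalar reference. The C sketch below is not part of the patch; the helper names are illustrative, and it simply restates the kernels' formulas with the same constants (compile with -lm).

#include <math.h>
#include <stdio.h>

// Same constants as the kernels above.
#define GELU_COEF_A     0.044715f
#define GELU_QUICK_COEF -1.702f
#define SQRT_2_OVER_PI  0.79788456080286535587989211986876f

// tanh-approximated GELU, as in kernel_gelu
static float gelu_ref(float x) {
    return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}

// sigmoid-approximated "quick" GELU, as in kernel_gelu_quick
static float gelu_quick_ref(float x) {
    return x*(1.0f/(1.0f + expf(GELU_QUICK_COEF*x)));
}

// SiLU (a.k.a. swish): x * sigmoid(x), as in kernel_silu
static float silu_ref(float x) {
    return x / (1.0f + expf(-x));
}

int main(void) {
    for (float x = -2.0f; x <= 2.0f; x += 1.0f) {
        printf("x=% .1f  gelu=% .4f  gelu_quick=% .4f  silu=% .4f\n",
               x, gelu_ref(x), gelu_quick_ref(x), silu_ref(x));
    }
    return 0;
}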
-
-//------------------------------------------------------------------------------
-// clamp
-//------------------------------------------------------------------------------
-kernel void kernel_clamp(
-        global float * src0,
-        ulong offset0,
-        global float * dst,
-        ulong offsetd,
-        float min,
-        float max
-) {
-    src0 = (global float*)((global char*)src0 + offset0);
-    dst = (global float*)((global char*)dst + offsetd);
-
-    dst[get_global_id(0)] = src0[get_global_id(0)] < min ?
-        min :
-        (src0[get_global_id(0)] > max ? max : src0[get_global_id(0)]);
-}
-
-//------------------------------------------------------------------------------
-// norm
-//------------------------------------------------------------------------------
-kernel void kernel_norm(
-        global void * src0,
-        ulong offset0,
-        global float * dst,
-        ulong offsetd,
-        int ne00,
-        int ne01,
-        int ne02,
-        int ne03,
-        ulong nb01,
-        ulong nb02,
-        ulong nb03,
-        float eps,
-        local float * sum
-) {
-    src0 = (global void*)((global char*)src0 + offset0);
-    dst = (global void*)((global char*)dst + offsetd);
-
-    int i03 = get_group_id(2);
-    int i02 = get_group_id(1);
-    int i01 = get_group_id(0);
-
-    global float * x = (global float *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01);
-
-    // MEAN
-    // parallel sum
-    sum[get_local_id(0)] = 0.0f;
-    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
-        sum[get_local_id(0)] += x[i00];
-    }
-    // reduce
-    barrier(CLK_LOCAL_MEM_FENCE);
-    for (uint i = get_local_size(0)/2; i > 0; i /= 2) {
-        if (get_local_id(0) < i) {
-            sum[get_local_id(0)] += sum[get_local_id(0) + i];
-        }
-        barrier(CLK_LOCAL_MEM_FENCE);
-    }
-    float mean = sum[0] / ne00;
-
-    // recenter and VARIANCE
-    barrier(CLK_LOCAL_MEM_FENCE);
-    global float * y = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-    sum[get_local_id(0)] = 0.0f;
-    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
-        y[i00] = x[i00] - mean;
-        sum[get_local_id(0)] += y[i00] * y[i00];
-    }
-
-    // reduce
-    barrier(CLK_LOCAL_MEM_FENCE);
-    for (uint i = get_local_size(0)/2; i > 0; i /= 2) {
-        if (get_local_id(0) < i) {
-            sum[get_local_id(0)] += sum[get_local_id(0) + i];
-        }
-        barrier(CLK_LOCAL_MEM_FENCE);
-    }
-    float variance = sum[0] / ne00;
-
-    float scale = 1.0f/sqrt(variance + eps);
-    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
-        y[i00] = y[i00] * scale;
-    }
-}
-
-//------------------------------------------------------------------------------
-// rms_norm
-//------------------------------------------------------------------------------
-// This kernel depends on subgroup size.
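The kernel that follows normalizes each row by the root-mean-square of its elements rather than by mean and variance. A scalar C reference of the same computation, not part of the patch and with illustrative names, is:

#include <math.h>
#include <stdio.h>

// Scalar reference for RMS normalization of one row of n elements:
// y[i] = x[i] / sqrt(mean(x^2) + eps). kernel_rms_norm below computes the
// same values, but accumulates the sum of squares with float4 loads and
// subgroup reductions instead of this serial loop.
static void rms_norm_ref(const float *x, float *y, int n, float eps) {
    float ss = 0.0f;
    for (int i = 0; i < n; ++i) {
        ss += x[i]*x[i];
    }
    const float scale = 1.0f / sqrtf(ss/(float)n + eps);
    for (int i = 0; i < n; ++i) {
        y[i] = x[i]*scale;
    }
}

int main(void) {
    float x[4] = {1.0f, 2.0f, 3.0f, 4.0f}, y[4];
    rms_norm_ref(x, y, 4, 1e-5f);
    printf("%f %f %f %f\n", y[0], y[1], y[2], y[3]);
    return 0;
}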
-kernel void kernel_rms_norm( - global void * src0, - ulong offset0, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb01, - ulong nb02, - ulong nb03, - float eps, - local float * sum // Note, the size depends on number of subgroups -) { - src0 = (global void*)((global char*)src0 + offset0); - dst = (global float*)((global char*)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - global float4 * x = (global float4 *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01); - global float * x_scalar = (global float *) x; - float4 sumf = 0; - float all_sum = 0; - - // parallel sum - for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { - sumf += x[i00] * x[i00]; - } - all_sum = sumf.s0 + sumf.s1 + sumf.s2 + sumf.s3; - all_sum = sub_group_reduce_add(all_sum); - if (get_sub_group_local_id() == 0) { - sum[get_sub_group_id()] = all_sum; - } - - barrier(CLK_LOCAL_MEM_FENCE); - // broadcast - for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) { - if (get_local_id(0) < i) { - sum[get_local_id(0)] += sum[get_local_id(0) + i]; - } - } - if (get_local_id(0) == 0) { - for (int i = 4 * (ne00 / 4); i < ne00; i++) { - sum[0] += x_scalar[i]; - } - sum[0] /= ne00; - } - - barrier(CLK_LOCAL_MEM_FENCE); - - const float mean = sum[0]; - const float scale = 1.0f/sqrt(mean + eps); - - global float4 * y = (global float4 *) (dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - global float * y_scalar = (global float *) y; - for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { - y[i00] = x[i00] * scale; - } - if (get_local_id(0) == 0) { - for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) { - y_scalar[i00] = x_scalar[i00] * scale; - } - } -} - -//------------------------------------------------------------------------------ -// diag_mask_inf kernels -//------------------------------------------------------------------------------ -kernel void kernel_diag_mask_inf( - global float * src0, - ulong offset0, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int n_past -) { - src0 = (global float*)((global char*)src0 + offset0); - dst = (global float*)((global char*)dst + offsetd); - - int i02 = get_global_id(2); - int i01 = get_global_id(1); - int i00 = get_global_id(0); - - if (i00 > n_past + i01) { - dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY; - } else { - dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00]; - } -} - -kernel void kernel_diag_mask_inf_8( - global float4 * src0, - ulong offset0, - global float4 * dst, - ulong offsetd, - int ne00, - int ne01, - int n_past -) { - src0 = (global float4*)((global char*)src0 + offset0); - dst = (global float4*)((global char*)dst + offsetd); - - int i = 2*get_global_id(0); - - dst[i+0] = src0[i+0]; - dst[i+1] = src0[i+1]; - int i4 = 4*i; - int i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01; - int i01 = i4/(ne00); i4 -= i01*ne00; - int i00 = i4; - for (int k = 3; k >= 0; --k) { - if (i00 + 4 + k <= n_past + i01) { - break; - } - (&dst[i+1])[k] = -INFINITY; - if (i00 + k > n_past + i01) { - (&dst[i])[k] = -INFINITY; - } - } -} - -//------------------------------------------------------------------------------ -// softmax -//------------------------------------------------------------------------------ -#ifdef ADRENO_GPU -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_soft_max( - global float * src0, - ulong offset0, - global float * src1, - ulong offset1, - global 
float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - float scale, - float max_bias, - float m0, - float m1, - int n_head_log2 -) { - src0 = (global float*)((global char*)src0 + offset0); - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - global float * pmask = src1 != src0 ? src1 + i01*ne00 : 0; - global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - float slope = 1.0f; - - // ALiBi - if (max_bias > 0.0f) { - int h = i02; - - float base = h < n_head_log2 ? m0 : m1; - int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - - slope = pow(base, exp); - } - - // parallel max - float lmax = -INFINITY; - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); - } - float max = sub_group_reduce_max(lmax); - - // parallel sum - float lsum = 0.0f; - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max); - lsum += exp_psrc0; - // Remember the result of exp here. exp is expensive, so we really do not - // wish to compute it twice. - pdst[i00] = exp_psrc0; - } - - const float sum = sub_group_reduce_add(lsum); - - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - pdst[i00] /= sum; - } -} - -#ifdef ADRENO_GPU -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_soft_max_4( - global float * src0, - ulong offset0, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - float scale, - float max_bias, - float m0, - float m1, - int n_head_log2 -) { - src0 = (global float*)((global char*)src0 + offset0); - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i01*ne00) : 0; - global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - - float slope = 1.0f; - - // ALiBi - if (max_bias > 0.0f) { - int h = i02; - - float base = h < n_head_log2 ? m0 : m1; - int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - - slope = pow(base, exp); - } - - // parallel max - float4 lmax4 = -INFINITY; - for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { - lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); - } - float lmax = fmax(fmax(lmax4.s0, lmax4.s1), fmax(lmax4.s2, lmax4.s3)); - - const float max = sub_group_reduce_max(lmax); - - // parallel sum - float4 lsum4 = 0.0f; - for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { - const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? 
slope*pmask[i00] : 0.0f)) - max); - lsum4 += exp_psrc4; - pdst4[i00] = exp_psrc4; - } - float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3; - - const float sum = sub_group_reduce_add(lsum); - - for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { - pdst4[i00] /= sum; - } -} - -#ifdef ADRENO_GPU -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_soft_max_f16( - global float * src0, - ulong offset0, - global half * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - float scale, - float max_bias, - float m0, - float m1, - int n_head_log2 -) { - src0 = (global float *)((global char *)src0 + offset0); - src1 = (global half *)((global char *)src1 + offset1); - dst = (global float *)((global char *)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - global half * pmask = (global char *)src1 != (global char *)src0 ? src1 + i01*ne00 : 0; - global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - float slope = 1.0f; - - // ALiBi - if (max_bias > 0.0f) { - int h = i02; - - float base = h < n_head_log2 ? m0 : m1; - int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - - slope = pow(base, exp); - } - - // parallel max - float lmax = -INFINITY; - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); - } - float max = sub_group_reduce_max(lmax); - - // parallel sum - float lsum = 0.0f; - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max); - lsum += exp_psrc0; - // Remember the result of exp here. exp is expensive, so we really do not - // wish to compute it twice. - pdst[i00] = exp_psrc0; - } - - const float sum = sub_group_reduce_add(lsum); - - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - pdst[i00] /= sum; - } -} - -#ifdef ADRENO_GPU -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_soft_max_4_f16( - global float * src0, - ulong offset0, - global half * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - float scale, - float max_bias, - float m0, - float m1, - int n_head_log2 -) { - src0 = (global float *)((global char *)src0 + offset0); - src1 = (global half *)((global char *)src1 + offset1); - dst = (global float *)((global char *)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - global half4 * pmask = (global char *)src1 != (global char *)src0 ? (global half4 *)(src1 + i01*ne00) : 0; - global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - - float slope = 1.0f; - - // ALiBi - if (max_bias > 0.0f) { - int h = i02; - - float base = h < n_head_log2 ? m0 : m1; - int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - - slope = pow(base, exp); - } - - // parallel max - float4 lmax4 = -INFINITY; - for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { - lmax4 = fmax(lmax4, psrc4[i00]*scale + slope*(pmask ? 
convert_float4(pmask[i00]) : 0.0f)); - } - float lmax = fmax(fmax(lmax4.s0, lmax4.s1), fmax(lmax4.s2, lmax4.s3)); - - const float max = sub_group_reduce_max(lmax); - - // parallel sum - float4 lsum4 = 0.0f; - for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { - const float4 exp_psrc4 = exp((psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f)) - max); - lsum4 += exp_psrc4; - pdst4[i00] = exp_psrc4; - } - float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3; - - const float sum = sub_group_reduce_add(lsum); - - for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { - pdst4[i00] /= sum; - } -} - -//------------------------------------------------------------------------------ -// kernel_rope -//------------------------------------------------------------------------------ -float rope_yarn_ramp(float low, float high, int i0) { - const float y = (i0 / 2 - low) / max(0.001f, high - low); - return 1.0f - min(1.0f, max(0.0f, y)); -} - -// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn -// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. -float2 rope_yarn( - float theta_extrap, float freq_scale, float2 corr_dims, int i0, float ext_factor, float mscale -) { - // Get n-d rotational scaling corrected for extrapolation - float theta_interp = freq_scale * theta_extrap; - float theta = theta_interp; - if (ext_factor != 0.0f) { - float ramp_mix = rope_yarn_ramp(corr_dims.s0, corr_dims.s1, i0) * ext_factor; - theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; - - // Get n-d magnitude scaling corrected for interpolation - mscale *= 1.0f + 0.1f * log(1.0f / freq_scale); - } - return (float2)(cos(theta) * mscale, sin(theta) * mscale); -} - -// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get -// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` -float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) { - return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base)); -} - -float2 rope_yarn_corr_dims( - int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow -) { - // start and end correction dims - return (float2)( - max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base))), - min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base))) - ); -} - -kernel void kernel_rope_norm_f32( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * src2, - ulong offset2, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3, - int n_past, - int n_dims, - int n_ctx_orig, - float freq_base, - float freq_scale, - float ext_factor, - float attn_factor, - float beta_fast, - float beta_slow -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - src2 = (global float*)((global char*)src2 + offset2); - dst = (global float*)((global char*)dst + offsetd); - - int i3 = get_group_id(2); - int i2 = get_group_id(1); - int i1 = get_group_id(0); - - float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); - - global int * pos = src1; - - float theta_base = (float) pos[i2]; - float inv_ndims = -1.f/n_dims; - - 
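// NOTE: each iteration of the loop below rotates one (x0, x1) pair of the head
// dimension. The angle for pair i0 is theta = pos * freq_base^(-i0/n_dims);
// inv_ndims above is -1/n_dims, so pow(freq_base, inv_ndims*i0) is exactly that
// power. The angle is optionally divided by a per-frequency factor from src2,
// and rope_yarn() then blends interpolated and extrapolated angles along the
// YaRN ramp whenever ext_factor is non-zero.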
for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { - if (i0 < n_dims) { - int ic = i0/2; - - float theta = theta_base * pow(freq_base, inv_ndims*i0); - - float freq_factor = src2 != src0 ? src2[ic] : 1.0f; - - float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); - - global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - float x0 = src[0]; - float x1 = src[1]; - - dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; - dst_data[1] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; - } else { - global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } - } -} - -kernel void kernel_rope_norm_f16( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * src2, - ulong offset2, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3, - int n_past, - int n_dims, - int n_ctx_orig, - float freq_base, - float freq_scale, - float ext_factor, - float attn_factor, - float beta_fast, - float beta_slow -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - src2 = (global float*)((global char*)src2 + offset2); - dst = (global float*)((global char*)dst + offsetd); - - int i3 = get_group_id(2); - int i2 = get_group_id(1); - int i1 = get_group_id(0); - - float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); - - global int * pos = src1; - - float theta_base = (float) pos[i2]; - float inv_ndims = -1.f/n_dims; - - for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { - if (i0 < n_dims) { - int ic = i0/2; - - float theta = theta_base * pow(freq_base, inv_ndims*i0); - - float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; - - float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); - - global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - float x0 = src[0]; - float x1 = src[1]; - - dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; - dst_data[1] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; - } else { - global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } - } -} - -kernel void kernel_rope_neox_f32( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * src2, - ulong offset2, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3, - int n_past, - int n_dims, - int n_ctx_orig, - float freq_base, - float freq_scale, - float ext_factor, - float attn_factor, - float beta_fast, - float beta_slow -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - src2 = (global float*)((global char*)src2 + offset2); - dst = (global float*)((global char*)dst + offsetd); - - int i3 = get_group_id(2); - int i2 = get_group_id(1); - int i1 = get_group_id(0); - - float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); - - global int * pos = src1; - - float theta_base = (float) pos[i2]; - float inv_ndims = -1.f/n_dims; - - for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { - if (i0 < n_dims) { - int ic = i0/2; - - const float theta = theta_base * pow(freq_base, inv_ndims*i0); - - const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; - - float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); - - global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); - global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims/2]; - - dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; - dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; - } else { - global float * const src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } - } -} - -kernel void kernel_rope_neox_f16( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * src2, - ulong offset2, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3, - int n_past, - int n_dims, - int n_ctx_orig, - float freq_base, - float freq_scale, - float ext_factor, - float attn_factor, - float beta_fast, - float beta_slow -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - src2 = (global float*)((global char*)src2 + offset2); - dst = (global float*)((global char*)dst + offsetd); - - int i3 = get_group_id(2); - int i2 = get_group_id(1); - int i1 = get_group_id(0); - - float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); - - global int * pos = src1; - - float theta_base = (float) pos[i2]; - float inv_ndims = -1.f/n_dims; - - for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { - if (i0 < n_dims) { - int ic = i0/2; - - const float theta = theta_base * pow(freq_base, inv_ndims*i0); - - const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; - - float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); - - global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); - global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims/2]; - - dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; - dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; - } else { - global half * const src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } - } -} - -kernel void kernel_rope_multi_f32( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * src2, - ulong offset2, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3, - int n_past, - int n_dims, - int n_ctx_orig, - float freq_base, - float freq_scale, - float ext_factor, - float attn_factor, - float beta_fast, - float beta_slow, - int4 sections -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - src2 = (global float*)((global char*)src2 + offset2); - dst = (global float*)((global char*)dst + offsetd); - - int i3 = get_group_id(2); - int i2 = get_group_id(1); - int i1 = get_group_id(0); - - float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); - - global int * pos = src1; - - const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3; - const int sec_w = sections.s1 + sections.s0; - - float inv_ndims = -1.f/n_dims; - - for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { - if (i0 < n_dims) { - int ic = i0/2; - - const int sector = (i0 / 2) % sect_dims; - float theta_base = 0.0f; - - if (sector < sections.s0) { - theta_base = pos[i2]; - } - else if (sector >= sections.s0 && sector < sec_w) { - theta_base = pos[i2 + ne2 * 1]; - } - else if (sector >= sec_w && sector < sec_w + sections.s2) { - theta_base = pos[i2 + ne2 * 2]; - } - else if (sector >= sec_w + sections.s2) { - theta_base = pos[i2 + ne2 * 3]; - } - - const float theta = theta_base * pow(freq_base, inv_ndims*i0); - - const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; - - float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); - - global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); - global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims/2]; - - dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; - dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; - } else { - global float * const src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } - } -} - -kernel void kernel_rope_multi_f16( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * src2, - ulong offset2, - global half * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3, - int n_past, - int n_dims, - int n_ctx_orig, - float freq_base, - float freq_scale, - float ext_factor, - float attn_factor, - float beta_fast, - float beta_slow, - int4 sections -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - src2 = (global float*)((global char*)src2 + offset2); - dst = (global float*)((global char*)dst + offsetd); - - int i3 = get_group_id(2); - int i2 = get_group_id(1); - int i1 = get_group_id(0); - - float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); - - global int * pos = src1; - - const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3; - const int sec_w = sections.s1 + sections.s0; - - float inv_ndims = -1.f/n_dims; - - for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { - if (i0 < n_dims) { - int ic = i0/2; - - const int sector = (i0 / 2) % sect_dims; - float theta_base = 0.0f; - - if (sector < sections.s0) { - theta_base = pos[i2]; - } - else if (sector >= sections.s0 && sector < sec_w) { - theta_base = pos[i2 + ne2 * 1]; - } - else if (sector >= sec_w && sector < sec_w + sections.s2) { - theta_base = pos[i2 + ne2 * 2]; - } - else if (sector >= sec_w + sections.s2) { - theta_base = pos[i2 + ne2 * 3]; - } - - const float theta = theta_base * pow(freq_base, inv_ndims*i0); - - const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; - - float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); - - global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); - global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims/2]; - - dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; - dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; - } else { - global half * const src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } - } -} - -kernel void kernel_rope_vision_f32( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * src2, - ulong offset2, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3, - int n_past, - int n_dims, - int n_ctx_orig, - float freq_base, - float freq_scale, - float ext_factor, - float attn_factor, - float beta_fast, - float beta_slow, - int4 sections -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - src2 = (global float*)((global char*)src2 + offset2); - dst = (global float*)((global char*)dst + offsetd); - - int i3 = get_group_id(2); - int i2 = get_group_id(1); - int i1 = get_group_id(0); - - float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); - - global int * pos = src1; - - const int sect_dims = sections.s0 + sections.s1; - const int sec_w = sections.s1 + sections.s0; - - float inv_ndims = -1.f/n_dims; - - for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { - int ic = i0/2; - - const int sector = (i0/2) % sect_dims; - float theta_base = 0.0f; - - if (sector < sections.s0) { - const int p = sector; - theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p); - } else if (sector >= sections.s0 && sector < sec_w) { - const int p = sector - sections.s0; - theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p); - } - - const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; - - float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); - - global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); - global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims]; - - dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; - dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; - } -} - -kernel void kernel_rope_vision_f16( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * src2, - ulong offset2, - global half * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3, - int n_past, - int n_dims, - int n_ctx_orig, - float freq_base, - float freq_scale, - float ext_factor, - float attn_factor, - float beta_fast, - float beta_slow, - int4 sections -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - src2 = (global float*)((global char*)src2 + offset2); - dst = (global float*)((global char*)dst + offsetd); - - int i3 = get_group_id(2); - int i2 = get_group_id(1); - int i1 = get_group_id(0); - - float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); - - global int * pos = src1; - - const int sect_dims = sections.s0 + sections.s1; - const int sec_w = sections.s1 + sections.s0; - - float inv_ndims = -1.f/n_dims; - - for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { - int ic = i0/2; - - const int sector = (i0/2) % sect_dims; - float theta_base = 0.0f; - - if (sector < sections.s0) { - const int p = sector; - theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p); - } else if (sector >= sections.s0 && sector < sec_w) { - const int p = sector - sections.s0; - theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p); - } - - const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; - - float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); - - global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); - global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims]; - - dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; - dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; - } -} - -//------------------------------------------------------------------------------ -// cpy -//------------------------------------------------------------------------------ - -kernel void kernel_cpy_f16_f16( - global half * src0, - ulong offset0, - global half * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3 -) { - src0 = (global half*)((global char*)src0 + offset0); - dst = (global half*)((global char*)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - int i3 = n / (ne2*ne1*ne0); - int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - global const half * src = (global half *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - dst_data[i00] = src[0]; - } -} - -kernel void kernel_cpy_f16_f32( - global half * src0, - ulong offset0, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3 -) { - - src0 = (global half*)((global char*)src0 + offset0); - dst = (global float*)((global char*)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - int i3 = n / (ne2*ne1*ne0); - int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - global half * src = (global half *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - dst_data[i00] = src[0]; - } -} - -kernel void kernel_cpy_f32_f16( - global float * src0, - ulong offset0, - global half * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3 -) { - src0 = (global float*)((global char*)src0 + offset0); - dst = (global half*)((global char*)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - int i3 = n / (ne2*ne1*ne0); - int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - int i1 = (n - i3*ne2*ne1*ne0 - 
i2*ne1*ne0) / ne0; - int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - global const float * src = (global float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - dst_data[i00] = src[0]; - } -} - -kernel void kernel_cpy_f32_f32( - global float * src0, - ulong offset0, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3 -) { - src0 = (global float*)((global char*)src0 + offset0); - dst = (global float*)((global char*)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - int i3 = n / (ne2*ne1*ne0); - int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - global const float * src = (global float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - dst_data[i00] = src[0]; - } -} - -//------------------------------------------------------------------------------ -// get_rows -//------------------------------------------------------------------------------ -kernel void kernel_get_rows_f32( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - ulong nb01, - ulong nb02, - int ne10, - ulong nb10, - ulong nb11, - ulong nb1, - ulong nb2 -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - int i10 = get_group_id(0); - int i11 = get_group_id(1); - - int r = ((global int *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; - - int i02 = i11; - - for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) { - ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] = - ((global float *) ((global char *) src0 + r*nb01 + i02*nb02))[ind]; - } -} - -kernel void kernel_get_rows_f16( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - ulong nb01, - ulong nb02, - int ne10, - ulong nb10, - ulong nb11, - ulong nb1, - ulong nb2 -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - int i10 = get_group_id(0); - int i11 = get_group_id(1); - - int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; - - int i02 = i11; - - for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) { - ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] = - ((global half *) ((global char *) src0 + r*nb01 + i02*nb02))[ind]; - } -} - -kernel void kernel_get_rows_q4_0( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - ulong nb01, - ulong nb02, - int ne10, - ulong nb10, - ulong nb11, - ulong nb1, - ulong nb2 -) { - src0 = 
(global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - const int NL = 2; - - int i10 = get_group_id(0); - int i11 = get_group_id(1); - - int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; - - int i02 = i11; - - for (int ind = get_local_id(0); ind < ne00/16; ind += get_local_size(0)) { - float16 temp; - dequantize_q4_0_f32( - ((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02)) + ind/NL, ind%NL, &temp); - *(((global float16 *) ((global char *) dst + i11*nb2 + i10*nb1)) + ind) = temp; - } -} - -//------------------------------------------------------------------------------ -// mul_mat_f32_f32 -//------------------------------------------------------------------------------ -#define N_F32_F32 4 - -kernel void kernel_mul_mat_f32_f32( - global char * src0, - ulong offset0, - global char * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne10, - int ne11, - int ne12, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13, - int ne0, - int ne1, - int r2, - int r3 -) { - src0 = (global char*)((global char*)src0 + offset0); - src1 = (global char*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - int r0 = get_group_id(0); - int rb = get_group_id(1)*N_F32_F32; - int im = get_group_id(2); - - int i12 = im%ne12; - int i13 = im/ne12; - - ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - - global float * x = (global float *) (src0 + offset_src0); - - if (ne00 < 128) { - for (int row = 0; row < N_F32_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - - global float * y = (global float *) (src1 + offset_src1); - - float sumf = 0; - for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { - sumf += (float) x[i] * (float) y[i]; - } - - float all_sum = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } else { - global float4 * x4 = (global float4 *)x; - for (int row = 0; row < N_F32_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - - global float * y = (global float *) (src1 + offset_src1); - global float4 * y4 = (global float4 *) y; - - float sumf = 0; - for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { - sumf += (float) x4[i].s0 * y4[i].s0; - sumf += (float) x4[i].s1 * y4[i].s1; - sumf += (float) x4[i].s2 * y4[i].s2; - sumf += (float) x4[i].s3 * y4[i].s3; - } - - float all_sum = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) { - all_sum += (float) x[i] * y[i]; - } - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } -} - -//------------------------------------------------------------------------------ -// mul_mat_f16_f16 -//------------------------------------------------------------------------------ -#define N_F16_F16 4 - -kernel void kernel_mul_mat_f16_f16( - global char * src0, - ulong offset0, - global char * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne10, - int ne11, - int ne12, - ulong nb10, - ulong nb11, - ulong nb12, - ulong 
nb13, - int ne0, - int ne1, - int r2, - int r3) -{ - src0 = (global char*)((global char*)src0 + offset0); - src1 = (global char*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - int r0 = get_group_id(0); - int rb = get_group_id(1)*N_F16_F16; - int im = get_group_id(2); - - int i12 = im%ne12; - int i13 = im/ne12; - - ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - - global half * x = (global half *) (src0 + offset_src0); - - if (ne00 < 128) { - for (int row = 0; row < N_F16_F16; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - - global half * y = (global half *) (src1 + offset_src1); - - float sumf = 0; - for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { - sumf += (half) x[i] * (half) y[i]; - } - - float all_sum = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } else { - global half4 * x4 = (global half4 *)x; - for (int row = 0; row < N_F16_F16; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - - global half * y = (global half *) (src1 + offset_src1); - global half4 * y4 = (global half4 *) y; - - float sumf = 0; - for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { - sumf += (half) x4[i].s0 * y4[i].s0; - sumf += (half) x4[i].s1 * y4[i].s1; - sumf += (half) x4[i].s2 * y4[i].s2; - sumf += (half) x4[i].s3 * y4[i].s3; - } - - float all_sum = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) { - all_sum += (half) x[i] * y[i]; - } - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } -} - -//------------------------------------------------------------------------------ -// mul_mat_f16_f32_1row -//------------------------------------------------------------------------------ -#ifdef ADRENO_GPU -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_f16_f32_1row( - global char * src0, - ulong offset0, - global char * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne10, - int ne11, - int ne12, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13, - int ne0, - int ne1, - int r2, - int r3 -) { - src0 = (global char*)((global char*)src0 + offset0); - src1 = (global char*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - int i12 = im%ne12; - int i13 = im/ne12; - - ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - - global half * x = (global half *) (src0 + offset_src0); - global float * y = (global float *) (src1 + offset_src1); - - float sumf = 0; - if (ne00 < 128) { - for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { - sumf += (float) x[i] * (float) y[i]; - } - float all_sum = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } else { - global half4 * x4 = (global half4 *) x; - global float4 * y4 = (global float4 *) y; - for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { - sumf += (float) x4[i].s0 * y4[i].s0; - sumf += (float) x4[i].s1 * y4[i].s1; - sumf += (float) 
x4[i].s2 * y4[i].s2; - sumf += (float) x4[i].s3 * y4[i].s3; - } - float all_sum = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) { - all_sum += (float) x[i] * y[i]; - } - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - -} - -//------------------------------------------------------------------------------ -// mul_mat_f16_f32 -//------------------------------------------------------------------------------ -#define N_F16_F32 4 - -#ifdef ADRENO_GPU -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_f16_f32( - global char * src0, - ulong offset0, - global char * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne10, - int ne11, - int ne12, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13, - int ne0, - int ne1, - int r2, - int r3 -) { - src0 = (global char*)((global char*)src0 + offset0); - src1 = (global char*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - int r0 = get_group_id(0); - int rb = get_group_id(1)*N_F16_F32; - int im = get_group_id(2); - - int i12 = im%ne12; - int i13 = im/ne12; - - ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - - global half * x = (global half *) (src0 + offset_src0); - - if (ne00 < 128) { - for (int row = 0; row < N_F16_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - - global float * y = (global float *) (src1 + offset_src1); - - float sumf = 0; - for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { - sumf += convert_float(x[i]) * y[i]; - } - - float all_sum = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } else { - global half4 * x4 = (global half4 *)x; - for (int row = 0; row < N_F16_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - - global float * y = (global float *) (src1 + offset_src1); - global float4 * y4 = (global float4 *) y; - - float sumf = 0; - for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { - sumf += convert_float(x4[i].s0) * y4[i].s0; - sumf += convert_float(x4[i].s1) * y4[i].s1; - sumf += convert_float(x4[i].s2) * y4[i].s2; - sumf += convert_float(x4[i].s3) * y4[i].s3; - } - - float all_sum = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) { - all_sum += (float) x[i] * y[i]; - } - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } -} - -//------------------------------------------------------------------------------ -// mul_mat_f16_f32_l4 -//------------------------------------------------------------------------------ -// Assumes row size (ne00) is a multiple of 4 -#ifdef ADRENO_GPU -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_f16_f32_l4( - global char * src0, - ulong offset0, - global char * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne10, - int ne11, - int ne12, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13, - int ne0, - int ne1, - int r2, - int r3 -) { - src0 = (global char*)((global char*)src0 + offset0); - src1 = (global char*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + 
offsetd); - - int nrows = ne11; - int r0 = get_group_id(0); - int im = get_group_id(2); - - int i12 = im%ne12; - int i13 = im/ne12; - - ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - - global half4 * x4 = (global half4 *) (src0 + offset_src0); - - for (int r1 = 0; r1 < nrows; ++r1) { - ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - - global float4 * y4 = (global float4 *) (src1 + offset_src1); - - float sumf = 0; - for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { - sumf += convert_float(x4[i].s0) * y4[i].s0; - sumf += convert_float(x4[i].s1) * y4[i].s1; - sumf += convert_float(x4[i].s2) * y4[i].s2; - sumf += convert_float(x4[i].s3) * y4[i].s3; - } - - float all_sum = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } -} - -//------------------------------------------------------------------------------ -// mul_vec_q_n_f32 -//------------------------------------------------------------------------------ -// function to calculate the inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) -// il indicates where the q4 quants begin (0 or QK4_0/4) -// we assume that the yl's have been multiplied with the appropriate scale factor -// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) -inline float block_q_4_0_dot_y( - global struct block_q4_0 * qb_curr, - float sumy, - private float * yl, - int il -) { - float d = qb_curr->d; - float2 acc = 0.f; - global ushort * qs = ((global ushort *)qb_curr + 1 + il/2); - for (int i = 0; i < 8; i+=2) { - acc.s0 += yl[i + 0] * (qs[i / 2] & 0x000F) - + yl[i + 1] * (qs[i / 2] & 0x0F00); - acc.s1 += yl[i + 8] * (qs[i / 2] & 0x00F0) - + yl[i + 9] * (qs[i / 2] & 0xF000); - } - return d * (sumy * -8.f + acc.s0 + acc.s1); -} - -#ifdef INTEL_GPU -#define N_DST 4 // each SIMD group works on 4 rows -#define N_SIMDGROUP 1 // number of SIMD groups in a thread group -#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 -#elif defined (ADRENO_GPU) -#define N_DST 4 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 64 -#endif - -inline void mul_vec_q_n_f32( - global void * src0, - global float * src1, - global float * dst, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - - const ulong nb = ne00/QK4_0; - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - // (r0 * N_SIMDGROUP + get_sub_group_id()) is essentially the linear global - // id of a SIMD group in the grid. - int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; - - int i12 = im%ne12; - int i13 = im/ne12; - - ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - global struct block_q4_0 * x = (global struct block_q4_0 *) src0 + offset0; - global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[16]; // src1 vector cache - float sumf[N_DST]={0.f}; - - int ix = get_sub_group_local_id()/2; - int il = 8*(get_sub_group_local_id()%2); - - global float * yb = y + ix * QK4_0 + il; - - // each thread in a SIMD group deals with half a block.
- for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { - float sumy = 0; - for (int i = 0; i < 8; i += 2) { - sumy += yb[i] + yb[i+1]; - yl[i+0] = yb[i+ 0]; - yl[i+1] = yb[i+ 1]/256.f; - sumy += yb[i+16] + yb[i+17]; - yl[i+8] = yb[i+16]/16.f; - yl[i+9] = yb[i+17]/4096.f; - } - - for (int row = 0; row < N_DST; row++) { - sumf[row] += block_q_4_0_dot_y(x+ib+row*nb, sumy, yl, il); - } - - // One thread in a SIMD group (i.e., subgroup) handles a half block, - // hence the entire SIMD group handles SIMDWIDTH/2 blocks. - // y points to the activation matrix (of type float). Therefore for - // one thread, the # of blocks y should advance is SIMDWIDTH/2 (because - // SIMDWIDTH/2 blocks are processed by a SIMD group) - in terms of - // floats, it is QK4_0 * (SIMDWIDTH/2), where QK4_0 is the block size. - yb += QK4_0 * (N_SIMDWIDTH/2); - } - - // The above does not work for Adreno - it produces incorrect results for - // row = 1, 2, 3 and only row = 0 gives the correct result. - // If N_DST is changed, the below array must be initialized accordingly. - // This also seems to perform better on Intel. - float tot[N_DST] = { - sub_group_reduce_add(sumf[0]), sub_group_reduce_add(sumf[1]), - sub_group_reduce_add(sumf[2]), sub_group_reduce_add(sumf[3])}; - for (int row = 0; row < N_DST; ++row) { - if (get_sub_group_local_id() == 0 && first_row + row < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot[row]; - } - } -} - -#ifdef INTEL_GPU -REQD_SUBGROUP_SIZE_16 -#elif defined (ADRENO_GPU) -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_q4_0_f32( - global void * src0, - ulong offset0, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); -} - -// -// This variant unrolls the loops and uses vector types instead of pointers. -// It improves performance on Adreno but not so much on Intel.
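The pre-scaled activations that block_q_4_0_dot_y consumes are easy to sanity-check on the host. The following plain-C sketch (standalone, with names of our own choosing) computes one 16-element slice of the dot product both the straightforward way and with the masked-nibble trick: dividing the activations by 1, 256, 16 and 4096 compensates for the nibbles sitting at bit offsets 0, 8, 4 and 12 of each ushort, and the trailing sumy * -8.f term restores the q4_0 zero point of 8. The two results agree up to floating-point summation order.

#include <stdio.h>

int main(void) {
    // 4 ushorts = 8 bytes of shuffled q4_0 data: byte k packs quant k in its
    // low nibble and quant k+8 (k+16 in a full block) in its high nibble.
    unsigned short qs[4] = { 0x4F21, 0xA3B7, 0x08C5, 0x96D0 };  // arbitrary test data
    float y[16], d = 0.25f, ref = 0.f, acc = 0.f, sumy = 0.f;
    for (int i = 0; i < 16; ++i) y[i] = 0.1f * (i + 1);

    for (int i = 0; i < 8; ++i) {                    // reference dequantize-then-dot
        int lo = (qs[i/2] >> (8*(i%2)))     & 0xF;   // quant i
        int hi = (qs[i/2] >> (8*(i%2) + 4)) & 0xF;   // quant i+8
        ref  += (lo - 8) * y[i] + (hi - 8) * y[i + 8];
        sumy += y[i] + y[i + 8];
    }
    ref *= d;

    for (int i = 0; i < 8; i += 2) {                 // masked-nibble version
        acc += y[i]            * (qs[i/2] & 0x000F)
             + (y[i+1]/256.f)  * (qs[i/2] & 0x0F00)
             + (y[i+8]/16.f)   * (qs[i/2] & 0x00F0)
             + (y[i+9]/4096.f) * (qs[i/2] & 0xF000);
    }
    printf("ref=%f trick=%f\n", ref, d * (sumy * -8.f + acc));  // should match
    return 0;
}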
-// -inline float block_q_4_0_dot_y_v( - global struct block_q4_0 * qb_curr, - float sumy, - float16 yl, - int il -) { - float d = qb_curr->d; - float acc = 0.f; - global ushort * qs = ((global ushort *)qb_curr + 1 + il/2); - - acc += yl.s0 * (qs[0] & 0x000F); - acc += yl.s1 * (qs[0] & 0x0F00); - acc += yl.s8 * (qs[0] & 0x00F0); - acc += yl.s9 * (qs[0] & 0xF000); - - acc += yl.s2 * (qs[1] & 0x000F); - acc += yl.s3 * (qs[1] & 0x0F00); - acc += yl.sa * (qs[1] & 0x00F0); - acc += yl.sb * (qs[1] & 0xF000); - - acc += yl.s4 * (qs[2] & 0x000F); - acc += yl.s5 * (qs[2] & 0x0F00); - acc += yl.sc * (qs[2] & 0x00F0); - acc += yl.sd * (qs[2] & 0xF000); - - acc += yl.s6 * (qs[3] & 0x000F); - acc += yl.s7 * (qs[3] & 0x0F00); - acc += yl.se * (qs[3] & 0x00F0); - acc += yl.sf * (qs[3] & 0xF000); - - return d * (sumy * -8.f + acc); -} - -#undef N_DST -#undef N_SIMDGROUP -#undef N_SIMDWIDTH - -#ifdef INTEL_GPU -#define N_DST 4 // each SIMD group works on 4 rows -#define N_SIMDGROUP 1 // number of SIMD groups in a thread group -#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 -#elif defined (ADRENO_GPU) -#define N_DST 4 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 64 -#endif - -inline void mul_vec_q_n_f32_v( - global void * src0, - global float * src1, - global float * dst, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - const ulong nb = ne00/QK4_0; - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - // (r0 * N_SIMDGROUP + get_sub_group_id()) is essentially the linear global - // id of a SIMD group in the grid. - int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; - - int i12 = im%ne12; - int i13 = im/ne12; - - ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - global struct block_q4_0 * x = (global struct block_q4_0 *) src0 + offset0; - global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; - - float16 yl; // src1 vector cache - float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); - - int ix = get_sub_group_local_id()/2; - int il = 8*(get_sub_group_local_id()%2); - - global float * yb = y + ix * QK4_0 + il; - - // each thread in a SIMD group deals with half a block. - for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { - float sumy = 0; - - sumy += yb[0]; - sumy += yb[1]; - sumy += yb[2]; - sumy += yb[3]; - sumy += yb[4]; - sumy += yb[5]; - sumy += yb[6]; - sumy += yb[7]; - - sumy += yb[16]; - sumy += yb[17]; - sumy += yb[18]; - sumy += yb[19]; - sumy += yb[20]; - sumy += yb[21]; - sumy += yb[22]; - sumy += yb[23]; - - - yl.s0 = yb[0]; - yl.s1 = yb[1]/256.f; - - yl.s2 = yb[2]; - yl.s3 = yb[3]/256.f; - - yl.s4 = yb[4]; - yl.s5 = yb[5]/256.f; - - yl.s6 = yb[6]; - yl.s7 = yb[7]/256.f; - - yl.s8 = yb[16]/16.f; - yl.s9 = yb[17]/4096.f; - - yl.sa = yb[18]/16.f; - yl.sb = yb[19]/4096.f; - - yl.sc = yb[20]/16.f; - yl.sd = yb[21]/4096.f; - - yl.se = yb[22]/16.f; - yl.sf = yb[23]/4096.f; - - sumf.s0 += block_q_4_0_dot_y_v(x+ib+0*nb, sumy, yl, il); - sumf.s1 += block_q_4_0_dot_y_v(x+ib+1*nb, sumy, yl, il); - sumf.s2 += block_q_4_0_dot_y_v(x+ib+2*nb, sumy, yl, il); - sumf.s3 += block_q_4_0_dot_y_v(x+ib+3*nb, sumy, yl, il); - - // One thread in a SIMD group (i.e., subgroup) handles a half block, - hence the entire SIMD group handles SIMDWIDTH/2 blocks. - // y points to the activation matrix (of type float).
Therefore for - // one thread, the # of blocks y should advance is SIMDWIDTH/2 (because - // SIMDWIDTH/2 blocks are processed by a SIMD group) - in terms of - // floats, it is QK4_0 * (SIMDWIDTH/2), where QK4_0 is the block size. - yb += QK4_0 * (N_SIMDWIDTH/2); - } - - // The above does not work for Adreno - it produces incorrect results for - // row = 1, 2, 3 and only row = 0 gives the correct result. - // If N_DST is changed, the below array must be initialized accordingly. - // This also seems to perform better on Intel. - float4 tot = (float4)( - sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), - sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) - ); - - if (get_sub_group_local_id() == 0) { - if (first_row + 0 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; - } - if (first_row + 1 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; - } - if (first_row + 2 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; - } - if (first_row + 3 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; - } - } -} - -#ifdef INTEL_GPU -REQD_SUBGROUP_SIZE_16 -#elif defined (ADRENO_GPU) -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_q4_0_f32_v( - global void * src0, - ulong offset0, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - mul_vec_q_n_f32_v(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); -} - -//------------------------------------------------------------------------------ -// kernel_convert_block_q4_0 -// Convert the block_q4_0 format to 2 separate arrays (AOS -> SOA). -// This kernel does not deshuffle the bits. -//------------------------------------------------------------------------------ -kernel void kernel_convert_block_q4_0( - global struct block_q4_0 * src0, - global uchar * dst_q, - global half * dst_d -) { - global struct block_q4_0 * b = (global struct block_q4_0 *) src0 + get_global_id(0); - global uchar * q = (global uchar *) dst_q + QK4_0/2*get_global_id(0); - global half * d = (global half *) dst_d + get_global_id(0); - - *d = b->d; - - for (int i = 0; i < QK4_0/2; ++i) { - q[i] = b->qs[i]; - } -} - -kernel void kernel_restore_block_q4_0( - global uchar * src_q, - global half * src_d, - global struct block_q4_0 * dst -) { - global struct block_q4_0 * b = (global struct block_q4_0 *) dst + get_global_id(0); - global uchar * q = (global uchar *) src_q + QK4_0/2*get_global_id(0); - global half * d = (global half *) src_d + get_global_id(0); - - b->d = *d; - for (int i = 0; i < QK4_0/2; ++i) { - b->qs[i] = q[i]; - } -} - -//------------------------------------------------------------------------------ -// mul_vec_q_n_f32_flat -// -// This variation uses flat arrays (struct of arrays, SOA) representation for -// quant tensors. -//------------------------------------------------------------------------------ - -// This function requires the original shuffled weights. -// As a reminder, the original weights are shuffled so that (q[0], q[16]) are -// packed together in a byte, so are (q[1], q[17]) and so on. 
-inline float block_q_4_0_dot_y_flat( - global uchar * x, - global half * dh, - float sumy, - float16 yl, - int il -) { - float d = *dh; - global ushort * qs = ((global ushort *)x + il/2); - float acc = 0.f; - - acc += yl.s0 * (qs[0] & 0x000F); - acc += yl.s1 * (qs[0] & 0x0F00); - acc += yl.s8 * (qs[0] & 0x00F0); - acc += yl.s9 * (qs[0] & 0xF000); - - acc += yl.s2 * (qs[1] & 0x000F); - acc += yl.s3 * (qs[1] & 0x0F00); - acc += yl.sa * (qs[1] & 0x00F0); - acc += yl.sb * (qs[1] & 0xF000); - - acc += yl.s4 * (qs[2] & 0x000F); - acc += yl.s5 * (qs[2] & 0x0F00); - acc += yl.sc * (qs[2] & 0x00F0); - acc += yl.sd * (qs[2] & 0xF000); - - acc += yl.s6 * (qs[3] & 0x000F); - acc += yl.s7 * (qs[3] & 0x0F00); - acc += yl.se * (qs[3] & 0x00F0); - acc += yl.sf * (qs[3] & 0xF000); - - return d * (sumy * -8.f + acc); -} - -#undef N_DST -#undef N_SIMDGROUP -#undef N_SIMDWIDTH - -#ifdef INTEL_GPU -#define N_DST 4 // each SIMD group works on 4 rows -#define N_SIMDGROUP 1 // number of SIMD groups in a thread group -#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 -#elif defined (ADRENO_GPU) -#define N_DST 4 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 64 -#endif - -inline void mul_vec_q_n_f32_flat( - global uchar * src0_q, - global half * src0_d, - global float * src1, - global float * dst, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - const ulong nb = ne00/QK4_0; - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of - // a SIMD group in the grid. Each SIMD group produces N_DST values in the - // result, hence uses nb blocks, i.e., the offset becomes first_row*nb. - // Currently with llama2 7B, im is always 0. - // TODO: how to handle im/gqa*(nb*ne0)? - int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; - - int i12 = im%ne12; - int i13 = im/ne12; - - // The number of scales is the same as the number of blocks. - ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - // Each block contains QK4_0/2 uchars, hence offset for qs is as follows.
- ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; - - global uchar * x = (global uchar *) src0_q + offset0_q; - global half * d = (global half *) src0_d + offset0_d; - global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; - - float16 yl; - float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); - - int ix = get_sub_group_local_id()/2; - int il = 8*(get_sub_group_local_id()%2); - - global float * yb = y + ix*QK4_0 + il; - - for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { - float sumy = 0.f; - - sumy += yb[0]; - sumy += yb[1]; - sumy += yb[2]; - sumy += yb[3]; - sumy += yb[4]; - sumy += yb[5]; - sumy += yb[6]; - sumy += yb[7]; - - sumy += yb[16]; - sumy += yb[17]; - sumy += yb[18]; - sumy += yb[19]; - sumy += yb[20]; - sumy += yb[21]; - sumy += yb[22]; - sumy += yb[23]; - - yl.s0 = yb[0]; - yl.s1 = yb[1]/256.f; - - yl.s2 = yb[2]; - yl.s3 = yb[3]/256.f; - - yl.s4 = yb[4]; - yl.s5 = yb[5]/256.f; - - yl.s6 = yb[6]; - yl.s7 = yb[7]/256.f; - - yl.s8 = yb[16]/16.f; - yl.s9 = yb[17]/4096.f; - - yl.sa = yb[18]/16.f; - yl.sb = yb[19]/4096.f; - - yl.sc = yb[20]/16.f; - yl.sd = yb[21]/4096.f; - - yl.se = yb[22]/16.f; - yl.sf = yb[23]/4096.f; - - sumf.s0 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); - sumf.s1 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); - sumf.s2 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); - sumf.s3 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); - - yb += QK4_0 * (N_SIMDWIDTH/2); - } - - float4 tot = (float4)( - sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), - sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) - ); - - if (get_sub_group_local_id() == 0) { - if (first_row + 0 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; - } - if (first_row + 1 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; - } - if (first_row + 2 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; - } - if (first_row + 3 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; - } - } -} - -#ifdef INTEL_GPU -REQD_SUBGROUP_SIZE_16 -#elif defined (ADRENO_GPU) -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_q4_0_f32_flat( - global uchar * src0_q, - global half * src0_d, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - mul_vec_q_n_f32_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); -} - -// -// This variant outputs 8 values. 
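Host code that feeds these flat kernels has to size the two SOA buffers and compute row offsets consistently with the offset0_d/offset0_q math above. A small C sketch of that bookkeeping (our names; the actual allocation lives in the OpenCL backend and is not shown here):

#include <stddef.h>

#define QK4_0 32

// SOA sizing for an ne01 x ne00 q4_0 tensor: one half scale plus
// QK4_0/2 quant bytes per block, blocks laid out row-major.
static size_t q4_0_n_blocks(int ne00, int ne01) { return (size_t)ne01 * (ne00 / QK4_0); }
static size_t q4_0_q_bytes (int ne00, int ne01) { return q4_0_n_blocks(ne00, ne01) * (QK4_0/2); }
static size_t q4_0_d_halfs (int ne00, int ne01) { return q4_0_n_blocks(ne00, ne01); }

// Offsets of a row: scales are indexed in blocks, quants in bytes,
// mirroring offset0_d and offset0_q in the kernels.
static size_t row_d_offset(int row, int ne00) { return (size_t)row * (ne00 / QK4_0); }
static size_t row_q_offset(int row, int ne00) { return row_d_offset(row, ne00) * (QK4_0/2); }

The quant offset is exactly the scale offset multiplied by QK4_0/2, which is why the kernels carry both values rather than recomputing one from the other per block.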
-// -#undef N_DST -#undef N_SIMDGROUP -#undef N_SIMDWIDTH - -#ifdef INTEL_GPU -#define N_DST 8 // each SIMD group works on 8 rows -#define N_SIMDGROUP 1 // number of SIMD groups in a thread group -#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 -#elif defined (ADRENO_GPU) -#define N_DST 8 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 64 -#endif - -inline void mul_vec_q_n_f32_8x_flat( - global uchar * src0_q, - global half * src0_d, - global float * src1, - global float * dst, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - const ulong nb = ne00/QK4_0; - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of - // a SIMD group in the grid. Each SIMD group produces N_DST values in the - // result, hence uses nb blocks, i.e., the offset becomes first_row*nb. - // Currently with llama2 7B, im is always 0. - // TODO: how to handle im/gqa*(nb*ne0)? - int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; - - int i12 = im%ne12; - int i13 = im/ne12; - - // The number of scales is the same as the number of blocks. - ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. - ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; - - global uchar * x = (global uchar *) src0_q + offset0_q; - global half * d = (global half *) src0_d + offset0_d; - global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; - - float16 yl; - float8 sumf = 0.f; - - int ix = get_sub_group_local_id()/2; - int il = 8*(get_sub_group_local_id()%2); - - global float * yb = y + ix*QK4_0 + il; - - for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { - float sumy = 0.f; - - sumy += yb[0]; - sumy += yb[1]; - sumy += yb[2]; - sumy += yb[3]; - sumy += yb[4]; - sumy += yb[5]; - sumy += yb[6]; - sumy += yb[7]; - - sumy += yb[16]; - sumy += yb[17]; - sumy += yb[18]; - sumy += yb[19]; - sumy += yb[20]; - sumy += yb[21]; - sumy += yb[22]; - sumy += yb[23]; - - yl.s0 = yb[0]; - yl.s1 = yb[1]/256.f; - - yl.s2 = yb[2]; - yl.s3 = yb[3]/256.f; - - yl.s4 = yb[4]; - yl.s5 = yb[5]/256.f; - - yl.s6 = yb[6]; - yl.s7 = yb[7]/256.f; - - yl.s8 = yb[16]/16.f; - yl.s9 = yb[17]/4096.f; - - yl.sa = yb[18]/16.f; - yl.sb = yb[19]/4096.f; - - yl.sc = yb[20]/16.f; - yl.sd = yb[21]/4096.f; - - yl.se = yb[22]/16.f; - yl.sf = yb[23]/4096.f; - - sumf.s0 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); - sumf.s1 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); - sumf.s2 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); - sumf.s3 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); - - sumf.s4 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il); - sumf.s5 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il); - sumf.s6 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il); - sumf.s7 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il); - - yb += QK4_0 * (N_SIMDWIDTH/2); - } - - float8 tot = (float8)( - sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), - sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), - sub_group_reduce_add(sumf.s4),
sub_group_reduce_add(sumf.s5), - sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) - ); - - if (get_sub_group_local_id() == 0) { - if (first_row + 0 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; - } - if (first_row + 1 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; - } - if (first_row + 2 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; - } - if (first_row + 3 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; - } - - if (first_row + 4 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; - } - if (first_row + 5 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; - } - if (first_row + 6 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; - } - if (first_row + 7 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; - } - } -} - -#ifdef INTEL_GPU -REQD_SUBGROUP_SIZE_16 -#elif defined (ADRENO_GPU) -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_q4_0_f32_8x_flat( - global uchar * src0_q, - global half * src0_d, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - mul_vec_q_n_f32_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); -} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_im2col.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_im2col.cl deleted file mode 100644 index 9b41dfb25..000000000 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl_im2col.cl +++ /dev/null @@ -1,146 +0,0 @@ -#ifdef cl_khr_fp16 -#pragma OPENCL EXTENSION cl_khr_fp16 : enable -#elif defined(cl_amd_fp16) -#pragma OPENCL EXTENSION cl_amd_fp16 : enable -#else -#error "Half precision floating point not supported by OpenCL implementation on your device." -#endif - -#ifdef cl_khr_subgroups -#pragma OPENCL EXTENSION cl_khr_subgroups : enable -#elif defined(cl_intel_subgroups) -#pragma OPENCL EXTENSION cl_intel_subgroups : enable -#else -#error "Subgroup not supported on your device." -#endif - -#ifdef cl_intel_required_subgroup_size -// Always use subgroup size of 32 on Intel. -#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable -#define INTEL_GPU 1 -#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) -#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) -#elif defined(cl_qcom_reqd_sub_group_size) -// Always use subgroup size of 64 on Adreno. -#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable -#define ADRENO_GPU 1 -#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) -#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) -#else -// TODO: do not know how to choose subgroup size on other GPUs. -#error "Selecting subgroup size is not supported on your device."
-#endif - -kernel void kernel_im2col_f32( - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - ulong batch_offset, - ulong delta_offset, - long IW, - long IH, - long IC, - long OW, - long OH, - long KW, - long KH, - long pelements, - long CHW, - int s0, - int s1, - int p0, - int p1, - int d0, - int d1 -) { - // threadIdx.x + blockIdx.x * blockDim.x - long i = get_global_id(0); - if (i >= pelements) { - return; - } - - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - long ksize = OW * (KH > 1 ? KW : 1); - long kx = i / ksize; - long kd = kx * ksize; - long ky = (i - kd) / OW; - long ix = i % OW; - - long oh = get_group_id(1); - long batch = get_group_id(2) / IC; - long ic = get_group_id(2) % IC; - - long iiw = ix * s0 + kx * d0 - p0; - long iih = oh * s1 + ky * d1 - p1; - - long offset_dst = - ((batch * OH + oh) * OW + ix) * CHW + - (ic * (KW * KH) + ky * KW + kx); - - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { - dst[offset_dst] = 0.0f; - } else { - long offset_src = ic * delta_offset + batch * batch_offset; - dst[offset_dst] = src1[offset_src + iih * IW + iiw]; - } -} - -kernel void kernel_im2col_f16( - global float * src1, - ulong offset1, - global half * dst, - ulong offsetd, - ulong batch_offset, - ulong delta_offset, - long IW, - long IH, - long IC, - long OW, - long OH, - long KW, - long KH, - long pelements, - long CHW, - int s0, - int s1, - int p0, - int p1, - int d0, - int d1 -) { - long i = get_global_id(0); - - if (i >= pelements) { - return; - } - - src1 = (global float*)((global char*)src1 + offset1); - dst = (global half*)((global char*)dst + offsetd); - - long ksize = OW * (KH > 1 ? KW : 1); - long kx = i / ksize; - long kd = kx * ksize; - long ky = (i - kd) / OW; - long ix = i % OW; - - long oh = get_group_id(1); - long batch = get_group_id(2) / IC; - long ic = get_group_id(2) % IC; - - long iiw = ix * s0 + kx * d0 - p0; - long iih = oh * s1 + ky * d1 - p1; - - long offset_dst = - ((batch * OH + oh) * OW + ix) * CHW + - (ic * (KW * KH) + ky * KW + kx); - - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { - dst[offset_dst] = 0.0f; - } else { - long offset_src = ic * delta_offset + batch * batch_offset; - dst[offset_dst] = src1[offset_src + iih * IW + iiw]; - } -} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_mm.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_mm.cl deleted file mode 100644 index e19e9a2f4..000000000 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl_mm.cl +++ /dev/null @@ -1,1225 +0,0 @@ -//------------------------------------------------------------------------------ -// This file contains additional mulmat kernels -// (and potentially other kernels). -//------------------------------------------------------------------------------ -#ifdef cl_khr_fp16 -#pragma OPENCL EXTENSION cl_khr_fp16 : enable -#elif defined(cl_amd_fp16) -#pragma OPENCL EXTENSION cl_amd_fp16 : enable -#else -#error "Half precision floating point not supported by OpenCL implementation on your device." -#endif - -#ifdef cl_khr_subgroups -#pragma OPENCL EXTENSION cl_khr_subgroups : enable -#elif defined(cl_intel_subgroups) -#pragma OPENCL EXTENSION cl_intel_subgroups : enable -#else -#error "Subgroup not supported on your device." -#endif - -#ifdef cl_intel_required_subgroup_size -// Always use subgroup size of 32 on Intel.
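For reference, the index math in the two im2col kernels deleted above can be written as one single-threaded C loop. This is our reconstruction, assuming the host passes delta_offset = IH*IW and batch_offset = IC*IH*IW; the output layout is [N, OH, OW, IC*KH*KW] with zero padding outside the input:

// CPU reference for the im2col mapping used by kernel_im2col_f32.
static void im2col_ref(const float *src, float *dst,
                       int IW, int IH, int IC, int N,
                       int OW, int OH, int KW, int KH,
                       int s0, int s1, int p0, int p1, int d0, int d1) {
    int CHW = IC * KH * KW;
    for (int n = 0; n < N; ++n)
    for (int oh = 0; oh < OH; ++oh)
    for (int ow = 0; ow < OW; ++ow)
    for (int ic = 0; ic < IC; ++ic)
    for (int ky = 0; ky < KH; ++ky)
    for (int kx = 0; kx < KW; ++kx) {
        int iw = ow * s0 + kx * d0 - p0;   // input column, same as iiw
        int ih = oh * s1 + ky * d1 - p1;   // input row, same as iih
        long o = (((long)n * OH + oh) * OW + ow) * CHW
               + ic * KH * KW + ky * KW + kx;
        dst[o] = (iw < 0 || iw >= IW || ih < 0 || ih >= IH)
               ? 0.0f
               : src[((long)n * IC + ic) * IH * IW + (long)ih * IW + iw];
    }
}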
-#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable -#define INTEL_GPU 1 -#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) -#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) -#elif defined(cl_qcom_reqd_sub_group_size) -// Always use subgroup size of 64 on Adreno. -#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable -#define ADRENO_GPU 1 -#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) -#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) -#else -// TODO: do not know how to choose subgroup size on other GPUs. -#error "Selecting subgroup size is not supported on your device." -#endif - -#define QK4_0 32 -#define QR4_0 2 -#define QK4_1 32 -#define QR4_1 2 -#define QK5_0 32 -#define QR5_0 2 -#define QK5_1 32 -#define QR5_1 2 -#define QK8_0 32 -#define QR8_0 1 -#define QK_K 256 -#define K_QUANTS_PER_ITERATION 2 - -typedef char int8_t; -typedef uchar uint8_t; -typedef short int16_t; -typedef ushort uint16_t; -typedef int int32_t; -typedef uint uint32_t; - -//------------------------------------------------------------------------------ -// block_q4_0 -//------------------------------------------------------------------------------ -struct block_q4_0 -{ - half d; - uint8_t qs[QK4_0 / 2]; -}; - -//------------------------------------------------------------------------------ -// block_q6_K -//------------------------------------------------------------------------------ -// 6-bit quantization -// weight is represented as x = a * q -// 16 blocks of 16 elements each -// Effectively 6.5625 bits per weight -typedef struct { - uint8_t ql[QK_K/2]; // quants, lower 4 bits - uint8_t qh[QK_K/4]; // quants, upper 2 bits - int8_t scales[QK_K/16]; // scales, quantized with 8 bits - half d; // super-block scale -} block_q6_K; - -//------------------------------------------------------------------------------ -// These are the variants for matmatmul, based on the matvecmul kernel with -// flattened block_q4_0. -//------------------------------------------------------------------------------ - -// Common dot prod. -inline float mm_block_q_4_0_dot_y_flat( - global uchar * x, - global half * dh, - float sumy, - float16 yl, - int il -) { - float d = *dh; - global ushort * qs = ((global ushort *)x + il/2); - float acc = 0.f; - - acc += yl.s0 * (qs[0] & 0x000F); - acc += yl.s1 * (qs[0] & 0x0F00); - acc += yl.s8 * (qs[0] & 0x00F0); - acc += yl.s9 * (qs[0] & 0xF000); - - acc += yl.s2 * (qs[1] & 0x000F); - acc += yl.s3 * (qs[1] & 0x0F00); - acc += yl.sa * (qs[1] & 0x00F0); - acc += yl.sb * (qs[1] & 0xF000); - - acc += yl.s4 * (qs[2] & 0x000F); - acc += yl.s5 * (qs[2] & 0x0F00); - acc += yl.sc * (qs[2] & 0x00F0); - acc += yl.sd * (qs[2] & 0xF000); - - acc += yl.s6 * (qs[3] & 0x000F); - acc += yl.s7 * (qs[3] & 0x0F00); - acc += yl.se * (qs[3] & 0x00F0); - acc += yl.sf * (qs[3] & 0xF000); - - return d * (sumy * -8.f + acc); -} - -#undef N_DST -#undef N_SIMDGROUP -#undef N_SIMDWIDTH - -#ifdef INTEL_GPU -#define N_DST 8 // each SIMD group works on 8 rows (in weights matrix) -#define N_SIMDGROUP 1 // number of SIMD groups in a thread group -#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 -#elif defined (ADRENO_GPU) -#define N_DST 8 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 64 -#endif -// -// This variant performs 1d blocking with 8x output. -// Each simdgroup outputs 8 values on `n0` dim (row in the output matrix).
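In this 1d blocking scheme, subgroup s of work-group r0 covers the N_DST consecutive output rows starting at (r0 * N_SIMDGROUP + s) * N_DST, and rows past ne01 are dropped by the guarded stores at the end of each kernel. A tiny C sketch of that mapping (names ours):

// Half-open row range [begin, end) written by one subgroup; mirrors the
// first_row computation and the trailing bounds checks in the kernels.
static void subgroup_rows(int r0, int sg, int n_simdgroup, int n_dst,
                          int ne01, int *begin, int *end) {
    int first_row = (r0 * n_simdgroup + sg) * n_dst;
    *begin = first_row < ne01 ? first_row : ne01;
    int stop = first_row + n_dst;
    *end   = stop < ne01 ? stop : ne01;
}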
-// -inline void mul_mat_q_n_f32_1d_8x_flat( - global uchar * src0_q, - global half * src0_d, - global float * src1, - global float * dst, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - const int nb = ne00/QK4_0; - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of - // a SIMD group in the grid. Each SIMD group produces N_DST values in the - // result, hence uses nb blocks, i.e., the offset becomes first_row*nb. - // Currently with llama2 7B, im is always 0. - // TODO: how to handle im/gqa*(nb*ne0)? - int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; - - int i12 = im%ne12; - int i13 = im/ne12; - - // The number of scales is the same as the number of blocks. - ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. - ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; - - global uchar * x = (global uchar *) src0_q + offset0_q; - global half * d = (global half *) src0_d + offset0_d; - global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; - - float16 yl; - float8 sumf = (float8)(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f); - - int ix = get_sub_group_local_id()/2; - int il = 8*(get_sub_group_local_id()%2); - - global float * yb = y + ix*QK4_0 + il; - - for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { - float sumy = 0.f; - - sumy += yb[0]; - sumy += yb[1]; - sumy += yb[2]; - sumy += yb[3]; - sumy += yb[4]; - sumy += yb[5]; - sumy += yb[6]; - sumy += yb[7]; - - sumy += yb[16]; - sumy += yb[17]; - sumy += yb[18]; - sumy += yb[19]; - sumy += yb[20]; - sumy += yb[21]; - sumy += yb[22]; - sumy += yb[23]; - - yl.s0 = yb[0]; - yl.s1 = yb[1]/256.f; - - yl.s2 = yb[2]; - yl.s3 = yb[3]/256.f; - - yl.s4 = yb[4]; - yl.s5 = yb[5]/256.f; - - yl.s6 = yb[6]; - yl.s7 = yb[7]/256.f; - - yl.s8 = yb[16]/16.f; - yl.s9 = yb[17]/4096.f; - - yl.sa = yb[18]/16.f; - yl.sb = yb[19]/4096.f; - - yl.sc = yb[20]/16.f; - yl.sd = yb[21]/4096.f; - - yl.se = yb[22]/16.f; - yl.sf = yb[23]/4096.f; - - sumf.s0 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); - sumf.s1 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); - sumf.s2 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); - sumf.s3 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); - - sumf.s4 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il); - sumf.s5 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il); - sumf.s6 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il); - sumf.s7 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il); - - yb += QK4_0 * (N_SIMDWIDTH/2); - } - - float8 tot = (float8)( - sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), - sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), - sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5), - sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) - ); - - if (get_sub_group_local_id() == 0) { - if (first_row + 0 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; - } - if (first_row + 1 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = 
tot.s1; - } - if (first_row + 2 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; - } - if (first_row + 3 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; - } - - if (first_row + 4 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; - } - if (first_row + 5 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; - } - if (first_row + 6 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; - } - if (first_row + 7 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; - } - } -} - -#ifdef INTEL_GPU -REQD_SUBGROUP_SIZE_16 -#elif defined (ADRENO_GPU) -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_q4_0_f32_1d_8x_flat( - global uchar * src0_q, - global half * src0_d, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - mul_mat_q_n_f32_1d_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); -} - -#undef N_DST -#undef N_SIMDGROUP -#undef N_SIMDWIDTH - -#ifdef INTEL_GPU -#define N_DST 16 // each SIMD group works on 16 rows (in weights matrix) -#define N_SIMDGROUP 1 // number of SIMD groups in a thread group -#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 -#elif defined (ADRENO_GPU) -#define N_DST 16 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 64 -#endif -// -// This variant performs 1d blocking with 16x output. -// Each simdgroup outputs 16 values on `n0` dim (row in the output matrix). -// -inline void mul_mat_q_n_f32_1d_16x_flat( - global uchar * src0_q, - global half * src0_d, - global float * src1, - global float * dst, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - const int nb = ne00/QK4_0; - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of - // a SIMD group in the grid. Each SIMD group produces N_DST values in the - // result, hence uses nb blocks, i.e., the offset becomes first_row*nb. - // Currently with llama2 7B, im is always 0. - // TODO: how to handle im/gqa*(nb*ne0)? - int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; - - int i12 = im%ne12; - int i13 = im/ne12; - - // The number of scales is the same as the number of blocks. - ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - // Each block contains QK4_0/2 uchars, hence offset for qs is as follows.
- ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; - - global uchar * x = (global uchar *) src0_q + offset0_q; - global half * d = (global half *) src0_d + offset0_d; - global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; - - float16 yl; - float16 sumf = (float16)(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f); - - int ix = get_sub_group_local_id()/2; - int il = 8*(get_sub_group_local_id()%2); - - global float * yb = y + ix*QK4_0 + il; - - for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { - float sumy = 0.f; - - sumy += yb[0]; - sumy += yb[1]; - sumy += yb[2]; - sumy += yb[3]; - sumy += yb[4]; - sumy += yb[5]; - sumy += yb[6]; - sumy += yb[7]; - - sumy += yb[16]; - sumy += yb[17]; - sumy += yb[18]; - sumy += yb[19]; - sumy += yb[20]; - sumy += yb[21]; - sumy += yb[22]; - sumy += yb[23]; - - yl.s0 = yb[0]; - yl.s1 = yb[1]/256.f; - - yl.s2 = yb[2]; - yl.s3 = yb[3]/256.f; - - yl.s4 = yb[4]; - yl.s5 = yb[5]/256.f; - - yl.s6 = yb[6]; - yl.s7 = yb[7]/256.f; - - yl.s8 = yb[16]/16.f; - yl.s9 = yb[17]/4096.f; - - yl.sa = yb[18]/16.f; - yl.sb = yb[19]/4096.f; - - yl.sc = yb[20]/16.f; - yl.sd = yb[21]/4096.f; - - yl.se = yb[22]/16.f; - yl.sf = yb[23]/4096.f; - - sumf.s0 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); - sumf.s1 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); - sumf.s2 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); - sumf.s3 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); - - sumf.s4 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il); - sumf.s5 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il); - sumf.s6 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il); - sumf.s7 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il); - - sumf.s8 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 8*nb*QK4_0/2, d + ib + 8*nb, sumy, yl, il); - sumf.s9 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 9*nb*QK4_0/2, d + ib + 9*nb, sumy, yl, il); - sumf.sa += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 10*nb*QK4_0/2, d + ib + 10*nb, sumy, yl, il); - sumf.sb += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 11*nb*QK4_0/2, d + ib + 11*nb, sumy, yl, il); - - sumf.sc += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 12*nb*QK4_0/2, d + ib + 12*nb, sumy, yl, il); - sumf.sd += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 13*nb*QK4_0/2, d + ib + 13*nb, sumy, yl, il); - sumf.se += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 14*nb*QK4_0/2, d + ib + 14*nb, sumy, yl, il); - sumf.sf += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 15*nb*QK4_0/2, d + ib + 15*nb, sumy, yl, il); - - yb += QK4_0 * (N_SIMDWIDTH/2); - } - - float16 tot = (float16)( - sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), - sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), - sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5), - sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7), - - sub_group_reduce_add(sumf.s8), sub_group_reduce_add(sumf.s9), - sub_group_reduce_add(sumf.sa), sub_group_reduce_add(sumf.sb), - sub_group_reduce_add(sumf.sc), sub_group_reduce_add(sumf.sd), - sub_group_reduce_add(sumf.se), sub_group_reduce_add(sumf.sf) - ); - - if (get_sub_group_local_id() == 0) { - if (first_row + 0 < ne01) { - dst[r1*ne0 + 
im*ne0*ne1 + first_row + 0] = tot.s0; - } - if (first_row + 1 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; - } - if (first_row + 2 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; - } - if (first_row + 3 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; - } - - if (first_row + 4 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; - } - if (first_row + 5 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; - } - if (first_row + 6 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; - } - if (first_row + 7 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; - } - - if (first_row + 8 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 8] = tot.s8; - } - if (first_row + 9 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 9] = tot.s9; - } - if (first_row + 10 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 10] = tot.sa; - } - if (first_row + 11 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 11] = tot.sb; - } - - if (first_row + 12 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 12] = tot.sc; - } - if (first_row + 13 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 13] = tot.sd; - } - if (first_row + 14 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 14] = tot.se; - } - if (first_row + 15 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 15] = tot.sf; - } - } -} - -#ifdef INTEL_GPU -REQD_SUBGROUP_SIZE_16 -#elif defined (ADRENO_GPU) -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_q4_0_f32_1d_16x_flat( - global uchar * src0_q, - global half * src0_d, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - mul_mat_q_n_f32_1d_16x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); -} - -//------------------------------------------------------------------------------ -// kernel_mul_mat_q4_0_f32_flat_v0 -//------------------------------------------------------------------------------ -inline float block_q_4_0_dot_y_flat_v2( - half x, - half d, - float sumy, - float4 yl -) { - uchar2 q = as_uchar2(x); - float acc = 0.0f; - - acc += (q.s0 & 0x0F) * yl.s0; - acc += (q.s1 & 0x0F) * yl.s1; - - acc += (q.s0 & 0xF0) * yl.s2; - acc += (q.s1 & 0xF0) * yl.s3; - - return d * (sumy * -8.f + acc); -} - -inline float block_q_4_0_dot_y_flat_v4( - float x, - half d, - float sumy, - float8 yl -) { - uchar4 q = as_uchar4(x); - float acc = 0.0f; - - acc += (q.s0 & 0x0F) * yl.s0; - acc += (q.s1 & 0x0F) * yl.s1; - acc += (q.s2 & 0x0F) * yl.s2; - acc += (q.s3 & 0x0F) * yl.s3; - - acc += (q.s0 & 0xF0) * yl.s4; - acc += (q.s1 & 0xF0) * yl.s5; - acc += (q.s2 & 0xF0) * yl.s6; - acc += (q.s3 & 0xF0) * yl.s7; - - return d * (sumy * -8.f + acc); -} - -inline float block_q_4_0_dot_y_flat_v8( - float2 x, - half d, - float sumy, - float16 yl -) { - uchar8 q = as_uchar8(x); - float acc = 0.0f; - - acc += (q.s0 & 0x0F) * yl.s0; - acc += (q.s1 & 0x0F) * yl.s1; - acc += (q.s2 & 0x0F) * yl.s2; - acc += (q.s3 & 0x0F) * yl.s3; - acc += (q.s4 & 0x0F) * yl.s4; - acc += (q.s5 & 0x0F) * yl.s5; - acc += (q.s6 & 0x0F) * yl.s6; - acc += (q.s7 & 0x0F) * yl.s7; - - acc += (q.s0 & 0xF0) * yl.s8; - acc += (q.s1 & 0xF0) * yl.s9; - acc += (q.s2 & 0xF0) * yl.sa; - acc += (q.s3 & 0xF0) * yl.sb; - acc += (q.s4 & 0xF0) * yl.sc; - acc += (q.s5 & 0xF0) * yl.sd; - acc += (q.s6 & 0xF0) *
yl.se; - acc += (q.s7 & 0xF0) * yl.sf; - - return d * (sumy * -8.f + acc); -} - -#undef N_DST -#undef N_SIMDGROUP -#undef N_SIMDWIDTH - -#ifdef INTEL_GPU -#define THREADS_PER_BLK 4 // Number of threads per block, or each thread processes 1/THREADS_PER_BLK of a block -#define N_DST 4 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 16 -#elif defined (ADRENO_GPU) -#define THREADS_PER_BLK 4 -#define N_DST 4 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 64 -#endif - -#if THREADS_PER_BLK == 2 // Each thread processes 1/2 block -# define ACT_TY float16 -# define Q_BLK_LD_TY float2 -# define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v8 -#elif THREADS_PER_BLK == 4 // Each thread processes 1/4 block -# define ACT_TY float8 -# define Q_BLK_LD_TY float -# define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v4 -#elif THREADS_PER_BLK == 8 // Each thread processes 1/8 block -# define ACT_TY float4 -# define Q_BLK_LD_TY half -# define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v2 -#endif - -#define BTYES_PER_THREAD_IN_BLK (QK4_0/2/THREADS_PER_BLK) - -#if N_DST == 2 -# define SUM_TY float2 -#elif N_DST == 4 -# define SUM_TY float4 -#elif N_DST == 8 -# define SUM_TY float8 -#elif N_DST == 16 -# define SUM_TY float16 -#endif - -#ifdef INTEL_GPU -REQD_SUBGROUP_SIZE_16 -#elif defined (ADRENO_GPU) -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_q4_0_f32_flat_v0( - global uchar * src0_q, - global half * src0_d, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - const int nb = ne00/QK4_0; - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; - - int i12 = im%ne12; - int i13 = im/ne12; - - // The number of scales is the same as the number of blocks. - ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - // Each block contains QK4_0/2 uchars, hence offset for qs is as follows.
- ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; - - global uchar * x = (global uchar *) src0_q + offset0_q; - global half * d = (global half *) src0_d + offset0_d; - global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; - - int ix = get_sub_group_local_id()/THREADS_PER_BLK; - int il = get_sub_group_local_id()%THREADS_PER_BLK; - - global float * yb = y + ix*QK4_0 + BTYES_PER_THREAD_IN_BLK*il; - - // Registers for caching activation - ACT_TY yl = 0.f; - - // Registers for caching quants - Q_BLK_LD_TY q_blk_0 = 0, q_blk_1 = 0; -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - Q_BLK_LD_TY q_blk_2 = 0, q_blk_3 = 0; -#endif -#if N_DST == 8 || N_DST == 16 - Q_BLK_LD_TY q_blk_4 = 0, q_blk_5 = 0, q_blk_6 = 0, q_blk_7 = 0; -#endif - - // Partial sum - SUM_TY sumf = 0.f; - - for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/THREADS_PER_BLK) { - float sumy = 0.f; - - q_blk_0 = *(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 0*nb*QK4_0/2); - q_blk_1 = *(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 1*nb*QK4_0/2); -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - q_blk_2 = *(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 2*nb*QK4_0/2); - q_blk_3 = *(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 3*nb*QK4_0/2); -#endif -#if N_DST == 8 || N_DST == 16 - q_blk_4 = (*(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 4*nb*QK4_0/2)); - q_blk_5 = (*(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 5*nb*QK4_0/2)); - q_blk_6 = (*(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 6*nb*QK4_0/2)); - q_blk_7 = (*(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 7*nb*QK4_0/2)); -#endif - - // Load activation -#if THREADS_PER_BLK == 2 // Each thread processes 1/2 block - yl.s01234567 = *(global float8 *)(yb); - yl.s89abcdef = *(global float8 *)(yb + 16); - - sumy += yl.s0; - sumy += yl.s1; - sumy += yl.s2; - sumy += yl.s3; - sumy += yl.s4; - sumy += yl.s5; - sumy += yl.s6; - sumy += yl.s7; - sumy += yl.s8; yl.s8 /= 16.f; - sumy += yl.s9; yl.s9 /= 16.f; - sumy += yl.sa; yl.sa /= 16.f; - sumy += yl.sb; yl.sb /= 16.f; - sumy += yl.sc; yl.sc /= 16.f; - sumy += yl.sd; yl.sd /= 16.f; - sumy += yl.se; yl.se /= 16.f; - sumy += yl.sf; yl.sf /= 16.f; -#elif THREADS_PER_BLK == 4 // Each thread processes 1/4 block - yl.s0123 = *(global float4 *)(yb); - yl.s4567 = *(global float4 *)(yb + 16); - - sumy += yl.s0; - sumy += yl.s1; - sumy += yl.s2; - sumy += yl.s3; - sumy += yl.s4; yl.s4 /= 16.f; - sumy += yl.s5; yl.s5 /= 16.f; - sumy += yl.s6; yl.s6 /= 16.f; - sumy += yl.s7; yl.s7 /= 16.f; -#elif THREADS_PER_BLK == 8 // Each thread processes 1/8 block - yl.s01 = *(global float2 *)(yb); - yl.s23 = *(global float2 *)(yb + 16); - - sumy += yl.s0; - sumy += yl.s1; - sumy += yl.s2; yl.s2 /= 16.f; - sumy += yl.s3; yl.s3 /= 16.f; -#endif - - sumf.s0 += block_q_4_0_dot_y_flat(q_blk_0, *(d + ib + 0*nb), sumy, yl); - sumf.s1 += block_q_4_0_dot_y_flat(q_blk_1, *(d + ib + 1*nb), sumy, yl); -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - sumf.s2 += block_q_4_0_dot_y_flat(q_blk_2, *(d + ib + 2*nb), sumy, yl); - sumf.s3 += block_q_4_0_dot_y_flat(q_blk_3, *(d + ib + 3*nb), sumy, yl); -#endif -#if N_DST == 8 || N_DST == 16 - sumf.s4 += block_q_4_0_dot_y_flat(q_blk_4, *(d + ib + 4*nb), sumy, yl); - sumf.s5 += block_q_4_0_dot_y_flat(q_blk_5, *(d + ib + 5*nb), sumy, yl); - sumf.s6 += block_q_4_0_dot_y_flat(q_blk_6, *(d + ib + 6*nb), sumy, yl); - sumf.s7 += 
block_q_4_0_dot_y_flat(q_blk_7, *(d + ib + 7*nb), sumy, yl); -#endif - - yb += QK4_0 * (N_SIMDWIDTH/THREADS_PER_BLK); - } - - SUM_TY tot = (SUM_TY)( - sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1) -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - , sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) -#endif -#if N_DST == 8 || N_DST == 16 - , sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5) - , sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) -#endif - ); - - if (get_sub_group_local_id() == 0) { - if (first_row + 0 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; - } - if (first_row + 1 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; - } -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - if (first_row + 2 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; - } - if (first_row + 3 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; - } -#endif -#if N_DST == 8 || N_DST == 16 - if (first_row + 4 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; - } - if (first_row + 5 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; - } - if (first_row + 6 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; - } - if (first_row + 7 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; - } -#endif - } -} - -//------------------------------------------------------------------------------ -// Using image1d_buffer_t - -#if defined(cl_qcom_subgroup_shuffle) -#pragma OPENCL EXTENSION cl_qcom_subgroup_shuffle : enable -float qcom_sub_group_reduce_add(float sum) { - sum += qcom_sub_group_shuffle_down(sum, 32, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f); - sum += qcom_sub_group_shuffle_down(sum, 16, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f); - sum += qcom_sub_group_shuffle_down(sum, 8, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f); - sum += qcom_sub_group_shuffle_down(sum, 4, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f); - sum += qcom_sub_group_shuffle_down(sum, 2, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f); - sum += qcom_sub_group_shuffle_down(sum, 1, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f); - return sum; -} -#define sub_group_reduce_add qcom_sub_group_reduce_add -#else -#define sub_group_reduce_add sub_group_reduce_add -#endif - -#undef THREADS_PER_BLK -#undef N_DST -#undef N_SIMDGROUP -#undef N_SIMDWIDTH - -#ifdef INTEL_GPU -#define THREADS_PER_BLK 4 // Number of threads per block, or each thread processes 1/THREADS_PER_BLK of a block -#define N_DST 4 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 16 -#elif defined (ADRENO_GPU) -#define THREADS_PER_BLK 4 -#define N_DST 4 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 64 -#endif - -#if THREADS_PER_BLK == 2 // Each thread processes 1/2 block -# define ACT_TY float16 -# define Q_BLK_LD_TY float2 -# define EXTRACT_BLK_DATA(tmp, part) *((float2*)&tmp + part) -# define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v8 -#elif THREADS_PER_BLK == 4 // Each thread processes 1/4 block -# define ACT_TY float8 -# define Q_BLK_LD_TY float -# define EXTRACT_BLK_DATA(tmp, part) *((float*)&tmp + part) -# define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v4 -#elif THREADS_PER_BLK == 8 // Each thread processes 1/8 block -# define ACT_TY float4 -# define Q_BLK_LD_TY half -# define EXTRACT_BLK_DATA(tmp, part) *((half*)&tmp + part) -# define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v2 -#endif - -#define BTYES_PER_THREAD_IN_BLK (QK4_0/2/THREADS_PER_BLK) - -#if N_DST == 2 -# define SUM_TY float2 -#elif N_DST == 4 -#
define SUM_TY float4 -#elif N_DST == 8 -# define SUM_TY float8 -#elif N_DST == 16 -# define SUM_TY float16 -#endif - -#ifdef INTEL_GPU -REQD_SUBGROUP_SIZE_16 -#elif defined (ADRENO_GPU) -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_q4_0_f32_flat_img_v0( - read_only image1d_buffer_t src0_q, - read_only image1d_buffer_t src0_d, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - const int nb = ne00/QK4_0; - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; - - int i12 = im%ne12; - int i13 = im/ne12; - - // The number of scales is the same as the number of blocks. - ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. - ulong offset0_q = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; - - int ix = get_sub_group_local_id()/THREADS_PER_BLK; - int il = get_sub_group_local_id()%THREADS_PER_BLK; - - global float * yb = y + ix*QK4_0 + BTYES_PER_THREAD_IN_BLK*il; - - // Registers for caching activation - ACT_TY yl = 0.f; - - // Registers for caching quants - Q_BLK_LD_TY q_blk_0 = 0, q_blk_1 = 0; -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - Q_BLK_LD_TY q_blk_2 = 0, q_blk_3 = 0; -#endif -#if N_DST == 8 || N_DST == 16 - Q_BLK_LD_TY q_blk_4 = 0, q_blk_5 = 0, q_blk_6 = 0, q_blk_7 = 0; -#endif - - // Partial sum - SUM_TY sumf = 0.f; - - for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/THREADS_PER_BLK) { - float sumy = 0.f; - - float4 tmp; - tmp = read_imagef(src0_q, offset0_q + ib + 0*nb); - q_blk_0 = EXTRACT_BLK_DATA(tmp, il); - tmp = read_imagef(src0_q, offset0_q + ib + 1*nb); - q_blk_1 = EXTRACT_BLK_DATA(tmp, il); -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - tmp = read_imagef(src0_q, offset0_q + ib + 2*nb); - q_blk_2 = EXTRACT_BLK_DATA(tmp, il); - tmp = read_imagef(src0_q, offset0_q + ib + 3*nb); - q_blk_3 = EXTRACT_BLK_DATA(tmp, il); -#endif -#if N_DST == 8 || N_DST == 16 - tmp = read_imagef(src0_q, offset0_q + ib + 4*nb); - q_blk_4 = EXTRACT_BLK_DATA(tmp, il); - tmp = read_imagef(src0_q, offset0_q + ib + 5*nb); - q_blk_5 = EXTRACT_BLK_DATA(tmp, il); - tmp = read_imagef(src0_q, offset0_q + ib + 6*nb); - q_blk_6 = EXTRACT_BLK_DATA(tmp, il); - tmp = read_imagef(src0_q, offset0_q + ib + 7*nb); - q_blk_7 = EXTRACT_BLK_DATA(tmp, il); -#endif - - // Load activation -#if THREADS_PER_BLK == 2 // Each thread processes 1/2 block - yl.s01234567 = *(global float8 *)(yb); - yl.s89abcdef = *(global float8 *)(yb + 16); - - sumy += yl.s0; - sumy += yl.s1; - sumy += yl.s2; - sumy += yl.s3; - sumy += yl.s4; - sumy += yl.s5; - sumy += yl.s6; - sumy += yl.s7; - sumy += yl.s8; yl.s8 /= 16.f; - sumy += yl.s9; yl.s9 /= 16.f; - sumy += yl.sa; yl.sa /= 16.f; - sumy += yl.sb; yl.sb /= 16.f; - sumy += yl.sc; yl.sc /= 16.f; - sumy += yl.sd; yl.sd /= 16.f; - sumy += yl.se; yl.se /= 16.f; - sumy += yl.sf; yl.sf /= 16.f; -#elif THREADS_PER_BLK == 4 // Each thread processes 1/4 block - yl.s0123 = *(global float4 *)(yb); - yl.s4567 = *(global float4 *)(yb + 16); - - sumy += yl.s0; - sumy += yl.s1; - sumy += yl.s2; - sumy += yl.s3; - sumy += yl.s4; yl.s4 /= 16.f; - sumy += yl.s5; yl.s5 /=
16.f; - sumy += yl.s6; yl.s6 /= 16.f; - sumy += yl.s7; yl.s7 /= 16.f; -#elif THREADS_PER_BLK == 8 // Each thread processes 1/8 block - yl.s01 = *(global float2 *)(yb); - yl.s23 = *(global float2 *)(yb + 16); - - sumy += yl.s0; - sumy += yl.s1; - sumy += yl.s2; yl.s2 /= 16.f; - sumy += yl.s3; yl.s3 /= 16.f; -#endif - - sumf.s0 += block_q_4_0_dot_y_flat(q_blk_0, read_imageh(src0_d, offset0_d + ib + 0*nb).s0, sumy, yl); - sumf.s1 += block_q_4_0_dot_y_flat(q_blk_1, read_imageh(src0_d, offset0_d + ib + 1*nb).s0, sumy, yl); -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - sumf.s2 += block_q_4_0_dot_y_flat(q_blk_2, read_imageh(src0_d, offset0_d + ib + 2*nb).s0, sumy, yl); - sumf.s3 += block_q_4_0_dot_y_flat(q_blk_3, read_imageh(src0_d, offset0_d + ib + 3*nb).s0, sumy, yl); -#endif -#if N_DST == 8 || N_DST == 16 - sumf.s4 += block_q_4_0_dot_y_flat(q_blk_4, read_imageh(src0_d, offset0_d + ib + 4*nb).s0, sumy, yl); - sumf.s5 += block_q_4_0_dot_y_flat(q_blk_5, read_imageh(src0_d, offset0_d + ib + 5*nb).s0, sumy, yl); - sumf.s6 += block_q_4_0_dot_y_flat(q_blk_6, read_imageh(src0_d, offset0_d + ib + 6*nb).s0, sumy, yl); - sumf.s7 += block_q_4_0_dot_y_flat(q_blk_7, read_imageh(src0_d, offset0_d + ib + 7*nb).s0, sumy, yl); -#endif - - yb += QK4_0 * (N_SIMDWIDTH/THREADS_PER_BLK); - } - - SUM_TY tot = (SUM_TY)( - sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1) -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - , sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) -#endif -#if N_DST == 8 || N_DST == 16 - , sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5) - , sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) -#endif - ); - - if (get_sub_group_local_id() == 0) { - if (first_row + 0 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; - } - if (first_row + 1 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; - } -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - if (first_row + 2 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; - } - if (first_row + 3 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; - } -#endif -#if N_DST == 8 || N_DST == 16 - if (first_row + 4 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; - } - if (first_row + 5 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; - } - if (first_row + 6 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; - } - if (first_row + 7 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; - } -#endif - } -} - -//------------------------------------------------------------------------------ -// kernel_mul_mv_q6_K_f32 -//------------------------------------------------------------------------------ - -#undef N_DST -#undef N_SIMDGROUP -#undef N_SIMDWIDTH - -#ifdef INTEL_GPU -#define N_DST 1 // number of rows each SIMD group works on -#define N_SIMDGROUP 2 // number of SIMD groups in a thread group -#define N_SIMDWIDTH 16 // SIMD group size -#elif defined (ADRENO_GPU) -#define N_DST 1 -#define N_SIMDGROUP 2 -#define N_SIMDWIDTH 64 -#endif - -#define BLOCK_STRIDE (N_SIMDWIDTH/16) // number of blocks each subgroup processes - -#ifdef INTEL_GPU -REQD_SUBGROUP_SIZE_16 -#elif defined (ADRENO_GPU) -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mv_q6_K_f32( - global void * src0, - ulong offset0, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = 
(global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - uchar kmask1 = 0x03; - uchar kmask2 = 0x0C; - uchar kmask3 = 0x30; - uchar kmask4 = 0xC0; - - int nb = ne00/QK_K; - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - int row = N_SIMDGROUP * r0 + get_sub_group_id(); - - int i12 = im%ne12; - int i13 = im/ne12; - - ulong offset_src0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - global block_q6_K * x = (global block_q6_K *) src0 + row*nb + offset_src0; - global float * yy = (global float *) src1 + r1*ne10 + im*ne00*ne1; - - float sumf = 0; - - // For Q6_K quantization, 16 values forms a subblock, 16 subblock forms a - // block. Values in a subblock shares a scale that is quantized with 8 bits; - // the entire block shares a single floating point scale. - // For work distribution, each thread processes a subblock (16 weights), hence - // 16 threads process a (super) block -- a subgroup thus handles SIMDWIDTH/16 - // (super) blocks -- this is the block stride. - // The 16 threads that process a (super) block are split into 2 portions, each has - // 8 threads; each portion works on 8 subblocks. - // For subgroup of 16 threads, the entire subgroup works on a single (super) block - // before moving to the next (super) block. Thread0 - thread7 work on the - // first 8 subblocks; thread8 - thread15 works on the last 8 subblocks. - // Thread0 - thread3 work on subblocks 0, 2, 4, 6; thread4 - thread7 work on - // subblocks 1, 3, 5, 7. Each thread does not work on an entire subblock, but - // works on a total of 16 weight values. - int tid = get_sub_group_local_id()/BLOCK_STRIDE; // first block_stride groups have tid=0 - int ix = get_sub_group_local_id()%BLOCK_STRIDE; // first block is 0..block_stride-1 - int ip = tid/8; // first or second half of (super) block (0 or 1) - int il = tid%8; // each half has 8 parts, one per scale - int n = 4; // 4 scales at a time (and 4 sums) - int l0 = n*il; // offset into half-block, 0..28 - int is = 8*ip + l0/16; // 0, 1, 8, 9 - - int y_offset = 128*ip + l0; - int q_offset_l = 64*ip + l0; - int q_offset_h = 32*ip + l0; - - for (int i = ix; i < nb; i += BLOCK_STRIDE) { - - global uint8_t * q1 = x[i].ql + q_offset_l; - global uint8_t * q2 = q1 + QK_K/8; - global uint8_t * qh = x[i].qh + q_offset_h; - global int8_t * sc = x[i].scales + is; - - global float * y = yy + i * QK_K + y_offset; - - float dall = x[i].d; - - float4 sums = {0.f, 0.f, 0.f, 0.f}; - - sums.s0 += y[0+ 0] * ((float)((q1[0] & 0xF) | ((qh[0] & kmask1) << 4)) - 32.f); - sums.s1 += y[0+32] * ((float)((q2[0] & 0xF) | ((qh[0] & kmask2) << 2)) - 32.f); - sums.s2 += y[0+64] * ((float)((q1[0] >> 4) | ((qh[0] & kmask3) << 0)) - 32.f); - sums.s3 += y[0+96] * ((float)((q2[0] >> 4) | ((qh[0] & kmask4) >> 2)) - 32.f); - - sums.s0 += y[1+ 0] * ((float)((q1[1] & 0xF) | ((qh[1] & kmask1) << 4)) - 32.f); - sums.s1 += y[1+32] * ((float)((q2[1] & 0xF) | ((qh[1] & kmask2) << 2)) - 32.f); - sums.s2 += y[1+64] * ((float)((q1[1] >> 4) | ((qh[1] & kmask3) << 0)) - 32.f); - sums.s3 += y[1+96] * ((float)((q2[1] >> 4) | ((qh[1] & kmask4) >> 2)) - 32.f); - - sums.s0 += y[2+ 0] * ((float)((q1[2] & 0xF) | ((qh[2] & kmask1) << 4)) - 32.f); - sums.s1 += y[2+32] * ((float)((q2[2] & 0xF) | ((qh[2] & kmask2) << 2)) - 32.f); - sums.s2 += y[2+64] * ((float)((q1[2] >> 4) | ((qh[2] & kmask3) << 0)) - 32.f); - sums.s3 += y[2+96] * ((float)((q2[2] >> 4) | ((qh[2] & kmask4) >> 2)) - 32.f); - - sums.s0 += y[3+ 0] * ((float)((q1[3] & 0xF) | ((qh[3] & 
kmask1) << 4)) - 32.f); - sums.s1 += y[3+32] * ((float)((q2[3] & 0xF) | ((qh[3] & kmask2) << 2)) - 32.f); - sums.s2 += y[3+64] * ((float)((q1[3] >> 4) | ((qh[3] & kmask3) << 0)) - 32.f); - sums.s3 += y[3+96] * ((float)((q2[3] >> 4) | ((qh[3] & kmask4) >> 2)) - 32.f); - - sumf += dall * (sums.s0 * sc[0] + sums.s1 * sc[2] + sums.s2 * sc[4] + sums.s3 * sc[6]); - } - - float tot = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - dst[r1*ne0 + im*ne0*ne1 + row] = tot; - } -} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_16.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_16.cl deleted file mode 100644 index cd4e0afba..000000000 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_16.cl +++ /dev/null @@ -1,26 +0,0 @@ -// 16-bit transpose, loading/storing a 4x4 tile of elements - -#pragma OPENCL EXTENSION cl_khr_fp16 : enable - -kernel void kernel_transpose_16( - __read_only image1d_buffer_t input, - __write_only image1d_buffer_t output, - const uint rows, - const uint cols -) { - - const int i = get_global_id(0); - const int j = get_global_id(1); - const int i_2 = i<<2; - const int j_2 = j<<2; - - half4 temp0 = read_imageh(input, (j_2+0)*cols+i); - half4 temp1 = read_imageh(input, (j_2+1)*cols+i); - half4 temp2 = read_imageh(input, (j_2+2)*cols+i); - half4 temp3 = read_imageh(input, (j_2+3)*cols+i); - - write_imageh(output, (i_2+0)*rows+j, (half4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); - write_imageh(output, (i_2+1)*rows+j, (half4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1)); - write_imageh(output, (i_2+2)*rows+j, (half4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2)); - write_imageh(output, (i_2+3)*rows+j, (half4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3)); -} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32.cl deleted file mode 100644 index 914ec0193..000000000 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32.cl +++ /dev/null @@ -1,25 +0,0 @@ -// 32-bit transpose, loading/storing a 4x4 tile of elements - -kernel void kernel_transpose_32( - __read_only image1d_buffer_t input, - __write_only image1d_buffer_t output, - const uint rows, - const uint cols -) { - - const int i = get_global_id(0); - const int j = get_global_id(1); - const int i_2 = i<<2; - const int j_2 = j<<2; - - float4 temp0 = read_imagef(input, (j_2+0)*cols+i); - float4 temp1 = read_imagef(input, (j_2+1)*cols+i); - float4 temp2 = read_imagef(input, (j_2+2)*cols+i); - float4 temp3 = read_imagef(input, (j_2+3)*cols+i); - - write_imagef(output, (i_2+0)*rows+j, (float4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); - write_imagef(output, (i_2+1)*rows+j, (float4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1)); - write_imagef(output, (i_2+2)*rows+j, (float4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2)); - write_imagef(output, (i_2+3)*rows+j, (float4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3)); - -} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32_16.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32_16.cl deleted file mode 100644 index d3bd1fabb..000000000 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32_16.cl +++ /dev/null @@ -1,35 +0,0 @@ -// 32-bit transpose, loading/storing a 4x4 tile of elements -// Only used for activations -// converts to FP16 -// also adds zero padding for non multiple of 8 prompt lengths -#pragma OPENCL EXTENSION cl_khr_fp16 : enable - -kernel void kernel_transpose_32_16(__read_only image1d_buffer_t input, __write_only image1d_buffer_t output, const uint rows, 
const uint cols, const uint padded_rows) { - - const int i = get_global_id(0); - const int j = get_global_id(1); - const int i_2 = i<<2; - const int j_2 = j<<2; - half4 temp0 = {0,0,0,0}; // initialize outputs to 0 - half4 temp1 = {0,0,0,0}; - half4 temp2 = {0,0,0,0}; - half4 temp3 = {0,0,0,0}; - - if((j_2+0)*cols+i*4+3 < rows*cols*16){ // only load from a valid location. Otherwise keep register data as 0 - temp0 = read_imageh(input, (j_2+0)*cols+i); - } - if((j_2+1)*cols+i*4+3 < rows*cols*16){ - temp1 = read_imageh(input, (j_2+1)*cols+i); - } - if((j_2+2)*cols+i*4+3 < rows*cols*16){ - temp2 = read_imageh(input, (j_2+2)*cols+i); - } - if((j_2+3)*cols+i*4+3 < rows*cols*16){ - temp3 = read_imageh(input, (j_2+3)*cols+i); - } - - write_imageh(output, (i_2+0)*padded_rows+j, (half4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); // no conditionals for output, includes zero padding - write_imageh(output, (i_2+1)*padded_rows+j, (half4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1)); - write_imageh(output, (i_2+2)*padded_rows+j, (half4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2)); - write_imageh(output, (i_2+3)*padded_rows+j, (half4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3)); -} diff --git a/ggml/src/ggml-opencl/kernels/im2col_f16.cl b/ggml/src/ggml-opencl/kernels/im2col_f16.cl new file mode 100644 index 000000000..b84c89846 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/im2col_f16.cl @@ -0,0 +1,57 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +kernel void kernel_im2col_f16( + global float * src1, + ulong offset1, + global half * dst, + ulong offsetd, + ulong batch_offset, + ulong delta_offset, + long IW, + long IH, + long IC, + long OW, + long OH, + long KW, + long KH, + long pelements, + long CHW, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1 +) { + long i = get_global_id(0); + if (i >= pelements) { + return; + } + + src1 = (global float*)((global char*)src1 + offset1); + dst = (global half*)((global char*)dst + offsetd); + + long ksize = OW * (KH > 1 ? KW : 1); + long kx = i / ksize; + long kd = kx * ksize; + long ky = (i - kd) / OW; + long ix = i % OW; + + long oh = get_group_id(1); + long batch = get_group_id(2) / IC; + long ic = get_group_id(2) % IC; + + long iiw = ix * s0 + kx * d0 - p0; + long iih = oh * s1 + ky * d1 - p1; + + long offset_dst = + ((batch * OH + oh) * OW + ix) * CHW + + (ic * (KW * KH) + ky * KW + kx); + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst[offset_dst] = 0.0f; + } else { + long offset_src = ic * delta_offset + batch * batch_offset; + dst[offset_dst] = src1[offset_src + iih * IW + iiw]; + } +} diff --git a/ggml/src/ggml-opencl/kernels/im2col_f32.cl b/ggml/src/ggml-opencl/kernels/im2col_f32.cl new file mode 100644 index 000000000..4bf65e4ea --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/im2col_f32.cl @@ -0,0 +1,57 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +kernel void kernel_im2col_f32( + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + ulong batch_offset, + ulong delta_offset, + long IW, + long IH, + long IC, + long OW, + long OH, + long KW, + long KH, + long pelements, + long CHW, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1 +) { + long i = get_global_id(0); + if (i >= pelements) { + return; + } + + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + long ksize = OW * (KH > 1 ? 
KW : 1); + long kx = i / ksize; + long kd = kx * ksize; + long ky = (i - kd) / OW; + long ix = i % OW; + + long oh = get_group_id(1); + long batch = get_group_id(2) / IC; + long ic = get_group_id(2) % IC; + + long iiw = ix * s0 + kx * d0 - p0; + long iih = oh * s1 + ky * d1 - p1; + + long offset_dst = + ((batch * OH + oh) * OW + ix) * CHW + + (ic * (KW * KH) + ky * KW + kx); + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst[offset_dst] = 0.0f; + } else { + long offset_src = ic * delta_offset + batch * batch_offset; + dst[offset_dst] = src1[offset_src + iih * IW + iiw]; + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul.cl b/ggml/src/ggml-opencl/kernels/mul.cl new file mode 100644 index 000000000..2a2b4eb70 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul.cl @@ -0,0 +1,79 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// mul +//------------------------------------------------------------------------------ +kernel void kernel_mul( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + int ne13, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int i13 = i03 % ne13; + int i12 = i02 % ne12; + int i11 = i01 % ne11; + + global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const int i10 = i0 % ne10; + *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) * *((global float *)(src1_ptr + i10*nb10)); + } +} + +// assumption: src1 is a row +// broadcast src1 into src0 +kernel void kernel_mul_row( + global float4 * src0, + ulong offset0, + global float4 * src1, + ulong offset1, + global float4 * dst, + ulong offsetd, + int ne +) { + src0 = (global float4*)((global char*)src0 + offset0); + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float4*)((global char*)dst + offsetd); + + // This performs better than using %. 
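+    // A hedged explanatory note (illustrative, no behavioral change): for
+    // gid >= 0 and ne > 0, truncating integer division gives the identity
+    //     gid - (gid/ne)*ne == gid % ne
+    // e.g. gid = 37, ne = 16: 37 - (37/16)*16 = 37 - 32 = 5 = 37 % 16.
+    // On some GPUs a mul/sub pair is cheaper than the hardware modulo.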
+ uint gid = get_global_id(0); + uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne + dst[gid] = src0[gid] * src1[idx1]; +} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl b/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl similarity index 100% rename from ggml/src/ggml-opencl/kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl rename to ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl new file mode 100644 index 000000000..9393b5494 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl @@ -0,0 +1,118 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define N_F16_F16 4 + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_f16_f16( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3) +{ + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int r0 = get_group_id(0); + int rb = get_group_id(1)*N_F16_F16; + int im = get_group_id(2); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + + global half * x = (global half *) (src0 + offset_src0); + + if (ne00 < 128) { + for (int row = 0; row < N_F16_F16; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global half * y = (global half *) (src1 + offset_src1); + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { + sumf += (half) x[i] * (half) y[i]; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + global half4 * x4 = (global half4 *)x; + for (int row = 0; row < N_F16_F16; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global half * y = (global half *) (src1 + offset_src1); + global half4 * y4 = (global half4 *) y; + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { + sumf += (half) x4[i].s0 * y4[i].s0; + sumf += (half) x4[i].s1 * y4[i].s1; + sumf += (half) x4[i].s2 * y4[i].s2; + sumf += (half) x4[i].s3 * y4[i].s3; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + for (int i = 4*(ne00/4); i < ne00; 
++i) { + all_sum += (half) x[i] * y[i]; + } + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl new file mode 100644 index 000000000..e52d3c6d4 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl @@ -0,0 +1,118 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define N_F16_F32 4 + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_f16_f32( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int r0 = get_group_id(0); + int rb = get_group_id(1)*N_F16_F32; + int im = get_group_id(2); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + + global half * x = (global half *) (src0 + offset_src0); + + if (ne00 < 128) { + for (int row = 0; row < N_F16_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global float * y = (global float *) (src1 + offset_src1); + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { + sumf += convert_float(x[i]) * y[i]; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + global half4 * x4 = (global half4 *)x; + for (int row = 0; row < N_F16_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global float * y = (global float *) (src1 + offset_src1); + global float4 * y4 = (global float4 *) y; + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { + sumf += convert_float(x4[i].s0) * y4[i].s0; + sumf += convert_float(x4[i].s1) * y4[i].s1; + sumf += convert_float(x4[i].s2) * y4[i].s2; + sumf += convert_float(x4[i].s3) * y4[i].s3; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) { + all_sum += (float) x[i] * y[i]; + } + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl new file mode 100644 index 000000000..28d30212c --- 
/dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl @@ -0,0 +1,94 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_f16_f32_1row( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global half * x = (global half *) (src0 + offset_src0); + global float * y = (global float *) (src1 + offset_src1); + + float sumf = 0; + if (ne00 < 128) { + for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { + sumf += (float) x[i] * (float) y[i]; + } + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } else { + global half4 * x4 = (global half4 *) x; + global float4 * y4 = (global float4 *) y; + for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { + sumf += (float) x4[i].s0 * y4[i].s0; + sumf += (float) x4[i].s1 * y4[i].s1; + sumf += (float) x4[i].s2 * y4[i].s2; + sumf += (float) x4[i].s3 * y4[i].s3; + } + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) { + all_sum += (float) x[i] * y[i]; + } + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl new file mode 100644 index 000000000..cdf8197c4 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl @@ -0,0 +1,84 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable 
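+// Note: each .cl file repeats this extension-detection preamble, presumably
+// because every kernel file is compiled as a standalone OpenCL program. On
+// Adreno the "half"/"full" names correspond to the 64- and 128-wide subgroup
+// sizes used by the REQD_SUBGROUP_SIZE_* macros below.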
+#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +// Assumes row size (ne00) is a multiple of 4 +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_f16_f32_l4( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int nrows = ne11; + int r0 = get_group_id(0); + int im = get_group_id(2); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + + global half4 * x4 = (global half4 *) (src0 + offset_src0); + + for (int r1 = 0; r1 < nrows; ++r1) { + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global float4 * y4 = (global float4 *) (src1 + offset_src1); + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { + sumf += convert_float(x4[i].s0) * y4[i].s0; + sumf += convert_float(x4[i].s1) * y4[i].s1; + sumf += convert_float(x4[i].s2) * y4[i].s2; + sumf += convert_float(x4[i].s3) * y4[i].s3; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl new file mode 100644 index 000000000..ec71b8756 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl @@ -0,0 +1,118 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define N_F32_F32 4 + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_f32_f32( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int r0 = get_group_id(0); + int rb = get_group_id(1)*N_F32_F32; + int im = get_group_id(2); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + + 
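+    // Descriptive note (an assumption based on the indexing above): im
+    // enumerates the flattened batch dims, so i12 = im % ne12 and
+    // i13 = im / ne12 recover the per-dim indices; r2 and r3 are the
+    // src1:src0 broadcast ratios along those dims, so dividing by them maps
+    // a src1 batch index back to the src0 slice it shares (the usual
+    // GQA-style weight sharing).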
global float * x = (global float *) (src0 + offset_src0); + + if (ne00 < 128) { + for (int row = 0; row < N_F32_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global float * y = (global float *) (src1 + offset_src1); + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { + sumf += (float) x[i] * (float) y[i]; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + global float4 * x4 = (global float4 *)x; + for (int row = 0; row < N_F32_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global float * y = (global float *) (src1 + offset_src1); + global float4 * y4 = (global float4 *) y; + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { + sumf += (float) x4[i].s0 * y4[i].s0; + sumf += (float) x4[i].s1 * y4[i].s1; + sumf += (float) x4[i].s2 * y4[i].s2; + sumf += (float) x4[i].s3 * y4[i].s3; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) { + all_sum += (float) x[i] * y[i]; + } + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl new file mode 100644 index 000000000..52141e0ed --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl @@ -0,0 +1,192 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_0 32 +#define QR4_0 2 +#define QK4_1 32 +#define QR4_1 2 +#define QK5_0 32 +#define QR5_0 2 +#define QK5_1 32 +#define QR5_1 2 +#define QK8_0 32 +#define QR8_0 1 +#define QK_K 256 +#define K_QUANTS_PER_ITERATION 2 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + +//------------------------------------------------------------------------------ +// mul_vec_q_n_f32 +//------------------------------------------------------------------------------ +// function to calculate the inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q4 quants begin (0 or QK4_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
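+// A hedged worked example of the prescaling trick used below (illustrative
+// only): two packed bytes are read as one ushort, so the mask 0x0F00
+// extracts a low nibble already shifted left by 8 bits, i.e. the quant
+// value times 256; multiplying by an activation prescaled by 1/256 cancels
+// the shift. Likewise 0x00F0 pairs with 1/16 and 0xF000 with 1/4096. The
+// Q4_0 zero offset of -8 is factored out once via sumy, since
+//     d * sum_i (q_i - 8) * y_i == d * (sum_i q_i * y_i - 8 * sum_i y_i).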
+inline float block_q_4_0_dot_y( + global struct block_q4_0 * qb_curr, + float sumy, + private float * yl, + int il +) { + float d = qb_curr->d; + float2 acc = 0.f; + global ushort * qs = ((global ushort *)qb_curr + 1 + il/2); + for (int i = 0; i < 8; i+=2) { + acc.s0 += yl[i + 0] * (qs[i / 2] & 0x000F) + + yl[i + 1] * (qs[i / 2] & 0x0F00); + acc.s1 += yl[i + 8] * (qs[i / 2] & 0x00F0) + + yl[i + 9] * (qs[i / 2] & 0xF000); + } + return d * (sumy * -8.f + acc.s0 + acc.s1); +} + +#ifdef INTEL_GPU +#define N_DST 4 // each SIMD group works on 4 rows +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32( + global void * src0, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + + const ulong nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + // (r0 * N_SIMDGROUP + get_sub_group_id()) is essentially the linear global + // id of a SIMD group in the grid. + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + global struct block_q4_0 * x = (global struct block_q4_0 *) src0 + offset0; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[16]; // src1 vector cache + float sumf[N_DST]={0.f}; + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK4_0 + il; + + // each thread in a SIMD group deals with half a block. + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0; + for (int i = 0; i < 8; i += 2) { + sumy += yb[i] + yb[i+1]; + yl[i+0] = yb[i+ 0]; + yl[i+1] = yb[i+ 1]/256.f; + sumy += yb[i+16] + yb[i+17]; + yl[i+8] = yb[i+16]/16.f; + yl[i+9] = yb[i+17]/4096.f; + } + + for (int row = 0; row < N_DST; row++) { + sumf[row] += block_q_4_0_dot_y(x+ib+row*nb, sumy, yl, il); + } + + // One thread in a SIMD group (i.e., subgroup) handles a half block, + // hence the entire SIMD group handles SIMDWIDTH/2 blocks. + // y points to the activation matrix (of type float). Therefore for + // one thread, the # of blocks y should advance is SIMDWIDTH/2 (because + // SIMDWIDTH/2 blocks are processed by a SIMD group) - in terms of + // floats, it is QK4_0 * (SIMDWIDTH/2), where QK4_0 is the block size. + yb += QK4_0 * (N_SIMDWIDTH/2); + } + + // The above does not work for Adreno - it produces incorrect results for + // row = 1, 2, 3 and only row = 0 gives the correct result. + // If N_DST is changed, the below array must be initialized accordingly. + // This also seems to perform better on Intel.
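+    // For reference, sub_group_reduce_add(v) returns the sum of v across
+    // the whole subgroup, uniform over lanes. A rough shuffle-based sketch
+    // of the same reduction (assuming the cl_khr_subgroup_shuffle_relative
+    // extension and a power-of-two subgroup size) would be:
+    //     for (uint off = get_max_sub_group_size()/2; off > 0; off >>= 1)
+    //         v += sub_group_shuffle_down(v, off);
+    // after which lane 0 holds the total, mirroring the Qualcomm helper
+    // removed from the old monolithic kernel file above.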
+ float tot[N_DST] = { + sub_group_reduce_add(sumf[0]), sub_group_reduce_add(sumf[1]), + sub_group_reduce_add(sumf[2]), sub_group_reduce_add(sumf[3])}; + for (int row = 0; row < N_DST; ++row) { + if (get_sub_group_local_id() == 0 && first_row + row < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot[row]; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_q4_0_f32( + global void * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl new file mode 100644 index 000000000..3eebab8f0 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl @@ -0,0 +1,307 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_0 32 +#define QR4_0 2 +#define QK4_1 32 +#define QR4_1 2 +#define QK5_0 32 +#define QR5_0 2 +#define QK5_1 32 +#define QR5_1 2 +#define QK8_0 32 +#define QR8_0 1 +#define QK_K 256 +#define K_QUANTS_PER_ITERATION 2 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + +inline float mm_block_q_4_0_dot_y_flat( + global uchar * x, + global half * dh, + float sumy, + float16 yl, + int il +) { + float d = *dh; + global ushort * qs = ((global ushort *)x + il/2); + float acc = 0.f; + + acc += yl.s0 * (qs[0] & 0x000F); + acc += yl.s1 * (qs[0] & 0x0F00); + acc += yl.s8 * (qs[0] & 0x00F0); + acc += yl.s9 * (qs[0] & 0xF000); + + acc += yl.s2 * (qs[1] & 0x000F); + acc += yl.s3 * (qs[1] & 0x0F00); + acc += yl.sa * (qs[1] & 0x00F0); + acc += yl.sb * (qs[1] & 0xF000); + + acc += yl.s4 * (qs[2] & 0x000F); + acc += yl.s5 * (qs[2] & 0x0F00); + acc += yl.sc * (qs[2] & 0x00F0); + acc += yl.sd * (qs[2] & 0xF000); + + acc += yl.s6 * (qs[3] & 0x000F); + acc += yl.s7 * (qs[3] & 0x0F00); + acc += yl.se * (qs[3] & 0x00F0); + acc += yl.sf * (qs[3] & 0xF000); + + return d * (sumy * -8.f + acc); +} + +#ifdef INTEL_GPU +#define N_DST 16 // each SIMD group works on 16 rows (in weights matrix) +#define
N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 16 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif +// +// This variant performs 1d blocking with 16x output. +// Each simdgroup outputs 16 values on `n0` dim (row in the output matrix). +// +inline void mul_mat_q_n_f32_1d_16x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const int nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of + // a SIMD group in the grid. Each SIMD group produces N_DST values in the + // result, hence uses nb blocks, i.e., the offset becomes first_row*nb. + // Currently with llama2 7B, im is always 0. + // TODO: how to handle im/gqa*(nb*ne0)? + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + // The number of scales is the same as the number of blocks. + ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. + ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; + + global uchar * x = (global uchar *) src0_q + offset0_q; + global half * d = (global half *) src0_d + offset0_d; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float16 sumf = (float16)(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix*QK4_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0.f; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); + sumf.s1 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); + sumf.s2 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); + sumf.s3 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); + + sumf.s4 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il); + sumf.s5 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il); + sumf.s6 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il); + sumf.s7 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il); + + sumf.s8 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 8*nb*QK4_0/2, d + ib + 8*nb, sumy, yl, il); + sumf.s9 +=
mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 9*nb*QK4_0/2, d + ib + 9*nb, sumy, yl, il); + sumf.sa += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 10*nb*QK4_0/2, d + ib + 10*nb, sumy, yl, il); + sumf.sb += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 11*nb*QK4_0/2, d + ib + 11*nb, sumy, yl, il); + + sumf.sc += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 12*nb*QK4_0/2, d + ib + 12*nb, sumy, yl, il); + sumf.sd += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 13*nb*QK4_0/2, d + ib + 13*nb, sumy, yl, il); + sumf.se += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 14*nb*QK4_0/2, d + ib + 14*nb, sumy, yl, il); + sumf.sf += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 15*nb*QK4_0/2, d + ib + 15*nb, sumy, yl, il); + + yb += QK4_0 * (N_SIMDWIDTH/2); + } + + float16 tot = (float16)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), + sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5), + sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7), + + sub_group_reduce_add(sumf.s8), sub_group_reduce_add(sumf.s9), + sub_group_reduce_add(sumf.sa), sub_group_reduce_add(sumf.sb), + sub_group_reduce_add(sumf.sc), sub_group_reduce_add(sumf.sd), + sub_group_reduce_add(sumf.se), sub_group_reduce_add(sumf.sf) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + + if (first_row + 4 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; + } + if (first_row + 5 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; + } + if (first_row + 6 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; + } + if (first_row + 7 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; + } + + if (first_row + 8 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 8] = tot.s8; + } + if (first_row + 9 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 9] = tot.s9; + } + if (first_row + 10 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 10] = tot.sa; + } + if (first_row + 11 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 11] = tot.sb; + } + + if (first_row + 12 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 12] = tot.sc; + } + if (first_row + 13 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 13] = tot.sd; + } + if (first_row + 14 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 14] = tot.se; + } + if (first_row + 15 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 15] = tot.sf; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_q4_0_f32_1d_16x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_mat_q_n_f32_1d_16x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl new file mode 100644 index 000000000..38024d00a --- /dev/null +++ 
b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl @@ -0,0 +1,265 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_0 32 +#define QR4_0 2 +#define QK4_1 32 +#define QR4_1 2 +#define QK5_0 32 +#define QR5_0 2 +#define QK5_1 32 +#define QR5_1 2 +#define QK8_0 32 +#define QR8_0 1 +#define QK_K 256 +#define K_QUANTS_PER_ITERATION 2 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + +inline float mm_block_q_4_0_dot_y_flat( + global uchar * x, + global half * dh, + float sumy, + float16 yl, + int il +) { + float d = *dh; + global ushort * qs = ((global ushort *)x + il/2); + float acc = 0.f; + + acc += yl.s0 * (qs[0] & 0x000F); + acc += yl.s1 * (qs[0] & 0x0F00); + acc += yl.s8 * (qs[0] & 0x00F0); + acc += yl.s9 * (qs[0] & 0xF000); + + acc += yl.s2 * (qs[1] & 0x000F); + acc += yl.s3 * (qs[1] & 0x0F00); + acc += yl.sa * (qs[1] & 0x00F0); + acc += yl.sb * (qs[1] & 0xF000); + + acc += yl.s4 * (qs[2] & 0x000F); + acc += yl.s5 * (qs[2] & 0x0F00); + acc += yl.sc * (qs[2] & 0x00F0); + acc += yl.sd * (qs[2] & 0xF000); + + acc += yl.s6 * (qs[3] & 0x000F); + acc += yl.s7 * (qs[3] & 0x0F00); + acc += yl.se * (qs[3] & 0x00F0); + acc += yl.sf * (qs[3] & 0xF000); + + return d * (sumy * -8.f + acc); +} + +#ifdef INTEL_GPU +#define N_DST 8 // each SIMD group works on 8 rows (in weights matrix) +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 8 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif +// +// This variant performs 1d blocking with 8x output. +// Each simdgroup outputs 8 values on `n0` dim (row in the output matrix). +// +inline void mul_mat_q_n_f32_1d_8x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const int nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of + // a SIMD group in the grid. Each SIMD group produces N_DST values in the + // result, hence uses nb blocks, i.e., the offset becomes first_row*nb. + // Currently with llama2 7B, im is always 0. + // TODO: how to handle im/gqa*(nb*ne0)?
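+    // Layout note for the *_flat variants (descriptive): src0_q packs all
+    // QK4_0/2-byte quant blocks contiguously while src0_d keeps the
+    // per-block half scales in a separate buffer, which is why offset0_q
+    // below scales the block index by QK4_0/2 bytes while offset0_d uses
+    // the block index directly.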
+ int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + // The number of scales is the same as the number of blocks. + ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. + ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; + + global uchar * x = (global uchar *) src0_q + offset0_q; + global half * d = (global half *) src0_d + offset0_d; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float8 sumf = (float8)(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix*QK4_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0.f; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); + sumf.s1 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); + sumf.s2 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); + sumf.s3 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); + + sumf.s4 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il); + sumf.s5 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il); + sumf.s6 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il); + sumf.s7 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il); + + yb += QK4_0 * (N_SIMDWIDTH/2); + } + + float8 tot = (float8)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), + sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5), + sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + + if (first_row + 4 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; + } + if (first_row + 5 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; + } + if (first_row + 6 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; + } + if (first_row + 7 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void 
kernel_mul_mat_q4_0_f32_1d_8x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_mat_q_n_f32_1d_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl new file mode 100644 index 000000000..aed1ce7b2 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl @@ -0,0 +1,272 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_0 32 +#define QR4_0 2 +#define QK4_1 32 +#define QR4_1 2 +#define QK5_0 32 +#define QR5_0 2 +#define QK5_1 32 +#define QR5_1 2 +#define QK8_0 32 +#define QR8_0 1 +#define QK_K 256 +#define K_QUANTS_PER_ITERATION 2 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + +// This function requires the original shuffled weights. +// As a reminder, the original weights are shuffled so that (q[0], q[16]) are +// packed together in a byte, so are (q[1], q[17]) and so on. +inline float block_q_4_0_dot_y_flat( + global uchar * x, + global half * dh, + float sumy, + float16 yl, + int il +) { + float d = *dh; + global ushort * qs = ((global ushort *)x + il/2); + float acc = 0.f; + + acc += yl.s0 * (qs[0] & 0x000F); + acc += yl.s1 * (qs[0] & 0x0F00); + acc += yl.s8 * (qs[0] & 0x00F0); + acc += yl.s9 * (qs[0] & 0xF000); + + acc += yl.s2 * (qs[1] & 0x000F); + acc += yl.s3 * (qs[1] & 0x0F00); + acc += yl.sa * (qs[1] & 0x00F0); + acc += yl.sb * (qs[1] & 0xF000); + + acc += yl.s4 * (qs[2] & 0x000F); + acc += yl.s5 * (qs[2] & 0x0F00); + acc += yl.sc * (qs[2] & 0x00F0); + acc += yl.sd * (qs[2] & 0xF000); + + acc += yl.s6 * (qs[3] & 0x000F); + acc += yl.s7 * (qs[3] & 0x0F00); + acc += yl.se * (qs[3] & 0x00F0); + acc += yl.sf * (qs[3] & 0xF000); + + return d * (sumy * -8.f + acc); +} + +// +// This variant outputs 8 values. 
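+// (block_q_4_0_dot_y_flat above computes the same dot product as
+// mm_block_q_4_0_dot_y_flat in the 1d_*x_flat files; the copy is kept,
+// presumably, so each standalone .cl program remains self-contained.)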
+//
+#undef N_DST
+#undef N_SIMDGROUP
+#undef N_SIMDWIDTH
+
+#ifdef INTEL_GPU
+#define N_DST 8 // each SIMD group works on 8 rows
+#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
+#define N_SIMDWIDTH 16 // assuming SIMD group size is 16
+#elif defined (ADRENO_GPU)
+#define N_DST 8
+#define N_SIMDGROUP 1
+#define N_SIMDWIDTH 64
+#endif
+
+inline void mul_vec_q_n_f32_8x_flat(
+    global uchar * src0_q,
+    global half * src0_d,
+    global float * src1,
+    global float * dst,
+    int ne00,
+    int ne01,
+    int ne02,
+    int ne10,
+    int ne12,
+    int ne0,
+    int ne1,
+    int r2,
+    int r3
+) {
+    const ulong nb = ne00/QK4_0;
+
+    int r0 = get_group_id(0);
+    int r1 = get_group_id(1);
+    int im = get_group_id(2);
+
+    // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of
+    // a SIMD group in the grid. Each SIMD group produces N_DST values in the
+    // result, hence uses nb blocks, i.e., the offset becomes first_row*nb.
+    // Currently with llama2 7B, im is always 0.
+    // TODO: how to handle im/gqa*(nb*ne0)?
+    int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
+
+    int i12 = im%ne12;
+    int i13 = im/ne12;
+
+    // The number of scales is the same as the number of blocks.
+    ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    // Each block contains QK4_0/2 uchars, hence offset for qs is as follows.
+    ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2;
+
+    global uchar * x = (global uchar *) src0_q + offset0_q;
+    global half * d = (global half *) src0_d + offset0_d;
+    global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float16 yl;
+    float8 sumf = 0.f;
+
+    int ix = get_sub_group_local_id()/2;
+    int il = 8*(get_sub_group_local_id()%2);
+
+    global float * yb = y + ix*QK4_0 + il;
+
+    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
+        float sumy = 0.f;
+
+        sumy += yb[0];
+        sumy += yb[1];
+        sumy += yb[2];
+        sumy += yb[3];
+        sumy += yb[4];
+        sumy += yb[5];
+        sumy += yb[6];
+        sumy += yb[7];
+
+        sumy += yb[16];
+        sumy += yb[17];
+        sumy += yb[18];
+        sumy += yb[19];
+        sumy += yb[20];
+        sumy += yb[21];
+        sumy += yb[22];
+        sumy += yb[23];
+
+        yl.s0 = yb[0];
+        yl.s1 = yb[1]/256.f;
+
+        yl.s2 = yb[2];
+        yl.s3 = yb[3]/256.f;
+
+        yl.s4 = yb[4];
+        yl.s5 = yb[5]/256.f;
+
+        yl.s6 = yb[6];
+        yl.s7 = yb[7]/256.f;
+
+        yl.s8 = yb[16]/16.f;
+        yl.s9 = yb[17]/4096.f;
+
+        yl.sa = yb[18]/16.f;
+        yl.sb = yb[19]/4096.f;
+
+        yl.sc = yb[20]/16.f;
+        yl.sd = yb[21]/4096.f;
+
+        yl.se = yb[22]/16.f;
+        yl.sf = yb[23]/4096.f;
+
+        sumf.s0 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il);
+        sumf.s1 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il);
+        sumf.s2 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il);
+        sumf.s3 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il);
+
+        sumf.s4 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il);
+        sumf.s5 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il);
+        sumf.s6 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il);
+        sumf.s7 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il);
+
+        yb += QK4_0 * (N_SIMDWIDTH/2);
+    }
+
+    float8 tot = (float8)(
+        sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
+        sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3),
+        sub_group_reduce_add(sumf.s4),
sub_group_reduce_add(sumf.s5), + sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + + if (first_row + 4 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; + } + if (first_row + 5 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; + } + if (first_row + 6 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; + } + if (first_row + 7 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_q4_0_f32_8x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl new file mode 100644 index 000000000..929552179 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl @@ -0,0 +1,254 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_0 32 +#define QR4_0 2 +#define QK4_1 32 +#define QR4_1 2 +#define QK5_0 32 +#define QR5_0 2 +#define QK5_1 32 +#define QR5_1 2 +#define QK8_0 32 +#define QR8_0 1 +#define QK_K 256 +#define K_QUANTS_PER_ITERATION 2 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + +// +// This variant unrolls the loops and uses vector types instead of pointers. +// It improves performance on Adreno but not so much on Intel. 
+//
+inline float block_q_4_0_dot_y_v(
+    global struct block_q4_0 * qb_curr,
+    float sumy,
+    float16 yl,
+    int il
+) {
+    float d = qb_curr->d;
+    float acc = 0.f;
+    global ushort * qs = ((global ushort *)qb_curr + 1 + il/2);
+
+    acc += yl.s0 * (qs[0] & 0x000F);
+    acc += yl.s1 * (qs[0] & 0x0F00);
+    acc += yl.s8 * (qs[0] & 0x00F0);
+    acc += yl.s9 * (qs[0] & 0xF000);
+
+    acc += yl.s2 * (qs[1] & 0x000F);
+    acc += yl.s3 * (qs[1] & 0x0F00);
+    acc += yl.sa * (qs[1] & 0x00F0);
+    acc += yl.sb * (qs[1] & 0xF000);
+
+    acc += yl.s4 * (qs[2] & 0x000F);
+    acc += yl.s5 * (qs[2] & 0x0F00);
+    acc += yl.sc * (qs[2] & 0x00F0);
+    acc += yl.sd * (qs[2] & 0xF000);
+
+    acc += yl.s6 * (qs[3] & 0x000F);
+    acc += yl.s7 * (qs[3] & 0x0F00);
+    acc += yl.se * (qs[3] & 0x00F0);
+    acc += yl.sf * (qs[3] & 0xF000);
+
+    return d * (sumy * -8.f + acc);
+}
+
+#undef N_DST
+#undef N_SIMDGROUP
+#undef N_SIMDWIDTH
+
+#ifdef INTEL_GPU
+#define N_DST 4 // each SIMD group works on 4 rows
+#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
+#define N_SIMDWIDTH 16 // assuming SIMD group size is 16
+#elif defined (ADRENO_GPU)
+#define N_DST 4
+#define N_SIMDGROUP 1
+#define N_SIMDWIDTH 64
+#endif
+
+inline void mul_vec_q_n_f32_v(
+    global void * src0,
+    global float * src1,
+    global float * dst,
+    int ne00,
+    int ne01,
+    int ne02,
+    int ne10,
+    int ne12,
+    int ne0,
+    int ne1,
+    int r2,
+    int r3
+) {
+    const ulong nb = ne00/QK4_0;
+
+    int r0 = get_group_id(0);
+    int r1 = get_group_id(1);
+    int im = get_group_id(2);
+
+    // (r0 * N_SIMDGROUP + get_sub_group_id()) is essentially the linear global
+    // id of a SIMD group in the grid.
+    int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
+
+    int i12 = im%ne12;
+    int i13 = im/ne12;
+
+    ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+    global struct block_q4_0 * x = (global struct block_q4_0 *) src0 + offset0;
+    global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float16 yl; // src1 vector cache
+    float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f);
+
+    int ix = get_sub_group_local_id()/2;
+    int il = 8*(get_sub_group_local_id()%2);
+
+    global float * yb = y + ix * QK4_0 + il;
+
+    // each thread in a SIMD group deals with half a block.
+    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
+        float sumy = 0;
+
+        sumy += yb[0];
+        sumy += yb[1];
+        sumy += yb[2];
+        sumy += yb[3];
+        sumy += yb[4];
+        sumy += yb[5];
+        sumy += yb[6];
+        sumy += yb[7];
+
+        sumy += yb[16];
+        sumy += yb[17];
+        sumy += yb[18];
+        sumy += yb[19];
+        sumy += yb[20];
+        sumy += yb[21];
+        sumy += yb[22];
+        sumy += yb[23];
+
+        yl.s0 = yb[0];
+        yl.s1 = yb[1]/256.f;
+
+        yl.s2 = yb[2];
+        yl.s3 = yb[3]/256.f;
+
+        yl.s4 = yb[4];
+        yl.s5 = yb[5]/256.f;
+
+        yl.s6 = yb[6];
+        yl.s7 = yb[7]/256.f;
+
+        yl.s8 = yb[16]/16.f;
+        yl.s9 = yb[17]/4096.f;
+
+        yl.sa = yb[18]/16.f;
+        yl.sb = yb[19]/4096.f;
+
+        yl.sc = yb[20]/16.f;
+        yl.sd = yb[21]/4096.f;
+
+        yl.se = yb[22]/16.f;
+        yl.sf = yb[23]/4096.f;
+
+        sumf.s0 += block_q_4_0_dot_y_v(x+ib+0*nb, sumy, yl, il);
+        sumf.s1 += block_q_4_0_dot_y_v(x+ib+1*nb, sumy, yl, il);
+        sumf.s2 += block_q_4_0_dot_y_v(x+ib+2*nb, sumy, yl, il);
+        sumf.s3 += block_q_4_0_dot_y_v(x+ib+3*nb, sumy, yl, il);
+
+        // One thread in a SIMD group (i.e., subgroup) handles a half block,
+        // hence the entire SIMD group handles SIMDWIDTH/2 blocks.
+        // y points to the activation matrix (of type float). Therefore for
+        // one thread, the # of blocks y should advance is SIMDWIDTH/2 (because
+        // SIMDWIDTH/2 blocks are processed by a SIMD group) - in terms of
+        // floats, it is QK4_0 * (SIMDWIDTH/2), where QK4_0 is the block size.
+        yb += QK4_0 * (N_SIMDWIDTH/2);
+    }
+
+    // The above does not work for Adreno - it produces incorrect results for
+    // row = 1, 2, 3 and only row = 0 gives the correct result.
+    // If N_DST is changed, the below array must be initialized accordingly.
+    // This also seems to perform better on Intel.
+    float4 tot = (float4)(
+        sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
+        sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3)
+    );
+
+    if (get_sub_group_local_id() == 0) {
+        if (first_row + 0 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
+        }
+        if (first_row + 1 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
+        }
+        if (first_row + 2 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
+        }
+        if (first_row + 3 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
+        }
+    }
+}
+
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_16
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_mul_mat_q4_0_f32_v(
+    global void * src0,
+    ulong offset0,
+    global float * src1,
+    ulong offset1,
+    global float * dst,
+    ulong offsetd,
+    int ne00,
+    int ne01,
+    int ne02,
+    int ne10,
+    int ne12,
+    int ne0,
+    int ne1,
+    int r2,
+    int r3
+) {
+    src0 = (global void*)((global char*)src0 + offset0);
+    src1 = (global float*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    mul_vec_q_n_f32_v(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
+}
diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl
new file mode 100644
index 000000000..8a17b9aae
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl
@@ -0,0 +1,190 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+#ifdef cl_intel_subgroups
+#pragma OPENCL EXTENSION cl_intel_subgroups : enable
+#else
+#pragma OPENCL EXTENSION cl_khr_subgroups : enable
+#endif
+
+#ifdef cl_intel_required_subgroup_size
+#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
+#define INTEL_GPU 1
+#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
+#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
+#elif defined(cl_qcom_reqd_sub_group_size)
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+#define ADRENO_GPU 1
+#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
+#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
+#endif
+
+#define QK4_0 32
+#define QR4_0 2
+#define QK4_1 32
+#define QR4_1 2
+#define QK5_0 32
+#define QR5_0 2
+#define QK5_1 32
+#define QR5_1 2
+#define QK8_0 32
+#define QR8_0 1
+#define QK_K 256
+#define K_QUANTS_PER_ITERATION 2
+
+typedef char int8_t;
+typedef uchar uint8_t;
+typedef short int16_t;
+typedef ushort uint16_t;
+typedef int int32_t;
+typedef uint uint32_t;
+
+//------------------------------------------------------------------------------
+// block_q6_K
+//------------------------------------------------------------------------------
+// 6-bit quantization
+// weight is represented as x = a * q
+// 16 blocks of 16 elements each
+// Effectively 6.5625 bits per weight
+typedef struct {
+    uint8_t ql[QK_K/2]; // quants, lower 4 bits
+    uint8_t qh[QK_K/4]; // quants, upper 2 bits
+    int8_t scales[QK_K/16]; // scales, quantized with 8 bits
+    half d; // super-block scale
+} block_q6_K;
+
+//------------------------------------------------------------------------------
+// kernel_mul_mv_q6_K_f32
+//------------------------------------------------------------------------------
+
+#undef N_DST
+#undef N_SIMDGROUP
+#undef N_SIMDWIDTH
+
+#ifdef INTEL_GPU
+#define N_DST 1 // number of rows each SIMD group works on
+#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
+#define N_SIMDWIDTH 16 // SIMD group size
+#elif defined (ADRENO_GPU)
+#define N_DST 1
+#define N_SIMDGROUP 2
+#define N_SIMDWIDTH 64
+#endif
+
+#define BLOCK_STRIDE (N_SIMDWIDTH/16) // number of blocks each subgroup processes
+
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_16
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_mul_mv_q6_K_f32(
+    global void * src0,
+    ulong offset0,
+    global float * src1,
+    ulong offset1,
+    global float * dst,
+    ulong offsetd,
+    int ne00,
+    int ne01,
+    int ne02,
+    int ne10,
+    int ne12,
+    int ne0,
+    int ne1,
+    int r2,
+    int r3
+) {
+    src0 = (global void*)((global char*)src0 + offset0);
+    src1 = (global float*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    uchar kmask1 = 0x03;
+    uchar kmask2 = 0x0C;
+    uchar kmask3 = 0x30;
+    uchar kmask4 = 0xC0;
+
+    int nb = ne00/QK_K;
+
+    int r0 = get_group_id(0);
+    int r1 = get_group_id(1);
+    int im = get_group_id(2);
+
+    int row = N_SIMDGROUP * r0 + get_sub_group_id();
+
+    int i12 = im%ne12;
+    int i13 = im/ne12;
+
+    ulong offset_src0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+    global block_q6_K * x = (global block_q6_K *) src0 + row*nb + offset_src0;
+    global float * yy = (global float *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float sumf = 0;
+
+    // For Q6_K quantization, 16 values form a subblock and 16 subblocks form a
+    // block. Values in a subblock share a scale that is quantized with 8 bits;
+    // the entire block shares a single floating-point scale.
+    // For work distribution, each thread processes a subblock (16 weights), hence
+    // 16 threads process a (super) block -- a subgroup thus handles SIMDWIDTH/16
+    // (super) blocks -- this is the block stride.
+    // The 16 threads that process a (super) block are split into 2 portions, each
+    // with 8 threads; each portion works on 8 subblocks.
+    // For a subgroup of 16 threads, the entire subgroup works on a single (super)
+    // block before moving to the next (super) block. Thread0 - thread7 work on the
+    // first 8 subblocks; thread8 - thread15 work on the last 8 subblocks.
+    // Thread0 - thread3 work on subblocks 0, 2, 4, 6; thread4 - thread7 work on
+    // subblocks 1, 3, 5, 7. Each thread does not work on an entire subblock, but
+    // works on a total of 16 weight values.
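+    // Worked example of the index math below (added for illustration; it is
+    // not part of the upstream kernel): assume N_SIMDWIDTH == 16, so
+    // BLOCK_STRIDE == 1. For sub-group lane 5: tid = 5, ix = 0, ip = 0,
+    // il = 5, l0 = 20, is = 1, y_offset = 20, q_offset_l = 20, q_offset_h = 20.
+    // In the loop, lane 5 thus reads ql[20..23] (via q1) and ql[52..55]
+    // (via q2), qh[20..23] and scales[1], [3], [5], [7], pairing them with
+    // activations y[20..23], y[52..55], y[84..87], y[116..119] of each
+    // (super) block -- i.e., 16 weight values per lane, as stated above.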
+ int tid = get_sub_group_local_id()/BLOCK_STRIDE; // first block_stride groups have tid=0 + int ix = get_sub_group_local_id()%BLOCK_STRIDE; // first block is 0..block_stride-1 + int ip = tid/8; // first or second half of (super) block (0 or 1) + int il = tid%8; // each half has 8 parts, one per scale + int n = 4; // 4 scales at a time (and 4 sums) + int l0 = n*il; // offset into half-block, 0..28 + int is = 8*ip + l0/16; // 0, 1, 8, 9 + + int y_offset = 128*ip + l0; + int q_offset_l = 64*ip + l0; + int q_offset_h = 32*ip + l0; + + for (int i = ix; i < nb; i += BLOCK_STRIDE) { + + global uint8_t * q1 = x[i].ql + q_offset_l; + global uint8_t * q2 = q1 + QK_K/8; + global uint8_t * qh = x[i].qh + q_offset_h; + global int8_t * sc = x[i].scales + is; + + global float * y = yy + i * QK_K + y_offset; + + float dall = x[i].d; + + float4 sums = {0.f, 0.f, 0.f, 0.f}; + + sums.s0 += y[0+ 0] * ((float)((q1[0] & 0xF) | ((qh[0] & kmask1) << 4)) - 32.f); + sums.s1 += y[0+32] * ((float)((q2[0] & 0xF) | ((qh[0] & kmask2) << 2)) - 32.f); + sums.s2 += y[0+64] * ((float)((q1[0] >> 4) | ((qh[0] & kmask3) << 0)) - 32.f); + sums.s3 += y[0+96] * ((float)((q2[0] >> 4) | ((qh[0] & kmask4) >> 2)) - 32.f); + + sums.s0 += y[1+ 0] * ((float)((q1[1] & 0xF) | ((qh[1] & kmask1) << 4)) - 32.f); + sums.s1 += y[1+32] * ((float)((q2[1] & 0xF) | ((qh[1] & kmask2) << 2)) - 32.f); + sums.s2 += y[1+64] * ((float)((q1[1] >> 4) | ((qh[1] & kmask3) << 0)) - 32.f); + sums.s3 += y[1+96] * ((float)((q2[1] >> 4) | ((qh[1] & kmask4) >> 2)) - 32.f); + + sums.s0 += y[2+ 0] * ((float)((q1[2] & 0xF) | ((qh[2] & kmask1) << 4)) - 32.f); + sums.s1 += y[2+32] * ((float)((q2[2] & 0xF) | ((qh[2] & kmask2) << 2)) - 32.f); + sums.s2 += y[2+64] * ((float)((q1[2] >> 4) | ((qh[2] & kmask3) << 0)) - 32.f); + sums.s3 += y[2+96] * ((float)((q2[2] >> 4) | ((qh[2] & kmask4) >> 2)) - 32.f); + + sums.s0 += y[3+ 0] * ((float)((q1[3] & 0xF) | ((qh[3] & kmask1) << 4)) - 32.f); + sums.s1 += y[3+32] * ((float)((q2[3] & 0xF) | ((qh[3] & kmask2) << 2)) - 32.f); + sums.s2 += y[3+64] * ((float)((q1[3] >> 4) | ((qh[3] & kmask3) << 0)) - 32.f); + sums.s3 += y[3+96] * ((float)((q2[3] >> 4) | ((qh[3] & kmask4) >> 2)) - 32.f); + + sumf += dall * (sums.s0 * sc[0] + sums.s1 * sc[2] + sums.s2 * sc[4] + sums.s3 * sc[6]); + } + + float tot = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[r1*ne0 + im*ne0*ne1 + row] = tot; + } +} diff --git a/ggml/src/ggml-opencl/kernels/norm.cl b/ggml/src/ggml-opencl/kernels/norm.cl new file mode 100644 index 000000000..43167ba4d --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/norm.cl @@ -0,0 +1,81 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +//------------------------------------------------------------------------------ +// norm +//------------------------------------------------------------------------------ +kernel void kernel_norm( + global void * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, 
+ int ne02, + int ne03, + ulong nb01, + ulong nb02, + ulong nb03, + float eps, + local float * sum +) { + src0 = (global void*)((global char*)src0 + offset0); + dst = (global void*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float * x = (global float *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01); + + // MEAN + // parallel sum + sum[get_local_id(0)] = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + sum[get_local_id(0)] += x[i00]; + } + // reduce + barrier(CLK_LOCAL_MEM_FENCE); + for (uint i = get_local_size(0)/2; i > 0; i /= 2) { + if (get_local_id(0) < i) { + sum[get_local_id(0)] += sum[get_local_id(0) + i]; + } + barrier(CLK_LOCAL_MEM_FENCE); + } + float mean = sum[0] / ne00; + + // recenter and VARIANCE + barrier(CLK_LOCAL_MEM_FENCE); + global float * y = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + sum[get_local_id(0)] = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + y[i00] = x[i00] - mean; + sum[get_local_id(0)] += y[i00] * y[i00]; + } + + // reduce + barrier(CLK_LOCAL_MEM_FENCE); + for (uint i = get_local_size(0)/2; i > 0; i /= 2) { + if (get_local_id(0) < i) { + sum[get_local_id(0)] += sum[get_local_id(0) + i]; + } + barrier(CLK_LOCAL_MEM_FENCE); + } + float variance = sum[0] / ne00; + + float scale = 1.0f/sqrt(variance + eps); + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + y[i00] = y[i00] * scale; + } +} diff --git a/ggml/src/ggml-opencl/kernels/relu.cl b/ggml/src/ggml-opencl/kernels/relu.cl new file mode 100644 index 000000000..60ff28a61 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/relu.cl @@ -0,0 +1,16 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// relu +//------------------------------------------------------------------------------ +kernel void kernel_relu( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = fmax(0.0f, src0[get_global_id(0)]); +} diff --git a/ggml/src/ggml-opencl/kernels/rms_norm.cl b/ggml/src/ggml-opencl/kernels/rms_norm.cl new file mode 100644 index 000000000..9d21f3398 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/rms_norm.cl @@ -0,0 +1,96 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +//------------------------------------------------------------------------------ +// rms_norm +//------------------------------------------------------------------------------ +// This kernel depends on subgroup size. 
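+// As a rough host-side sketch (illustrative only; the argument index and
+// work-group size below are assumptions, not taken from this patch), the
+// local buffer for the partial sums would be sized as one float per subgroup
+// of the work-group:
+//
+//   size_t lws = 64;                       // assumed work-group size
+//   size_t n_sg = lws / 32;                // assumed 32-wide subgroups
+//   // `sum` is the last kernel argument; local memory is passed as size only
+//   clSetKernelArg(kernel, 12, n_sg * sizeof(float), NULL);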
+#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_32 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_rms_norm( + global void * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb01, + ulong nb02, + ulong nb03, + float eps, + local float * sum // Note, the size depends on number of subgroups +) { + src0 = (global void*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float4 * x = (global float4 *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01); + global float * x_scalar = (global float *) x; + float4 sumf = 0; + float all_sum = 0; + + // parallel sum + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + sumf += x[i00] * x[i00]; + } + all_sum = sumf.s0 + sumf.s1 + sumf.s2 + sumf.s3; + all_sum = sub_group_reduce_add(all_sum); + if (get_sub_group_local_id() == 0) { + sum[get_sub_group_id()] = all_sum; + } + + barrier(CLK_LOCAL_MEM_FENCE); + // broadcast + for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) { + if (get_local_id(0) < i) { + sum[get_local_id(0)] += sum[get_local_id(0) + i]; + } + } + if (get_local_id(0) == 0) { + for (int i = 4 * (ne00 / 4); i < ne00; i++) { + sum[0] += x_scalar[i]; + } + sum[0] /= ne00; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + const float mean = sum[0]; + const float scale = 1.0f/sqrt(mean + eps); + + global float4 * y = (global float4 *) (dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + global float * y_scalar = (global float *) y; + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + y[i00] = x[i00] * scale; + } + if (get_local_id(0) == 0) { + for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) { + y_scalar[i00] = x_scalar[i00] * scale; + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/rope.cl b/ggml/src/ggml-opencl/kernels/rope.cl new file mode 100644 index 000000000..0247730c0 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/rope.cl @@ -0,0 +1,721 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// kernel_rope +//------------------------------------------------------------------------------ +float rope_yarn_ramp(float low, float high, int i0) { + const float y = (i0 / 2 - low) / max(0.001f, high - low); + return 1.0f - min(1.0f, max(0.0f, y)); +} + +// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn +// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. 
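+// In short (summary comment added for clarity; the equations restate the code
+// below): rope_yarn() blends interpolated and extrapolated angles per
+// dimension,
+//
+//   ramp_mix = rope_yarn_ramp(corr_dims.s0, corr_dims.s1, i0) * ext_factor
+//   theta    = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix
+//
+// so dimensions with i0/2 below corr_dims.s0 get the extrapolated angle
+// (fully so when ext_factor == 1), dimensions above corr_dims.s1 keep the
+// interpolated angle, and mscale compensates the magnitude when interpolating.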
+float2 rope_yarn( + float theta_extrap, float freq_scale, float2 corr_dims, int i0, float ext_factor, float mscale +) { + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = freq_scale * theta_extrap; + float theta = theta_interp; + if (ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(corr_dims.s0, corr_dims.s1, i0) * ext_factor; + theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale *= 1.0f + 0.1f * log(1.0f / freq_scale); + } + return (float2)(cos(theta) * mscale, sin(theta) * mscale); +} + +// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get +// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` +float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) { + return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base)); +} + +float2 rope_yarn_corr_dims( + int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow +) { + // start and end correction dims + return (float2)( + max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base))), + min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base))) + ); +} + +kernel void kernel_rope_norm_f32( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + float theta_base = (float) pos[i2]; + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + float theta = theta_base * pow(freq_base, inv_ndims*i0); + + float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + float x0 = src[0]; + float x1 = src[1]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[1] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_norm_f16( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + float theta_base = (float) pos[i2]; + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + float theta = theta_base * pow(freq_base, inv_ndims*i0); + + float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + float x0 = src[0]; + float x1 = src[1]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[1] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_neox_f32( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + float theta_base = (float) pos[i2]; + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + const float theta = theta_base * pow(freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global float * const src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_neox_f16( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + float theta_base = (float) pos[i2]; + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + const float theta = theta_base * pow(freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global half * const src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_multi_f32( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + int4 sections +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3; + const int sec_w = sections.s1 + sections.s0; + + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + const int sector = (i0 / 2) % sect_dims; + float theta_base = 0.0f; + + if (sector < sections.s0) { + theta_base = pos[i2]; + } + else if (sector >= sections.s0 && sector < sec_w) { + theta_base = pos[i2 + ne2 * 1]; + } + else if (sector >= sec_w && sector < sec_w + sections.s2) { + theta_base = pos[i2 + ne2 * 2]; + } + else if (sector >= sec_w + sections.s2) { + theta_base = pos[i2 + ne2 * 3]; + } + + const float theta = theta_base * pow(freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global float * const src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_multi_f16( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global half * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + int4 sections +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3; + const int sec_w = sections.s1 + sections.s0; + + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + const int sector = (i0 / 2) % sect_dims; + float theta_base = 0.0f; + + if (sector < sections.s0) { + theta_base = pos[i2]; + } + else if (sector >= sections.s0 && sector < sec_w) { + theta_base = pos[i2 + ne2 * 1]; + } + else if (sector >= sec_w && sector < sec_w + sections.s2) { + theta_base = pos[i2 + ne2 * 2]; + } + else if (sector >= sec_w + sections.s2) { + theta_base = pos[i2 + ne2 * 3]; + } + + const float theta = theta_base * pow(freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global half * const src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_vision_f32( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + int4 sections +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + const int sect_dims = sections.s0 + sections.s1; + const int sec_w = sections.s1 + sections.s0; + + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + int ic = i0/2; + + const int sector = (i0/2) % sect_dims; + float theta_base = 0.0f; + + if (sector < sections.s0) { + const int p = sector; + theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p); + } else if (sector >= sections.s0 && sector < sec_w) { + const int p = sector - sections.s0; + theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p); + } + + const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } +} + +kernel void kernel_rope_vision_f16( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global half * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + int4 sections +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + const int sect_dims = sections.s0 + sections.s1; + const int sec_w = sections.s1 + sections.s0; + + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + int ic = i0/2; + + const int sector = (i0/2) % sect_dims; + float theta_base = 0.0f; + + if (sector < sections.s0) { + const int p = sector; + theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p); + } else if (sector >= sections.s0 && sector < sec_w) { + const int p = sector - sections.s0; + theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p); + } + + const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } +} diff --git a/ggml/src/ggml-opencl/kernels/scale.cl b/ggml/src/ggml-opencl/kernels/scale.cl new file mode 100644 index 000000000..8cfd518fa --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/scale.cl @@ -0,0 +1,16 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// scale +//------------------------------------------------------------------------------ +kernel void kernel_scale( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd, + float scale +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + dst[get_global_id(0)] = src0[get_global_id(0)] * scale; +} diff --git a/ggml/src/ggml-opencl/kernels/silu.cl b/ggml/src/ggml-opencl/kernels/silu.cl new file mode 100644 index 000000000..1d95e1b50 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/silu.cl @@ -0,0 +1,30 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// silu +//------------------------------------------------------------------------------ +kernel void kernel_silu( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + float x = src0[get_global_id(0)]; + dst[get_global_id(0)] = x / (1.0f + exp(-x)); +} + +kernel void kernel_silu_4( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + float4 x = src0[get_global_id(0)]; + dst[get_global_id(0)] = x / (1.0f + exp(-x)); +} diff --git a/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl b/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl new file mode 100644 index 000000000..62c05369a --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl @@ -0,0 +1,87 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_soft_max_4_f16( + global float * src0, + ulong offset0, + global half * src1, + ulong offset1, + global float * dst, + ulong offsetd, + 
int ne00, + int ne01, + int ne02, + float scale, + float max_bias, + float m0, + float m1, + int n_head_log2 +) { + src0 = (global float *)((global char *)src0 + offset0); + src1 = (global half *)((global char *)src1 + offset1); + dst = (global float *)((global char *)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + global half4 * pmask = (global char *)src1 != (global char *)src0 ? (global half4 *)(src1 + i01*ne00) : 0; + global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + int h = i02; + + float base = h < n_head_log2 ? m0 : m1; + int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + + // parallel max + float4 lmax4 = -INFINITY; + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + lmax4 = fmax(lmax4, psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f)); + } + float lmax = fmax(fmax(lmax4.s0, lmax4.s1), fmax(lmax4.s2, lmax4.s3)); + + const float max = sub_group_reduce_max(lmax); + + // parallel sum + float4 lsum4 = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + const float4 exp_psrc4 = exp((psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f)) - max); + lsum4 += exp_psrc4; + pdst4[i00] = exp_psrc4; + } + float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3; + + const float sum = sub_group_reduce_add(lsum); + + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + pdst4[i00] /= sum; + } +} diff --git a/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl b/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl new file mode 100644 index 000000000..d562774ea --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl @@ -0,0 +1,87 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_soft_max_4( + global float * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + float scale, + float max_bias, + float m0, + float m1, + int n_head_log2 +) { + src0 = (global float*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + global float4 * pmask = src1 != src0 ? 
(global float4 *)(src1 + i01*ne00) : 0; + global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + int h = i02; + + float base = h < n_head_log2 ? m0 : m1; + int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + + // parallel max + float4 lmax4 = -INFINITY; + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); + } + float lmax = fmax(fmax(lmax4.s0, lmax4.s1), fmax(lmax4.s2, lmax4.s3)); + + const float max = sub_group_reduce_max(lmax); + + // parallel sum + float4 lsum4 = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max); + lsum4 += exp_psrc4; + pdst4[i00] = exp_psrc4; + } + float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3; + + const float sum = sub_group_reduce_add(lsum); + + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + pdst4[i00] /= sum; + } +} diff --git a/ggml/src/ggml-opencl/kernels/softmax_f16.cl b/ggml/src/ggml-opencl/kernels/softmax_f16.cl new file mode 100644 index 000000000..d38d09967 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/softmax_f16.cl @@ -0,0 +1,86 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_soft_max_f16( + global float * src0, + ulong offset0, + global half * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + float scale, + float max_bias, + float m0, + float m1, + int n_head_log2 +) { + src0 = (global float *)((global char *)src0 + offset0); + src1 = (global half *)((global char *)src1 + offset1); + dst = (global float *)((global char *)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + global half * pmask = (global char *)src1 != (global char *)src0 ? src1 + i01*ne00 : 0; + global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + int h = i02; + + float base = h < n_head_log2 ? m0 : m1; + int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + + // parallel max + float lmax = -INFINITY; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? 
slope*pmask[i00] : 0.0f)); + } + float max = sub_group_reduce_max(lmax); + + // parallel sum + float lsum = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max); + lsum += exp_psrc0; + // Remember the result of exp here. exp is expensive, so we really do not + // wish to compute it twice. + pdst[i00] = exp_psrc0; + } + + const float sum = sub_group_reduce_add(lsum); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + pdst[i00] /= sum; + } +} diff --git a/ggml/src/ggml-opencl/kernels/softmax_f32.cl b/ggml/src/ggml-opencl/kernels/softmax_f32.cl new file mode 100644 index 000000000..001b587ab --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/softmax_f32.cl @@ -0,0 +1,86 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_soft_max( + global float * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + float scale, + float max_bias, + float m0, + float m1, + int n_head_log2 +) { + src0 = (global float*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + global float * pmask = src1 != src0 ? src1 + i01*ne00 : 0; + global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + int h = i02; + + float base = h < n_head_log2 ? m0 : m1; + int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + + // parallel max + float lmax = -INFINITY; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); + } + float max = sub_group_reduce_max(lmax); + + // parallel sum + float lsum = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max); + lsum += exp_psrc0; + // Remember the result of exp here. exp is expensive, so we really do not + // wish to compute it twice. 
+        pdst[i00] = exp_psrc0;
+    }
+
+    const float sum = sub_group_reduce_add(lsum);
+
+    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
+        pdst[i00] /= sum;
+    }
+}
diff --git a/ggml/src/ggml-opencl/kernels/transpose.cl b/ggml/src/ggml-opencl/kernels/transpose.cl
new file mode 100644
index 000000000..a11490b30
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/transpose.cl
@@ -0,0 +1,84 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+// 16-bit transpose, loading/storing a 4x4 tile of elements
+kernel void kernel_transpose_16(
+    __read_only image1d_buffer_t input,
+    __write_only image1d_buffer_t output,
+    const uint rows,
+    const uint cols
+) {
+
+    const int i = get_global_id(0);
+    const int j = get_global_id(1);
+    const int i_2 = i<<2;
+    const int j_2 = j<<2;
+
+    half4 temp0 = read_imageh(input, (j_2+0)*cols+i);
+    half4 temp1 = read_imageh(input, (j_2+1)*cols+i);
+    half4 temp2 = read_imageh(input, (j_2+2)*cols+i);
+    half4 temp3 = read_imageh(input, (j_2+3)*cols+i);
+
+    write_imageh(output, (i_2+0)*rows+j, (half4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0));
+    write_imageh(output, (i_2+1)*rows+j, (half4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1));
+    write_imageh(output, (i_2+2)*rows+j, (half4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2));
+    write_imageh(output, (i_2+3)*rows+j, (half4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3));
+}
+
+// 32-bit transpose, loading/storing a 4x4 tile of elements
+kernel void kernel_transpose_32(
+    __read_only image1d_buffer_t input,
+    __write_only image1d_buffer_t output,
+    const uint rows,
+    const uint cols
+) {
+
+    const int i = get_global_id(0);
+    const int j = get_global_id(1);
+    const int i_2 = i<<2;
+    const int j_2 = j<<2;
+
+    float4 temp0 = read_imagef(input, (j_2+0)*cols+i);
+    float4 temp1 = read_imagef(input, (j_2+1)*cols+i);
+    float4 temp2 = read_imagef(input, (j_2+2)*cols+i);
+    float4 temp3 = read_imagef(input, (j_2+3)*cols+i);
+
+    write_imagef(output, (i_2+0)*rows+j, (float4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0));
+    write_imagef(output, (i_2+1)*rows+j, (float4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1));
+    write_imagef(output, (i_2+2)*rows+j, (float4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2));
+    write_imagef(output, (i_2+3)*rows+j, (float4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3));
+
+}
+
+// 32-bit transpose, loading/storing a 4x4 tile of elements.
+// Only used for activations: converts the tile to FP16 and zero-pads the
+// output so prompt lengths that are not a multiple of 8 are handled.
+kernel void kernel_transpose_32_16(
+    __read_only image1d_buffer_t input,
+    __write_only image1d_buffer_t output,
+    const uint rows,
+    const uint cols,
+    const uint padded_rows
+) {
+
+    const int i = get_global_id(0);
+    const int j = get_global_id(1);
+    const int i_2 = i<<2;
+    const int j_2 = j<<2;
+    half4 temp0 = {0,0,0,0}; // initialize outputs to 0
+    half4 temp1 = {0,0,0,0};
+    half4 temp2 = {0,0,0,0};
+    half4 temp3 = {0,0,0,0};
+
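+    // The activation tensor may be shorter than padded_rows; each load below is
+    // bounds-checked so out-of-range positions keep the zeros initialized above,
+    // which produces the zero padding mentioned in the header comment.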
+    if ((j_2+0)*cols+i*4+3 < rows*cols*16) { // only load from a valid location
+        temp0 = read_imageh(input, (j_2+0)*cols+i);
+    }
+    if ((j_2+1)*cols+i*4+3 < rows*cols*16) {
+        temp1 = read_imageh(input, (j_2+1)*cols+i);
+    }
+    if ((j_2+2)*cols+i*4+3 < rows*cols*16) {
+        temp2 = read_imageh(input, (j_2+2)*cols+i);
+    }
+    if ((j_2+3)*cols+i*4+3 < rows*cols*16) {
+        temp3 = read_imageh(input, (j_2+3)*cols+i);
+    }
+
+    write_imageh(output, (i_2+0)*padded_rows+j, (half4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); // no conditionals for output; includes zero padding
+    write_imageh(output, (i_2+1)*padded_rows+j, (half4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1));
+    write_imageh(output, (i_2+2)*padded_rows+j, (half4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2));
+    write_imageh(output, (i_2+3)*padded_rows+j, (half4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3));
+}

From b43d89e311c5e7fbf62e5ec3c0401eb536677267 Mon Sep 17 00:00:00 2001
From: Chenguang Li <757486878@qq.com>
Date: Wed, 16 Apr 2025 16:21:05 +0800
Subject: [PATCH 20/20] CANN: Add 310P operator support check (#12962)

---
 ggml/src/ggml-cann/aclnn_ops.cpp |  8 ++++++++
 ggml/src/ggml-cann/ggml-cann.cpp | 10 ++++++++++
 2 files changed, 18 insertions(+)

diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp
index 2c5cdcae3..2c6737ea8 100644
--- a/ggml/src/ggml-cann/aclnn_ops.cpp
+++ b/ggml/src/ggml-cann/aclnn_ops.cpp
@@ -625,6 +625,10 @@ static void ggml_cann_avg_pool2d(ggml_backend_cann_context& ctx,
     bool count_include_pad = true;
     int64_t divisor_override = 0;
     int8_t cube_math_type = 0;
+#ifdef ASCEND_310P
+    cube_math_type = 1;
+#endif
+
     GGML_CANN_CALL_ACLNN_OP(AvgPool2d, acl_src, kernel_size, strides, paddings_avg,
                             ceil_mode, count_include_pad, divisor_override,
                             cube_math_type, acl_dst);
@@ -2590,6 +2594,10 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
     int64_t groups = 1;
     int8_t cubeMathType = 0;
+#ifdef ASCEND_310P
+    cubeMathType = 1;
+#endif
+
     GGML_CANN_CALL_ACLNN_OP(Convolution, acl_input, acl_weight, nullptr, stride,
                             padding, dilation, transposed, padding, groups,
                             acl_dst, cubeMathType);
diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp
index 08b9ca301..ca41e0260 100644
--- a/ggml/src/ggml-cann/ggml-cann.cpp
+++ b/ggml/src/ggml-cann/ggml-cann.cpp
@@ -2022,6 +2022,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
                     return true;
                 case GGML_TYPE_Q8_0:
                 case GGML_TYPE_Q4_0:
+#ifdef ASCEND_310P
+                    // Q4 and Q8 per-group quantization is not supported on 310P devices
+                    return false;
+#endif
                     // only support contiguous for quantized types.
                    return ggml_is_contiguous(op->src[0]) &&
                           ggml_is_contiguous(op->src[1]);
@@ -2107,6 +2111,12 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         }
         case GGML_OP_POOL_2D: {
             const int32_t * opts = (const int32_t *) op->op_params;
+#ifdef ASCEND_310P
+            enum ggml_op_pool opt = static_cast<enum ggml_op_pool>(opts[0]);
+            if (opt == GGML_OP_POOL_MAX) {
+                return false;
+            }
+#endif
             const int k0 = opts[1];
             const int k1 = opts[2];
             const int p0 = opts[5];
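             // note: opts[] is ggml_pool_2d's op_params, laid out as
             // { op, k0, k1, s0, s1, p0, p1 } in 32-bit ints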