From 02082f1519565fc7b49de211b28bc5404a69209b Mon Sep 17 00:00:00 2001 From: Ivy233 <952254420@qq.com> Date: Wed, 26 Mar 2025 22:06:04 +0800 Subject: [PATCH 01/26] clip: Fix llama-llava-clip-quantize-cli quantization error under CUDA backend (#12566) * [Fix] Compiling clip-quantize-cli and running it in a CUDA environment will cause ggml_fp16_to_fp32 to report an error when trying to access video memory. You need to switch to the CPU backend to run quantize. After the fix, it will automatically run in the CPU backend and will no longer be bound to CUDA. * [Fix]Roll back the signature and implementation of clip_model_load, and change the call in clip_model_quantize to clip_init. --- examples/llava/clip.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index a1f050e39..58ee5cf01 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -2989,7 +2989,10 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i assert(itype < GGML_TYPE_COUNT); ggml_type type = static_cast(itype); - auto * ctx_clip = clip_model_load(fname_inp, 2); + auto * ctx_clip = clip_init(fname_inp, clip_context_params{ + /* use_gpu */ false, + /* verbosity */ 2, + }); const auto & ctx_src = ctx_clip->ctx_gguf; const auto & ctx_data = ctx_clip->ctx_data; From 2447ad8a981253a2b8e9f4b31cc8e7fdff83423e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Moskal?= Date: Wed, 26 Mar 2025 11:06:09 -0700 Subject: [PATCH 02/26] upgrade to llguidance 0.7.10 (#12576) --- common/CMakeLists.txt | 4 +- common/llguidance.cpp | 77 ++++++++++++------------------- tests/test-grammar-llguidance.cpp | 62 +++++++++++++++++++++++++ 3 files changed, 94 insertions(+), 49 deletions(-) diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 17146fffc..829eb5b72 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -114,8 +114,8 @@ if (LLAMA_LLGUIDANCE) ExternalProject_Add(llguidance_ext GIT_REPOSITORY https://github.com/guidance-ai/llguidance - # v0.6.12: - GIT_TAG ced1c9023d47ec194fa977932d35ce65c2ebfc09 + # v0.7.10: + GIT_TAG 0309d2a6bf40abda35344a362edc71e06d5009f8 PREFIX ${CMAKE_BINARY_DIR}/llguidance SOURCE_DIR ${LLGUIDANCE_SRC} BUILD_IN_SOURCE TRUE diff --git a/common/llguidance.cpp b/common/llguidance.cpp index 2feeb93c8..8bff89ea4 100644 --- a/common/llguidance.cpp +++ b/common/llguidance.cpp @@ -11,25 +11,24 @@ struct llama_sampler_llg { std::string grammar_kind; std::string grammar_data; LlgTokenizer * tokenizer; - LlgConstraint * grammar; - LlgMaskResult llg_res; - bool has_llg_res; + LlgMatcher * grammar; }; -static LlgConstraint * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind, - const char * grammar_data) { +static LlgMatcher * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind, + const char * grammar_data) { LlgConstraintInit cinit; llg_constraint_init_set_defaults(&cinit, tokenizer); const char * log_level = getenv("LLGUIDANCE_LOG_LEVEL"); if (log_level && *log_level) { cinit.log_stderr_level = atoi(log_level); } - auto c = llg_new_constraint_any(&cinit, grammar_kind, grammar_data); - if (llg_get_error(c)) { - LOG_ERR("llg error: %s\n", llg_get_error(c)); - llg_free_constraint(c); + auto c = llg_new_matcher(&cinit, grammar_kind, grammar_data); + if (llg_matcher_get_error(c)) { + LOG_ERR("llg error: %s\n", llg_matcher_get_error(c)); + llg_free_matcher(c); return nullptr; } + return c; } @@ -40,39 +39,29 @@ static const char * 
llama_sampler_llg_name(const llama_sampler * /*smpl*/) { static void llama_sampler_llg_accept_impl(llama_sampler * smpl, llama_token token) { auto * ctx = (llama_sampler_llg *) smpl->ctx; if (ctx->grammar) { - LlgCommitResult res; - llg_commit_token(ctx->grammar, token, &res); - ctx->has_llg_res = false; + llg_matcher_consume_token(ctx->grammar, token); } } static void llama_sampler_llg_apply(llama_sampler * smpl, llama_token_data_array * cur_p) { auto * ctx = (llama_sampler_llg *) smpl->ctx; if (ctx->grammar) { - if (!ctx->has_llg_res) { - if (llg_compute_mask(ctx->grammar, &ctx->llg_res) == 0) { - ctx->has_llg_res = true; + const uint32_t * mask = llg_matcher_get_mask(ctx->grammar); + if (mask == nullptr) { + if (llg_matcher_compute_mask(ctx->grammar) == 0) { + mask = llg_matcher_get_mask(ctx->grammar); } else { - LOG_ERR("llg error: %s\n", llg_get_error(ctx->grammar)); - llg_free_constraint(ctx->grammar); + LOG_ERR("llg error: %s\n", llg_matcher_get_error(ctx->grammar)); + llg_free_matcher(ctx->grammar); ctx->grammar = nullptr; + return; } } - if (ctx->has_llg_res) { - if (ctx->llg_res.is_stop) { - for (size_t i = 0; i < cur_p->size; ++i) { - if (!llama_vocab_is_eog(ctx->vocab, cur_p->data[i].id)) { - cur_p->data[i].logit = -INFINITY; - } - } - } else { - const uint32_t * mask = ctx->llg_res.sample_mask; - for (size_t i = 0; i < cur_p->size; ++i) { - auto token = cur_p->data[i].id; - if ((mask[token / 32] & (1 << (token % 32))) == 0) { - cur_p->data[i].logit = -INFINITY; - } - } + + for (size_t i = 0; i < cur_p->size; ++i) { + auto token = cur_p->data[i].id; + if ((mask[token / 32] & (1 << (token % 32))) == 0) { + cur_p->data[i].logit = -INFINITY; } } } @@ -80,14 +69,9 @@ static void llama_sampler_llg_apply(llama_sampler * smpl, llama_token_data_array static void llama_sampler_llg_reset(llama_sampler * smpl) { auto * ctx = (llama_sampler_llg *) smpl->ctx; - if (!ctx->grammar) { - return; + if (ctx->grammar) { + llg_matcher_reset(ctx->grammar); } - - auto * grammar_new = llama_sampler_llg_new(ctx->tokenizer, ctx->grammar_kind.c_str(), ctx->grammar_data.c_str()); - llg_free_constraint(ctx->grammar); - ctx->grammar = grammar_new; - ctx->has_llg_res = false; } static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) { @@ -102,7 +86,7 @@ static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) { if (ctx->grammar) { result_ctx->grammar_kind = ctx->grammar_kind; result_ctx->grammar_data = ctx->grammar_data; - result_ctx->grammar = llg_clone_constraint(ctx->grammar); + result_ctx->grammar = llg_clone_matcher(ctx->grammar); result_ctx->tokenizer = llg_clone_tokenizer(ctx->tokenizer); } } @@ -114,7 +98,7 @@ static void llama_sampler_llg_free(llama_sampler * smpl) { const auto * ctx = (llama_sampler_llg *) smpl->ctx; if (ctx->grammar) { - llg_free_constraint(ctx->grammar); + llg_free_matcher(ctx->grammar); llg_free_tokenizer(ctx->tokenizer); } @@ -239,9 +223,11 @@ llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * g /* .grammar_data = */ grammar_data, /* .tokenizer = */ tokenizer, /* .grammar = */ llama_sampler_llg_new(tokenizer, grammar_kind, grammar_data), - /* .llg_res = */ {}, - /* .has_llg_res = */ false, }; + if (ctx->grammar) { + GGML_ASSERT(((size_t) llama_vocab_n_tokens(vocab) + 31) / 32 * 4 == + llg_matcher_get_mask_byte_size(ctx->grammar)); + } } else { *ctx = { /* .vocab = */ vocab, @@ -249,15 +235,12 @@ llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * g /* .grammar_data = */ {}, /* .tokenizer 
= */ nullptr, /* .grammar = */ nullptr, - /* .llg_res = */ {}, - /* .has_llg_res = */ false, }; } return llama_sampler_init( /* .iface = */ &llama_sampler_llg_i, - /* .ctx = */ ctx - ); + /* .ctx = */ ctx); } #else diff --git a/tests/test-grammar-llguidance.cpp b/tests/test-grammar-llguidance.cpp index 8b696006b..3c19220e1 100644 --- a/tests/test-grammar-llguidance.cpp +++ b/tests/test-grammar-llguidance.cpp @@ -1086,6 +1086,65 @@ static void test_json_schema() { }); } +static void one_hot(llama_token_data_array & tok_arr, llama_token selected) { + auto n_vocab = tok_arr.size; + + tok_arr.selected = -1; + tok_arr.sorted = false; + for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) { + tok_arr.data[token_id].id = token_id; + tok_arr.data[token_id].logit = 0.0f; + } + + tok_arr.data[selected].logit = 100.0f; +} + +static void test_sampler_chain(void) { + auto sparams = llama_sampler_chain_default_params(); + sparams.no_perf = false; + llama_sampler * sampler = llama_sampler_chain_init(sparams); + + const auto grammar_data = R"(%llguidance {} +start: /[A-Z ]*/)"; + + llama_sampler_chain_add(sampler, llama_sampler_init_llg(vocab, "lark", grammar_data)); + llama_sampler_chain_add(sampler, llama_sampler_init_dist(42)); + + auto input = "ALL YOUR BASE ARE BELONG TO US"; + auto tokens = common_tokenize(vocab, input, false, false); + + auto n_vocab = llama_vocab_n_tokens(vocab); + + std::vector cur; + cur.reserve(n_vocab); + for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) { + cur.emplace_back(llama_token_data{ token_id, 0.0f, 0.0f }); + } + auto tok_arr = llama_token_data_array{ cur.data(), cur.size(), -1, false }; + + for (const auto token : tokens) { + one_hot(tok_arr, token); + + fprintf(stderr, "applying token: %d\n", token); + llama_sampler_apply(sampler, &tok_arr); + + auto idx = tok_arr.selected; + fprintf(stderr, " -> %d %f\n", cur[idx].id, cur[idx].logit); + assert(cur[tok_arr.selected].id == token); + llama_sampler_accept(sampler, token); + } + + auto tok_eos = llama_vocab_eot(vocab); + if (tok_eos == LLAMA_TOKEN_NULL) { + tok_eos = llama_vocab_eos(vocab); + } + + one_hot(tok_arr, tok_eos); + + llama_sampler_apply(sampler, &tok_arr); + assert(cur[tok_arr.selected].id == tok_eos); +} + int main(int argc, const char ** argv) { fprintf(stdout, "Running llguidance integration tests...\n"); @@ -1135,6 +1194,9 @@ int main(int argc, const char ** argv) { test_special_chars(); test_quantifiers(); test_json_schema(); + + test_sampler_chain(); + fprintf(stdout, "All tests passed.\n"); return 0; } From b3298fa47a2d56ae892127ea038942ab1cada190 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 26 Mar 2025 21:38:38 +0200 Subject: [PATCH 03/26] metal : refactor mat-vec code (#12569) * metal : refactor mat-vec code ggml-ci * metal : rename all_sum -> sum_all ggml-ci * metal : fix comments [no ci] * metal : fix nr constant [no ci] * metal : mv q6_K support nr0 > 1 ggml-ci * metal : reduce register pressure ggml-ci * metal : fix typo [no ci] * metal : reduce register pressure ggml-ci --- ggml/src/ggml-metal/ggml-metal-impl.h | 64 +++ ggml/src/ggml-metal/ggml-metal.m | 298 +++++------- ggml/src/ggml-metal/ggml-metal.metal | 657 +++++++++++++------------- 3 files changed, 521 insertions(+), 498 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 1e954b4ce..ca5a00b03 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -1,6 +1,70 @@ #ifndef 
GGML_METAL_IMPL #define GGML_METAL_IMPL +// kernel parameters for mat-vec threadgroups +// +// N_R0: number of src0 rows to process per simdgroup +// N_SG: number of simdgroups per threadgroup +// +// TODO: for optimal performance, become function of the device and work size + +#define N_R0_Q4_0 4 +#define N_SG_Q4_0 2 + +#define N_R0_Q4_1 4 +#define N_SG_Q4_1 2 + +#define N_R0_Q5_0 4 +#define N_SG_Q5_0 2 + +#define N_R0_Q5_1 4 +#define N_SG_Q5_1 2 + +#define N_R0_Q8_0 4 +#define N_SG_Q8_0 2 + +#define N_R0_Q2_K 4 +#define N_SG_Q2_K 2 + +#define N_R0_Q3_K 2 +#define N_SG_Q3_K 2 + +#define N_R0_Q4_K 4 +#define N_SG_Q4_K 2 + +#define N_R0_Q5_K 2 +#define N_SG_Q5_K 2 + +#define N_R0_Q6_K 1 +#define N_SG_Q6_K 2 + +#define N_R0_IQ1_S 4 +#define N_SG_IQ1_S 2 + +#define N_R0_IQ1_M 4 +#define N_SG_IQ1_M 2 + +#define N_R0_IQ2_XXS 4 +#define N_SG_IQ2_XXS 2 + +#define N_R0_IQ2_XS 4 +#define N_SG_IQ2_XS 2 + +#define N_R0_IQ2_S 4 +#define N_SG_IQ2_S 2 + +#define N_R0_IQ3_XXS 4 +#define N_SG_IQ3_XXS 2 + +#define N_R0_IQ3_S 4 +#define N_SG_IQ3_S 2 + +#define N_R0_IQ4_NL 2 +#define N_SG_IQ4_NL 2 + +#define N_R0_IQ4_XS 2 +#define N_SG_IQ4_XS 2 + // kernel argument structs // // - element counters (e.g. ne00) typically use int32_t to reduce register usage diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index af65e7d9f..195d96782 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -2561,171 +2561,180 @@ static void ggml_metal_encode_node( [encoder setThreadgroupMemoryLength:8192 atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } else { - int nth0 = 32; - int nth1 = 1; - int nrows = 1; - //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); - id pipeline = nil; + int nsg = 0; // number of simdgroups + int nr0 = 0; // number of src0 rows per simdgroup + int nr1 = 1; // number of src1 rows per threadgroup + + size_t smem = 0; // shared memory + // use custom matrix x vector kernel switch (src0t) { case GGML_TYPE_F32: { GGML_ASSERT(src1t == GGML_TYPE_F32); + nsg = 1; + nr0 = 1; + nr1 = 4; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline; - nrows = 4; } break; case GGML_TYPE_F16: { - nth0 = 32; - nth1 = 1; + nsg = 1; + nr0 = 1; if (src1t == GGML_TYPE_F32) { if (ne11 * ne12 < 4) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline; } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline; - nrows = ne11; + nr1 = ne11; } else { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline; - nrows = 4; + nr1 = 4; } } else { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline; - nrows = 4; + nr1 = 4; } } break; case GGML_TYPE_BF16: { - nth0 = 32; - nth1 = 1; + nsg = 1; + nr0 = 1; if (src1t == GGML_TYPE_F32) { if (ne11 * ne12 < 4) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline; } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline; - nrows = ne11; + nr1 = ne11; } else { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline; - nrows = 4; + nr1 = 4; } } else { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline; - nrows = 4; + nr1 = 4; } } break; case GGML_TYPE_Q4_0: { - nth0 = 8; - nth1 = 8; + nsg = 
N_SG_Q4_0; + nr0 = N_R0_Q4_0; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline; } break; case GGML_TYPE_Q4_1: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q4_1; + nr0 = N_R0_Q4_1; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline; } break; case GGML_TYPE_Q5_0: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q5_0; + nr0 = N_R0_Q5_0; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline; } break; case GGML_TYPE_Q5_1: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q5_1; + nr0 = N_R0_Q5_1; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline; } break; case GGML_TYPE_Q8_0: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q8_0; + nr0 = N_R0_Q8_0; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline; } break; case GGML_TYPE_Q2_K: { - nth0 = 2; - nth1 = 32; + nsg = N_SG_Q2_K; + nr0 = N_R0_Q2_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline; } break; case GGML_TYPE_Q3_K: { - nth0 = 2; - nth1 = 32; + nsg = N_SG_Q3_K; + nr0 = N_R0_Q3_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline; } break; case GGML_TYPE_Q4_K: { - nth0 = 4; //1; - nth1 = 8; //32; + nsg = N_SG_Q4_K; + nr0 = N_R0_Q4_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline; } break; case GGML_TYPE_Q5_K: { - nth0 = 2; - nth1 = 32; + nsg = N_SG_Q5_K; + nr0 = N_R0_Q5_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline; } break; case GGML_TYPE_Q6_K: { - nth0 = 2; - nth1 = 32; + nsg = N_SG_Q6_K; + nr0 = N_R0_Q6_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline; } break; case GGML_TYPE_IQ2_XXS: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ2_XXS; + nr0 = N_R0_IQ2_XXS; + smem = 256*8+128; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline; } break; case GGML_TYPE_IQ2_XS: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ2_XS; + nr0 = N_R0_IQ2_XS; + smem = 512*8+128; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline; } break; case GGML_TYPE_IQ3_XXS: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ3_XXS; + nr0 = N_R0_IQ3_XXS; + smem = 256*4+128; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline; } break; case GGML_TYPE_IQ3_S: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ3_S; + nr0 = N_R0_IQ3_S; + smem = 512*4; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline; } break; case GGML_TYPE_IQ2_S: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ2_S; + nr0 = N_R0_IQ2_S; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline; } break; case GGML_TYPE_IQ1_S: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ1_S; + nr0 = N_R0_IQ1_S; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline; } break; case GGML_TYPE_IQ1_M: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ1_M; + nr0 = N_R0_IQ1_M; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline; } break; case GGML_TYPE_IQ4_NL: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ4_NL; + nr0 = N_R0_IQ4_NL; + smem = 32*sizeof(float); pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline; } break; case GGML_TYPE_IQ4_XS: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ4_XS; + nr0 = N_R0_IQ4_XS; + smem = 32*sizeof(float); pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline; } break; default: @@ -2762,41 +2771,10 @@ static void ggml_metal_encode_node( [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - if (src0t == 
GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || - src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || - src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { - const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) { - const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) { - const int mem_size = 32*sizeof(float); - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q4_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q3_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q5_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q6_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } else { - const int64_t ny = (ne11 + nrows - 1)/nrows; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + if (smem > 0) { + [encoder setThreadgroupMemoryLength:smem atIndex:0]; } + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (ne11 + nr1 - 1)/nr1, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; } } break; case GGML_OP_MUL_MAT_ID: @@ -2902,146 +2880,155 @@ static void ggml_metal_encode_node( [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } else { - int nth0 = 32; - int nth1 = 1; - int nrows = 1; - //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); - id pipeline = nil; + int nsg = 0; // number of simdgroups + int nr0 = 0; // number of src0 rows per simdgroup + int nr1 = 1; // number of src1 rows per threadgroup + + size_t smem = 0; // shared memory + // use custom matrix x vector kernel switch (src0t) { case GGML_TYPE_F32: { GGML_ASSERT(src1t == GGML_TYPE_F32); + nsg = 1; + nr0 = 1; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline; } break; case GGML_TYPE_F16: { GGML_ASSERT(src1t == GGML_TYPE_F32); - nth0 = 32; - nth1 = 1; + nsg = 1; + nr0 = 1; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline; } break; case GGML_TYPE_BF16: { GGML_ASSERT(src1t == GGML_TYPE_F32); - nth0 = 32; - nth1 = 1; + nsg = 1; + nr0 = 1; pipeline = 
ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline; } break; case GGML_TYPE_Q4_0: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q4_0; + nr0 = N_R0_Q4_0; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline; } break; case GGML_TYPE_Q4_1: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q4_1; + nr0 = N_R0_Q4_1; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline; } break; case GGML_TYPE_Q5_0: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q5_0; + nr0 = N_R0_Q5_0; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline; } break; case GGML_TYPE_Q5_1: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q5_1; + nr0 = N_R0_Q5_1; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline; } break; case GGML_TYPE_Q8_0: { - nth0 = 8; - nth1 = 8; + nsg = N_SG_Q8_0; + nr0 = N_R0_Q8_0; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline; } break; case GGML_TYPE_Q2_K: { - nth0 = 2; - nth1 = 32; + nsg = N_SG_Q2_K; + nr0 = N_R0_Q2_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline; } break; case GGML_TYPE_Q3_K: { - nth0 = 2; - nth1 = 32; + nsg = N_SG_Q3_K; + nr0 = N_R0_Q3_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline; } break; case GGML_TYPE_Q4_K: { - nth0 = 4; //1; - nth1 = 8; //32; + nsg = N_SG_Q4_K; + nr0 = N_R0_Q4_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline; } break; case GGML_TYPE_Q5_K: { - nth0 = 2; - nth1 = 32; + nsg = N_SG_Q5_K; + nr0 = N_R0_Q5_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline; } break; case GGML_TYPE_Q6_K: { - nth0 = 2; - nth1 = 32; + nsg = N_SG_Q6_K; + nr0 = N_R0_Q6_K; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline; } break; case GGML_TYPE_IQ2_XXS: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ2_XXS; + nr0 = N_R0_IQ2_XXS; + smem = 256*8+128; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline; } break; case GGML_TYPE_IQ2_XS: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ2_XS; + nr0 = N_R0_IQ2_XS; + smem = 512*8+128; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline; } break; case GGML_TYPE_IQ3_XXS: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ3_XXS; + nr0 = N_R0_IQ3_XXS; + smem = 256*4+128; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline; } break; case GGML_TYPE_IQ3_S: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ3_S; + nr0 = N_R0_IQ3_S; + smem = 512*4; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline; } break; case GGML_TYPE_IQ2_S: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ2_S; + nr0 = N_R0_IQ2_S; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline; } break; case GGML_TYPE_IQ1_S: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ1_S; + nr0 = N_R0_IQ1_S; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline; } break; case GGML_TYPE_IQ1_M: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ1_M; + nr0 = N_R0_IQ1_M; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline; } break; case GGML_TYPE_IQ4_NL: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ4_NL; + nr0 = N_R0_IQ4_NL; + smem = 32*sizeof(float); pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline; } break; case GGML_TYPE_IQ4_XS: { - nth0 = 4; - nth1 = 16; + nsg = N_SG_IQ4_XS; + nr0 = N_R0_IQ4_XS; + smem = 32*sizeof(float); pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline; } break; default: @@ 
-3052,7 +3039,7 @@ static void ggml_metal_encode_node( }; if (ggml_is_quantized(src0t)) { - GGML_ASSERT(ne00 >= nth0*nth1); + GGML_ASSERT(ne00 >= nsg*nr0); } ggml_metal_kargs_mul_mv_id args = { @@ -3085,43 +3072,12 @@ static void ggml_metal_encode_node( [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4]; const int64_t _ne1 = 1; - const int tgz = dst_rows; + const int64_t ne123 = dst_rows; - if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || - src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || - src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { - const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) { - const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) { - const int mem_size = 32*sizeof(float); - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q4_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q3_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q5_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q6_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } else { - const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1 - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + if (smem > 0) { + [encoder setThreadgroupMemoryLength:smem atIndex:0]; } + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; } } break; case GGML_OP_GET_ROWS: diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 3cef81b79..38f03efba 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1439,7 +1439,7 @@ kernel void kernel_rwkv_wkv7_f32( float4 sa_vec(0.0); - for (int j = 0; j < head_size; j += 4) { + for (uint j = 0; j < head_size; j += 4) { float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]); float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]); sa_vec += a_vec * s_vec; @@ -1853,14 +1853,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m; } -// putting them in the kernel cause a significant performance 
penalty -#define N_DST 4 // each SIMD group works on 4 rows -#define N_SIMDGROUP 2 // number of SIMD groups in a thread group -//Note: This is a template, but strictly speaking it only applies to -// quantizations where the block size is 32. It also does not -// guard against the number of rows not being divisible by -// N_DST, so this is another explicit assumption of the implementation. -template +template void mul_vec_q_n_f32_impl( args_t args, device const char * src0, @@ -1876,7 +1869,7 @@ void mul_vec_q_n_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -1888,15 +1881,15 @@ void mul_vec_q_n_f32_impl( device const float * y = (device const float *) (src1 + offset1); // pointers to src0 rows - device const block_q_type * ax[nr]; - for (int row = 0; row < nr; ++row) { + device const block_q_type * ax[nr0]; + for (int row = 0; row < nr0; ++row) { const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; ax[row] = (device const block_q_type *) ((device char *) src0 + offset0); } float yl[16]; // src1 vector cache - float sumf[nr] = {0.f}; + float sumf[nr0] = {0.f}; const short ix = (tiisg/2); const short il = (tiisg%2)*8; @@ -1908,7 +1901,7 @@ void mul_vec_q_n_f32_impl( float sumy[2] = { 0.f, 0.f }; #pragma unroll - for (int i = 0; i < 8; i += 2) { + for (short i = 0; i < 8; i += 2) { sumy[0] += yb[i + 0] + yb[i + 1]; yl[i + 0] = yb[i + 0]; yl[i + 1] = yb[i + 1]/256.f; @@ -1919,7 +1912,7 @@ void mul_vec_q_n_f32_impl( } #pragma unroll - for (int row = 0; row < nr; row++) { + for (short row = 0; row < nr0; row++) { sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il); } @@ -1928,7 +1921,7 @@ void mul_vec_q_n_f32_impl( device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; - for (int row = 0; row < nr; ++row) { + for (int row = 0; row < nr0; ++row) { const float tot = simd_sum(sumf[row]); if (tiisg == 0 && first_row + row < args.ne01) { @@ -1945,7 +1938,7 @@ kernel void kernel_mul_mv_q4_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q4_1_f32( @@ -1956,7 +1949,7 @@ kernel void kernel_mul_mv_q4_1_f32( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q5_0_f32( @@ -1967,7 +1960,7 @@ kernel void kernel_mul_mv_q5_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q5_1_f32( @@ -1978,12 +1971,12 @@ kernel void kernel_mul_mv_q5_1_f32( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, 
tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } #define NB_Q8_0 8 -template +template void kernel_mul_mv_q8_0_f32_impl( args_t args, device const char * src0, @@ -1993,16 +1986,13 @@ void kernel_mul_mv_q8_0_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { - const int nr = N_DST; - const int nsg = N_SIMDGROUP; - const int nw = N_SIMDWIDTH; - const int nb = args.ne00/QK8_0; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0*nsg + sgitg)*nr; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -2014,15 +2004,15 @@ void kernel_mul_mv_q8_0_f32_impl( device const float * y = (device const float *) (src1 + offset1); // pointers to src0 rows - device const block_q8_0 * ax[nr]; - for (int row = 0; row < nr; ++row) { + device const block_q8_0 * ax[nr0]; + for (int row = 0; row < nr0; ++row) { const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0); } float yl[NB_Q8_0]; - float sumf[nr] = { 0.f }; + float sumf[nr0] = { 0.f }; const short ix = tiisg/4; const short il = tiisg%4; @@ -2035,7 +2025,7 @@ void kernel_mul_mv_q8_0_f32_impl( yl[i] = yb[i]; } - for (int row = 0; row < nr; row++) { + for (short row = 0; row < nr0; row++) { device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0; float sumq = 0.f; for (short iq = 0; iq < NB_Q8_0; ++iq) { @@ -2049,7 +2039,7 @@ void kernel_mul_mv_q8_0_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < nr; ++row) { + for (int row = 0; row < nr0; ++row) { const float tot = simd_sum(sumf[row]); if (tiisg == 0 && first_row + row < args.ne01) { @@ -2067,7 +2057,7 @@ kernel void kernel_mul_mv_q8_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } // mat-vec kernel processing in chunks of float4 @@ -2404,9 +2394,9 @@ void kernel_mul_mv_impl( sumf += (T0) x[i] * (T1) y[i]; } - float all_sum = simd_sum(sumf); + float sum_all = simd_sum(sumf); if (tiisg == 0) { - dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum; + dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all; } } } else { @@ -2427,10 +2417,10 @@ void kernel_mul_mv_impl( sumf += dot((float4) x4[i], (float4) y4[i]); } - float all_sum = simd_sum(sumf); + float sum_all = simd_sum(sumf); if (tiisg == 0) { - for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]); - dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum; + for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]); + dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all; } } } @@ -2492,9 +2482,9 @@ kernel void kernel_mul_mv_1row( for (int i = tiisg; i < args.ne00; i += 32) { sumf += (float) x[i] * (float) y[i]; } - float all_sum = simd_sum(sumf); + float sum_all = simd_sum(sumf); if (tiisg == 0) { - dst_f32[r0] = all_sum; + dst_f32[r0] = sum_all; } } else { device const T4 * x4 = (device const T4 *) x; @@ -2504,11 +2494,11 @@ kernel void kernel_mul_mv_1row( sumf += dot((float4) x4[i], y4[i]); } - float all_sum = simd_sum(sumf); + float sum_all = simd_sum(sumf); if (tiisg == 0) { - for (int i = 
4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]); - dst_f32[r0] = all_sum; + for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]); + dst_f32[r0] = sum_all; } } } @@ -2553,9 +2543,9 @@ kernel void kernel_mul_mv_l4( sumf += dot((float4) x4[i], y4[i]); } - float all_sum = simd_sum(sumf); + float sum_all = simd_sum(sumf); if (tiisg == 0) { - dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum; + dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all; } } } @@ -4321,7 +4311,7 @@ kernel void kernel_cpy_f32_iq4_nl( float amax = 0.0f; // absolute max float max = 0.0f; - for (int j = 0; j < QK4_0; j++) { + for (int j = 0; j < QK4_NL; j++) { const float v = src[j]; if (amax < fabs(v)) { amax = fabs(v); @@ -4429,7 +4419,7 @@ kernel void kernel_concat( } } -template +template void kernel_mul_mv_q2_K_f32_impl( args_t args, device const char * src0, @@ -4445,7 +4435,7 @@ void kernel_mul_mv_q2_K_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -4457,20 +4447,19 @@ void kernel_mul_mv_q2_K_f32_impl( device const float * y = (device const float *) (src1 + offset1); float yl[32]; - float sumf[N_DST]={0.f}, all_sum; + float sumf[nr0]={0.f}; - const int ix = tiisg/8; // 0...3 - const int it = tiisg%8; // 0...7 - const int iq = it/4; // 0 or 1 - const int ir = it%4; // 0...3 - const int is = (8*ir)/16;// 0 or 1 + const short ix = tiisg/8; // 0...3 + const short it = tiisg%8; // 0...7 + const short iq = it/4; // 0 or 1 + const short ir = it%4; // 0...3 + const short is = (8*ir)/16;// 0 or 1 device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir; for (int ib = ix; ib < nb; ib += 4) { - float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; ++i) { + for (short i = 0; i < 8; ++i) { yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8]; yl[i+16] = y4[i+64]; sumy[2] += yl[i+16]; @@ -4481,7 +4470,7 @@ void kernel_mul_mv_q2_K_f32_impl( device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; device const half * dh = &x[ib].d; - for (int row = 0; row < N_DST; row++) { + for (short row = 0; row < nr0; row++) { float4 acc1 = {0.f, 0.f, 0.f, 0.f}; float4 acc2 = {0.f, 0.f, 0.f, 0.f}; for (int i = 0; i < 8; i += 2) { @@ -4512,10 +4501,10 @@ void kernel_mul_mv_q2_K_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum; + dst_f32[first_row + row] = sum_all; } } } @@ -4530,10 +4519,10 @@ kernel void kernel_mul_mv_q2_K_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q2_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q2_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q3_K_f32_impl( args_t args, device const char * src0, @@ -4550,7 +4539,7 @@ void kernel_mul_mv_q3_K_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = 
im/args.ne12; @@ -4566,13 +4555,12 @@ void kernel_mul_mv_q3_K_f32_impl( //const uint16_t kmask1 = 0x3030; //const uint16_t kmask2 = 0x0f0f; - const int tid = tiisg/4; - const int ix = tiisg%4; - const int ip = tid/4; // 0 or 1 - const int il = 2*((tid%4)/2); // 0 or 2 - const int ir = tid%2; - const int n = 8; - const int l0 = n*ir; + const short tid = tiisg/4; + const short ix = tiisg%4; + const short ip = tid/4; // 0 or 1 + const short il = 2*((tid%4)/2); // 0 or 2 + const short ir = tid%2; + const short l0 = 8*ir; // One would think that the Metal compiler would figure out that ip and il can only have // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it @@ -4597,8 +4585,8 @@ void kernel_mul_mv_q3_K_f32_impl( const uint16_t s_shift1 = 4*ip; const uint16_t s_shift2 = s_shift1 + il; - const int q_offset = 32*ip + l0; - const int y_offset = 128*ip + 32*il + l0; + const short q_offset = 32*ip + l0; + const short y_offset = 128*ip + 32*il + l0; device const float * y1 = yy + ix*QK_K + y_offset; @@ -4606,10 +4594,11 @@ void kernel_mul_mv_q3_K_f32_impl( thread uint16_t * scales16 = (thread uint16_t *)&scales32; thread const int8_t * scales = (thread const int8_t *)&scales32; - float sumf1[2] = {0.f}; - float sumf2[2] = {0.f}; + float sumf1[nr0] = {0.f}; + float sumf2[nr0] = {0.f}; + for (int i = ix; i < nb; i += 4) { - for (int l = 0; l < 8; ++l) { + for (short l = 0; l < 8; ++l) { yl[l+ 0] = y1[l+ 0]; yl[l+ 8] = y1[l+16]; yl[l+16] = y1[l+32]; @@ -4621,7 +4610,7 @@ void kernel_mul_mv_q3_K_f32_impl( device const uint16_t * a = (device const uint16_t *)(x[i].scales); device const half * dh = &x[i].d; - for (int row = 0; row < 2; ++row) { + for (short row = 0; row < nr0; ++row) { const float d_all = (float)dh[0]; scales16[0] = a[4]; @@ -4632,7 +4621,7 @@ void kernel_mul_mv_q3_K_f32_impl( scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32; float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0; - for (int l = 0; l < n; l += 2) { + for (short l = 0; l < 8; l += 2) { const int32_t qs = q[l/2]; s1 += yl[l+0] * (qs & qm[il/2][0]); s2 += yl[l+1] * (qs & qm[il/2][1]); @@ -4647,7 +4636,7 @@ void kernel_mul_mv_q3_K_f32_impl( sumf2[row] += d2 * (scales[2] - 32); s1 = s2 = s3 = s4 = s5 = s6 = 0; - for (int l = 0; l < n; l += 2) { + for (short l = 0; l < 8; l += 2) { const int32_t qs = q[l/2+8]; s1 += yl[l+8] * (qs & qm[il/2][0]); s2 += yl[l+9] * (qs & qm[il/2][1]); @@ -4670,7 +4659,7 @@ void kernel_mul_mv_q3_K_f32_impl( y1 += 4 * QK_K; } - for (int row = 0; row < 2; ++row) { + for (int row = 0; row < nr0; ++row) { const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift); sumf1[row] = simd_sum(sumf); } @@ -4678,7 +4667,7 @@ void kernel_mul_mv_q3_K_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; if (tiisg == 0) { - for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) { + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { dst_f32[first_row + row] = sumf1[row]; } } @@ -4694,10 +4683,10 @@ kernel void kernel_mul_mv_q3_K_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q3_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q3_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q4_K_f32_impl( args_t args, device const char * src0, @@ -4707,22 +4696,22 @@ void kernel_mul_mv_q4_K_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { 
- const uint16_t kmask1 = 0x3f3f; const uint16_t kmask2 = 0x0f0f; const uint16_t kmask3 = 0xc0c0; - const int ix = tiisg/8; // 0...3 - const int it = tiisg%8; // 0...7 - const int iq = it/4; // 0 or 1 - const int ir = it%4; // 0...3 + const short ix = tiisg/8; // 0...3 + const short it = tiisg%8; // 0...7 + const short iq = it/4; // 0 or 1 + const short ir = it%4; // 0...3 const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int first_row = r0 * N_DST; + + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -4735,7 +4724,8 @@ void kernel_mul_mv_q4_K_f32_impl( float yl[16]; float yh[16]; - float sumf[N_DST]={0.f}, all_sum; + + float sumf[nr0]={0.f}; device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; @@ -4744,7 +4734,8 @@ void kernel_mul_mv_q4_K_f32_impl( for (int ib = ix; ib < nb; ib += 4) { float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; ++i) { + + for (short i = 0; i < 8; ++i) { yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0]; yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8]; yh[i+0] = y4[i+128]; sumy[2] += yh[i+0]; @@ -4755,7 +4746,7 @@ void kernel_mul_mv_q4_K_f32_impl( device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; device const half * dh = &x[ib].d; - for (int row = 0; row < N_DST; row++) { + for (short row = 0; row < nr0; row++) { sc16[0] = sc[0] & kmask1; sc16[1] = sc[2] & kmask1; sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); @@ -4765,19 +4756,21 @@ void kernel_mul_mv_q4_K_f32_impl( float4 acc1 = {0.f, 0.f, 0.f, 0.f}; float4 acc2 = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; i += 2) { - acc1[0] += yl[i+0] * (q1[i/2] & 0x000F); - acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00); - acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0); - acc1[3] += yl[i+9] * (q1[i/2] & 0xF000); - acc2[0] += yh[i+0] * (q2[i/2] & 0x000F); - acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00); - acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0); - acc2[3] += yh[i+9] * (q2[i/2] & 0xF000); + + for (short i = 0; i < 4; ++i) { + acc1[0] += yl[2*i + 0] * (q1[i] & 0x000F); + acc1[1] += yl[2*i + 1] * (q1[i] & 0x0F00); + acc1[2] += yl[2*i + 8] * (q1[i] & 0x00F0); + acc1[3] += yl[2*i + 9] * (q1[i] & 0xF000); + acc2[0] += yh[2*i + 0] * (q2[i] & 0x000F); + acc2[1] += yh[2*i + 1] * (q2[i] & 0x0F00); + acc2[2] += yh[2*i + 8] * (q2[i] & 0x00F0); + acc2[3] += yh[2*i + 9] * (q2[i] & 0xF000); } float dall = dh[0]; float dmin = dh[1]; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] + (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f + (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] + @@ -4794,10 +4787,10 @@ void kernel_mul_mv_q4_K_f32_impl( device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0; - for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum; + dst_f32[first_row + row] = sum_all; } } } @@ -4812,10 +4805,10 @@ kernel void kernel_mul_mv_q4_K_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q4_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q4_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q5_K_f32_impl( 
args_t args, device const char * src0, @@ -4832,7 +4825,7 @@ void kernel_mul_mv_q5_K_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -4843,7 +4836,7 @@ void kernel_mul_mv_q5_K_f32_impl( device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0); device const float * yy = (device const float *) (src1 + offset1); - float sumf[2]={0.f}; + float sumf[nr0]={0.f}; float yl[16], yh[16]; @@ -4851,15 +4844,14 @@ void kernel_mul_mv_q5_K_f32_impl( const uint16_t kmask2 = 0x0f0f; const uint16_t kmask3 = 0xc0c0; - const int tid = tiisg/4; - const int ix = tiisg%4; - const int iq = tid/4; - const int ir = tid%4; - const int n = 8; + const short tid = tiisg/4; + const short ix = tiisg%4; + const short iq = tid/4; + const short ir = tid%4; - const int l0 = n*ir; - const int q_offset = 32*iq + l0; - const int y_offset = 64*iq + l0; + const short l0 = 8*ir; + const short q_offset = 32*iq + l0; + const short y_offset = 64*iq + l0; const uint8_t hm1 = 1u << (2*iq); const uint8_t hm2 = hm1 << 1; @@ -4879,14 +4871,14 @@ void kernel_mul_mv_q5_K_f32_impl( device const float * y2 = y1 + 128; float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int l = 0; l < 8; ++l) { + for (short l = 0; l < 8; ++l) { yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0]; yl[l+8] = y1[l+32]; sumy[1] += yl[l+8]; yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0]; yh[l+8] = y2[l+32]; sumy[3] += yh[l+8]; } - for (int row = 0; row < 2; ++row) { + for (short row = 0; row < nr0; ++row) { device const uint8_t * q2 = q1 + 64; sc16[0] = a[0] & kmask1; @@ -4896,7 +4888,7 @@ void kernel_mul_mv_q5_K_f32_impl( float4 acc1 = {0.f}; float4 acc2 = {0.f}; - for (int l = 0; l < n; ++l) { + for (short l = 0; l < 8; ++l) { uint8_t h = qh[l]; acc1[0] += yl[l+0] * (q1[l] & 0x0F); acc1[1] += yl[l+8] * (q1[l] & 0xF0); @@ -4926,7 +4918,7 @@ void kernel_mul_mv_q5_K_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) { + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { const float tot = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = tot; @@ -4944,10 +4936,10 @@ kernel void kernel_mul_mv_q5_K_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q5_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q5_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q6_K_f32_impl( args_t args, device const char * src0, @@ -4969,62 +4961,77 @@ void kernel_mul_mv_q6_K_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int row = 2*r0 + sgitg; - - if (row >= args.ne0) { - return; - } + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint64_t offset0 = row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0); device const float * yy = (device const float *) (src1 + offset1); - float sumf = 0; + 
float sumf[nr0] = { 0.f }; - const int tid = tiisg/2; - const int ix = tiisg%2; - const int ip = tid/8; // 0 or 1 - const int il = tid%8; - const int n = 4; - const int l0 = n*il; - const int is = 8*ip + l0/16; + float yl[16]; - const int y_offset = 128*ip + l0; - const int q_offset_l = 64*ip + l0; - const int q_offset_h = 32*ip + l0; + const short tid = tiisg/2; + const short ix = tiisg%2; + const short ip = tid/8; // 0 or 1 + const short il = tid%8; + const short l0 = 4*il; + const short is = 8*ip + l0/16; + + const short y_offset = 128*ip + l0; + const short q_offset_l = 64*ip + l0; + const short q_offset_h = 32*ip + l0; for (int i = ix; i < nb; i += 2) { device const uint8_t * q1 = x[i].ql + q_offset_l; device const uint8_t * q2 = q1 + 32; device const uint8_t * qh = x[i].qh + q_offset_h; device const int8_t * sc = x[i].scales + is; + device const half * dh = &x[i].d; device const float * y = yy + i * QK_K + y_offset; - const float dall = x[i].d; - - float4 sums = {0.f, 0.f, 0.f, 0.f}; - for (int l = 0; l < n; ++l) { - sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); - sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); - sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32); - sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); + for (short l = 0; l < 4; ++l) { + yl[4*l + 0] = y[l + 0]; + yl[4*l + 1] = y[l + 32]; + yl[4*l + 2] = y[l + 64]; + yl[4*l + 3] = y[l + 96]; } - sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]); + for (short row = 0; row < nr0; ++row) { + const float dall = dh[0]; + float4 sums = {0.f, 0.f, 0.f, 0.f}; + + for (short l = 0; l < 4; ++l) { + sums[0] += yl[4*l + 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); + sums[1] += yl[4*l + 1] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); + sums[2] += yl[4*l + 2] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32); + sums[3] += yl[4*l + 3] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); + } + + sumf[row] += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]); + + q1 += args.nb01; + q2 += args.nb01; + qh += args.nb01; + sc += args.nb01; + dh += args.nb01/2; + } } device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - const float tot = simd_sum(sumf); - if (tiisg == 0) { - dst_f32[row] = tot; + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); + if (tiisg == 0) { + dst_f32[first_row + row] = sum_all; + } } } @@ -5038,12 +5045,12 @@ kernel void kernel_mul_mv_q6_K_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q6_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q6_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } // ======================= "True" 2-bit -template +template void kernel_mul_mv_iq2_xxs_f32_impl( args_t args, device const char * src0, @@ -5059,7 +5066,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5071,7 +5078,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl( device const float * y = (device const float *) (src1 + offset1); float yl[32]; - float sumf[N_DST]={0.f}, all_sum; + 
float sumf[nr0]={0.f}; const int nb32 = nb * (QK_K / 32); @@ -5092,8 +5099,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl( device const float * y4 = y + 32 * ix; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - for (int i = 0; i < 32; ++i) { + for (short i = 0; i < 32; ++i) { yl[i] = y4[i]; } @@ -5104,18 +5110,17 @@ void kernel_mul_mv_iq2_xxs_f32_impl( device const uint16_t * q2 = xr->qs + 4 * ib; device const half * dh = &xr->d; - for (int row = 0; row < N_DST; row++) { - + for (short row = 0; row < nr0; row++) { const float db = dh[0]; device const uint8_t * aux8 = (device const uint8_t *)q2; const uint32_t aux32 = q2[2] | (q2[3] << 16); const float d = db * (0.5f + (aux32 >> 28)); float sum = 0; - for (int l = 0; l < 4; ++l) { + for (short l = 0; l < 4; ++l) { const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]); const uint8_t signs = ssigns[(aux32 >> 7*l) & 127]; - for (int j = 0; j < 8; ++j) { + for (short j = 0; j < 8; ++j) { sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); } } @@ -5130,10 +5135,10 @@ void kernel_mul_mv_iq2_xxs_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum * 0.25f; + dst_f32[first_row + row] = sum_all * 0.25f; } } } @@ -5148,10 +5153,10 @@ kernel void kernel_mul_mv_iq2_xxs_f32( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq2_xs_f32_impl( args_t args, device const char * src0, @@ -5167,7 +5172,7 @@ void kernel_mul_mv_iq2_xs_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5179,7 +5184,7 @@ void kernel_mul_mv_iq2_xs_f32_impl( device const float * y = (device const float *) (src1 + offset1); float yl[32]; - float sumf[N_DST]={0.f}, all_sum; + float sumf[nr0]={0.f}; const int nb32 = nb * (QK_K / 32); @@ -5200,8 +5205,7 @@ void kernel_mul_mv_iq2_xs_f32_impl( device const float * y4 = y + 32 * ix; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - for (int i = 0; i < 32; ++i) { + for (short i = 0; i < 32; ++i) { yl[i] = y4[i]; } @@ -5213,8 +5217,7 @@ void kernel_mul_mv_iq2_xs_f32_impl( device const uint8_t * sc = xr->scales + ib; device const half * dh = &xr->d; - for (int row = 0; row < N_DST; row++) { - + for (short row = 0; row < nr0; row++) { const float db = dh[0]; const uint8_t ls1 = sc[0] & 0xf; const uint8_t ls2 = sc[0] >> 4; @@ -5222,17 +5225,17 @@ void kernel_mul_mv_iq2_xs_f32_impl( const float d2 = db * (0.5f + ls2); float sum1 = 0, sum2 = 0; - for (int l = 0; l < 2; ++l) { + for (short l = 0; l < 2; ++l) { const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511)); const uint8_t signs = ssigns[(q2[l] >> 9)]; - for (int j = 0; j < 8; ++j) { + for (short j = 0; j < 8; ++j) { sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? 
-1.f : 1.f); } } - for (int l = 2; l < 4; ++l) { + for (short l = 2; l < 4; ++l) { const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511)); const uint8_t signs = ssigns[(q2[l] >> 9)]; - for (int j = 0; j < 8; ++j) { + for (short j = 0; j < 8; ++j) { sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); } } @@ -5248,10 +5251,10 @@ void kernel_mul_mv_iq2_xs_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum * 0.25f; + dst_f32[first_row + row] = sum_all * 0.25f; } } } @@ -5267,10 +5270,10 @@ kernel void kernel_mul_mv_iq2_xs_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq3_xxs_f32_impl( args_t args, device const char * src0, @@ -5286,7 +5289,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5298,7 +5301,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device const float * y = (device const float *) (src1 + offset1); float yl[32]; - float sumf[N_DST]={0.f}, all_sum; + float sumf[nr0]={0.f}; const int nb32 = nb * (QK_K / 32); @@ -5319,7 +5322,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device const float * y4 = y + 32 * ix; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - for (int i = 0; i < 32; ++i) { + for (short i = 0; i < 32; ++i) { yl[i] = y4[i]; } @@ -5331,17 +5334,17 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib; device const half * dh = &xr->d; - for (int row = 0; row < N_DST; row++) { + for (short row = 0; row < nr0; row++) { const float db = dh[0]; const uint32_t aux32 = gas[0] | (gas[1] << 16); const float d = db * (0.5f + (aux32 >> 28)); float2 sum = {0}; - for (int l = 0; l < 4; ++l) { + for (short l = 0; l < 4; ++l) { const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + q3[2*l+0]); const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + q3[2*l+1]); const uint8_t signs = ssigns[(aux32 >> 7*l) & 127]; - for (int j = 0; j < 4; ++j) { + for (short j = 0; j < 4; ++j) { sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? 
-1.f : 1.f); } @@ -5358,10 +5361,10 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum * 0.5f; + dst_f32[first_row + row] = sum_all * 0.5f; } } } @@ -5377,10 +5380,10 @@ kernel void kernel_mul_mv_iq3_xxs_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq3_s_f32_impl( args_t args, device const char * src0, @@ -5396,7 +5399,7 @@ void kernel_mul_mv_iq3_s_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5408,7 +5411,7 @@ void kernel_mul_mv_iq3_s_f32_impl( device const float * y = (device const float *) (src1 + offset1); float yl[32]; - float sumf[N_DST]={0.f}, all_sum; + float sumf[nr0]={0.f}; const int nb32 = nb * (QK_K / 32); @@ -5425,8 +5428,7 @@ void kernel_mul_mv_iq3_s_f32_impl( device const float * y4 = y + 32 * ix; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - for (int i = 0; i < 32; ++i) { + for (short i = 0; i < 32; ++i) { yl[i] = y4[i]; } @@ -5440,18 +5442,17 @@ void kernel_mul_mv_iq3_s_f32_impl( device const uint8_t * signs = xr->signs + 4 * ib; device const half * dh = &xr->d; - for (int row = 0; row < N_DST; row++) { - + for (short row = 0; row < nr0; row++) { const float db = dh[0]; const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf)); float2 sum = {0}; - for (int l = 0; l < 4; ++l) { + for (short l = 0; l < 4; ++l) { const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? svalues + 256 : svalues; const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? 
svalues + 256 : svalues; const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]); const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]); - for (int j = 0; j < 4; ++j) { + for (short j = 0; j < 4; ++j) { sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]); sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]); } @@ -5470,10 +5471,10 @@ void kernel_mul_mv_iq3_s_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum; + dst_f32[first_row + row] = sum_all; } } } @@ -5489,10 +5490,10 @@ kernel void kernel_mul_mv_iq3_s_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq2_s_f32_impl( args_t args, device const char * src0, @@ -5508,7 +5509,7 @@ void kernel_mul_mv_iq2_s_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5520,7 +5521,7 @@ void kernel_mul_mv_iq2_s_f32_impl( device const float * y = (device const float *) (src1 + offset1); float yl[32]; - float sumf[N_DST]={0.f}, all_sum; + float sumf[nr0]={0.f}; const int nb32 = nb * (QK_K / 32); @@ -5532,13 +5533,12 @@ void kernel_mul_mv_iq2_s_f32_impl( // threadgroup_barrier(mem_flags::mem_threadgroup); //} - const int ix = tiisg; + const short ix = tiisg; device const float * y4 = y + 32 * ix; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - for (int i = 0; i < 32; ++i) { + for (short i = 0; i < 32; ++i) { yl[i] = y4[i]; } @@ -5552,19 +5552,18 @@ void kernel_mul_mv_iq2_s_f32_impl( device const uint8_t * signs = qs + QK_K/8; device const half * dh = &xr->d; - for (int row = 0; row < N_DST; row++) { - + for (short row = 0; row < nr0; row++) { const float db = dh[0]; const float d1 = db * (0.5f + (sc[0] & 0xf)); const float d2 = db * (0.5f + (sc[0] >> 4)); float2 sum = {0}; - for (int l = 0; l < 2; ++l) { + for (short l = 0; l < 2; ++l) { //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); - for (int j = 0; j < 8; ++j) { + for (short j = 0; j < 8; ++j) { sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]); sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]); } @@ -5583,10 +5582,10 @@ void kernel_mul_mv_iq2_s_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for 
(int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum * 0.25f; + dst_f32[first_row + row] = sum_all * 0.25f; } } } @@ -5602,10 +5601,10 @@ kernel void kernel_mul_mv_iq2_s_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq1_s_f32_impl( args_t args, device const char * src0, @@ -5621,7 +5620,7 @@ void kernel_mul_mv_iq1_s_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5633,18 +5632,17 @@ void kernel_mul_mv_iq1_s_f32_impl( device const float * y = (device const float *) (src1 + offset1); float yl[32]; - float sumf[N_DST]={0.f}, all_sum; + float sumf[nr0]={0.f}; const int nb32 = nb * (QK_K / 32); - const int ix = tiisg; + const short ix = tiisg; device const float * y4 = y + 32 * ix; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - float sumy = 0; - for (int i = 0; i < 32; ++i) { + for (short i = 0; i < 32; ++i) { yl[i] = y4[i]; sumy += yl[i]; } @@ -5657,15 +5655,14 @@ void kernel_mul_mv_iq1_s_f32_impl( device const uint16_t * qh = xr->qh + ib; device const half * dh = &xr->d; - for (int row = 0; row < N_DST; row++) { - + for (short row = 0; row < nr0; row++) { constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700))); constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700))); constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700))); float sum = 0; - for (int j = 0; j < 4; ++j) { + for (short j = 0; j < 4; ++j) { sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4) + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4) + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4) @@ -5683,15 +5680,28 @@ void kernel_mul_mv_iq1_s_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum; + dst_f32[first_row + row] = sum_all; } } } -template +[[host_name("kernel_mul_mv_iq1_s_f32")]] +kernel void kernel_mul_mv_iq1_s_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq1_s_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + +template void kernel_mul_mv_iq1_m_f32_impl( args_t args, device const char * src0, @@ -5703,11 +5713,12 @@ void kernel_mul_mv_iq1_m_f32_impl( ushort sgitg) { const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = (r0 
* nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5719,20 +5730,19 @@ void kernel_mul_mv_iq1_m_f32_impl( device const float * y = (device const float *) (src1 + offset1); float yl[32]; - float sumf[N_DST]={0.f}, all_sum; + float sumf[nr0]={0.f}; const int nb32 = nb * (QK_K / 32); - const int ix = tiisg; + const short ix = tiisg; device const float * y4 = y + 32 * ix; iq1m_scale_t scale; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - float4 sumy = {0.f}; - for (int i = 0; i < 8; ++i) { + for (short i = 0; i < 8; ++i) { yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8]; yl[i+16] = y4[i+16]; sumy[2] += yl[i+16]; @@ -5747,7 +5757,7 @@ void kernel_mul_mv_iq1_m_f32_impl( device const uint8_t * qh = xr->qh + 2 * ib; device const uint16_t * sc = (device const uint16_t *)xr->scales; - for (int row = 0; row < N_DST; row++) { + for (short row = 0; row < nr0; row++) { scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); @@ -5756,7 +5766,7 @@ void kernel_mul_mv_iq1_m_f32_impl( constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700))); float2 sum = {0.f}; - for (int j = 0; j < 4; ++j) { + for (short j = 0; j < 4; ++j) { sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4) + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4); sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4) @@ -5778,15 +5788,28 @@ void kernel_mul_mv_iq1_m_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum; + dst_f32[first_row + row] = sum_all; } } } -template +[[host_name("kernel_mul_mv_iq1_m_f32")]] +kernel void kernel_mul_mv_iq1_m_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq1_m_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + +template void kernel_mul_mv_iq4_nl_f32_impl( args_t args, device const char * src0, @@ -5799,10 +5822,12 @@ void kernel_mul_mv_iq4_nl_f32_impl( threadgroup float * shmem_f32 = (threadgroup float *) shmem; const int nb = args.ne00/QK4_NL; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * 2 + sgitg) * 2; + + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5813,14 +5838,14 @@ void kernel_mul_mv_iq4_nl_f32_impl( device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); - const int ix = tiisg/2; // 0...15 - const int it = tiisg%2; // 0 or 1 + const short ix = tiisg/2; // 0...15 + const short it = tiisg%2; // 0 or 1 shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16]; threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; - float sumf[2]={0.f}, all_sum; + float sumf[nr0]={0.f}; device const float * yb = y + ix * QK4_NL + it * 8; @@ -5830,12 
+5855,13 @@ void kernel_mul_mv_iq4_nl_f32_impl( float4 qf1, qf2; for (int ib = ix; ib < nb; ib += 16) { - device const float4 * y4 = (device const float4 *)yb; - yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; - - for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) { + yl[0] = y4[0]; + yl[1] = y4[4]; + yl[2] = y4[1]; + yl[3] = y4[5]; + for (short row = 0; row < nr0; row++) { device const block_iq4_nl & xb = x[row*nb + ib]; device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it); @@ -5860,7 +5886,6 @@ void kernel_mul_mv_iq4_nl_f32_impl( acc1 += acc2; sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); - } yb += 16 * QK4_NL; @@ -5868,15 +5893,29 @@ void kernel_mul_mv_iq4_nl_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum; + dst_f32[first_row + row] = sum_all; } } } -template +[[host_name("kernel_mul_mv_iq4_nl_f32")]] +kernel void kernel_mul_mv_iq4_nl_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq4_nl_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); +} + +template void kernel_mul_mv_iq4_xs_f32_impl( args_t args, device const char * src0, @@ -5892,7 +5931,7 @@ void kernel_mul_mv_iq4_xs_f32_impl( const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * 2 + sgitg) * 2; + const int first_row = (r0 * nsg + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5903,16 +5942,16 @@ void kernel_mul_mv_iq4_xs_f32_impl( device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); - const int ix = tiisg/16; // 0 or 1 - const int it = tiisg%16; // 0...15 - const int ib = it/2; - const int il = it%2; + const short ix = tiisg/16; // 0 or 1 + const short it = tiisg%16; // 0...15 + const short ib = it/2; + const short il = it%2; shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16]; threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; - float sumf[2]={0.f}, all_sum; + float sumf[nr0]={0.f}; device const float * yb = y + ix * QK_K + ib * 32 + il * 8; @@ -5923,9 +5962,12 @@ void kernel_mul_mv_iq4_xs_f32_impl( for (int ibl = ix; ibl < nb; ibl += 2) { device const float4 * y4 = (device const float4 *)yb; - yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; + yl[0] = y4[0]; + yl[1] = y4[4]; + yl[2] = y4[1]; + yl[3] = y4[5]; - for (int row = 0; row < 2; ++row) { + for (short row = 0; row < nr0; ++row) { device const block_iq4_xs & xb = x[row*nb + ibl]; device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); @@ -5949,7 +5991,6 @@ void kernel_mul_mv_iq4_xs_f32_impl( const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32; sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); - } yb += 2 * QK_K; @@ -5957,54 +5998,14 @@ void kernel_mul_mv_iq4_xs_f32_impl( device float * dst_f32 = (device float *) dst 
+ (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) { - all_sum = simd_sum(sumf[row]); + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { - dst_f32[first_row + row] = all_sum; + dst_f32[first_row + row] = sum_all; } } } -[[host_name("kernel_mul_mv_iq1_s_f32")]] -kernel void kernel_mul_mv_iq1_s_f32( - constant ggml_metal_kargs_mul_mv & args, - device const char * src0, - device const char * src1, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq1_s_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); -} - -[[host_name("kernel_mul_mv_iq1_m_f32")]] -kernel void kernel_mul_mv_iq1_m_f32( - constant ggml_metal_kargs_mul_mv & args, - device const char * src0, - device const char * src1, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq1_m_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); -} - -[[host_name("kernel_mul_mv_iq4_nl_f32")]] -kernel void kernel_mul_mv_iq4_nl_f32( - constant ggml_metal_kargs_mul_mv & args, - device const char * src0, - device const char * src1, - device char * dst, - threadgroup char * shmem [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq4_nl_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); -} - [[host_name("kernel_mul_mv_iq4_xs_f32")]] kernel void kernel_mul_mv_iq4_xs_f32( constant ggml_metal_kargs_mul_mv & args, @@ -6016,7 +6017,7 @@ kernel void kernel_mul_mv_iq4_xs_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } template @@ -6660,25 +6661,27 @@ template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t #if defined(GGML_METAL_USE_BF16) template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; #endif -template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] 
kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; + +template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; + +template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; kernel void kernel_pool_2d_max_f32( device const float * src0, From bd40678df768ab189b26b8b03e3de4062e3b71a3 Mon Sep 17 00:00:00 2001 From: Slobodan Josic <127323561+slojosic-amd@users.noreply.github.com> Date: Wed, 26 Mar 2025 23:46:30 +0100 Subject: [PATCH 04/26] HIP: Add support for RDNA4 targets (#12372) --- docs/build.md | 2 +- ggml/src/ggml-cuda/common.cuh | 18 ++++++++++-------- ggml/src/ggml-cuda/ggml-cuda.cu | 8 +++++--- ggml/src/ggml-cuda/mmq.cu | 2 +- ggml/src/ggml-cuda/mmq.cuh | 4 ++-- ggml/src/ggml-cuda/mmvq.cu | 4 ++-- ggml/src/ggml-cuda/vendors/hip.h | 4 ++++ 7 files changed, 25 insertions(+), 17 deletions(-) diff --git a/docs/build.md b/docs/build.md index aa1db9a04..9c1314a29 100644 --- a/docs/build.md +++ b/docs/build.md @@ -191,7 +191,7 @@ The following compilation options are also available to tweak performance: | Option | Legal values | Default | Description | 
|-------------------------------|------------------------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| GGML_CUDA_FORCE_MMQ | Boolean | false | Force the use of custom matrix multiplication kernels for quantized models instead of FP16 cuBLAS even if there is no int8 tensor core implementation available (affects V100, RDNA3). MMQ kernels are enabled by default on GPUs with int8 tensor core support. With MMQ force enabled, speed for large batch sizes will be worse but VRAM consumption will be lower. | +| GGML_CUDA_FORCE_MMQ | Boolean | false | Force the use of custom matrix multiplication kernels for quantized models instead of FP16 cuBLAS even if there is no int8 tensor core implementation available (affects V100, CDNA and RDNA3+). MMQ kernels are enabled by default on GPUs with int8 tensor core support. With MMQ force enabled, speed for large batch sizes will be worse but VRAM consumption will be lower. | | GGML_CUDA_FORCE_CUBLAS | Boolean | false | Force the use of FP16 cuBLAS instead of custom matrix multiplication kernels for quantized models | | GGML_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. | | GGML_CUDA_PEER_MAX_BATCH_SIZE | Positive integer | 128 | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial. 
| diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 954ff5f16..f8c55a2b8 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -52,7 +52,7 @@ #define GGML_CUDA_CC_IS_NVIDIA(cc) (cc < GGML_CUDA_CC_OFFSET_MTHREADS) // AMD -// GCN/CNDA, wave size is 64 +// GCN/CDNA, wave size is 64 #define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16 #define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue #define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 0x906) // MI50/Radeon VII, minimum for dp4a @@ -60,16 +60,18 @@ #define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum acc register renameing #define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300 -// RNDA removes MFMA, dp4a, xnack, acc registers, wave size is 32 +// RDNA removes MFMA, dp4a, xnack, acc registers, wave size is 32 #define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000 #define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a #define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA +#define GGML_CUDA_CC_RDNA4 (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000 #define GGML_CUDA_CC_IS_AMD(cc) (cc >= GGML_CUDA_CC_OFFSET_AMD) #define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1) #define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2) #define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3) -#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3) +#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4) +#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4) #define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA) #define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1) @@ -209,9 +211,9 @@ typedef float2 dfloat2; #define FP16_MMA_AVAILABLE #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA -#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3)) +#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4)) #define FP16_MMA_AVAILABLE -#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3)) +#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4)) #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING #define NEW_MMA_AVAILABLE @@ -244,14 +246,14 @@ static bool fp16_mma_available(const int cc) { return false; #else return (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) || - GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc); + GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc); #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN) } // To be used for feature selection of external libraries, e.g. cuBLAS. static bool fp16_mma_hardware_available(const int cc) { return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || - GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc); + GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc); } // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later. 
@@ -409,7 +411,7 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) #if defined(CDNA) || defined(RDNA2) || defined(__gfx906__) c = __builtin_amdgcn_sdot4(a, b, c, false); -#elif defined(RDNA3) +#elif defined(RDNA3) || defined(RDNA4) c = __builtin_amdgcn_sudot4( true, a, true, b, c, false); #elif defined(RDNA1) || defined(__gfx900__) int tmp1; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 6dd5dcb85..3bb472ffb 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -1216,7 +1216,7 @@ static void ggml_cuda_op_mul_mat_cublas( CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream)); - if (GGML_CUDA_CC_IS_CDNA(cc)) { + if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) { const float alpha = 1.0f; const float beta = 0.0f; CUBLAS_CHECK( @@ -1759,7 +1759,9 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co beta = &beta_f32; } - if (GGML_CUDA_CC_IS_CDNA(ggml_cuda_info().devices[ctx.device].cc)) { + int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) { cu_compute_type = CUBLAS_COMPUTE_32F; alpha = &alpha_f32; beta = &beta_f32; @@ -1836,7 +1838,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co } #endif - if (dst->op_params[0] == GGML_PREC_DEFAULT) { + if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type == CUDA_R_16F) { const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream); } diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 2c19485d5..b36b43d54 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -149,5 +149,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; } - return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; + return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; } diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index ee0115425..f136c4195 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -2577,9 +2577,9 @@ static __device__ void mul_mat_q_process_tile( template #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN) +#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN) __launch_bounds__(WARP_SIZE*nwarps, 2) -#endif // defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN) +#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN) #else #if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA __launch_bounds__(WARP_SIZE*nwarps, 1) diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index a7d518a57..45ea30f62 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -54,7 +54,7 @@ enum mmvq_parameter_table_id { }; static constexpr __device__ mmvq_parameter_table_id get_device_table_id() { -#if defined(RDNA2) || defined(RDNA3) +#if defined(RDNA2) || defined(RDNA3) || defined(RDNA4) return MMVQ_PARAMETERS_RDNA2; #elif defined(GCN) || defined(CDNA) return MMVQ_PARAMETERS_GCN; @@ -64,7 +64,7 @@ static 
constexpr __device__ mmvq_parameter_table_id get_device_table_id() { } static __host__ mmvq_parameter_table_id get_device_table_id(int cc) { - if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) { + if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) { return MMVQ_PARAMETERS_RDNA2; } if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) { diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index a4c717a32..3983ce5b4 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -151,6 +151,10 @@ #define CDNA #endif +#if defined(__GFX12__) +#define RDNA4 +#endif + #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \ defined(__gfx1150__) || defined(__gfx1151__) #define RDNA3 From f17a3bb4e8b0aa24c0f86636d234aca7dc2cfa01 Mon Sep 17 00:00:00 2001 From: Akarshan Biswas Date: Thu, 27 Mar 2025 07:16:00 +0530 Subject: [PATCH 05/26] SYCL: implement memset ggml backend buffer interface (#12580) * SYCL: implement memset ggml backend buffer interface * use GGML_ABORT macro * Do not wait for all queues to finish for memset operation --- ggml/src/ggml-sycl/ggml-sycl.cpp | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 9fa24b980..39d53da33 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -37,6 +37,7 @@ #include "ggml-backend-impl.h" #include "ggml-sycl/backend.hpp" +#include "ggml-sycl/common.hpp" #include "ggml-sycl/presets.hpp" #include "ggml-sycl/gemm.hpp" #include "ggml-sycl/sycl_hw.hpp" @@ -490,6 +491,23 @@ catch (sycl::exception const &exc) { std::exit(1); } +static void ggml_backend_sycl_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, + size_t offset, size_t size) { + GGML_SYCL_DEBUG(" [SYCL] call %s\n", __func__); + ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context; + SYCL_CHECK(ggml_sycl_set_device(ctx->device)); + auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue()); + if (size == 0) { + return; // Nothing to do + } + if (tensor->data == nullptr) { + GGML_ABORT("Error: Tensor data pointer is null.\n"); + } + void * target_ptr = static_cast(tensor->data) + offset; + SYCL_CHECK(CHECK_TRY_ERROR((*stream).memset(target_ptr, value, size))); + SYCL_CHECK(CHECK_TRY_ERROR((*stream).wait())); +} + static void ggml_backend_sycl_buffer_reset(ggml_backend_buffer_t buffer) { GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__); if (buffer == nullptr) { @@ -510,7 +528,7 @@ static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = { /* .free_buffer = */ ggml_backend_sycl_buffer_free_buffer, /* .get_base = */ ggml_backend_sycl_buffer_get_base, /* .init_tensor = */ ggml_backend_sycl_buffer_init_tensor, - /* .memset_tensor = */ NULL, + /* .memset_tensor = */ ggml_backend_sycl_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_sycl_buffer_set_tensor, /* .get_tensor = */ ggml_backend_sycl_buffer_get_tensor, /* .cpy_tensor = */ ggml_backend_sycl_buffer_cpy_tensor, From f28bc4c286c453cd8385388eea057a404fdb6402 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 27 Mar 2025 08:24:10 +0200 Subject: [PATCH 06/26] llama : make loras compatible with repacking (#12593) * llama : make loras compatible with repacking ggml-ci * cont : simplify ggml-ci * cont : add TODO [no ci] --- 
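Note: the change boils down to one decision per LoRA tensor: if the base weight was placed in one of the CPU backend's extra (repacking) buffer types, the LoRA tensor falls back to the plain CPU buffer type. A minimal standalone sketch of that decision, using only the ggml-backend calls that appear in the diff below (the helper name and the ggml-cpu.h include here are assumptions for illustration, not code from the patch):

    #include "ggml-backend.h"
    #include "ggml-cpu.h" // assumed header providing ggml_backend_dev_get_extra_bufts_t

    // Pick the buffer type for a LoRA tensor given the buffer type of its base weight.
    static ggml_backend_buffer_type_t lora_buft_for(ggml_backend_buffer_type_t model_buft) {
        ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
        ggml_backend_reg_t cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);

        auto get_extra_bufts = (ggml_backend_dev_get_extra_bufts_t)
            ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");

        if (get_extra_bufts) {
            // the proc returns a NULL-terminated array of extra buffer types
            for (ggml_backend_buffer_type_t * cur = get_extra_bufts(cpu_dev); cur && *cur; ++cur) {
                if (*cur == model_buft) {
                    // base weight is repacked -> place the LoRA tensor in the default CPU buffer
                    return ggml_backend_dev_buffer_type(cpu_dev);
                }
            }
        }

        return model_buft;
    }
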
src/llama-adapter.cpp | 38 +++++++++++++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/src/llama-adapter.cpp b/src/llama-adapter.cpp index b448614e4..7ac54d239 100644 --- a/src/llama-adapter.cpp +++ b/src/llama-adapter.cpp @@ -247,6 +247,26 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_ } } + // get extra buffer types of the CPU + // TODO: a more general solution for non-CPU extra buft should be imlpemented in the future + // ref: https://github.com/ggml-org/llama.cpp/pull/12593#pullrequestreview-2718659948 + std::vector buft_extra; + { + auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); + + auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) + ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts"); + + if (ggml_backend_dev_get_extra_bufts_fn) { + ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev); + while (extra_bufts && *extra_bufts) { + buft_extra.emplace_back(*extra_bufts); + ++extra_bufts; + } + } + } + // add tensors for (auto & it : ab_map) { const std::string & name = it.first; @@ -263,7 +283,23 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_ throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model (hint: maybe wrong base model?)"); } - ggml_context * dev_ctx = ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer)); + auto * buft = ggml_backend_buffer_get_type(model_tensor->buffer); + + // do not load loras to extra buffer types (i.e. bufts for repacking) -> use the CPU in that case + for (auto & ex : buft_extra) { + if (ex == buft) { + LLAMA_LOG_WARN("%s: lora for '%s' cannot use buft '%s', fallback to CPU\n", __func__, model_tensor->name, ggml_backend_buft_name(buft)); + + auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + buft = ggml_backend_dev_buffer_type(cpu_dev); + + break; + } + } + + LLAMA_LOG_DEBUG("%s: lora for '%s' -> '%s'\n", __func__, model_tensor->name, ggml_backend_buft_name(buft)); + + ggml_context * dev_ctx = ctx_for_buft(buft); // validate tensor shape if (is_token_embd) { // expect B to be non-transposed, A and B are flipped; see llm_build_inp_embd() From 24feaec05792b972d5ff3e2b12d9237ebd50d1ac Mon Sep 17 00:00:00 2001 From: xctan Date: Thu, 27 Mar 2025 14:38:34 +0800 Subject: [PATCH 07/26] ggml : riscv: add 128-bit RVV support (#12530) * ggml : add 128-bit RVV support * ggml : revert to old RVV 256+ q2_K, q3_K, q4_K, q6_K impl * remove trailing whitespaces * restructure vector length selection code --- ggml/CMakeLists.txt | 1 + ggml/src/ggml-cpu/CMakeLists.txt | 6 +- ggml/src/ggml-cpu/ggml-cpu-quants.c | 1142 +++++++++++++++++---------- ggml/src/ggml-impl.h | 29 + 4 files changed, 781 insertions(+), 397 deletions(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 740f9f69c..433628c4c 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -123,6 +123,7 @@ endif() option(GGML_LASX "ggml: enable lasx" ON) option(GGML_LSX "ggml: enable lsx" ON) option(GGML_RVV "ggml: enable rvv" ON) +option(GGML_RV_ZFH "ggml: enable riscv zfh" OFF) option(GGML_VXE "ggml: enable vxe" ON) option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF) diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index cb71e9b39..b9076513a 100644 --- 
a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -320,7 +320,11 @@ function(ggml_add_cpu_backend_variant_impl tag_name) elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "riscv64") message(STATUS "RISC-V detected") if (GGML_RVV) - list(APPEND ARCH_FLAGS -march=rv64gcv -mabi=lp64d) + if (GGML_RV_ZFH) + list(APPEND ARCH_FLAGS -march=rv64gcv_zfhmin -DGGML_RV_ZFH -mabi=lp64d) + else() + list(APPEND ARCH_FLAGS -march=rv64gcv -mabi=lp64d) + endif() endif() elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x") message(STATUS "s390x detected") diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ggml/src/ggml-cpu/ggml-cpu-quants.c index 4e0ae0572..91a81bdc3 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-quants.c +++ b/ggml/src/ggml-cpu/ggml-cpu-quants.c @@ -891,15 +891,15 @@ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i } #elif defined(__riscv_v_intrinsic) - size_t vl = __riscv_vsetvl_e32m4(QK8_0); + size_t vl = QK8_0; for (int i = 0; i < nb; i++) { // load elements - vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_0, vl); + vfloat32m8_t v_x = __riscv_vle32_v_f32m8(x+i*QK8_0, vl); - vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl); + vfloat32m8_t vfabs = __riscv_vfabs_v_f32m8(v_x, vl); vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); - vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl); + vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m8_f32m1(vfabs, tmp, vl); float amax = __riscv_vfmv_f_s_f32m1_f32(vmax); const float d = amax / ((1 << 7) - 1); @@ -907,14 +907,14 @@ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i y[i].d = GGML_FP32_TO_FP16(d); - vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl); + vfloat32m8_t x0 = __riscv_vfmul_vf_f32m8(v_x, id, vl); // convert to integer - vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl); - vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl); + vint16m4_t vi = __riscv_vfncvt_x_f_w_i16m4(x0, vl); + vint8m2_t vs = __riscv_vncvt_x_x_w_i8m2(vi, vl); // store result - __riscv_vse8_v_i8m1(y[i].qs , vs, vl); + __riscv_vse8_v_i8m2(y[i].qs , vs, vl); } #elif defined(__POWER9_VECTOR__) @@ -1229,15 +1229,15 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i } #elif defined(__riscv_v_intrinsic) - size_t vl = __riscv_vsetvl_e32m4(QK8_1); + size_t vl = QK8_1; for (int i = 0; i < nb; i++) { // load elements - vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_1, vl); + vfloat32m8_t v_x = __riscv_vle32_v_f32m8(x+i*QK8_1, vl); - vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl); + vfloat32m8_t vfabs = __riscv_vfabs_v_f32m8(v_x, vl); vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0, vl); - vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl); + vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m8_f32m1(vfabs, tmp, vl); float amax = __riscv_vfmv_f_s_f32m1_f32(vmax); const float d = amax / ((1 << 7) - 1); @@ -1245,18 +1245,18 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i y[i].d = GGML_FP32_TO_FP16(d); - vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl); + vfloat32m8_t x0 = __riscv_vfmul_vf_f32m8(v_x, id, vl); // convert to integer - vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl); - vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl); + vint16m4_t vi = __riscv_vfncvt_x_f_w_i16m4(x0, vl); + vint8m2_t vs = __riscv_vncvt_x_x_w_i8m2(vi, vl); // store result - __riscv_vse8_v_i8m1(y[i].qs , vs, vl); + __riscv_vse8_v_i8m2(y[i].qs , vs, vl); // compute sum for y[i].s vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl); 
- vint16m1_t vwrs = __riscv_vwredsum_vs_i8m1_i16m1(vs, tmp2, vl); + vint16m1_t vwrs = __riscv_vwredsum_vs_i8m2_i16m1(vs, tmp2, vl); // set y[i].s int sum = __riscv_vmv_x_s_i16m1_i16(vwrs); @@ -2391,33 +2391,31 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); #elif defined(__riscv_v_intrinsic) - size_t vl = __riscv_vsetvl_e8m1(qk/2); + size_t vl = qk / 2; for (; ib < nb; ++ib) { // load elements - vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl); + vuint8m1_t tx = __riscv_vle8_v_u8m1(x[ib].qs, vl); - vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl); - vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl); + vint8m1_t y0 = __riscv_vle8_v_i8m1(y[ib].qs, vl); + vint8m1_t y1 = __riscv_vle8_v_i8m1(y[ib].qs+16, vl); // mask and store lower part of x, and then upper part - vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); - vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); + vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl); + vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl); - vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); - vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); + vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a); + vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l); // subtract offset - vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl); - vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl); + vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 8, vl); + vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 8, vl); - vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); - vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); + vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl); + vint16m2_t vec_mul2 = __riscv_vwmacc_vv_i16m2(vec_mul1, v1, y1, vl); vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - - vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl); int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); @@ -2783,29 +2781,27 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi sumf = hsum_float_8(acc) + summs; #elif defined(__riscv_v_intrinsic) - size_t vl = __riscv_vsetvl_e8m1(qk/2); + size_t vl = qk / 2; for (; ib < nb; ++ib) { // load elements - vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl); + vuint8m1_t tx = __riscv_vle8_v_u8m1(x[ib].qs, vl); - vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl); - vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl); + vint8m1_t y0 = __riscv_vle8_v_i8m1(y[ib].qs, vl); + vint8m1_t y1 = __riscv_vle8_v_i8m1(y[ib].qs+16, vl); // mask and store lower part of x, and then upper part - vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); - vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); + vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl); + vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl); - vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); - vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); + vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a); + vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l); - vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); - vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); + vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl); + vint16m2_t vec_mul2 = __riscv_vwmacc_vv_i16m2(vec_mul1, v1, y1, vl); vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - - 
vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl); int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); @@ -3132,65 +3128,33 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi sumf = hsum_float_8(acc); #elif defined(__riscv_v_intrinsic) - uint32_t qh; - - size_t vl = __riscv_vsetvl_e8m1(qk/2); - - // These temporary registers are for masking and shift operations - vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl); - vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl); - - vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl); - vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl); + size_t vl; + size_t vlenb = __riscv_vlenb(); for (; ib < nb; ++ib) { - memcpy(&qh, x[ib].qh, sizeof(uint32_t)); + vl = qk / 2; + vuint8m1_t v0 = __riscv_vle8_v_u8m1(x[ib].qs, vl); + vint8m1_t v0l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(v0, 0x0F, vl)); + vint8m1_t v0h = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(v0, 4, vl)); + vint8m2_t v0c; + if (vlenb == 16) { + v0c = __riscv_vcreate_v_i8m1_i8m2(v0l, v0h); + } else { + v0l = __riscv_vslideup_vx_i8m1(v0l, v0h, 16, 32); + v0c = __riscv_vlmul_ext_v_i8m1_i8m2(v0l); + } - // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; - vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl); - vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl); - vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl); + vl = qk; + vbool4_t qh = __riscv_vlm_v_b4(x[ib].qh, vl); + qh = __riscv_vmnand_mm_b4(qh, qh, vl); + vint8m2_t v0f = __riscv_vsub_vx_i8m2_mu(qh, v0c, v0c, 0x10, vl); + vint8m2_t v1 = __riscv_vle8_v_i8m2(y[ib].qs, vl); + vint16m4_t mul = __riscv_vwmul_vv_i16m4(v0f, v1, vl); + vint32m1_t zero = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t sum = __riscv_vwredsum_vs_i16m4_i32m1(mul, zero, vl); + int32_t sumi = __riscv_vmv_x_s_i32m1_i32(sum); - // ((qh & (1u << (j + 16))) >> (j + 12)); - vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl); - vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl); - - // narrowing - vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl); - vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl); - - vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl); - vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl); - - // load - vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl); - - vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl); - vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl); - - vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); - vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); - - vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl); - vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl); - - vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); - vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); - - vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl); - vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl); - - vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); - vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); - - vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - - vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); - - int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); - - sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)) * sumi; + 
sumf += (GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)) * sumi; } #elif defined(__POWER9_VECTOR__) @@ -3503,60 +3467,30 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi sumf = hsum_float_8(acc) + summs; #elif defined(__riscv_v_intrinsic) - uint32_t qh; - - size_t vl = __riscv_vsetvl_e8m1(qk/2); - - // temporary registers for shift operations - vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl); - vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl); + size_t vl; + size_t vlenb = __riscv_vlenb(); for (; ib < nb; ++ib) { - memcpy(&qh, x[ib].qh, sizeof(uint32_t)); + vl = qk / 2; + vuint8m1_t v0 = __riscv_vle8_v_u8m1(x[ib].qs, vl); + vint8m1_t v0l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(v0, 0x0F, vl)); + vint8m1_t v0h = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(v0, 4, vl)); + vint8m2_t v0c; + if (vlenb == 16) { + v0c = __riscv_vcreate_v_i8m1_i8m2(v0l, v0h); + } else { + v0l = __riscv_vslideup_vx_i8m1(v0l, v0h, 16, 32); + v0c = __riscv_vlmul_ext_v_i8m1_i8m2(v0l); + } - // load qh - vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl); - - // ((qh >> (j + 0)) << 4) & 0x10; - vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl); - vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl); - vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl); - - // ((qh >> (j + 12)) ) & 0x10; - vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl); - vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl); - - // narrowing - vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl); - vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl); - - vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl); - vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl); - - // load - vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl); - - vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl); - vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl); - - vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); - vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); - - vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl); - vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl); - - vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); - vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); - - vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); - vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); - - vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - - vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); - - int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); + vl = qk; + vbool4_t qh = __riscv_vlm_v_b4(x[ib].qh, vl); + vint8m2_t v0f = __riscv_vor_vx_i8m2_mu(qh, v0c, v0c, 0x10, vl); + vint8m2_t v1 = __riscv_vle8_v_i8m2(y[ib].qs, vl); + vint16m4_t mul = __riscv_vwmul_vv_i16m4(v0f, v1, vl); + vint32m1_t zero = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t sum = __riscv_vwredsum_vs_i16m4_i32m1(mul, zero, vl); + int32_t sumi = __riscv_vmv_x_s_i32m1_i32(sum); sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s); } @@ -3970,17 +3904,17 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi sumf = hsum_float_8(accum); #elif defined(__riscv_v_intrinsic) - size_t vl = __riscv_vsetvl_e8m1(qk); + size_t vl = qk; for (; ib < nb; ++ib) { // load elements - vint8m1_t bx_0 = __riscv_vle8_v_i8m1(x[ib].qs, vl); - vint8m1_t by_0 = __riscv_vle8_v_i8m1(y[ib].qs, vl); + 
vint8m2_t bx_0 = __riscv_vle8_v_i8m2(x[ib].qs, vl); + vint8m2_t by_0 = __riscv_vle8_v_i8m2(y[ib].qs, vl); - vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx_0, by_0, vl); + vint16m4_t vw_mul = __riscv_vwmul_vv_i16m4(bx_0, by_0, vl); vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl); - vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl); + vint32m1_t v_sum = __riscv_vwredsum_vs_i16m4_i32m1(vw_mul, v_zero, vl); int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum); @@ -5174,84 +5108,182 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi #elif defined __riscv_v_intrinsic + const int vector_length = __riscv_vlenb() * 8; float sumf = 0; - uint8_t temp_01[32] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; - for (int i = 0; i < nb; ++i) { + uint8_t temp_01[32] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 }; + uint8_t atmp[16]; - const uint8_t * q2 = x[i].qs; - const int8_t * q8 = y[i].qs; - const uint8_t * sc = x[i].scales; + switch (vector_length) { + case 256: + for (int i = 0; i < nb; ++i) { + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; - const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - size_t vl = 16; + size_t vl = 16; - vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl); - vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl); + vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl); + vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl); - vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl); + vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl); - vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl); - vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl); - vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); - vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl); - vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl); + vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl); + vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); + vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl); + vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); - sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums); + sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums); - vl = 32; + vl = 32; - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl); + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl); - uint8_t is=0; - int isum=0; + uint8_t is = 0; + int isum = 0; - for (int j = 0; j < QK_K/128; ++j) { - // load Q2 - vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl); + for (int j = 0; j < QK_K / 128; ++j) { + // load Q2 + vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl); - vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl); - vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03 , vl); - vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03 , vl); - vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03 , vl); + vuint8m1_t q2_0 = 
__riscv_vand_vx_u8m1(q2_x, 0x03, vl); + vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03, vl); + vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03, vl); + vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03, vl); - // duplicate scale elements for product - vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0+is, vl), vl); - vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2+is, vl), vl); - vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4+is, vl), vl); - vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6+is, vl), vl); + // duplicate scale elements for product + vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0 + is, vl), vl); + vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2 + is, vl), vl); + vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4 + is, vl), vl); + vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6 + is, vl), vl); - vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl)); - vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl)); - vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl)); - vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl)); + vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl)); + vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl)); + vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl)); + vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl)); - // load Q8 - vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); - vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); - vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8+64, vl); - vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8+96, vl); + // load Q8 + vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8 + 32, vl); + vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8 + 64, vl); + vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8 + 96, vl); - vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl); - vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl); - vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl); - vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl); + vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl); + vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl); + vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl); + vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl); - vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl); + vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl); - isum += __riscv_vmv_x_s_i32m1_i32(isum1); + isum += __riscv_vmv_x_s_i32m1_i32(isum1); - q2+=32; q8+=128; is=8; + q2 += 32; + q8 += 128; + is = 8; + } 
+ sumf += dall * isum; } + break; + case 128: + for (int i = 0; i < nb; ++i) { + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + uint8_t *patmp = atmp; + int vsums; + int tmp; + __asm__ __volatile__( + "vsetivli zero, 16, e8, m1\n\t" + "vmv.v.x v8, zero\n\t" + "vle8.v v1, (%[sc])\n\t" + "vand.vi v0, v1, 0xF\n\t" + "vsrl.vi v1, v1, 4\n\t" + "vse8.v v0, (%[scale])\n\t" + "vsetivli zero, 16, e16, m2\n\t" + "vle16.v v2, (%[bsums])\n\t" + "vzext.vf2 v0, v1\n\t" + "vwmul.vv v4, v0, v2\n\t" + "vsetivli zero, 16, e32, m4\n\t" + "vredsum.vs v8, v4, v8\n\t" + "vmv.x.s %[vsums], v8" + : [tmp] "=&r" (tmp), [vsums] "=&r" (vsums) + : [sc] "r" (sc), [scale] "r" (atmp), [bsums] "r" (y[i].bsums) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + sumf += dmin * vsums; + int isum = 0; - sumf += dall * isum; + for (int j = 0; j < QK_K/128; ++j) { + __asm__ __volatile__( + "vsetvli zero, %[vl32], e8, m2\n\t" + "vle8.v v0, (%[q2])\n\t" + "vsrl.vi v2, v0, 2\n\t" + "vsrl.vi v4, v0, 4\n\t" + "vsrl.vi v6, v0, 6\n\t" + "vand.vi v0, v0, 0x3\n\t" + "vand.vi v2, v2, 0x3\n\t" + "vand.vi v4, v4, 0x3\n\t" + "vsetvli zero, %[vl128], e8, m8\n\t" + "vle8.v v8, (%[q8])\n\t" + "vsetvli zero, %[vl64], e8, m4\n\t" + "vwmul.vv v16, v0, v8\n\t" + "vwmul.vv v24, v4, v12\n\t" + "vsetivli zero, 16, e16, m2\n\t" + "vmv.v.x v0, zero\n\t" + "vwredsum.vs v10, v16, v0\n\t" + "vwredsum.vs v9, v18, v0\n\t" + "vwredsum.vs v8, v20, v0\n\t" + "vwredsum.vs v7, v22, v0\n\t" + "vwredsum.vs v11, v24, v0\n\t" + "vwredsum.vs v12, v26, v0\n\t" + "vwredsum.vs v13, v28, v0\n\t" + "vwredsum.vs v14, v30, v0\n\t" + "vsetivli zero, 4, e32, m1\n\t" + "vslideup.vi v10, v9, 1\n\t" + "vslideup.vi v8, v7, 1\n\t" + "vslideup.vi v11, v12, 1\n\t" + "vslideup.vi v13, v14, 1\n\t" + "vslideup.vi v10, v8, 2\n\t" + "vslideup.vi v11, v13, 2\n\t" + "vsetivli zero, 8, e32, m2\n\t" + "vle8.v v15, (%[scale])\n\t" + "vzext.vf4 v12, v15\n\t" + "vmul.vv v10, v10, v12\n\t" + "vredsum.vs v0, v10, v0\n\t" + "vmv.x.s %[tmp], v0\n\t" + "add %[isum], %[isum], %[tmp]" + : [tmp] "=&r" (tmp), [isum] "+&r" (isum) + : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8) + , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + q2 += 32; q8 += 128; patmp += 8; + } + sumf += dall * isum; + } + break; + default: + assert(false && "Unsupported vector length"); + break; } *s = sumf; @@ -6116,97 +6148,221 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi uint32_t aux[3]; uint32_t utmp[4]; + const int vector_length = __riscv_vlenb() * 8; float sumf = 0; - for (int i = 0; i < nb; ++i) { - const uint8_t * GGML_RESTRICT q3 = x[i].qs; - const uint8_t * GGML_RESTRICT qh = x[i].hmask; - const int8_t * GGML_RESTRICT q8 = y[i].qs; + switch (vector_length) { + case 256: + for (int i = 0; i < nb; ++i) { - memcpy(aux, x[i].scales, 12); - utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); - utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); - 
utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); - utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].hmask; + const int8_t * GGML_RESTRICT q8 = y[i].qs; - int8_t * scale = (int8_t *)utmp; - for (int j = 0; j < 16; ++j) scale[j] -= 32; + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + int8_t * scale = (int8_t *)utmp; + for (int j = 0; j < 16; ++j) scale[j] -= 32; - size_t vl = 32; - uint8_t m = 1; + size_t vl = 32; + uint8_t m = 1; - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl); + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl); - int sum_t = 0; + int sum_t = 0; - for (int j = 0; j < QK_K; j += 128) { + for (int j = 0; j < QK_K; j += 128) { - vl = 32; + vl = 32; - // load Q3 - vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl); + // load Q3 + vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl); - vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl)); - vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl)); - vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl)); - vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl)); + vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl)); + vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl)); + vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl)); + vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl)); - // compute mask for subtraction - vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl); - vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl); - m <<= 1; + // compute mask for subtraction + vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl); + vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl); + m <<= 1; - vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl); - vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl); - m <<= 1; + vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl); + vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl); + m <<= 1; - vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl); - vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl); - m <<= 1; + vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl); + vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl); + m <<= 1; - vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_3 = 
__riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl); - vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl); - m <<= 1; + vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl); + vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl); + m <<= 1; - // load Q8 and take product with Q3 - vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl); - vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl); - vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl); - vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl); + // load Q8 and take product with Q3 + vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl); + vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl); + vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl); + vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl); - vl = 16; + vl = 16; - // retrieve lane to multiply with scale - vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl); - vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl); - vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl); - vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl); - vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl); - vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl); - vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl); - vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl); + // retrieve lane to multiply with scale + vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl); + vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl); + vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl); + vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl); + vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl); + vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl); + vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl); + vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl); - vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl); - vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl); - vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl); + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl); + vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl); + vint32m1_t isum3 = 
__riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl); - sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); - q3 += 32; q8 += 128; scale += 8; + q3 += 32; q8 += 128; scale += 8; + + } + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + + sumf += d*sum_t; } + break; + case 128: + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict qh = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + int8_t * scale = (int8_t *)utmp; + int tmp; + __asm__ __volatile__( + "vsetivli zero, 12, e8, m1\n\t" + "vle8.v v0, (%[s6b])\n\t" + "vmv1r.v v2, v0\n\t" + "vsetivli zero, 2, e64, m1\n\t" + "vmv.v.x v9, %[sh]\n\t"\ + "vslidedown.vi v1, v0, 1\n\t" + "vslide1up.vx v8, v9, zero\n\t" // {0, 0, 4, 4} + "vslideup.vi v0, v2, 1\n\t" // {aux[0], aux[1], aux[0], aux[1]} + "vsetivli zero, 4, e32, m1\n\t" + "vid.v v9\n\t" + "vmv.x.s %[tmp], v1\n\t" + "vsll.vi v9, v9, 1\n\t" // {0, 2, 4, 6} + "vmv.v.x v1, %[tmp]\n\t" // {aux[2], aux[2], aux[2], aux[2]} + "vsrl.vv v4, v1, v9\n\t" + "vsrl.vv v2, v0, v8\n\t" + "vand.vx v5, v4, %[kmask1]\n\t" + "vand.vx v3, v2, %[kmask2]\n\t" + "vsll.vi v6, v5, 4\n\t" + "vor.vv v7, v6, v3\n\t" + "vsetivli zero, 16, e8, m1\n\t" + "vsub.vx v0, v7, %[c]\n\t" + "vse8.v v0, (%[scale])" + : [tmp] "=&r" (tmp) + : [sh] "r" (0x0000000400000004), [s6b] "r" (x[i].scales), [c] "r" (32) + , [scale] "r" (scale), [kmask1] "r" (kmask1), [kmask2] "r" (kmask2) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); - sumf += d*sum_t; + uint8_t m = 1; + int isum = 0; + for (int j = 0; j < QK_K; j += 128) { + __asm__ __volatile__( + "vsetvli zero, %[vl32], e8, m2, ta, mu\n\t" + "vle8.v v8, (%[q3])\n\t" + "vsrl.vi v10, v8, 2\n\t" + "vsrl.vi v12, v8, 4\n\t" + "vsrl.vi v14, v8, 6\n\t" + "vand.vi v8, v8, 3\n\t" + "vand.vi v10, v10, 3\n\t" + "vand.vi v12, v12, 3\n\t" + "vle8.v v2, (%[qh])\n\t" + "vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "vmseq.vx v0, v4, zero\n\t" + "vadd.vi v8, v8, -4, v0.t\n\t" + "vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "vmseq.vx v0, v4, zero\n\t" + "vadd.vi v10, v10, -4, v0.t\n\t" + "vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "vmseq.vx v0, v4, zero\n\t" + "vadd.vi v12, v12, -4, v0.t\n\t" + "vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "vmseq.vx v0, v4, zero\n\t" + "vadd.vi v14, v14, -4, v0.t\n\t" + "vsetvli zero, %[vl128], e8, m8\n\t" + "vle8.v v0, (%[q8])\n\t" + "vsetvli zero, %[vl64], e8, m4\n\t" + "vwmul.vv v16, v0, v8\n\t" + "vwmul.vv v24, v4, v12\n\t" + "vsetivli zero, 16, e16, m2\n\t" + "vmv.v.x v0, zero\n\t" + "vwredsum.vs v10, v16, v0\n\t" + "vwredsum.vs v9, v18, v0\n\t" + "vwredsum.vs v8, v20, v0\n\t" + "vwredsum.vs v7, v22, v0\n\t" + "vwredsum.vs v11, v24, v0\n\t" + "vwredsum.vs v12, v26, v0\n\t" + "vwredsum.vs v13, v28, v0\n\t" + "vwredsum.vs v14, v30, v0\n\t" + "vsetivli zero, 4, e32, m1\n\t" + "vslideup.vi v10, v9, 1\n\t" + "vslideup.vi v8, v7, 1\n\t" + "vslideup.vi v11, v12, 1\n\t" + "vslideup.vi v13, v14, 1\n\t" + "vslideup.vi v10, v8, 2\n\t" + "vslideup.vi v11, v13, 2\n\t" + "vsetivli zero, 8, e32, m2\n\t"\ + "vle8.v v15, (%[scale])\n\t" + "vsext.vf4 v12, v15\n\t" + "vmul.vv v10, v10, v12\n\t" + "vredsum.vs v0, v10, v0\n\t" + "vmv.x.s %[tmp], v0\n\t" + "add %[isum], 
%[isum], %[tmp]" + : [tmp] "=&r" (tmp), [m] "+&r" (m), [isum] "+&r" (isum) + : [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32) + , [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + q3 += 32; q8 += 128; scale += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + sumf += d * isum; + } + break; + default: + assert(false && "Unsupported vector length"); + break; } *s = sumf; @@ -6924,69 +7080,181 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const uint8_t * scales = (const uint8_t*)&utmp[0]; const uint8_t * mins = (const uint8_t*)&utmp[2]; + const int vector_length = __riscv_vlenb() * 8; float sumf = 0; - for (int i = 0; i < nb; ++i) { + switch (vector_length) { + case 256: + for (int i = 0; i < nb; ++i) { - size_t vl = 8; + size_t vl = 8; - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); - vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); - vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); + vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); + vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); + vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; - vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); - vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); - vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); + vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); + vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); + vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); - vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); - sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); + vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); - const uint8_t * GGML_RESTRICT q4 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; + const uint8_t * GGML_RESTRICT q4 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; - vl = 32; + vl = 32; - int32_t sum_1 = 0; - int32_t sum_2 = 0; + int32_t sum_1 = 0; + int32_t sum_2 = 0; - vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); + vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); - for (int j = 0; j < QK_K/64; ++j) { - // load Q4 - vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); + for (int j = 0; j < QK_K/64; ++j) { + // load Q4 + vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); - // load Q8 and multiply it 
with lower Q4 nibble - vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); - vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); - vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl); - vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl); + // load Q8 and multiply it with lower Q4 nibble + vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); + vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl); + vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl); - sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0]; + sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0]; - // load Q8 and multiply it with upper Q4 nibble - vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); - vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); - vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl); - vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl); + // load Q8 and multiply it with upper Q4 nibble + vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); + vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); + vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl); + vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl); - sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1]; + sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1]; - q4 += 32; q8 += 64; + q4 += 32; q8 += 64; + + } + + sumf += d*(sum_1 + sum_2); } + break; + case 128: + for (int i = 0; i < nb; ++i) { + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - sumf += d*(sum_1 + sum_2); + int tmp, tmp2, sumi; + __asm__ __volatile__( + "vsetivli zero, 12, e8, m1\n\t" + "vle8.v v1, (%[s6b])\n\t" // {aux[0], aux[1], aux[2]} + "vsetivli zero, 4, e32, m1\n\t" + "vslidedown.vi v2, v1, 2\n\t" + "vmv1r.v v3, v2\n\t" + "vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]} + "vsetivli zero, 2, e32, m1\n\t" + "vmv.v.i v4, 4\n\t" + "vand.vx v8, v1, %[kmask1]\n\t" + "vslide1up.vx v5, v4, zero\n\t" // {0, 4} + "vsrl.vi v6, v1, 6\n\t" + "vsrl.vv v7, v2, v5\n\t" + "vand.vx v0, v6, %[kmask3]\n\t" + "vand.vx v2, v7, %[kmask2]\n\t" + "vsll.vi v6, v0, 4\n\t" + "li %[t2], 8\n\t" + "addi %[t1], %[utmp], 4\n\t" + "vor.vv v1, v6, v2\n\t" + "vsse32.v v8, (%[utmp]), %[t2]\n\t" + "vsse32.v v1, (%[t1]), %[t2]\n\t" + "vsetivli zero, 8, e16, m1\n\t" + "vle32.v v2, (%[bsums])\n\t" + "vnsrl.wi v0, v2, 0\n\t" + "vnsrl.wi v1, v2, 16\n\t" + "vadd.vv v2, v0, v1\n\t" + "vle8.v v3, (%[mins])\n\t" + "vzext.vf2 v4, v3\n\t" + "vwmul.vv v6, v4, v2\n\t" + "vmv.v.x v0, zero\n\t" + "vsetivli zero, 8, e32, m2\n\t" + "vredsum.vs v0, v6, v0\n\t" + "vmv.x.s %[sumi], v0" + : [t1] "=&r" (tmp), [t2] "=&r" (tmp2), [sumi] "=&r" (sumi) + : [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp) + , [s6b] "r" (x[i].scales), [kmask1] "r" (kmask1) + , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + sumf -= dmin * sumi; + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + sumi = 0; + const uint8_t * scale = scales; + + for (int j = 0; j < QK_K/128; ++j) { + int vl128 = 128, vl64 = 64, vl32 = 32; + __asm__ __volatile__( + 
"vsetvli zero, %[vl128], e8, m8\n\t" + "vle8.v v8, (%[q8])\n\t" + "vsetvli zero, %[vl64], e8, m4\n\t" + "vle8.v v0, (%[q4])\n\t" + "vsrl.vi v4, v0, 4\n\t" + "vand.vi v0, v0, 0xF\n\t" + "vsetvli zero, %[vl32], e8, m2\n\t" + "vwmul.vv v28, v6, v14\n\t" + "vwmul.vv v20, v4, v10\n\t" + "vwmul.vv v24, v2, v12\n\t" + "vwmul.vv v16, v0, v8\n\t" + "vsetivli zero, 4, e32, m1\n\t" + "vle8.v v2, (%[scale])\n\t" + "vmv.v.x v0, zero\n\t" + "vzext.vf4 v1, v2\n\t" + "vsetvli zero, %[vl32], e16, m4\n\t" + "vwredsum.vs v6, v24, v0\n\t" + "vwredsum.vs v7, v28, v0\n\t" + "vwredsum.vs v4, v16, v0\n\t" + "vwredsum.vs v5, v20, v0\n\t" + "vsetivli zero, 4, e32, m1\n\t" + "vslideup.vi v6, v7, 1\n\t" + "vslideup.vi v4, v5, 1\n\t" + "vslideup.vi v4, v6, 2\n\t" + "vmul.vv v8, v4, v1\n\t" + "vredsum.vs v0, v8, v0\n\t" + "vmv.x.s %[tmp], v0\n\t" + "add %[sumi], %[sumi], %[tmp]" + : [tmp] "=&r" (tmp), [sumi] "+&r" (sumi) + : [vl128] "r" (vl128), [vl64] "r" (vl64), [vl32] "r" (vl32) + , [q4] "r" (q4), [q8] "r" (q8), [scale] "r" (scale) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + + q4 += 64; q8 += 128; scale += 4; + } + + sumf += d * sumi; + } + break; + default: + assert(false && "Unsupported vector length"); + break; } *s = sumf; @@ -7722,9 +7990,9 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; - vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); - vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); - vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); + vint16m1_t q8sums_0 = __riscv_vlse16_v_i16m1(y[i].bsums, 4, vl); + vint16m1_t q8sums_1 = __riscv_vlse16_v_i16m1(y[i].bsums+1, 4, vl); + vint16m1_t q8sums = __riscv_vadd_vv_i16m1(q8sums_0, q8sums_1, vl); memcpy(utmp, x[i].scales, 12); utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); @@ -7733,11 +8001,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi utmp[2] = uaux; utmp[0] &= kmask1; - vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); - vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); - vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); + vuint8mf2_t mins8 = __riscv_vle8_v_u8mf2(mins, vl); + vint16m1_t v_mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); + vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, v_mins, vl); - vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + vint32m1_t sumi = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); vl = 32; @@ -7746,43 +8014,42 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi uint8_t m = 1; vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m1_t vqh = __riscv_vle8_v_u8m1(hm, vl); + vuint8m2_t vqh = __riscv_vle8_v_u8m2(hm, vl); for (int j = 0; j < QK_K/64; ++j) { // load Q5 and Q8 - vuint8m1_t q5_x = __riscv_vle8_v_u8m1(q5, vl); - vint8m1_t q8_y1 = __riscv_vle8_v_i8m1(q8, vl); - vint8m1_t q8_y2 = __riscv_vle8_v_i8m1(q8+32, vl); + vuint8m2_t q5_x = __riscv_vle8_v_u8m2(q5, vl); + vint8m2_t q8_y1 = __riscv_vle8_v_i8m2(q8, vl); + vint8m2_t q8_y2 = 
__riscv_vle8_v_i8m2(q8+32, vl); // compute mask for addition - vint8m1_t q5_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q5_x, 0x0F, vl)); - vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_1 = __riscv_vmsne_vx_u8m1_b8(qh_m1, 0, vl); - vint8m1_t q5_m1 = __riscv_vadd_vx_i8m1_mu(vmask_1, q5_a, q5_a, 16, vl); + vint8m2_t q5_a = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vand_vx_u8m2(q5_x, 0x0F, vl)); + vuint8m2_t qh_m1 = __riscv_vand_vx_u8m2(vqh, m, vl); + vbool4_t vmask_1 = __riscv_vmsne_vx_u8m2_b4(qh_m1, 0, vl); + vint8m2_t q5_m1 = __riscv_vadd_vx_i8m2_mu(vmask_1, q5_a, q5_a, 16, vl); m <<= 1; - vint8m1_t q5_l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q5_x, 0x04, vl)); - vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_2 = __riscv_vmsne_vx_u8m1_b8(qh_m2, 0, vl); - vint8m1_t q5_m2 = __riscv_vadd_vx_i8m1_mu(vmask_2, q5_l, q5_l, 16, vl); + vint8m2_t q5_l = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vsrl_vx_u8m2(q5_x, 0x04, vl)); + vuint8m2_t qh_m2 = __riscv_vand_vx_u8m2(vqh, m, vl); + vbool4_t vmask_2 = __riscv_vmsne_vx_u8m2_b4(qh_m2, 0, vl); + vint8m2_t q5_m2 = __riscv_vadd_vx_i8m2_mu(vmask_2, q5_l, q5_l, 16, vl); m <<= 1; - vint16m2_t v0 = __riscv_vwmul_vv_i16m2(q5_m1, q8_y1, vl); - vint16m2_t v1 = __riscv_vwmul_vv_i16m2(q5_m2, q8_y2, vl); + vint16m4_t v0 = __riscv_vwmul_vv_i16m4(q5_m1, q8_y1, vl); + vint16m4_t v1 = __riscv_vwmul_vv_i16m4(q5_m2, q8_y2, vl); - vint32m4_t vs1 = __riscv_vwmul_vx_i32m4(v0, scales[is++], vl); - vint32m4_t vs2 = __riscv_vwmul_vx_i32m4(v1, scales[is++], vl); + vint32m8_t vs1 = __riscv_vwmul_vx_i32m8(v0, scales[is++], vl); + vint32m8_t vs2 = __riscv_vwmul_vx_i32m8(v1, scales[is++], vl); - vint32m1_t vacc1 = __riscv_vredsum_vs_i32m4_i32m1(vs1, vzero, vl); - vint32m1_t vacc2 = __riscv_vredsum_vs_i32m4_i32m1(vs2, vzero, vl); + vint32m1_t vacc1 = __riscv_vredsum_vs_i32m8_i32m1(vs1, vzero, vl); + vint32m1_t vacc2 = __riscv_vredsum_vs_i32m8_i32m1(vs2, vacc1, vl); - aux32 += __riscv_vmv_x_s_i32m1_i32(vacc1) + __riscv_vmv_x_s_i32m1_i32(vacc2); + aux32 += __riscv_vmv_x_s_i32m1_i32(vacc2); q5 += 32; q8 += 64; } - vfloat32m1_t vaux = __riscv_vfmul_vf_f32m1(__riscv_vfmv_v_f_f32m1(aux32, 1), d, 1); - sums += __riscv_vfmv_f_s_f32m1_f32(vaux); + sums += aux32 * d; } @@ -8667,85 +8934,168 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi #elif defined __riscv_v_intrinsic + const int vector_length = __riscv_vlenb() * 8; float sumf = 0; - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + switch (vector_length) { + case 256: + for (int i = 0; i < nb; ++i) { - const uint8_t * GGML_RESTRICT q6 = x[i].ql; - const uint8_t * GGML_RESTRICT qh = x[i].qh; - const int8_t * GGML_RESTRICT q8 = y[i].qs; + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const int8_t * GGML_RESTRICT scale = x[i].scales; + const uint8_t * GGML_RESTRICT q6 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; - size_t vl; + const int8_t * GGML_RESTRICT scale = x[i].scales; - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + size_t vl; - int sum_t = 0; - int is = 0; + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - for (int j = 0; j < QK_K/128; ++j) { + int sum_t = 0; + int is = 0; - vl = 32; + for (int j = 0; j < QK_K/128; ++j) { - // load qh - vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl); + vl = 32; - // load Q6 - vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl); - vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl); + // load qh + 
vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl); - vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl); - vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl); - vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl); - vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl); + // load Q6 + vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl); + vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl); - vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl); - vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl); - vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl); - vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl); + vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl); + vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl); + vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl); + vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl); - vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl); - vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl); - vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl); - vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl); + vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl); + vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl); + vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl); + vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl); - vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl); - vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl); - vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl); - vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl); + vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl); + vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl); + vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl); + vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl); - // load Q8 and take product - vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl); - vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl); - vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl); - vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl); + vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl); + vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl); + vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl); + vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl); - vl = 16; + // load Q8 and take product + vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl); + vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl); + vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl); + vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl); - vint32m2_t vaux_0 = 
__riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl); - vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl); - vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl); - vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl); - vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl); - vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl); - vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl); - vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl); + vl = 16; - vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); - vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl); - vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl); + vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl); + vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl); + vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl); + vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl); + vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl); + vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl); + vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl); + vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl); - sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); + vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl); + vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl); - q6 += 64; qh += 32; q8 += 128; is=8; + sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + + q6 += 64; qh += 32; q8 += 128; is=8; + + } + + sumf += d * sum_t; } + break; + case 128: + for (int i = 0; i < nb; ++i) { - sumf += d * sum_t; + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const int8_t * restrict scale = x[i].scales; + + int sum_t = 0; + int t0; + + for (int j = 0; j < QK_K/128; ++j) { + __asm__ __volatile__( + "vsetvli zero, %[vl32], e8, m2\n\t" + "vle8.v v4, (%[qh])\n\t" + "vsll.vi v0, v4, 4\n\t" + "vsll.vi v2, v4, 2\n\t" + "vsrl.vi v6, v4, 2\n\t" + "vsetvli zero, %[vl64], e8, m4\n\t" + "vle8.v v8, (%[q6])\n\t" + "vsrl.vi v12, v8, 4\n\t" + "vand.vi v8, v8, 0xF\n\t" + "vsetvli zero, %[vl128], e8, m8\n\t" + "vand.vx v0, v0, %[mask]\n\t" + "vor.vv v8, v8, v0\n\t" + "vle8.v v0, (%[q8])\n\t" + "vsub.vx v8, v8, %[vl32]\n\t" + "vsetvli zero, %[vl64], e8, m4\n\t" + "vwmul.vv v16, v0, 
v8\n\t" + "vwmul.vv v24, v4, v12\n\t" + "vsetivli zero, 16, e16, m2\n\t" + "vmv.v.x v0, zero\n\t" + "vwredsum.vs v10, v16, v0\n\t" + "vwredsum.vs v9, v18, v0\n\t" + "vwredsum.vs v8, v20, v0\n\t" + "vwredsum.vs v7, v22, v0\n\t" + "vwredsum.vs v11, v24, v0\n\t" + "vwredsum.vs v12, v26, v0\n\t" + "vwredsum.vs v13, v28, v0\n\t" + "vwredsum.vs v14, v30, v0\n\t" + "vsetivli zero, 4, e32, m1\n\t" + "vslideup.vi v10, v9, 1\n\t" + "vslideup.vi v8, v7, 1\n\t" + "vslideup.vi v11, v12, 1\n\t" + "vslideup.vi v13, v14, 1\n\t" + "vslideup.vi v10, v8, 2\n\t" + "vslideup.vi v11, v13, 2\n\t" + "vsetivli zero, 8, e32, m2\n\t" + "vle8.v v2, (%[scale])\n\t" + "vsext.vf4 v4, v2\n\t" + "vmul.vv v2, v4, v10\n\t" + "vredsum.vs v0, v2, v0\n\t" + "vmv.x.s %[t0], v0\n\t" + "add %[sumi], %[sumi], %[t0]" + : [sumi] "+&r" (sum_t), [t0] "=&r" (t0) + : [qh] "r" (qh), [q6] "r" (q6), [q8] "r" (q8), [scale] "r" (scale) + , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128) + , [mask] "r" (0x30) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + q6 += 64; qh += 32; q8 += 128; scale += 8; + } + + sumf += d * sum_t; + + } + break; + default: + assert(false && "Unsupported vector length"); + break; } *s = sumf; diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index 1fbcbd045..be2e3fc91 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -381,6 +381,35 @@ GGML_API void ggml_aligned_free(void * ptr, size_t size); return r; } +#elif defined(__riscv) && defined(GGML_RV_ZFH) + + static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) { + float f; + __asm__( + "fmv.h.x %[f], %[h]\n\t" + "fcvt.s.h %[f], %[f]" + : [f] "=&f" (f) + : [h] "r" (h) + ); + return f; + } + + static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { + ggml_fp16_t res; + __asm__( + "fcvt.h.s %[f], %[f]\n\t" + "fmv.x.h %[h], %[f]" + : [h] "=&r" (res) + : [f] "f" (f) + ); + return res; + } + + #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) + #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x) + #define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x) + #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x) + #else // FP16 <-> FP32 From c7b43ab60855f752ae79937fb93d561bc30b69a4 Mon Sep 17 00:00:00 2001 From: amritahs-ibm Date: Thu, 27 Mar 2025 12:21:47 +0530 Subject: [PATCH 08/26] llamafile : ppc64le MMA implementation for Q4_0. (#12489) This change upstreams llamafile's cpu matrix multiplication kernels for ppc64le ISA using MMA builtins. This patch handles matrix multiplication between quantised datatypes, block_q4_0 and block_q8_0. This change results in 5% - 50% improvement in total speed(ie all tokens/total time), across various batch sizes. The patch is tested with Meta-Lllama-3-8B, Mistral-7B, Llama-2-7B-chat-hf models on a IBM POWER10 machine. 
Signed-off-by: Amrita H S --- ggml/src/ggml-cpu/llamafile/sgemm.cpp | 603 ++++++++++++++++++++++---- 1 file changed, 517 insertions(+), 86 deletions(-) diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp index e0482c593..92dfbc2d2 100644 --- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -55,6 +55,7 @@ #include #include +#include #ifdef _MSC_VER #define NOINLINE __declspec(noinline) @@ -1092,13 +1093,403 @@ class tinyBLAS_Q0_PPC { } } - template - void packNormal(const TA* a, int64_t lda, int rows, int cols, VA* vec, bool flip) { + template + void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, VA* vec, std::array& comparray) { int64_t i, j; TA *aoffset = NULL; VA *vecOffset = NULL; TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL; TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL; + VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0}; + VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0}; + VB t1, t2, t3, t4, t5, t6, t7, t8; + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + const vector signed char v8 = vec_splats((signed char)0x8); + aoffset = const_cast(a); + vecOffset = vec; + vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23}; + vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}; + vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27}; + vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31}; + vector signed int vsum = {0}; + vector signed int vsum2 = {0}; + + j = (rows >> 3); + if (j > 0) { + do { + aoffset1 = aoffset; + aoffset2 = aoffset1 + lda; + aoffset3 = aoffset2 + lda; + aoffset4 = aoffset3 + lda; + aoffset5 = aoffset4 + lda; + aoffset6 = aoffset5 + lda; + aoffset7 = aoffset6 + lda; + aoffset8 = aoffset7 + lda; + aoffset += 8 * lda; + + i = (cols >> 2); + if (i > 0) { + do { + c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); + c2[1] = reinterpret_cast(vec_xl(0, aoffset2->qs)); + c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); + c4[1] = reinterpret_cast(vec_xl(0, aoffset4->qs)); + c5[1] = reinterpret_cast(vec_xl(0, aoffset5->qs)); + c6[1] = reinterpret_cast(vec_xl(0, aoffset6->qs)); + c7[1] = reinterpret_cast(vec_xl(0, aoffset7->qs)); + c8[1] = reinterpret_cast(vec_xl(0, aoffset8->qs)); + + c1[0] = vec_and(c1[1], lowMask); + c1[1] = vec_sr(c1[1], v4); + c1[0] = vec_sub(c1[0], v8); + c1[1] = vec_sub(c1[1], v8); + vsum = vec_sum4s(c1[0], vsum); + vsum2 = vec_sum4s(c1[1], vsum2); + vsum = vec_add(vsum, vsum2); + comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + vsum = vec_splats(0); + vsum2 = vec_splats(0); + + c2[0] = vec_and(c2[1], lowMask); + c2[1] = vec_sr(c2[1], v4); + c2[0] = vec_sub(c2[0], v8); + c2[1] = vec_sub(c2[1], v8); + vsum = vec_sum4s(c2[0], vsum); + vsum2 = vec_sum4s(c2[1], vsum2); + vsum = vec_add(vsum, vsum2); + comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + vsum = vec_splats(0); + vsum2 = vec_splats(0); + + c3[0] = vec_and(c3[1], lowMask); + c3[1] = vec_sr(c3[1], v4); + c3[0] = vec_sub(c3[0], v8); + c3[1] = vec_sub(c3[1], v8); + vsum = vec_sum4s(c3[0], vsum); + vsum2 = vec_sum4s(c3[1], vsum2); + vsum = vec_add(vsum, vsum2); + comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + vsum = vec_splats(0); + vsum2 = vec_splats(0); + + c4[0] = vec_and(c4[1], lowMask); + c4[1] = 
vec_sr(c4[1], v4); + c4[0] = vec_sub(c4[0], v8); + c4[1] = vec_sub(c4[1], v8); + vsum = vec_sum4s(c4[0], vsum); + vsum2 = vec_sum4s(c4[1], vsum2); + vsum = vec_add(vsum, vsum2); + comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + vsum = vec_splats(0); + vsum2 = vec_splats(0); + + c5[0] = vec_and(c5[1], lowMask); + c5[1] = vec_sr(c5[1], v4); + c5[0] = vec_sub(c5[0], v8); + c5[1] = vec_sub(c5[1], v8); + vsum = vec_sum4s(c5[0], vsum); + vsum2 = vec_sum4s(c5[1], vsum2); + vsum = vec_add(vsum, vsum2); + comparray[4] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + vsum = vec_splats(0); + vsum2 = vec_splats(0); + + c6[0] = vec_and(c6[1], lowMask); + c6[1] = vec_sr(c6[1], v4); + c6[0] = vec_sub(c6[0], v8); + c6[1] = vec_sub(c6[1], v8); + vsum = vec_sum4s(c6[0], vsum); + vsum2 = vec_sum4s(c6[1], vsum2); + vsum = vec_add(vsum, vsum2); + comparray[5] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + vsum = vec_splats(0); + vsum2 = vec_splats(0); + + c7[0] = vec_and(c7[1], lowMask); + c7[1] = vec_sr(c7[1], v4); + c7[0] = vec_sub(c7[0], v8); + c7[1] = vec_sub(c7[1], v8); + vsum = vec_sum4s(c7[0], vsum); + vsum2 = vec_sum4s(c7[1], vsum2); + vsum = vec_add(vsum, vsum2); + comparray[6] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + vsum = vec_splats(0); + vsum2 = vec_splats(0); + + c8[0] = vec_and(c8[1], lowMask); + c8[1] = vec_sr(c8[1], v4); + c8[0] = vec_sub(c8[0], v8); + c8[1] = vec_sub(c8[1], v8); + vsum = vec_sum4s(c8[0], vsum); + vsum2 = vec_sum4s(c8[1], vsum2); + vsum = vec_add(vsum, vsum2); + comparray[7] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + vsum = vec_splats(0); + vsum2 = vec_splats(0); + + t1 = vec_perm(c1[0], c2[0], swiz1); + t2 = vec_perm(c1[0], c2[0], swiz2); + t3 = vec_perm(c3[0], c4[0], swiz1); + t4 = vec_perm(c3[0], c4[0], swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + vec_xst(t5, 0, vecOffset); + vec_xst(t6, 0, vecOffset+16); + vec_xst(t7, 0, vecOffset+32); + vec_xst(t8, 0, vecOffset+48); + + t1 = vec_perm(c1[1], c2[1], swiz1); + t2 = vec_perm(c1[1], c2[1], swiz2); + t3 = vec_perm(c3[1], c4[1], swiz1); + t4 = vec_perm(c3[1], c4[1], swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + vec_xst(t5, 0, vecOffset+64); + vec_xst(t6, 0, vecOffset+80); + vec_xst(t7, 0, vecOffset+96); + vec_xst(t8, 0, vecOffset+112); + + t1 = vec_perm(c5[0], c6[0], swiz1); + t2 = vec_perm(c5[0], c6[0], swiz2); + t3 = vec_perm(c7[0], c8[0], swiz1); + t4 = vec_perm(c7[0], c8[0], swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + vec_xst(t5, 0, vecOffset+128); + vec_xst(t6, 0, vecOffset+144); + vec_xst(t7, 0, vecOffset+160); + vec_xst(t8, 0, vecOffset+176); + + t1 = vec_perm(c5[1], c6[1], swiz1); + t2 = vec_perm(c5[1], c6[1], swiz2); + t3 = vec_perm(c7[1], c8[1], swiz1); + t4 = vec_perm(c7[1], c8[1], swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + vec_xst(t5, 0, vecOffset+192); + vec_xst(t6, 0, vecOffset+208); + vec_xst(t7, 0, vecOffset+224); + vec_xst(t8, 0, vecOffset+240); + + aoffset1 += lda; + aoffset2 += lda; + aoffset3 += lda; + aoffset4 += lda; + aoffset5 += lda; + aoffset6 += lda; + aoffset7 += lda; + aoffset8 += lda; + vecOffset += 256; + i--; + } while (i > 0); + } + j--; + } while (j > 0); + } + + if (rows & 4) { + aoffset1 = aoffset; + aoffset2 = aoffset1 + lda; 
+ aoffset3 = aoffset2 + lda; + aoffset4 = aoffset3 + lda; + aoffset += 4 * lda; + + i = (cols >> 2); + if (i > 0) { + do { + c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); + c2[1] = reinterpret_cast(vec_xl(0, aoffset2->qs)); + c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); + c4[1] = reinterpret_cast(vec_xl(0, aoffset4->qs)); + + c1[0] = vec_and(c1[1], lowMask); + c1[1] = vec_sr(c1[1], v4); + c1[0] = vec_sub(c1[0], v8); + c1[1] = vec_sub(c1[1], v8); + vsum = vec_sum4s(c1[0], vsum); + vsum2 = vec_sum4s(c1[1], vsum2); + vsum = vec_add(vsum, vsum2); + comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + vsum = vec_splats(0); + vsum2 = vec_splats(0); + + c2[0] = vec_and(c2[1], lowMask); + c2[1] = vec_sr(c2[1], v4); + c2[0] = vec_sub(c2[0], v8); + c2[1] = vec_sub(c2[1], v8); + vsum = vec_sum4s(c2[0], vsum); + vsum2 = vec_sum4s(c2[1], vsum2); + vsum = vec_add(vsum, vsum2); + comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + vsum = vec_splats(0); + vsum2 = vec_splats(0); + + c3[0] = vec_and(c3[1], lowMask); + c3[1] = vec_sr(c3[1], v4); + c3[0] = vec_sub(c3[0], v8); + c3[1] = vec_sub(c3[1], v8); + vsum = vec_sum4s(c3[0], vsum); + vsum2 = vec_sum4s(c3[1], vsum2); + vsum = vec_add(vsum, vsum2); + comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + vsum = vec_splats(0); + vsum2 = vec_splats(0); + + c4[0] = vec_and(c4[1], lowMask); + c4[1] = vec_sr(c4[1], v4); + c4[0] = vec_sub(c4[0], v8); + c4[1] = vec_sub(c4[1], v8); + vsum = vec_sum4s(c4[0], vsum); + vsum2 = vec_sum4s(c4[1], vsum2); + vsum = vec_add(vsum, vsum2); + comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + vsum = vec_splats(0); + vsum2 = vec_splats( 0); + + t1 = vec_perm(c1[0], c2[0], swiz1); + t2 = vec_perm(c1[0], c2[0], swiz2); + t3 = vec_perm(c3[0], c4[0], swiz1); + t4 = vec_perm(c3[0], c4[0], swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + vec_xst(t5, 0, vecOffset); + vec_xst(t6, 0, vecOffset+16); + vec_xst(t7, 0, vecOffset+32); + vec_xst(t8, 0, vecOffset+48); + + t1 = vec_perm(c1[1], c2[1], swiz1); + t2 = vec_perm(c1[1], c2[1], swiz2); + t3 = vec_perm(c3[1], c4[1], swiz1); + t4 = vec_perm(c3[1], c4[1], swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + vec_xst(t5, 0, vecOffset+64); + vec_xst(t6, 0, vecOffset+80); + vec_xst(t7, 0, vecOffset+96); + vec_xst(t8, 0, vecOffset+112); + + aoffset1 += lda; + aoffset2 += lda; + aoffset3 += lda; + aoffset4 += lda; + vecOffset += 128; + i--; + } while (i > 0); + } + } + + if (rows & 3) { + aoffset1 = aoffset; + aoffset2 = aoffset1 + lda; + aoffset3 = aoffset2 + lda; + i = (cols >> 2); + if (i > 0) { + do { + switch(rows) { + case 3: c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); + case 2: c2[1] = reinterpret_cast(vec_xl(0, aoffset2->qs)); + case 1: c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); + break; + } + c1[0] = vec_and(c1[1], lowMask); + c1[1] = vec_sr(c1[1], v4); + c1[0] = vec_sub(c1[0], v8); + c1[1] = vec_sub(c1[1], v8); + vsum = vec_sum4s(c1[0], vsum); + vsum2 = vec_sum4s(c1[1], vsum2); + vsum = vec_add(vsum, vsum2); + comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + vsum = vec_splats(0); + vsum2 = vec_splats(0); + + c2[0] = vec_and(c2[1], lowMask); + c2[1] = vec_sr(c2[1], v4); + c2[0] = vec_sub(c2[0], v8); + c2[1] = vec_sub(c2[1], v8); + vsum = vec_sum4s(c2[0], vsum); + vsum2 = vec_sum4s(c2[1], vsum2); + vsum = vec_add(vsum, vsum2); + comparray[1] = vsum[0] + vsum[1] + 
vsum[2] + vsum[3]; + vsum = vec_splats(0); + vsum2 = vec_splats(0); + + c3[0] = vec_and(c3[1], lowMask); + c3[1] = vec_sr(c3[1], v4); + c3[0] = vec_sub(c3[0], v8); + c3[1] = vec_sub(c3[1], v8); + vsum = vec_sum4s(c3[0], vsum); + vsum2 = vec_sum4s(c3[1], vsum2); + vsum = vec_add(vsum, vsum2); + comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + vsum = vec_splats(0); + vsum2 = vec_splats(0); + + c4[0] = vec_and(c4[1], lowMask); + c4[1] = vec_sr(c4[1], v4); + c4[0] = vec_sub(c4[0], v8); + c4[1] = vec_sub(c4[1], v8); + vsum = vec_sum4s(c4[0], vsum); + vsum2 = vec_sum4s(c4[1], vsum2); + vsum = vec_add(vsum, vsum2); + comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + vsum = vec_splats(0); + vsum2 = vec_splats(0); + + t1 = vec_perm(c1[0], c2[0], swiz1); + t2 = vec_perm(c1[0], c2[0], swiz2); + t3 = vec_perm(c3[0], c4[0], swiz1); + t4 = vec_perm(c3[0], c4[0], swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + vec_xst(t5, 0, vecOffset); + vec_xst(t6, 0, vecOffset+16); + vec_xst(t7, 0, vecOffset+32); + vec_xst(t8, 0, vecOffset+48); + + t1 = vec_perm(c1[1], c2[1], swiz1); + t2 = vec_perm(c1[1], c2[1], swiz2); + t3 = vec_perm(c3[1], c4[1], swiz1); + t4 = vec_perm(c3[1], c4[1], swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + vec_xst(t5, 0, vecOffset+64); + vec_xst(t6, 0, vecOffset+80); + vec_xst(t7, 0, vecOffset+96); + vec_xst(t8, 0, vecOffset+112); + aoffset1 += lda; + aoffset2 += lda; + aoffset3 += lda; + vecOffset += 128; + i--; + } while(i > 0); + } + } + } + + template + void packNormal(const TB* a, int64_t lda, int rows, int cols, VA* vec, bool flip) { + int64_t i, j; + TB *aoffset = NULL; + VA *vecOffset = NULL; + TB *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL; + TB *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL; __vector_pair C1, C2, C3, C4, C5, C6, C7, C8; VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2]={0}; VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2]={0}; @@ -1111,24 +1502,24 @@ class tinyBLAS_Q0_PPC { vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27}; vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31}; - aoffset = const_cast(a); + aoffset = const_cast(a); vecOffset = vec; j = (rows >> 3); if (j > 0) { do { - aoffset1 = aoffset; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; - aoffset4 = aoffset3 + lda; - aoffset5 = aoffset4 + lda; - aoffset6 = aoffset5 + lda; - aoffset7 = aoffset6 + lda; - aoffset8 = aoffset7 + lda; - aoffset += 8 * lda; + aoffset1 = aoffset; + aoffset2 = aoffset1 + lda; + aoffset3 = aoffset2 + lda; + aoffset4 = aoffset3 + lda; + aoffset5 = aoffset4 + lda; + aoffset6 = aoffset5 + lda; + aoffset7 = aoffset6 + lda; + aoffset8 = aoffset7 + lda; + aoffset += 8 * lda; - i = (cols >> 3); - if (i > 0) { - do { + i = (cols >> 3); + if (i > 0) { + do { C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs); C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs); C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs); @@ -1156,10 +1547,10 @@ class tinyBLAS_Q0_PPC { t7 = vec_perm(t2, t4, swiz3); t8 = vec_perm(t2, t4, swiz4); if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); + t5 = vec_xor(t5, xor_vector); + t6 = vec_xor(t6, xor_vector); + t7 = vec_xor(t7, 
xor_vector); + t8 = vec_xor(t8, xor_vector); } vec_xst(t5, 0, vecOffset); vec_xst(t6, 0, vecOffset+16); @@ -1175,10 +1566,10 @@ class tinyBLAS_Q0_PPC { t7 = vec_perm(t2, t4, swiz3); t8 = vec_perm(t2, t4, swiz4); if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); + t5 = vec_xor(t5, xor_vector); + t6 = vec_xor(t6, xor_vector); + t7 = vec_xor(t7, xor_vector); + t8 = vec_xor(t8, xor_vector); } vec_xst(t5, 0, vecOffset+64); vec_xst(t6, 0, vecOffset+80); @@ -1194,10 +1585,10 @@ class tinyBLAS_Q0_PPC { t7 = vec_perm(t2, t4, swiz3); t8 = vec_perm(t2, t4, swiz4); if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); + t5 = vec_xor(t5, xor_vector); + t6 = vec_xor(t6, xor_vector); + t7 = vec_xor(t7, xor_vector); + t8 = vec_xor(t8, xor_vector); } vec_xst(t5, 0, vecOffset+128); vec_xst(t6, 0, vecOffset+144); @@ -1213,10 +1604,10 @@ class tinyBLAS_Q0_PPC { t7 = vec_perm(t2, t4, swiz3); t8 = vec_perm(t2, t4, swiz4); if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); + t5 = vec_xor(t5, xor_vector); + t6 = vec_xor(t6, xor_vector); + t7 = vec_xor(t7, xor_vector); + t8 = vec_xor(t8, xor_vector); } vec_xst(t5, 0, vecOffset+192); vec_xst(t6, 0, vecOffset+208); @@ -1240,11 +1631,11 @@ class tinyBLAS_Q0_PPC { } if (rows & 4) { - aoffset1 = aoffset; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; - aoffset4 = aoffset3 + lda; - aoffset += 4 * lda; + aoffset1 = aoffset; + aoffset2 = aoffset1 + lda; + aoffset3 = aoffset2 + lda; + aoffset4 = aoffset3 + lda; + aoffset += 4 * lda; i = (cols >> 3); if (i > 0) { @@ -1311,7 +1702,7 @@ class tinyBLAS_Q0_PPC { aoffset2 = aoffset1 + lda; aoffset3 = aoffset2 + lda; i = (cols >> 3); - if (i > 0) { + if (i > 0) { do { switch(rows) { case 3: C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs); @@ -1527,13 +1918,18 @@ class tinyBLAS_Q0_PPC { void KERNEL_4x8(int64_t ii, int64_t jj) { vec_t vec_A[8], vec_B[16] = {0}; acc_t acc_0, acc_1; - std::array comparray; + std::array comparray {}; vector float fin_res[8] = {0}; vector float vs[8] = {0}; + bool isAblock_q4 = std::is_same_v; for (int l = 0; l < k; l++) { __builtin_mma_xxsetaccz(&acc_0); __builtin_mma_xxsetaccz(&acc_1); - packNormal((A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false); + if (std::is_same_v) { + packNormalInt4((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray); + } else { + packNormal((const TB*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false); + } packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true); for(int x = 0; x < 8; x++) { __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]); @@ -1545,15 +1941,17 @@ class tinyBLAS_Q0_PPC { *((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d)); } } - auto aoffset = A+(ii*lda)+l; - for (int i = 0; i < 4; i++) { - comparray[i] = 0; - int ca = 0; - const int8_t *at = aoffset->qs; - for (int j = 0; j < 32; j++) - ca += (int)*at++; - comparray[i] = ca; - aoffset += lda; + if (!isAblock_q4) { + auto aoffset = A+(ii*lda)+l; + for (int i = 0; i < 4; i++) { + comparray[i] = 0; + int ca = 0; + auto *at = aoffset->qs; + for (int j = 0; j < 32; j++) + ca += (int)*at++; + comparray[i] = ca; + aoffset += lda; + } } compute<4>(&acc_0, 0, 0, comparray, vs, fin_res); compute<4>(&acc_1, 0, 4, comparray, vs, fin_res); @@ -1565,13 +1963,18 @@ class 
tinyBLAS_Q0_PPC { void KERNEL_8x4(int64_t ii, int64_t jj) { vec_t vec_A[16], vec_B[8] = {0}; acc_t acc_0, acc_1; - std::array comparray; + std::array comparray {}; vector float fin_res[8] = {0}; vector float vs[8] = {0}; + bool isAblock_q4 = std::is_same_v; for (int l = 0; l < k; l++) { __builtin_mma_xxsetaccz(&acc_0); __builtin_mma_xxsetaccz(&acc_1); - packNormal((A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false); + if (std::is_same_v) { + packNormalInt4((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray); + } else { + packNormal((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false); + } packNormal((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true); for(int x = 0; x < 8; x++) { __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]); @@ -1582,15 +1985,17 @@ class tinyBLAS_Q0_PPC { *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d)); } } - auto aoffset = A+(ii*lda)+l; - for (int i = 0; i < 8; i++) { - comparray[i] = 0; - int ca = 0; - const int8_t *at = aoffset->qs; - for (int j = 0; j < 32; j++) - ca += (int)*at++; - comparray[i] = ca; - aoffset += lda; + if (!isAblock_q4) { + auto aoffset = A+(ii*lda)+l; + for (int i = 0; i < 8; i++) { + comparray[i] = 0; + int ca = 0; + auto *at = aoffset->qs; + for (int j = 0; j < 32; j++) + ca += (int)*at++; + comparray[i] = ca; + aoffset += lda; + } } compute<8>(&acc_0, 0, 0, comparray, vs, fin_res); compute<8>(&acc_1, 4, 4, comparray, vs, fin_res); @@ -1602,15 +2007,20 @@ class tinyBLAS_Q0_PPC { void KERNEL_8x8(int64_t ii, int64_t jj) { vec_t vec_A[16], vec_B[16] = {0}; acc_t acc_0, acc_1, acc_2, acc_3; - std::array comparray; + std::array comparray {}; vector float fin_res[16] = {0}; vector float vs[16] = {0}; + bool isAblock_q4 = std::is_same_v; for (int l = 0; l < k; l++) { __builtin_mma_xxsetaccz(&acc_0); __builtin_mma_xxsetaccz(&acc_1); __builtin_mma_xxsetaccz(&acc_2); __builtin_mma_xxsetaccz(&acc_3); - packNormal((A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false); + if (std::is_same_v) { + packNormalInt4((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray); + } else { + packNormal((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false); + } packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true); for(int x = 0; x < 8; x++) { __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]); @@ -1624,15 +2034,17 @@ class tinyBLAS_Q0_PPC { *((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d)); } } - auto aoffset = A+(ii*lda)+l; - for (int i = 0; i < 8; i++) { - comparray[i] = 0; - int ca = 0; - const int8_t *at = aoffset->qs; - for (int j = 0; j < 32; j++) - ca += (int)*at++; - comparray[i] = ca; - aoffset += lda; + if (!isAblock_q4) { + auto aoffset = A+(ii*lda)+l; + for (int i = 0; i < 8; i++) { + comparray[i] = 0; + int ca = 0; + auto *at = aoffset->qs; + for (int j = 0; j < 32; j++) + ca += (int)*at++; + comparray[i] = ca; + aoffset += lda; + } } compute<8>(&acc_0, 0, 0, comparray, vs, fin_res); compute<8>(&acc_1, 4, 4, comparray, vs, fin_res); @@ -1653,16 +2065,17 @@ class tinyBLAS_Q0_PPC { int64_t duty = (tiles + nth - 1) / nth; int64_t start = duty * ith; int64_t end = start + duty; - vec_t vec_A[8], vec_B[8] = {0}; + vec_t vec_A[8] = {0}, vec_B[8] = {0}; vector signed int vec_C[4]; acc_t acc_0; + bool isAblock_q4 = std::is_same_v; if (end > tiles) end = tiles; for (int64_t job = start; job < end; ++job) { int64_t ii = m0 + job / xtiles * RM; int64_t jj = n0 + job % xtiles * RN; - std::array comparray; + std::array comparray{}; vector float res[4] = {0}; vector float 
fin_res[4] = {0}; vector float vs[4] = {0}; @@ -1673,7 +2086,11 @@ class tinyBLAS_Q0_PPC { __builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead __builtin_mma_xxsetaccz(&acc_0); - packNormal((A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false); + if (isAblock_q4) { + packNormalInt4((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray); + } else { + packNormal((const TB*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false); + } packNormal((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true); for(int x = 0; x < 8; x+=4) { __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]); @@ -1687,17 +2104,18 @@ class tinyBLAS_Q0_PPC { } } __builtin_mma_disassemble_acc(vec_C, &acc_0); - auto aoffset = A+(ii*lda)+l; - for (int i = 0; i < RM; i++) { - comparray[i] = 0; - int ca = 0; - const int8_t *at = aoffset->qs; - for (int j = 0; j < 32; j++) - ca += (int)*at++; - comparray[i] = ca; - aoffset += lda; + if (!isAblock_q4) { + auto aoffset = A+(ii*lda)+l; + for (int i = 0; i < RM; i++) { + comparray[i] = 0; + int ca = 0; + auto *at = aoffset->qs; + for (int j = 0; j < 32; j++) + ca += (int)*at++; + comparray[i] = ca; + aoffset += lda; + } } - for (int i = 0; i < RM; i++) { CA[i] = vec_splats((float)(((double)comparray[i]) * -128.0)); res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]); @@ -2013,6 +2431,7 @@ class tinyBLAS_PPC { } } } + void KERNEL_4x4(int64_t ii, int64_t jj) { vec_t vec_A[4], vec_B[4], vec_C[4]; acc_t acc_0; @@ -2259,7 +2678,7 @@ class tinyBLAS_PPC { vec_t vec_C[4]; acc_t acc_0; __builtin_mma_xxsetaccz(&acc_0); - vec_t vec_A[4], vec_B[4]; + vec_t vec_A[4] {0}, vec_B[4] = {0}; for (int l=0; l= 4 && RM == 1) { TA* a = const_cast(A+(ii)*lda+l); @@ -2503,8 +2922,8 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 params->ith, params->nth}; tb.matmul(m, n); return true; - #elif defined(__MMA__) + //TO-DO: Remove this condition once gemv forwarding is enabled. if (n < 8 && n != 4) return false; if (m < 8 && m != 4) @@ -2516,7 +2935,6 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 params->ith, params->nth}; tb.matmul(m, n); return true; - #else return false; #endif @@ -2541,6 +2959,19 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 params->ith, params->nth}; tb.matmul(m, n); return true; +#elif defined(__MMA__) + //TO-DO: Remove this condition once gemv forwarding is enabled. 
+ if (n < 8 && n != 4) + return false; + if (m < 8 && m != 4) + return false; + tinyBLAS_Q0_PPC tb{ + k, (const block_q4_0 *)A, lda, + (const block_q8_0 *)B, ldb, + (float *)C, ldc, + params->ith, params->nth}; + tb.matmul(m, n); + return true; #else return false; #endif From 0306aad1ca89a92fdb33450b35547b90b92f5dbe Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 27 Mar 2025 09:00:57 +0200 Subject: [PATCH 09/26] cmake : sync/merge PowerPC build commands (#0) --- ggml/src/ggml-cpu/CMakeLists.txt | 34 +++++++++++++++++++------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index b9076513a..971313d20 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -289,23 +289,29 @@ function(ggml_add_cpu_backend_variant_impl tag_name) endif() elseif ("${CMAKE_SYSTEM_PROCESSOR} " STREQUAL "ppc64le " OR "${CMAKE_SYSTEM_PROCESSOR} " STREQUAL "powerpc ") message(STATUS "PowerPC detected") - if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64") - file(READ "/proc/cpuinfo" POWER10_M) - elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "powerpc") - execute_process(COMMAND bash -c "prtconf |grep 'Implementation' | head -n 1" OUTPUT_VARIABLE POWER10_M) - endif() + if (GGML_NATIVE) + if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64") + file(READ "/proc/cpuinfo" POWER10_M) + elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "powerpc") + execute_process(COMMAND bash -c "prtconf |grep 'Implementation' | head -n 1" OUTPUT_VARIABLE POWER10_M) + endif() - string(REGEX MATCHALL "POWER *([0-9]+)" MATCHED_STRING "${POWER10_M}") - string(REGEX REPLACE "POWER *([0-9]+)" "\\1" EXTRACTED_NUMBER "${MATCHED_STRING}") + string(REGEX MATCHALL "POWER *([0-9]+)" MATCHED_STRING "${POWER10_M}") + string(REGEX REPLACE "POWER *([0-9]+)" "\\1" EXTRACTED_NUMBER "${MATCHED_STRING}") - if (EXTRACTED_NUMBER GREATER_EQUAL 10) - list(APPEND ARCH_FLAGS -mcpu=power10 -mpowerpc64) - elseif (EXTRACTED_NUMBER EQUAL 9) - list(APPEND ARCH_FLAGS -mcpu=power9 -mpowerpc64) - elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le") - list(APPEND ARCH_FLAGS -mcpu=powerpc64le -mtune=native) + if (EXTRACTED_NUMBER GREATER_EQUAL 10) + list(APPEND ARCH_FLAGS -mcpu=power10 -mpowerpc64) + elseif (EXTRACTED_NUMBER EQUAL 9) + list(APPEND ARCH_FLAGS -mcpu=power9 -mpowerpc64) + elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le") + list(APPEND ARCH_FLAGS -mcpu=powerpc64le -mtune=native) + else() + list(APPEND ARCH_FLAGS -mcpu=native -mtune=native -mpowerpc64) + endif() else() - list(APPEND ARCH_FLAGS -mcpu=native -mtune=native -mpowerpc64) + if (GGML_CPU_POWERPC_CPUTYPE) + list(APPEND ARCH_FLAGS -mcpu=${GGML_CPU_POWERPC_CPUTYPE}) + endif() endif() elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64") message(STATUS "loongarch64 detected") From df0665a483b08da1c9b55534b624ae4e1fe89767 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 27 Mar 2025 09:01:21 +0200 Subject: [PATCH 10/26] sync : ggml ggml-ci --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index c7944d1d4..d0532d85c 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -c7dfe3d174f98b14801f9ed12f129179d3e7b638 +fc21aba88324312d66f62af6a1b0683fcb2ce3b5 From 771d84371c0785c2554aef9a47cdd551b8c7ceaf Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 27 Mar 2025 09:22:30 +0200 Subject: [PATCH 11/26] scripts : update sync + fix cmake merge ggml-ci --- ggml/CMakeLists.txt | 3 ++- 
ggml/cmake/GitVars.cmake | 22 ++++++++++++++++++++++ ggml/cmake/ggml-config.cmake.in | 2 +- scripts/sync-ggml-am.sh | 19 ++++++++++++++++--- scripts/sync-ggml.sh | 4 +++- 5 files changed, 44 insertions(+), 6 deletions(-) create mode 100644 ggml/cmake/GitVars.cmake diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 433628c4c..0332c26f1 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -127,7 +127,8 @@ option(GGML_RV_ZFH "ggml: enable riscv zfh" OFF) option(GGML_VXE "ggml: enable vxe" ON) option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF) -set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM") +set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM") +set(GGML_CPU_POWERPC_CPUTYPE "" CACHE STRING "ggml: CPU type for PowerPC") if (WIN32) diff --git a/ggml/cmake/GitVars.cmake b/ggml/cmake/GitVars.cmake new file mode 100644 index 000000000..1a4c24ebf --- /dev/null +++ b/ggml/cmake/GitVars.cmake @@ -0,0 +1,22 @@ +find_package(Git) + +# the commit's SHA1 +execute_process(COMMAND + "${GIT_EXECUTABLE}" describe --match=NeVeRmAtCh --always --abbrev=8 + WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}" + OUTPUT_VARIABLE GIT_SHA1 + ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE) + +# the date of the commit +execute_process(COMMAND + "${GIT_EXECUTABLE}" log -1 --format=%ad --date=local + WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}" + OUTPUT_VARIABLE GIT_DATE + ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE) + +# the subject of the commit +execute_process(COMMAND + "${GIT_EXECUTABLE}" log -1 --format=%s + WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}" + OUTPUT_VARIABLE GIT_COMMIT_SUBJECT + ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE) diff --git a/ggml/cmake/ggml-config.cmake.in b/ggml/cmake/ggml-config.cmake.in index 823eb797b..8c2dc31c6 100644 --- a/ggml/cmake/ggml-config.cmake.in +++ b/ggml/cmake/ggml-config.cmake.in @@ -5,7 +5,7 @@ set_and_check(GGML_INCLUDE_DIR "@PACKAGE_GGML_INCLUDE_INSTALL_DIR@") set_and_check(GGML_LIB_DIR "@PACKAGE_GGML_LIB_INSTALL_DIR@") -set_and_check(GGML_BIN_DIR "@PACKAGE_GGML_BIN_INSTALL_DIR@") +#set_and_check(GGML_BIN_DIR "@PACKAGE_GGML_BIN_INSTALL_DIR@") find_package(Threads REQUIRED) diff --git a/scripts/sync-ggml-am.sh b/scripts/sync-ggml-am.sh index 33e8c6414..914ff7c55 100755 --- a/scripts/sync-ggml-am.sh +++ b/scripts/sync-ggml-am.sh @@ -69,7 +69,11 @@ while read c; do git format-patch -U${ctx} -k $c~1..$c --stdout -- \ CMakeLists.txt \ src/CMakeLists.txt \ - cmake/FindSIMD.cmake \ + cmake/BuildTypes.cmake \ + cmake/GitVars.cmake \ + cmake/common.cmake \ + cmake/ggml-config.cmake.in \ + src/ggml-cpu/cmake/FindSIMD.cmake \ src/ggml*.h \ src/ggml*.c \ src/ggml*.cpp \ @@ -121,7 +125,12 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then # # CMakelists.txt -> ggml/CMakeLists.txt # src/CMakeLists.txt -> ggml/src/CMakeLists.txt - # cmake/FindSIMD.cmake -> ggml/cmake/FindSIMD.cmake + + # cmake/BuildTypes.cmake -> ggml/cmake/BuildTypes.cmake + # cmake/GitVars.cmake -> ggml/cmake/GitVars.cmake + # cmake/common.cmake -> ggml/cmake/common.cmake + # cmake/ggml-config.cmake.in -> ggml/cmake/ggml-config.cmake.in + # src/ggml-cpu/cmake/FindSIMD.cmake -> ggml/src/ggml-cpu/cmake/FindSIMD.cmake # # src/ggml*.c -> ggml/src/ggml*.c # src/ggml*.cpp -> ggml/src/ggml*.cpp @@ -151,7 +160,11 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then cat ggml-src.patch | sed -E \ -e 's/(^[[:space:]]| [ab]\/)CMakeLists.txt/\1ggml\/CMakeLists.txt/g' \ -e 's/(^[[:space:]]| [ab]\/)src\/CMakeLists.txt/\1ggml\/src\/CMakeLists.txt/g' \ - 
-e 's/(^[[:space:]]| [ab]\/)cmake\/FindSIMD.cmake/\1ggml\/cmake\/FindSIMD.cmake/g' \ + -e 's/(^[[:space:]]| [ab]\/)cmake\/BuildTypes.cmake/\1ggml\/cmake\/BuildTypes.cmake/g' \ + -e 's/(^[[:space:]]| [ab]\/)cmake\/GitVars.cmake/\1ggml\/cmake\/GitVars.cmake/g' \ + -e 's/(^[[:space:]]| [ab]\/)cmake\/common.cmake/\1ggml\/cmake\/common.cmake/g' \ + -e 's/(^[[:space:]]| [ab]\/)cmake\/ggml-config.cmake.in/\1ggml\/cmake\/ggml-config.cmake.in/g' \ + -e 's/(^[[:space:]]| [ab]\/)src\/ggml-cpu\/cmake\/FindSIMD.cmake/\1ggml\/src\/ggml-cpu\/cmake\/FindSIMD.cmake/g' \ -e 's/([[:space:]]| [ab]\/)src\/ggml(.*)\.c/\1ggml\/src\/ggml\2.c/g' \ -e 's/([[:space:]]| [ab]\/)src\/ggml(.*)\.cpp/\1ggml\/src\/ggml\2.cpp/g' \ -e 's/([[:space:]]| [ab]\/)src\/ggml(.*)\.h/\1ggml\/src\/ggml\2.h/g' \ diff --git a/scripts/sync-ggml.sh b/scripts/sync-ggml.sh index e83d415c0..aa1a46b4b 100755 --- a/scripts/sync-ggml.sh +++ b/scripts/sync-ggml.sh @@ -2,7 +2,9 @@ cp -rpv ../ggml/CMakeLists.txt ./ggml/CMakeLists.txt cp -rpv ../ggml/src/CMakeLists.txt ./ggml/src/CMakeLists.txt -cp -rpv ../ggml/cmake/FindSIMD.cmake ./ggml/cmake/FindSIMD.cmake + +cp -rpv ../ggml/cmake/* ./ggml/cmake/ +cp -rpv ../ggml/src/ggml-cpu/cmake/* ./ggml/src/ggml-cpu/cmake/ cp -rpv ../ggml/src/ggml*.c ./ggml/src/ cp -rpv ../ggml/src/ggml*.cpp ./ggml/src/ From 029c693fdcdc80e8508553df61a4337bb5fe49a9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 27 Mar 2025 09:36:13 +0200 Subject: [PATCH 12/26] sync : ggml ggml-ci --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index d0532d85c..bf01d88ad 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -fc21aba88324312d66f62af6a1b0683fcb2ce3b5 +660def06391b3d6c9eed9fed38d7dc025ee1b1ca From d5c6309d91cb22ebc947920f92eb686d92f84eae Mon Sep 17 00:00:00 2001 From: Csaba Kecskemeti Date: Thu, 27 Mar 2025 03:11:23 -0700 Subject: [PATCH 13/26] convert : Support Qwen2_5_VLForConditionalGeneration (#12595) --- convert_hf_to_gguf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 52637c42f..a06010a79 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2269,7 +2269,7 @@ class Qwen2Model(Model): self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"]) -@Model.register("Qwen2VLForConditionalGeneration") +@Model.register("Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration") class Qwen2VLModel(Model): model_arch = gguf.MODEL_ARCH.QWEN2VL From 953c2a62cf487e618140f3ea18d94e3b0257af93 Mon Sep 17 00:00:00 2001 From: HighDoping Date: Thu, 27 Mar 2025 18:43:33 +0800 Subject: [PATCH 14/26] model : restore support for T5Encoder (#12590) --- src/llama-model.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 0ae754154..c8e3386fc 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -11846,10 +11846,11 @@ llm_graph_result_ptr llama_model::build_graph( GGML_ABORT("invalid graph type"); }; } break; - //case LLM_ARCH_T5ENCODER: - // { - // llm.build_t5_enc(gf); - // } break; + case LLM_ARCH_T5ENCODER: + { + llm = std::make_unique(*this, params, gf); + } + break; case LLM_ARCH_JAIS: { llm = std::make_unique(*this, params, gf); From f125b8dccff34439a26bf750c9edef358c48c1f8 Mon Sep 17 00:00:00 2001 From: Si1w <139008732+Si1w@users.noreply.github.com> Date: Thu, 27 Mar 2025 10:49:15 +0000 
Subject: [PATCH 15/26] llama : add PLM GGUF Conversion & Inference Support (#12457) * add edgellm model arch[conversation feature doesn't work] * remove output.weight layer for edgellm arch * [Model] update the name of the model * update the name of model arch in convert gguf * [Model] Refarctor the model arch into llama-model * [Bug] Fix the bug in create attn kv * [Code] Fix editorconfig erros * [Code] Remove Trailing whitespace * [Code] Remove Trailing whitespace * [Code] Change the order of model arch in list * [Code] Fix flake8 Lint errors * Remove trailing white space * [Code] Remove call in model arch --- convert_hf_to_gguf.py | 23 ++++ gguf-py/gguf/constants.py | 16 +++ src/llama-arch.cpp | 17 +++ src/llama-arch.h | 1 + src/llama-model.cpp | 216 ++++++++++++++++++++++++++++++++++++++ src/llama-model.h | 1 + 6 files changed, 274 insertions(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index a06010a79..c605e4d05 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4419,6 +4419,29 @@ class DeepseekV2Model(Model): raise ValueError(f"Unprocessed experts: {experts}") +@Model.register("PLMForCausalLM") +class PLMModel(Model): + model_arch = gguf.MODEL_ARCH.PLM + + def set_vocab(self): + self._set_vocab_gpt2() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"]) + self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) + self.gguf_writer.add_value_length(hparams["v_head_dim"]) + self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + return [(self.map_tensor_name(name), data_torch)] + + def prepare_tensors(self): + super().prepare_tensors() + + @Model.register("T5WithLMHeadModel") @Model.register("T5ForConditionalGeneration") @Model.register("MT5ForConditionalGeneration") diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 13cca7ab0..1753dca4b 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -286,6 +286,7 @@ class MODEL_ARCH(IntEnum): GRANITE_MOE = auto() CHAMELEON = auto() WAVTOKENIZER_DEC = auto() + PLM = auto() class MODEL_TENSOR(IntEnum): @@ -488,6 +489,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.GRANITE_MOE: "granitemoe", MODEL_ARCH.CHAMELEON: "chameleon", MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec", + MODEL_ARCH.PLM: "plm", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -1464,6 +1466,20 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_UP_SHEXP, MODEL_TENSOR.FFN_EXP_PROBS_B, ], + MODEL_ARCH.PLM: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_KV_A_MQA, + MODEL_TENSOR.ATTN_KV_A_NORM, + MODEL_TENSOR.ATTN_KV_B, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_DOWN, + ], MODEL_ARCH.CHATGLM : [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.ROPE_FREQS, diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 8664f8963..9e443d830 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -65,6 +65,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GRANITE_MOE, "granitemoe" }, { LLM_ARCH_CHAMELEON, "chameleon" }, { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, + { LLM_ARCH_PLM, "plm" }, { LLM_ARCH_UNKNOWN, 
"(unknown)" }, }; @@ -1043,6 +1044,22 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, }, }, + { + LLM_ARCH_PLM, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, + { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" }, + { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_CHATGLM, { diff --git a/src/llama-arch.h b/src/llama-arch.h index a28815d8a..39e3a2ce0 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -69,6 +69,7 @@ enum llm_arch { LLM_ARCH_GRANITE_MOE, LLM_ARCH_CHAMELEON, LLM_ARCH_WAVTOKENIZER_DEC, + LLM_ARCH_PLM, LLM_ARCH_UNKNOWN, }; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index c8e3386fc..a442abeb8 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -47,6 +47,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_1_4B: return "1.4B"; case LLM_TYPE_1_5B: return "1.5B"; case LLM_TYPE_1_6B: return "1.6B"; + case LLM_TYPE_1_8B: return "1.8B"; case LLM_TYPE_2B: return "2B"; case LLM_TYPE_2_8B: return "2.8B"; case LLM_TYPE_2_9B: return "2.9B"; @@ -1144,6 +1145,15 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_PLM: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_1_8B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_CHATGLM: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -3068,6 +3078,35 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } } break; + case LLM_ARCH_PLM: + { + const int64_t n_embd_head_qk_rope = hparams.n_rot; + const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + const int64_t kv_lora_rank = hparams.n_lora_kv; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0); + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); + layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = 
create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; case LLM_ARCH_BITNET: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -11615,6 +11654,178 @@ struct llm_build_wavtokenizer_dec : public llm_graph_context { } }; +struct llm_build_plm : public llm_graph_context { + llm_build_plm(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const float kq_scale = 1.0f/sqrtf(float(hparams.n_embd_head_k)); + + const uint32_t n_embd_head_qk_rope = hparams.n_rot; + const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + const uint32_t kv_lora_rank = hparams.n_lora_kv; + + ggml_tensor * cur; + ggml_tensor * inpL; + + // {n_embd, n_tokens} + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self_attention + { + ggml_tensor * q = NULL; + q = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(q, "q", il); + + // split into {n_head * n_embd_head_qk_nope, n_tokens} + ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(q->type, hparams.n_embd_head_k), + ggml_row_size(q->type, hparams.n_embd_head_k * n_head), + 0); + cb(q_nope, "q_nope", il); + + // and {n_head * n_embd_head_qk_rope, n_tokens} + ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, + ggml_row_size(q->type, hparams.n_embd_head_k), + ggml_row_size(q->type, hparams.n_embd_head_k * n_head), + ggml_row_size(q->type, n_embd_head_qk_nope)); + cb(q_pe, "q_pe", il); + + // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens} + ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); + cb(kv_pe_compresseed, "kv_pe_compresseed", il); + + // split into {kv_lora_rank, n_tokens} + ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens, + kv_pe_compresseed->nb[1], + 0); + cb(kv_compressed, "kv_compressed", il); + + // and {n_embd_head_qk_rope, n_tokens} + ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens, + kv_pe_compresseed->nb[1], + kv_pe_compresseed->nb[1], + ggml_row_size(kv_pe_compresseed->type, kv_lora_rank)); + cb(k_pe, "k_pe", il); + + kv_compressed = build_norm(kv_compressed, + model.layers[il].attn_kv_a_norm, NULL, + LLM_NORM_RMS, il); + cb(kv_compressed, "kv_compressed", il); + + // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens} + ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed); + cb(kv, "kv", il); + + // split into {n_head * n_embd_head_qk_nope, n_tokens} + ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), + ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + 0); + cb(k_nope, "k_nope", il); + + // and {n_head * n_embd_head_v, n_tokens} + ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, + ggml_row_size(kv->type, 
(n_embd_head_qk_nope + hparams.n_embd_head_v)), + ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head), + ggml_row_size(kv->type, (n_embd_head_qk_nope))); + cb(v_states, "v_states", il); + + v_states = ggml_cont(ctx0, v_states); + cb(v_states, "v_states", il); + + v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens, + ggml_row_size(kv->type, hparams.n_embd_head_v * n_head), + 0); + cb(v_states, "v_states", il); + + q_pe = ggml_rope_ext( + ctx0, q_pe, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(q_pe, "q_pe", il); + + // shared RoPE key + k_pe = ggml_rope_ext( + ctx0, k_pe, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(k_pe, "k_pe", il); + + ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0); + cb(q_states, "q_states", il); + + ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0); + cb(k_states, "k_states", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + q_states, k_states, v_states, nullptr, kq_scale, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_RELU_SQR, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + llama_memory_i * llama_model::create_memory() const { llama_memory_i * res; @@ -11887,6 +12098,10 @@ llm_graph_result_ptr llama_model::build_graph( { llm = std::make_unique(*this, params, gf); } break; + case LLM_ARCH_PLM: + { + llm = std::make_unique(*this, params, gf); + } break; default: GGML_ABORT("fatal error"); } @@ -12013,6 +12228,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_ARCTIC: case LLM_ARCH_DEEPSEEK: case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_PLM: case LLM_ARCH_CHATGLM: case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: diff --git a/src/llama-model.h b/src/llama-model.h index a9da1215a..0064d597a 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -44,6 +44,7 @@ enum llm_type { LLM_TYPE_1_4B, LLM_TYPE_1_5B, LLM_TYPE_1_6B, + LLM_TYPE_1_8B, LLM_TYPE_2B, LLM_TYPE_2_8B, LLM_TYPE_2_9B, From 5dec47dcd411fdf815a3708fd6194e2b13d19006 Mon Sep 17 00:00:00 2001 From: lhez Date: Thu, 27 Mar 2025 08:08:08 -0700 Subject: [PATCH 16/26] opencl: add multi and vision rope, `gelu_quick` and `im2col` (#12600) * opencl: add `im2col` * opencl: add `gelu_quick` * opencl: add mrope * opencl: add vision rope --- ggml/src/ggml-opencl/CMakeLists.txt | 1 + ggml/src/ggml-opencl/ggml-opencl.cpp | 252 +++++++++++- 
ggml/src/ggml-opencl/kernels/ggml-opencl.cl | 389 ++++++++++++++++++ .../ggml-opencl/kernels/ggml-opencl_im2col.cl | 146 +++++++ 4 files changed, 774 insertions(+), 14 deletions(-) create mode 100644 ggml/src/ggml-opencl/kernels/ggml-opencl_im2col.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 7efb51c8e..624cb1b9d 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -63,6 +63,7 @@ set(GGML_OPENCL_KERNELS ggml-opencl_transpose_16 ggml-opencl_transpose_32 ggml-opencl_transpose_32_16 + ggml-opencl_im2col ) foreach (K ${GGML_OPENCL_KERNELS}) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index efaf7f479..6c123ddef 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -224,12 +224,14 @@ struct ggml_backend_opencl_context { cl_program program; cl_program program_1; cl_program program_2; + cl_program program_im2col; cl_kernel kernel_add, kernel_add_row; cl_kernel kernel_mul, kernel_mul_row; cl_kernel kernel_scale; cl_kernel kernel_silu, kernel_silu_4; cl_kernel kernel_gelu, kernel_gelu_4; + cl_kernel kernel_gelu_quick, kernel_gelu_quick_4; cl_kernel kernel_relu; cl_kernel kernel_clamp; cl_kernel kernel_norm; @@ -239,6 +241,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16; cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0; cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16; + cl_kernel kernel_rope_multi_f32, kernel_rope_multi_f16, kernel_rope_vision_f32, kernel_rope_vision_f16; cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32; cl_kernel kernel_mul_mat_f32_f32; cl_kernel kernel_mul_mat_f16_f16; @@ -252,6 +255,7 @@ struct ggml_backend_opencl_context { kernel_mul_mat_q4_0_f32_flat_img_v0; cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; cl_kernel kernel_mul_mv_q6_K_f32; + cl_kernel kernel_im2col_f32, kernel_im2col_f16; #ifdef GGML_OPENCL_USE_ADRENO_KERNELS // Transpose kernels @@ -708,6 +712,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { CL_CHECK((backend_ctx->kernel_silu_4 = clCreateKernel(backend_ctx->program, "kernel_silu_4", &err), err)); CL_CHECK((backend_ctx->kernel_gelu = clCreateKernel(backend_ctx->program, "kernel_gelu", &err), err)); CL_CHECK((backend_ctx->kernel_gelu_4 = clCreateKernel(backend_ctx->program, "kernel_gelu_4", &err), err)); + CL_CHECK((backend_ctx->kernel_gelu_quick = clCreateKernel(backend_ctx->program, "kernel_gelu_quick", &err), err)); + CL_CHECK((backend_ctx->kernel_gelu_quick_4 = clCreateKernel(backend_ctx->program, "kernel_gelu_quick_4", &err), err)); CL_CHECK((backend_ctx->kernel_relu = clCreateKernel(backend_ctx->program, "kernel_relu", &err), err)); CL_CHECK((backend_ctx->kernel_clamp = clCreateKernel(backend_ctx->program, "kernel_clamp", &err), err)); CL_CHECK((backend_ctx->kernel_norm = clCreateKernel(backend_ctx->program, "kernel_norm", &err), err)); @@ -722,6 +728,10 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { CL_CHECK((backend_ctx->kernel_rope_norm_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_norm_f16", &err), err)); CL_CHECK((backend_ctx->kernel_rope_neox_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_neox_f32", &err), err)); CL_CHECK((backend_ctx->kernel_rope_neox_f16 = clCreateKernel(backend_ctx->program, 
"kernel_rope_neox_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_multi_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_multi_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_multi_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_multi_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_vision_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_vision_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_vision_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_vision_f16", &err), err)); CL_CHECK((backend_ctx->kernel_cpy_f16_f16 = clCreateKernel(backend_ctx->program, "kernel_cpy_f16_f16", &err), err)); CL_CHECK((backend_ctx->kernel_cpy_f16_f32 = clCreateKernel(backend_ctx->program, "kernel_cpy_f16_f32", &err), err)); CL_CHECK((backend_ctx->kernel_cpy_f32_f16 = clCreateKernel(backend_ctx->program, "kernel_cpy_f32_f16", &err), err)); @@ -769,6 +779,19 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { CL_CHECK((backend_ctx->kernel_convert_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_2, "kernel_convert_block_q4_0_noshuffle", &err), err)); + // im2col kernels +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_im2col { + #include "ggml-opencl_im2col.cl.h" + }; +#else + const std::string kernel_src_im2col = read_file("ggml-opencl_im2col.cl"); +#endif + backend_ctx->program_im2col = build_program_from_source(context, device, kernel_src_im2col.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_im2col_f32 = clCreateKernel(backend_ctx->program_im2col, "kernel_im2col_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_im2col_f16 = clCreateKernel(backend_ctx->program_im2col, "kernel_im2col_f16", &err), err)); + // Kernels for Adreno #ifdef GGML_OPENCL_USE_ADRENO_KERNELS #ifdef GGML_OPENCL_EMBED_KERNELS @@ -1187,6 +1210,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_GELU_QUICK: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; default: return false; @@ -1216,14 +1240,26 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te return op->ne[3] == 1; case GGML_OP_ROPE: { const int mode = ((const int32_t *) op->op_params)[2]; - if (mode & GGML_ROPE_TYPE_MROPE) { + const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; + if (is_mrope && !is_vision) { + if (op->src[0]->type == GGML_TYPE_F32 || + op->src[0]->type == GGML_TYPE_F16) { + return true; + } return false; } - if (mode & GGML_ROPE_TYPE_VISION) { + if (is_vision) { + if (op->src[0]->type == GGML_TYPE_F32 || + op->src[0]->type == GGML_TYPE_F16) { + return true; + } return false; } return true; } + case GGML_OP_IM2COL: + return true; default: return false; } @@ -2582,6 +2618,53 @@ static void ggml_cl_gelu(ggml_backend_t backend, const ggml_tensor * src0, const #endif } +static void ggml_cl_gelu_quick(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong 
offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel; + + int n = ggml_nelements(dst); + + if (n % 4 == 0) { + kernel = backend_ctx->kernel_gelu_quick_4; + n /= 4; + } else { + kernel = backend_ctx->kernel_gelu_quick; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL); +#endif +} + static void ggml_cl_silu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -3980,6 +4063,7 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const float attn_factor; float beta_fast; float beta_slow; + int32_t sections[4]; memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); @@ -3987,23 +4071,23 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int32_t)*4); const bool is_neox = mode & 2; + const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; + + if (is_mrope) { + GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0); + } + + if (is_vision) { + GGML_ASSERT(n_dims == ne00/2); + } cl_kernel kernel; - if (!is_neox) { - switch (src0->type) { - case GGML_TYPE_F32: - kernel = backend_ctx->kernel_rope_norm_f32; - break; - case GGML_TYPE_F16: - kernel = backend_ctx->kernel_rope_norm_f16; - break; - default: - GGML_ASSERT(false); - }; - } else { + if (is_neox) { switch (src0->type) { case GGML_TYPE_F32: kernel = backend_ctx->kernel_rope_neox_f32; @@ -4014,6 +4098,39 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const default: GGML_ASSERT(false); }; + } else if (is_mrope && !is_vision) { + switch (src0->type) { + case GGML_TYPE_F32: + kernel = backend_ctx->kernel_rope_multi_f32; + break; + case GGML_TYPE_F16: + kernel = backend_ctx->kernel_rope_multi_f16; + break; + default: + GGML_ASSERT(false); + }; + } else if (is_vision) { + switch (src0->type) { + case GGML_TYPE_F32: + kernel = backend_ctx->kernel_rope_vision_f32; + break; + case GGML_TYPE_F16: + kernel = backend_ctx->kernel_rope_vision_f16; + break; + default: + GGML_ASSERT(false); + } + } else { + switch (src0->type) { + case GGML_TYPE_F32: + kernel = backend_ctx->kernel_rope_norm_f32; + break; + case GGML_TYPE_F16: + kernel = backend_ctx->kernel_rope_norm_f16; + break; + default: + GGML_ASSERT(false); + }; } CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); @@ -4049,6 
+4166,9 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const CL_CHECK(clSetKernelArg(kernel, 30, sizeof(float), &attn_factor)); CL_CHECK(clSetKernelArg(kernel, 31, sizeof(float), &beta_fast)); CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float), &beta_slow)); + if (is_mrope || is_vision) { + CL_CHECK(clSetKernelArg(kernel, 33, sizeof(int32_t)*4, §ions)); + } size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; size_t local_work_size[] = {(size_t)nth, 1, 1}; @@ -4064,6 +4184,98 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const #endif } +static void ggml_cl_im2col(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + // src0 - filter, src1 - input + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t*)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t*)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t*)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t*)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t*)(dst->op_params))[5]; + + const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1; + + const cl_long IC = src1->ne[is_2D ? 2 : 1]; + const cl_long IH = is_2D ? src1->ne[1] : 1; + const cl_long IW = src1->ne[0]; + + const cl_long KH = is_2D ? src0->ne[1] : 1; + const cl_long KW = src0->ne[0]; + + const cl_long OH = is_2D ? dst->ne[2] : 1; + const cl_long OW = dst->ne[1]; + + // nb is byte offset, src is type float32 + const cl_ulong delta_offset = src1->nb[is_2D ? 2 : 1]/4; + const cl_long batch = src1->ne[is_2D ? 3 : 2]; + const cl_ulong batch_offset = src1->nb[is_2D ? 
3 : 2]/4; + + const cl_long pelements = OW*KW*KH; + const cl_long CHW = IC*KH*KW; + + cl_kernel kernel; + + if(dst->type == GGML_TYPE_F16) { + kernel = backend_ctx->kernel_im2col_f16; + } else { + kernel = backend_ctx->kernel_im2col_f32; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &batch_offset)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &delta_offset)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_long), &IW)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_long), &IH)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_long), &IC)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_long), &OW)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_long), &OH)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_long), &KW)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_long), &KH)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_long), &pelements)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_long), &CHW)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &s0)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &s1)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &p0)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &p1)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &d0)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &d1)); + + const int num_blocks = (pelements + 256 - 1) / 256; + size_t global_work_size[] = {(size_t)num_blocks*256, (size_t)OH, (size_t)batch*IC}; + size_t local_work_size[] = {256, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + //------------------------------------------------------------------------------ // Op offloading //------------------------------------------------------------------------------ @@ -4122,6 +4334,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_gelu; break; + case GGML_UNARY_OP_GELU_QUICK: + if (!any_on_device) { + return false; + } + func = ggml_cl_gelu_quick; + break; case GGML_UNARY_OP_SILU: if (!any_on_device) { return false; @@ -4194,6 +4412,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_rope; break; + case GGML_OP_IM2COL: + if (!any_on_device) { + return false; + } + func = ggml_cl_im2col; + break; default: return false; } diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl.cl index 1d43642a9..b88792887 100644 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl.cl +++ b/ggml/src/ggml-opencl/kernels/ggml-opencl.cl @@ -404,6 +404,7 @@ kernel void kernel_scale( // gelu //------------------------------------------------------------------------------ #define GELU_COEF_A 0.044715f +#define GELU_QUICK_COEF -1.702f #define SQRT_2_OVER_PI 0.79788456080286535587989211986876f kernel void kernel_gelu( @@ -434,6 +435,32 @@ kernel void kernel_gelu_4( dst[get_global_id(0)] = 0.5f*x*(1.0f + 
tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); } +kernel void kernel_gelu_quick( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + float x = src0[get_global_id(0)]; + dst[get_global_id(0)] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); +} + +kernel void kernel_gelu_quick_4( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + float4 x = src0[get_global_id(0)]; + dst[get_global_id(0)] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); +} + //------------------------------------------------------------------------------ // silu //------------------------------------------------------------------------------ @@ -1325,6 +1352,368 @@ kernel void kernel_rope_neox_f16( } } +kernel void kernel_rope_multi_f32( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + int4 sections +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3; + const int sec_w = sections.s1 + sections.s0; + + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + const int sector = (i0 / 2) % sect_dims; + float theta_base = 0.0f; + + if (sector < sections.s0) { + theta_base = pos[i2]; + } + else if (sector >= sections.s0 && sector < sec_w) { + theta_base = pos[i2 + ne2 * 1]; + } + else if (sector >= sec_w && sector < sec_w + sections.s2) { + theta_base = pos[i2 + ne2 * 2]; + } + else if (sector >= sec_w + sections.s2) { + theta_base = pos[i2 + ne2 * 3]; + } + + const float theta = theta_base * pow(freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global float * const src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_multi_f16( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global half * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + int4 sections +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3; + const int sec_w = sections.s1 + sections.s0; + + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + const int sector = (i0 / 2) % sect_dims; + float theta_base = 0.0f; + + if (sector < sections.s0) { + theta_base = pos[i2]; + } + else if (sector >= sections.s0 && sector < sec_w) { + theta_base = pos[i2 + ne2 * 1]; + } + else if (sector >= sec_w && sector < sec_w + sections.s2) { + theta_base = pos[i2 + ne2 * 2]; + } + else if (sector >= sec_w + sections.s2) { + theta_base = pos[i2 + ne2 * 3]; + } + + const float theta = theta_base * pow(freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global half * const src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_vision_f32( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + int4 sections +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + const int sect_dims = sections.s0 + sections.s1; + const int sec_w = sections.s1 + sections.s0; + + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + int ic = i0/2; + + const int sector = (i0/2) % sect_dims; + float theta_base = 0.0f; + + if (sector < sections.s0) { + const int p = sector; + theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p); + } else if (sector >= sections.s0 && sector < sec_w) { + const int p = sector - sections.s0; + theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p); + } + + const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } +} + +kernel void kernel_rope_vision_f16( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global half * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + int4 sections +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + const int sect_dims = sections.s0 + sections.s1; + const int sec_w = sections.s1 + sections.s0; + + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + int ic = i0/2; + + const int sector = (i0/2) % sect_dims; + float theta_base = 0.0f; + + if (sector < sections.s0) { + const int p = sector; + theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p); + } else if (sector >= sections.s0 && sector < sec_w) { + const int p = sector - sections.s0; + theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p); + } + + const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } +} + //------------------------------------------------------------------------------ // cpy //------------------------------------------------------------------------------ diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_im2col.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_im2col.cl new file mode 100644 index 000000000..9b41dfb25 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/ggml-opencl_im2col.cl @@ -0,0 +1,146 @@ +#ifdef cl_khr_fp16 +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#elif defined(cl_amd_fp16) +#pragma OPENCL EXTENSION cl_amd_fp16 : enable +#else +#error "Half precision floating point not supportedby OpenCL implementation on your device." 
+#endif + +#ifdef cl_khr_subgroups +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#elif defined(cl_intel_subgroups) +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#error "Subgroup not supported on your device." +#endif + +#ifdef cl_intel_required_subgroup_size +// Always use subgroup size of 32 on Intel. +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +// Always use subgroups size of 64 on Adreno. +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#else +// TODO: do not know how to choose subgroup size on other GPUs. +#error "Selecting subgroup size is not supported on your device." +#endif + +kernel void kernel_im2col_f32( + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + ulong batch_offset, + ulong delta_offset, + long IW, + long IH, + long IC, + long OW, + long OH, + long KW, + long KH, + long pelements, + long CHW, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1 +) { + // threadIdx.x + blockIdx.x * blockDim.x + long i = get_global_id(0); + if (i >= pelements) { + return; + } + + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + long ksize = OW * (KH > 1 ? KW : 1); + long kx = i / ksize; + long kd = kx * ksize; + long ky = (i - kd) / OW; + long ix = i % OW; + + long oh = get_group_id(1); + long batch = get_group_id(2) / IC; + long ic = get_group_id(2) % IC; + + long iiw = ix * s0 + kx * d0 - p0; + long iih = oh * s1 + ky * d1 - p1; + + long offset_dst = + ((batch * OH + oh) * OW + ix) * CHW + + (ic * (KW * KH) + ky * KW + kx); + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst[offset_dst] = 0.0f; + } else { + long offset_src = ic * delta_offset + batch * batch_offset; + dst[offset_dst] = src1[offset_src + iih * IW + iiw]; + } +} + +kernel void kernel_im2col_f16( + global float * src1, + ulong offset1, + global half * dst, + ulong offsetd, + ulong batch_offset, + ulong delta_offset, + long IW, + long IH, + long IC, + long OW, + long OH, + long KW, + long KH, + long pelements, + long CHW, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1 +) { + long i = get_global_id(0); + + if (i >= pelements) { + return; + } + + src1 = (global float*)((global char*)src1 + offset1); + dst = (global half*)((global char*)dst + offsetd); + + long ksize = OW * (KH > 1 ? 
KW : 1); + long kx = i / ksize; + long kd = kx * ksize; + long ky = (i - kd) / OW; + long ix = i % OW; + + long oh = get_group_id(1); + long batch = get_group_id(2) / IC; + long ic = get_group_id(2) % IC; + + long iiw = ix * s0 + kx * d0 - p0; + long iih = oh * s1 + ky * d1 - p1; + + long offset_dst = + ((batch * OH + oh) * OW + ix) * CHW + + (ic * (KW * KH) + ky * KW + kx); + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst[offset_dst] = 0.0f; + } else { + long offset_src = ic * delta_offset + batch * batch_offset; + dst[offset_dst] = src1[offset_src + iih * IW + iiw]; + } +} From 296901983700f3c37449bcb555d85d27150a679d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 27 Mar 2025 23:09:05 +0200 Subject: [PATCH 17/26] media : add SVG logo [no ci] (#12616) --- media/llama1-logo.svg | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 media/llama1-logo.svg diff --git a/media/llama1-logo.svg b/media/llama1-logo.svg new file mode 100644 index 000000000..e080481fa --- /dev/null +++ b/media/llama1-logo.svg @@ -0,0 +1,34 @@ + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file From 2099a9d5dbdcd2841d02dd396a71d0de70bf4490 Mon Sep 17 00:00:00 2001 From: Piotr Date: Thu, 27 Mar 2025 23:41:04 +0100 Subject: [PATCH 18/26] server : Support listening on a unix socket (#12613) * server : Bump cpp-httplib to include AF_UNIX windows support Signed-off-by: Piotr Stankiewicz * server : Allow running the server example on a unix socket Signed-off-by: Piotr Stankiewicz --------- Signed-off-by: Piotr Stankiewicz --- common/arg.cpp | 2 +- examples/server/httplib.h | 562 +++++++++++++++++++++---------------- examples/server/server.cpp | 23 +- 3 files changed, 331 insertions(+), 256 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index b6bfe6f89..8292adaac 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1979,7 +1979,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_examples({LLAMA_EXAMPLE_EMBEDDING})); add_opt(common_arg( {"--host"}, "HOST", - string_format("ip address to listen (default: %s)", params.hostname.c_str()), + string_format("ip address to listen, or bind to an UNIX socket if the address ends with .sock (default: %s)", params.hostname.c_str()), [](common_params & params, const std::string & value) { params.hostname = value; } diff --git a/examples/server/httplib.h b/examples/server/httplib.h index 593beb501..0f981dc89 100644 --- a/examples/server/httplib.h +++ b/examples/server/httplib.h @@ -8,7 +8,7 @@ #ifndef CPPHTTPLIB_HTTPLIB_H #define CPPHTTPLIB_HTTPLIB_H -#define CPPHTTPLIB_VERSION "0.19.0" +#define CPPHTTPLIB_VERSION "0.20.0" /* * Configuration @@ -188,15 +188,16 @@ using ssize_t = long; #include #include +// afunix.h uses types declared in winsock2.h, so has to be included after it. 
+#include + #ifndef WSA_FLAG_NO_HANDLE_INHERIT #define WSA_FLAG_NO_HANDLE_INHERIT 0x80 #endif +using nfds_t = unsigned long; using socket_t = SOCKET; using socklen_t = int; -#ifdef CPPHTTPLIB_USE_POLL -#define poll(fds, nfds, timeout) WSAPoll(fds, nfds, timeout) -#endif #else // not _WIN32 @@ -216,16 +217,11 @@ using socklen_t = int; #ifdef __linux__ #include #endif -#include -#ifdef CPPHTTPLIB_USE_POLL -#include -#endif #include +#include +#include #include #include -#ifndef __VMS -#include -#endif #include #include #include @@ -247,7 +243,6 @@ using socket_t = int; #include #include #include -#include #include #include #include @@ -320,6 +315,10 @@ using socket_t = int; #include #endif +#ifdef CPPHTTPLIB_ZSTD_SUPPORT +#include +#endif + /* * Declaration */ @@ -435,6 +434,15 @@ private: } // namespace detail +enum SSLVerifierResponse { + // no decision has been made, use the built-in certificate verifier + NoDecisionMade, + // connection certificate is verified and accepted + CertificateAccepted, + // connection certificate was processed but is rejected + CertificateRejected +}; + enum StatusCode { // Information responses Continue_100 = 100, @@ -670,7 +678,7 @@ struct Request { bool is_chunked_content_provider_ = false; size_t authorization_count_ = 0; std::chrono::time_point start_time_ = - std::chrono::steady_clock::time_point::min(); + (std::chrono::steady_clock::time_point::min)(); }; struct Response { @@ -736,7 +744,8 @@ public: virtual ~Stream() = default; virtual bool is_readable() const = 0; - virtual bool is_writable() const = 0; + virtual bool wait_readable() const = 0; + virtual bool wait_writable() const = 0; virtual ssize_t read(char *ptr, size_t size) = 0; virtual ssize_t write(const char *ptr, size_t size) = 0; @@ -879,7 +888,7 @@ public: * Captures parameters in request path and stores them in Request::path_params * * Capture name is a substring of a pattern from : to /. - * The rest of the pattern is matched agains the request path directly + * The rest of the pattern is matched against the request path directly * Parameters are captured starting from the next character after * the end of the last matched static pattern fragment until the next /. 
* @@ -1109,7 +1118,7 @@ private: virtual bool process_and_close_socket(socket_t sock); std::atomic is_running_{false}; - std::atomic is_decommisioned{false}; + std::atomic is_decommissioned{false}; struct MountPointEntry { std::string mount_point; @@ -1483,7 +1492,8 @@ public: #ifdef CPPHTTPLIB_OPENSSL_SUPPORT void enable_server_certificate_verification(bool enabled); void enable_server_hostname_verification(bool enabled); - void set_server_certificate_verifier(std::function verifier); + void set_server_certificate_verifier( + std::function verifier); #endif void set_logger(Logger logger); @@ -1600,7 +1610,7 @@ protected: #ifdef CPPHTTPLIB_OPENSSL_SUPPORT bool server_certificate_verification_ = true; bool server_hostname_verification_ = true; - std::function server_certificate_verifier_; + std::function server_certificate_verifier_; #endif Logger logger_; @@ -1913,7 +1923,8 @@ public: #ifdef CPPHTTPLIB_OPENSSL_SUPPORT void enable_server_certificate_verification(bool enabled); void enable_server_hostname_verification(bool enabled); - void set_server_certificate_verifier(std::function verifier); + void set_server_certificate_verifier( + std::function verifier); #endif void set_logger(Logger logger); @@ -2046,6 +2057,10 @@ inline void duration_to_sec_and_usec(const T &duration, U callback) { callback(static_cast(sec), static_cast(usec)); } +template inline constexpr size_t str_len(const char (&)[N]) { + return N - 1; +} + inline bool is_numeric(const std::string &str) { return !str.empty() && std::all_of(str.begin(), str.end(), ::isdigit); } @@ -2205,9 +2220,9 @@ inline const char *status_message(int status) { inline std::string get_bearer_token_auth(const Request &req) { if (req.has_header("Authorization")) { - static std::string BearerHeaderPrefix = "Bearer "; + constexpr auto bearer_header_prefix_len = detail::str_len("Bearer "); return req.get_header_value("Authorization") - .substr(BearerHeaderPrefix.length()); + .substr(bearer_header_prefix_len); } return ""; } @@ -2382,8 +2397,6 @@ std::string encode_query_param(const std::string &value); std::string decode_url(const std::string &s, bool convert_plus_to_space); -void read_file(const std::string &path, std::string &out); - std::string trim_copy(const std::string &s); void divide( @@ -2439,7 +2452,7 @@ ssize_t send_socket(socket_t sock, const void *ptr, size_t size, int flags); ssize_t read_socket(socket_t sock, void *ptr, size_t size, int flags); -enum class EncodingType { None = 0, Gzip, Brotli }; +enum class EncodingType { None = 0, Gzip, Brotli, Zstd }; EncodingType encoding_type(const Request &req, const Response &res); @@ -2449,7 +2462,8 @@ public: ~BufferStream() override = default; bool is_readable() const override; - bool is_writable() const override; + bool wait_readable() const override; + bool wait_writable() const override; ssize_t read(char *ptr, size_t size) override; ssize_t write(const char *ptr, size_t size) override; void get_remote_ip_and_port(std::string &ip, int &port) const override; @@ -2551,6 +2565,34 @@ private: }; #endif +#ifdef CPPHTTPLIB_ZSTD_SUPPORT +class zstd_compressor : public compressor { +public: + zstd_compressor(); + ~zstd_compressor(); + + bool compress(const char *data, size_t data_length, bool last, + Callback callback) override; + +private: + ZSTD_CCtx *ctx_ = nullptr; +}; + +class zstd_decompressor : public decompressor { +public: + zstd_decompressor(); + ~zstd_decompressor(); + + bool is_valid() const override; + + bool decompress(const char *data, size_t data_length, + Callback callback) 
override; + +private: + ZSTD_DCtx *ctx_ = nullptr; +}; +#endif + // NOTE: until the read size reaches `fixed_buffer_size`, use `fixed_buffer` // to store data. The call can set memory on stack for performance. class stream_line_reader { @@ -2569,7 +2611,7 @@ private: char *fixed_buffer_; const size_t fixed_buffer_size_; size_t fixed_buffer_used_size_ = 0; - std::string glowable_buffer_; + std::string growable_buffer_; }; class mmap { @@ -2910,18 +2952,9 @@ inline std::string decode_url(const std::string &s, return result; } -inline void read_file(const std::string &path, std::string &out) { - std::ifstream fs(path, std::ios_base::binary); - fs.seekg(0, std::ios_base::end); - auto size = fs.tellg(); - fs.seekg(0); - out.resize(static_cast(size)); - fs.read(&out[0], static_cast(size)); -} - inline std::string file_extension(const std::string &path) { std::smatch m; - static auto re = std::regex("\\.([a-zA-Z0-9]+)$"); + thread_local auto re = std::regex("\\.([a-zA-Z0-9]+)$"); if (std::regex_search(path, m, re)) { return m[1].str(); } return std::string(); } @@ -3005,18 +3038,18 @@ inline stream_line_reader::stream_line_reader(Stream &strm, char *fixed_buffer, fixed_buffer_size_(fixed_buffer_size) {} inline const char *stream_line_reader::ptr() const { - if (glowable_buffer_.empty()) { + if (growable_buffer_.empty()) { return fixed_buffer_; } else { - return glowable_buffer_.data(); + return growable_buffer_.data(); } } inline size_t stream_line_reader::size() const { - if (glowable_buffer_.empty()) { + if (growable_buffer_.empty()) { return fixed_buffer_used_size_; } else { - return glowable_buffer_.size(); + return growable_buffer_.size(); } } @@ -3027,7 +3060,7 @@ inline bool stream_line_reader::end_with_crlf() const { inline bool stream_line_reader::getline() { fixed_buffer_used_size_ = 0; - glowable_buffer_.clear(); + growable_buffer_.clear(); #ifndef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR char prev_byte = 0; @@ -3065,11 +3098,11 @@ inline void stream_line_reader::append(char c) { fixed_buffer_[fixed_buffer_used_size_++] = c; fixed_buffer_[fixed_buffer_used_size_] = '\0'; } else { - if (glowable_buffer_.empty()) { + if (growable_buffer_.empty()) { assert(fixed_buffer_[fixed_buffer_used_size_] == '\0'); - glowable_buffer_.assign(fixed_buffer_, fixed_buffer_used_size_); + growable_buffer_.assign(fixed_buffer_, fixed_buffer_used_size_); } - glowable_buffer_ += c; + growable_buffer_ += c; } } @@ -3246,35 +3279,23 @@ inline ssize_t send_socket(socket_t sock, const void *ptr, size_t size, }); } +inline int poll_wrapper(struct pollfd *fds, nfds_t nfds, int timeout) { +#ifdef _WIN32 + return ::WSAPoll(fds, nfds, timeout); +#else + return ::poll(fds, nfds, timeout); +#endif +} + template inline ssize_t select_impl(socket_t sock, time_t sec, time_t usec) { -#ifdef CPPHTTPLIB_USE_POLL struct pollfd pfd; pfd.fd = sock; pfd.events = (Read ? POLLIN : POLLOUT); auto timeout = static_cast(sec * 1000 + usec / 1000); - return handle_EINTR([&]() { return poll(&pfd, 1, timeout); }); -#else -#ifndef _WIN32 - if (sock >= FD_SETSIZE) { return -1; } -#endif - - fd_set fds, *rfds, *wfds; - FD_ZERO(&fds); - FD_SET(sock, &fds); - rfds = (Read ? &fds : nullptr); - wfds = (Read ? 
nullptr : &fds); - - timeval tv; - tv.tv_sec = static_cast(sec); - tv.tv_usec = static_cast(usec); - - return handle_EINTR([&]() { - return select(static_cast(sock + 1), rfds, wfds, nullptr, &tv); - }); -#endif + return handle_EINTR([&]() { return poll_wrapper(&pfd, 1, timeout); }); } inline ssize_t select_read(socket_t sock, time_t sec, time_t usec) { @@ -3287,14 +3308,14 @@ inline ssize_t select_write(socket_t sock, time_t sec, time_t usec) { inline Error wait_until_socket_is_ready(socket_t sock, time_t sec, time_t usec) { -#ifdef CPPHTTPLIB_USE_POLL struct pollfd pfd_read; pfd_read.fd = sock; pfd_read.events = POLLIN | POLLOUT; auto timeout = static_cast(sec * 1000 + usec / 1000); - auto poll_res = handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); + auto poll_res = + handle_EINTR([&]() { return poll_wrapper(&pfd_read, 1, timeout); }); if (poll_res == 0) { return Error::ConnectionTimeout; } @@ -3308,38 +3329,6 @@ inline Error wait_until_socket_is_ready(socket_t sock, time_t sec, } return Error::Connection; -#else -#ifndef _WIN32 - if (sock >= FD_SETSIZE) { return Error::Connection; } -#endif - - fd_set fdsr; - FD_ZERO(&fdsr); - FD_SET(sock, &fdsr); - - auto fdsw = fdsr; - auto fdse = fdsr; - - timeval tv; - tv.tv_sec = static_cast(sec); - tv.tv_usec = static_cast(usec); - - auto ret = handle_EINTR([&]() { - return select(static_cast(sock + 1), &fdsr, &fdsw, &fdse, &tv); - }); - - if (ret == 0) { return Error::ConnectionTimeout; } - - if (ret > 0 && (FD_ISSET(sock, &fdsr) || FD_ISSET(sock, &fdsw))) { - auto error = 0; - socklen_t len = sizeof(error); - auto res = getsockopt(sock, SOL_SOCKET, SO_ERROR, - reinterpret_cast(&error), &len); - auto successful = res >= 0 && !error; - return successful ? Error::Success : Error::Connection; - } - return Error::Connection; -#endif } inline bool is_socket_alive(socket_t sock) { @@ -3359,11 +3348,12 @@ public: time_t write_timeout_sec, time_t write_timeout_usec, time_t max_timeout_msec = 0, std::chrono::time_point start_time = - std::chrono::steady_clock::time_point::min()); + (std::chrono::steady_clock::time_point::min)()); ~SocketStream() override; bool is_readable() const override; - bool is_writable() const override; + bool wait_readable() const override; + bool wait_writable() const override; ssize_t read(char *ptr, size_t size) override; ssize_t write(const char *ptr, size_t size) override; void get_remote_ip_and_port(std::string &ip, int &port) const override; @@ -3378,7 +3368,7 @@ private: time_t write_timeout_sec_; time_t write_timeout_usec_; time_t max_timeout_msec_; - const std::chrono::time_point start_time; + const std::chrono::time_point start_time_; std::vector read_buff_; size_t read_buff_off_ = 0; @@ -3395,11 +3385,12 @@ public: time_t read_timeout_usec, time_t write_timeout_sec, time_t write_timeout_usec, time_t max_timeout_msec = 0, std::chrono::time_point start_time = - std::chrono::steady_clock::time_point::min()); + (std::chrono::steady_clock::time_point::min)()); ~SSLSocketStream() override; bool is_readable() const override; - bool is_writable() const override; + bool wait_readable() const override; + bool wait_writable() const override; ssize_t read(char *ptr, size_t size) override; ssize_t write(const char *ptr, size_t size) override; void get_remote_ip_and_port(std::string &ip, int &port) const override; @@ -3415,7 +3406,7 @@ private: time_t write_timeout_sec_; time_t write_timeout_usec_; time_t max_timeout_msec_; - const std::chrono::time_point start_time; + const std::chrono::time_point start_time_; }; #endif 
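// The two hunks above are where httplib 0.20.0 stops special-casing
// CPPHTTPLIB_USE_POLL: readiness checks now always go through poll()
// (WSAPoll() on Windows) instead of select(), which also removes the
// FD_SETSIZE limit on descriptors. A minimal standalone sketch of that
// pattern, assuming a POSIX descriptor `sock` (names below are illustrative,
// not httplib APIs):

#include <poll.h>
#include <cerrno>

// Wait up to timeout_ms for `sock` to become readable.
// Returns >0 when readable, 0 on timeout, <0 on error.
static int wait_readable_ms(int sock, int timeout_ms) {
    struct pollfd pfd;
    pfd.fd      = sock;
    pfd.events  = POLLIN;   // use POLLOUT to wait for writability instead
    pfd.revents = 0;
    int ret;
    do {
        ret = poll(&pfd, 1, timeout_ms);   // retry if interrupted by a signal
    } while (ret < 0 && errno == EINTR);
    return ret;
}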
@@ -3550,7 +3541,6 @@ socket_t create_socket(const std::string &host, const std::string &ip, int port, hints.ai_flags = socket_flags; } -#ifndef _WIN32 if (hints.ai_family == AF_UNIX) { const auto addrlen = host.length(); if (addrlen > sizeof(sockaddr_un::sun_path)) { return INVALID_SOCKET; } @@ -3574,11 +3564,19 @@ socket_t create_socket(const std::string &host, const std::string &ip, int port, sizeof(addr) - sizeof(addr.sun_path) + addrlen); #ifndef SOCK_CLOEXEC +#ifndef _WIN32 fcntl(sock, F_SETFD, FD_CLOEXEC); +#endif #endif if (socket_options) { socket_options(sock); } +#ifdef _WIN32 + // Setting SO_REUSEADDR seems not to work well with AF_UNIX on windows, so + // remove the option. + detail::set_socket_opt(sock, SOL_SOCKET, SO_REUSEADDR, 0); +#endif + bool dummy; if (!bind_or_connect(sock, hints, dummy)) { close_socket(sock); @@ -3587,7 +3585,6 @@ socket_t create_socket(const std::string &host, const std::string &ip, int port, } return sock; } -#endif auto service = std::to_string(port); @@ -3993,6 +3990,12 @@ inline EncodingType encoding_type(const Request &req, const Response &res) { if (ret) { return EncodingType::Gzip; } #endif +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + // TODO: 'Accept-Encoding' has zstd, not zstd;q=0 + ret = s.find("zstd") != std::string::npos; + if (ret) { return EncodingType::Zstd; } +#endif + return EncodingType::None; } @@ -4201,6 +4204,61 @@ inline bool brotli_decompressor::decompress(const char *data, } #endif +#ifdef CPPHTTPLIB_ZSTD_SUPPORT +inline zstd_compressor::zstd_compressor() { + ctx_ = ZSTD_createCCtx(); + ZSTD_CCtx_setParameter(ctx_, ZSTD_c_compressionLevel, ZSTD_fast); +} + +inline zstd_compressor::~zstd_compressor() { ZSTD_freeCCtx(ctx_); } + +inline bool zstd_compressor::compress(const char *data, size_t data_length, + bool last, Callback callback) { + std::array buff{}; + + ZSTD_EndDirective mode = last ? ZSTD_e_end : ZSTD_e_continue; + ZSTD_inBuffer input = {data, data_length, 0}; + + bool finished; + do { + ZSTD_outBuffer output = {buff.data(), CPPHTTPLIB_COMPRESSION_BUFSIZ, 0}; + size_t const remaining = ZSTD_compressStream2(ctx_, &output, &input, mode); + + if (ZSTD_isError(remaining)) { return false; } + + if (!callback(buff.data(), output.pos)) { return false; } + + finished = last ? 
(remaining == 0) : (input.pos == input.size); + + } while (!finished); + + return true; +} + +inline zstd_decompressor::zstd_decompressor() { ctx_ = ZSTD_createDCtx(); } + +inline zstd_decompressor::~zstd_decompressor() { ZSTD_freeDCtx(ctx_); } + +inline bool zstd_decompressor::is_valid() const { return ctx_ != nullptr; } + +inline bool zstd_decompressor::decompress(const char *data, size_t data_length, + Callback callback) { + std::array buff{}; + ZSTD_inBuffer input = {data, data_length, 0}; + + while (input.pos < input.size) { + ZSTD_outBuffer output = {buff.data(), CPPHTTPLIB_COMPRESSION_BUFSIZ, 0}; + size_t const remaining = ZSTD_decompressStream(ctx_, &output, &input); + + if (ZSTD_isError(remaining)) { return false; } + + if (!callback(buff.data(), output.pos)) { return false; } + } + + return true; +} +#endif + inline bool has_header(const Headers &headers, const std::string &key) { return headers.find(key) != headers.end(); } @@ -4227,6 +4285,9 @@ inline bool parse_header(const char *beg, const char *end, T fn) { p++; } + auto name = std::string(beg, p); + if (!detail::fields::is_field_name(name)) { return false; } + if (p == end) { return false; } auto key_end = p; @@ -4242,10 +4303,6 @@ inline bool parse_header(const char *beg, const char *end, T fn) { if (!key_len) { return false; } auto key = std::string(beg, key_end); - // auto val = (case_ignore::equal(key, "Location") || - // case_ignore::equal(key, "Referer")) - // ? std::string(p, end) - // : decode_url(std::string(p, end), false); auto val = std::string(p, end); if (!detail::fields::is_field_value(val)) { return false; } @@ -4341,7 +4398,8 @@ inline bool read_content_without_length(Stream &strm, uint64_t r = 0; for (;;) { auto n = strm.read(buf, CPPHTTPLIB_RECV_BUFSIZ); - if (n <= 0) { return false; } + if (n == 0) { return true; } + if (n < 0) { return false; } if (!out(buf, static_cast(n), r, 0)) { return false; } r += static_cast(n); @@ -4384,7 +4442,7 @@ inline bool read_content_chunked(Stream &strm, T &x, assert(chunk_len == 0); - // NOTE: In RFC 9112, '7.1 Chunked Transfer Coding' mentiones "The chunked + // NOTE: In RFC 9112, '7.1 Chunked Transfer Coding' mentions "The chunked // transfer coding is complete when a chunk with a chunk-size of zero is // received, possibly followed by a trailer section, and finally terminated by // an empty line". https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1 @@ -4394,8 +4452,8 @@ inline bool read_content_chunked(Stream &strm, T &x, // to be ok whether the final CRLF exists or not in the chunked data. // https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1.3 // - // According to the reference code in RFC 9112, cpp-htpplib now allows - // chuncked transfer coding data without the final CRLF. + // According to the reference code in RFC 9112, cpp-httplib now allows + // chunked transfer coding data without the final CRLF. 
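  // For reference, a chunked body carrying "hello" is framed on the wire as
  //   5\r\nhello\r\n
  //   0\r\n
  //   \r\n
  // i.e. "<hex-size>\r\n<data>\r\n" per chunk, then a zero-size chunk, an
  // optional trailer section, and a terminating CRLF; per the note above,
  // a stream that simply ends after "0\r\n" is also treated as complete here.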
if (!line_reader.getline()) { return true; } while (strcmp(line_reader.ptr(), "\r\n") != 0) { @@ -4442,6 +4500,13 @@ bool prepare_content_receiver(T &x, int &status, #else status = StatusCode::UnsupportedMediaType_415; return false; +#endif + } else if (encoding == "zstd") { +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + decompressor = detail::make_unique(); +#else + status = StatusCode::UnsupportedMediaType_415; + return false; #endif } @@ -4565,7 +4630,7 @@ inline bool write_content(Stream &strm, const ContentProvider &content_provider, data_sink.write = [&](const char *d, size_t l) -> bool { if (ok) { - if (strm.is_writable() && write_data(strm, d, l)) { + if (write_data(strm, d, l)) { offset += l; } else { ok = false; @@ -4574,10 +4639,10 @@ inline bool write_content(Stream &strm, const ContentProvider &content_provider, return ok; }; - data_sink.is_writable = [&]() -> bool { return strm.is_writable(); }; + data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); }; while (offset < end_offset && !is_shutting_down()) { - if (!strm.is_writable()) { + if (!strm.wait_writable()) { error = Error::Write; return false; } else if (!content_provider(offset, end_offset - offset, data_sink)) { @@ -4615,17 +4680,17 @@ write_content_without_length(Stream &strm, data_sink.write = [&](const char *d, size_t l) -> bool { if (ok) { offset += l; - if (!strm.is_writable() || !write_data(strm, d, l)) { ok = false; } + if (!write_data(strm, d, l)) { ok = false; } } return ok; }; - data_sink.is_writable = [&]() -> bool { return strm.is_writable(); }; + data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); }; data_sink.done = [&](void) { data_available = false; }; while (data_available && !is_shutting_down()) { - if (!strm.is_writable()) { + if (!strm.wait_writable()) { return false; } else if (!content_provider(offset, 0, data_sink)) { return false; @@ -4660,10 +4725,7 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider, // Emit chunked response header and footer for each chunk auto chunk = from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n"; - if (!strm.is_writable() || - !write_data(strm, chunk.data(), chunk.size())) { - ok = false; - } + if (!write_data(strm, chunk.data(), chunk.size())) { ok = false; } } } else { ok = false; @@ -4672,7 +4734,7 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider, return ok; }; - data_sink.is_writable = [&]() -> bool { return strm.is_writable(); }; + data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); }; auto done_with_trailer = [&](const Headers *trailer) { if (!ok) { return; } @@ -4692,17 +4754,14 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider, if (!payload.empty()) { // Emit chunked response header and footer for each chunk auto chunk = from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n"; - if (!strm.is_writable() || - !write_data(strm, chunk.data(), chunk.size())) { + if (!write_data(strm, chunk.data(), chunk.size())) { ok = false; return; } } - static const std::string done_marker("0\r\n"); - if (!write_data(strm, done_marker.data(), done_marker.size())) { - ok = false; - } + constexpr const char done_marker[] = "0\r\n"; + if (!write_data(strm, done_marker, str_len(done_marker))) { ok = false; } // Trailer if (trailer) { @@ -4714,8 +4773,8 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider, } } - static const std::string crlf("\r\n"); - if (!write_data(strm, crlf.data(), crlf.size())) { ok = false; } + constexpr 
const char crlf[] = "\r\n"; + if (!write_data(strm, crlf, str_len(crlf))) { ok = false; } }; data_sink.done = [&](void) { done_with_trailer(nullptr); }; @@ -4725,7 +4784,7 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider, }; while (data_available && !is_shutting_down()) { - if (!strm.is_writable()) { + if (!strm.wait_writable()) { error = Error::Write; return false; } else if (!content_provider(offset, 0, data_sink)) { @@ -4957,13 +5016,13 @@ public: return false; } - static const std::string header_content_type = "Content-Type:"; + constexpr const char header_content_type[] = "Content-Type:"; if (start_with_case_ignore(header, header_content_type)) { file_.content_type = - trim_copy(header.substr(header_content_type.size())); + trim_copy(header.substr(str_len(header_content_type))); } else { - static const std::regex re_content_disposition( + thread_local const std::regex re_content_disposition( R"~(^Content-Disposition:\s*form-data;\s*(.*)$)~", std::regex_constants::icase); @@ -4985,8 +5044,8 @@ public: it = params.find("filename*"); if (it != params.end()) { - // Only allow UTF-8 enconnding... - static const std::regex re_rfc5987_encoding( + // Only allow UTF-8 encoding... + thread_local const std::regex re_rfc5987_encoding( R"~(^UTF-8''(.+?)$)~", std::regex_constants::icase); std::smatch m2; @@ -5058,10 +5117,10 @@ private: file_.content_type.clear(); } - bool start_with_case_ignore(const std::string &a, - const std::string &b) const { - if (a.size() < b.size()) { return false; } - for (size_t i = 0; i < b.size(); i++) { + bool start_with_case_ignore(const std::string &a, const char *b) const { + const auto b_len = strlen(b); + if (a.size() < b_len) { return false; } + for (size_t i = 0; i < b_len; i++) { if (case_ignore::to_lower(a[i]) != case_ignore::to_lower(b[i])) { return false; } @@ -5148,19 +5207,18 @@ private: }; inline std::string random_string(size_t length) { - static const char data[] = + constexpr const char data[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; - // std::random_device might actually be deterministic on some - // platforms, but due to lack of support in the c++ standard library, - // doing better requires either some ugly hacks or breaking portability. - static std::random_device seed_gen; - - // Request 128 bits of entropy for initialization - static std::seed_seq seed_sequence{seed_gen(), seed_gen(), seed_gen(), - seed_gen()}; - - static std::mt19937 engine(seed_sequence); + thread_local auto engine([]() { + // std::random_device might actually be deterministic on some + // platforms, but due to lack of support in the c++ standard library, + // doing better requires either some ugly hacks or breaking portability. + std::random_device seed_gen; + // Request 128 bits of entropy for initialization + std::seed_seq seed_sequence{seed_gen(), seed_gen(), seed_gen(), seed_gen()}; + return std::mt19937(seed_sequence); + }()); std::string result; for (size_t i = 0; i < length; i++) { @@ -5232,7 +5290,7 @@ serialize_multipart_formdata(const MultipartFormDataItems &items, inline bool range_error(Request &req, Response &res) { if (!req.ranges.empty() && 200 <= res.status && res.status < 300) { - ssize_t contant_len = static_cast( + ssize_t content_len = static_cast( res.content_length_ ? 
res.content_length_ : res.body.size()); ssize_t prev_first_pos = -1; @@ -5252,12 +5310,12 @@ inline bool range_error(Request &req, Response &res) { if (first_pos == -1 && last_pos == -1) { first_pos = 0; - last_pos = contant_len; + last_pos = content_len; } if (first_pos == -1) { - first_pos = contant_len - last_pos; - last_pos = contant_len - 1; + first_pos = content_len - last_pos; + last_pos = content_len - 1; } // NOTE: RFC-9110 '14.1.2. Byte Ranges': @@ -5269,13 +5327,13 @@ inline bool range_error(Request &req, Response &res) { // with a value that is one less than the current length of the selected // representation). // https://www.rfc-editor.org/rfc/rfc9110.html#section-14.1.2-6 - if (last_pos == -1 || last_pos >= contant_len) { - last_pos = contant_len - 1; + if (last_pos == -1 || last_pos >= content_len) { + last_pos = content_len - 1; } // Range must be within content length if (!(0 <= first_pos && first_pos <= last_pos && - last_pos <= contant_len - 1)) { + last_pos <= content_len - 1)) { return true; } @@ -5674,7 +5732,8 @@ inline bool parse_www_authenticate(const Response &res, bool is_proxy) { auto auth_key = is_proxy ? "Proxy-Authenticate" : "WWW-Authenticate"; if (res.has_header(auth_key)) { - static auto re = std::regex(R"~((?:(?:,\s*)?(.+?)=(?:"(.*?)"|([^,]*))))~"); + thread_local auto re = + std::regex(R"~((?:(?:,\s*)?(.+?)=(?:"(.*?)"|([^,]*))))~"); auto s = res.get_header_value(auth_key); auto pos = s.find(' '); if (pos != std::string::npos) { @@ -5758,7 +5817,7 @@ inline void hosted_at(const std::string &hostname, inline std::string append_query_params(const std::string &path, const Params ¶ms) { std::string path_with_query = path; - const static std::regex re("[^?]+\\?.*"); + thread_local const std::regex re("[^?]+\\?.*"); auto delm = std::regex_match(path, re) ? 
'&' : '?'; path_with_query += delm + detail::params_to_query_str(params); return path_with_query; @@ -5987,14 +6046,14 @@ inline ssize_t Stream::write(const std::string &s) { namespace detail { -inline void calc_actual_timeout(time_t max_timeout_msec, - time_t duration_msec, time_t timeout_sec, - time_t timeout_usec, time_t &actual_timeout_sec, +inline void calc_actual_timeout(time_t max_timeout_msec, time_t duration_msec, + time_t timeout_sec, time_t timeout_usec, + time_t &actual_timeout_sec, time_t &actual_timeout_usec) { auto timeout_msec = (timeout_sec * 1000) + (timeout_usec / 1000); auto actual_timeout_msec = - std::min(max_timeout_msec - duration_msec, timeout_msec); + (std::min)(max_timeout_msec - duration_msec, timeout_msec); actual_timeout_sec = actual_timeout_msec / 1000; actual_timeout_usec = (actual_timeout_msec % 1000) * 1000; @@ -6010,12 +6069,16 @@ inline SocketStream::SocketStream( read_timeout_usec_(read_timeout_usec), write_timeout_sec_(write_timeout_sec), write_timeout_usec_(write_timeout_usec), - max_timeout_msec_(max_timeout_msec), start_time(start_time), + max_timeout_msec_(max_timeout_msec), start_time_(start_time), read_buff_(read_buff_size_, 0) {} inline SocketStream::~SocketStream() = default; inline bool SocketStream::is_readable() const { + return read_buff_off_ < read_buff_content_size_; +} + +inline bool SocketStream::wait_readable() const { if (max_timeout_msec_ <= 0) { return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; } @@ -6028,7 +6091,7 @@ inline bool SocketStream::is_readable() const { return select_read(sock_, read_timeout_sec, read_timeout_usec) > 0; } -inline bool SocketStream::is_writable() const { +inline bool SocketStream::wait_writable() const { return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 && is_socket_alive(sock_); } @@ -6055,7 +6118,7 @@ inline ssize_t SocketStream::read(char *ptr, size_t size) { } } - if (!is_readable()) { return -1; } + if (!wait_readable()) { return -1; } read_buff_off_ = 0; read_buff_content_size_ = 0; @@ -6080,7 +6143,7 @@ inline ssize_t SocketStream::read(char *ptr, size_t size) { } inline ssize_t SocketStream::write(const char *ptr, size_t size) { - if (!is_writable()) { return -1; } + if (!wait_writable()) { return -1; } #if defined(_WIN32) && !defined(_WIN64) size = @@ -6104,14 +6167,16 @@ inline socket_t SocketStream::socket() const { return sock_; } inline time_t SocketStream::duration() const { return std::chrono::duration_cast( - std::chrono::steady_clock::now() - start_time) + std::chrono::steady_clock::now() - start_time_) .count(); } // Buffer stream implementation inline bool BufferStream::is_readable() const { return true; } -inline bool BufferStream::is_writable() const { return true; } +inline bool BufferStream::wait_readable() const { return true; } + +inline bool BufferStream::wait_writable() const { return true; } inline ssize_t BufferStream::read(char *ptr, size_t size) { #if defined(_MSC_VER) && _MSC_VER < 1910 @@ -6141,7 +6206,7 @@ inline time_t BufferStream::duration() const { return 0; } inline const std::string &BufferStream::get_buffer() const { return buffer; } inline PathParamsMatcher::PathParamsMatcher(const std::string &pattern) { - static constexpr char marker[] = "/:"; + constexpr const char marker[] = "/:"; // One past the last ending position of a path param substring std::size_t last_param_end = 0; @@ -6162,7 +6227,7 @@ inline PathParamsMatcher::PathParamsMatcher(const std::string &pattern) { static_fragments_.push_back( 
pattern.substr(last_param_end, marker_pos - last_param_end + 1)); - const auto param_name_start = marker_pos + 2; + const auto param_name_start = marker_pos + str_len(marker); auto sep_pos = pattern.find(separator, param_name_start); if (sep_pos == std::string::npos) { sep_pos = pattern.length(); } @@ -6469,12 +6534,12 @@ inline Server &Server::set_payload_max_length(size_t length) { inline bool Server::bind_to_port(const std::string &host, int port, int socket_flags) { auto ret = bind_internal(host, port, socket_flags); - if (ret == -1) { is_decommisioned = true; } + if (ret == -1) { is_decommissioned = true; } return ret >= 0; } inline int Server::bind_to_any_port(const std::string &host, int socket_flags) { auto ret = bind_internal(host, 0, socket_flags); - if (ret == -1) { is_decommisioned = true; } + if (ret == -1) { is_decommissioned = true; } return ret; } @@ -6488,7 +6553,7 @@ inline bool Server::listen(const std::string &host, int port, inline bool Server::is_running() const { return is_running_; } inline void Server::wait_until_ready() const { - while (!is_running_ && !is_decommisioned) { + while (!is_running_ && !is_decommissioned) { std::this_thread::sleep_for(std::chrono::milliseconds{1}); } } @@ -6500,10 +6565,10 @@ inline void Server::stop() { detail::shutdown_socket(sock); detail::close_socket(sock); } - is_decommisioned = false; + is_decommissioned = false; } -inline void Server::decommission() { is_decommisioned = true; } +inline void Server::decommission() { is_decommissioned = true; } inline bool Server::parse_request_line(const char *s, Request &req) const { auto len = strlen(s); @@ -6526,7 +6591,7 @@ inline bool Server::parse_request_line(const char *s, Request &req) const { if (count != 3) { return false; } } - static const std::set methods{ + thread_local const std::set methods{ "GET", "HEAD", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH", "PRI"}; @@ -6680,6 +6745,10 @@ Server::write_content_with_provider(Stream &strm, const Request &req, } else if (type == detail::EncodingType::Brotli) { #ifdef CPPHTTPLIB_BROTLI_SUPPORT compressor = detail::make_unique(); +#endif + } else if (type == detail::EncodingType::Zstd) { +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + compressor = detail::make_unique(); #endif } else { compressor = detail::make_unique(); @@ -6862,7 +6931,7 @@ Server::create_server_socket(const std::string &host, int port, inline int Server::bind_internal(const std::string &host, int port, int socket_flags) { - if (is_decommisioned) { return -1; } + if (is_decommissioned) { return -1; } if (!is_valid()) { return -1; } @@ -6889,7 +6958,7 @@ inline int Server::bind_internal(const std::string &host, int port, } inline bool Server::listen_internal() { - if (is_decommisioned) { return false; } + if (is_decommissioned) { return false; } auto ret = true; is_running_ = true; @@ -6913,7 +6982,7 @@ inline bool Server::listen_internal() { #endif #if defined _WIN32 - // sockets conneced via WASAccept inherit flags NO_HANDLE_INHERIT, + // sockets connected via WASAccept inherit flags NO_HANDLE_INHERIT, // OVERLAPPED socket_t sock = WSAAccept(svr_sock_, nullptr, nullptr, nullptr, 0); #elif defined SOCK_CLOEXEC @@ -6955,7 +7024,7 @@ inline bool Server::listen_internal() { task_queue->shutdown(); } - is_decommisioned = !ret; + is_decommissioned = !ret; return ret; } @@ -7095,6 +7164,8 @@ inline void Server::apply_ranges(const Request &req, Response &res, res.set_header("Content-Encoding", "gzip"); } else if (type == detail::EncodingType::Brotli) { 
res.set_header("Content-Encoding", "br"); + } else if (type == detail::EncodingType::Zstd) { + res.set_header("Content-Encoding", "zstd"); } } } @@ -7134,6 +7205,11 @@ inline void Server::apply_ranges(const Request &req, Response &res, #ifdef CPPHTTPLIB_BROTLI_SUPPORT compressor = detail::make_unique(); content_encoding = "br"; +#endif + } else if (type == detail::EncodingType::Zstd) { +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + compressor = detail::make_unique(); + content_encoding = "zstd"; #endif } @@ -7189,20 +7265,6 @@ Server::process_request(Stream &strm, const std::string &remote_addr, res.version = "HTTP/1.1"; res.headers = default_headers_; -#ifdef _WIN32 - // TODO: Increase FD_SETSIZE statically (libzmq), dynamically (MySQL). -#else -#ifndef CPPHTTPLIB_USE_POLL - // Socket file descriptor exceeded FD_SETSIZE... - if (strm.socket() >= FD_SETSIZE) { - Headers dummy; - detail::read_headers(strm, dummy); - res.status = StatusCode::InternalServerError_500; - return write_response(strm, close_connection, req, res); - } -#endif -#endif - // Request line and headers if (!parse_request_line(line_reader.ptr(), req) || !detail::read_headers(strm, req.headers)) { @@ -7394,6 +7456,16 @@ inline ClientImpl::ClientImpl(const std::string &host, int port, client_cert_path_(client_cert_path), client_key_path_(client_key_path) {} inline ClientImpl::~ClientImpl() { + // Wait until all the requests in flight are handled. + size_t retry_count = 10; + while (retry_count-- > 0) { + { + std::lock_guard guard(socket_mutex_); + if (socket_requests_in_flight_ == 0) { break; } + } + std::this_thread::sleep_for(std::chrono::milliseconds{1}); + } + std::lock_guard guard(socket_mutex_); shutdown_socket(socket_); close_socket(socket_); @@ -7519,9 +7591,9 @@ inline bool ClientImpl::read_response_line(Stream &strm, const Request &req, if (!line_reader.getline()) { return false; } #ifdef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR - const static std::regex re("(HTTP/1\\.[01]) (\\d{3})(?: (.*?))?\r?\n"); + thread_local const std::regex re("(HTTP/1\\.[01]) (\\d{3})(?: (.*?))?\r?\n"); #else - const static std::regex re("(HTTP/1\\.[01]) (\\d{3})(?: (.*?))?\r\n"); + thread_local const std::regex re("(HTTP/1\\.[01]) (\\d{3})(?: (.*?))?\r\n"); #endif std::cmatch m; @@ -7577,7 +7649,7 @@ inline bool ClientImpl::send_(Request &req, Response &res, Error &error) { #endif if (!is_alive) { - // Attempt to avoid sigpipe by shutting down nongracefully if it seems + // Attempt to avoid sigpipe by shutting down non-gracefully if it seems // like the other side has already closed the connection Also, there // cannot be any requests in flight from other threads since we locked // request_mutex_, so safe to close everything immediately @@ -7753,7 +7825,7 @@ inline bool ClientImpl::redirect(Request &req, Response &res, Error &error) { auto location = res.get_header_value("location"); if (location.empty()) { return false; } - const static std::regex re( + thread_local const std::regex re( R"((?:(https?):)?(?://(?:\[([a-fA-F\d:]+)\]|([^:/?#]+))(?::(\d+))?)?([^?#]*)(\?[^#]*)?(?:#.*)?)"); std::smatch m; @@ -7862,6 +7934,10 @@ inline bool ClientImpl::write_request(Stream &strm, Request &req, #ifdef CPPHTTPLIB_ZLIB_SUPPORT if (!accept_encoding.empty()) { accept_encoding += ", "; } accept_encoding += "gzip, deflate"; +#endif +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + if (!accept_encoding.empty()) { accept_encoding += ", "; } + accept_encoding += "zstd"; #endif req.set_header("Accept-Encoding", accept_encoding); } @@ -8213,8 +8289,7 @@ inline bool 
ClientImpl::process_socket( std::function callback) { return detail::process_client_socket( socket.sock, read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, - write_timeout_usec_, max_timeout_msec_, start_time, - std::move(callback)); + write_timeout_usec_, max_timeout_msec_, start_time, std::move(callback)); } inline bool ClientImpl::is_ssl() const { return false; } @@ -9009,7 +9084,7 @@ inline void ClientImpl::enable_server_hostname_verification(bool enabled) { } inline void ClientImpl::set_server_certificate_verifier( - std::function verifier) { + std::function verifier) { server_certificate_verifier_ = verifier; } #endif @@ -9062,18 +9137,13 @@ inline void ssl_delete(std::mutex &ctx_mutex, SSL *ssl, socket_t sock, // Note that it is not always possible to avoid SIGPIPE, this is merely a // best-efforts. if (shutdown_gracefully) { -#ifdef _WIN32 (void)(sock); - SSL_shutdown(ssl); -#else - detail::set_socket_opt_time(sock, SOL_SOCKET, SO_RCVTIMEO, 1, 0); - - auto ret = SSL_shutdown(ssl); - while (ret == 0) { - std::this_thread::sleep_for(std::chrono::milliseconds{100}); - ret = SSL_shutdown(ssl); + // SSL_shutdown() returns 0 on first call (indicating close_notify alert + // sent) and 1 on subsequent call (indicating close_notify alert received) + if (SSL_shutdown(ssl) == 0) { + // Expected to return 1, but even if it doesn't, we free ssl + SSL_shutdown(ssl); } -#endif } std::lock_guard guard(ctx_mutex); @@ -9124,19 +9194,11 @@ inline bool process_client_socket_ssl( time_t max_timeout_msec, std::chrono::time_point start_time, T callback) { SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, - write_timeout_sec, write_timeout_usec, - max_timeout_msec, start_time); + write_timeout_sec, write_timeout_usec, max_timeout_msec, + start_time); return callback(strm); } -class SSLInit { -public: - SSLInit() { - OPENSSL_init_ssl( - OPENSSL_INIT_LOAD_SSL_STRINGS | OPENSSL_INIT_LOAD_CRYPTO_STRINGS, NULL); - } -}; - // SSL socket stream implementation inline SSLSocketStream::SSLSocketStream( socket_t sock, SSL *ssl, time_t read_timeout_sec, time_t read_timeout_usec, @@ -9147,13 +9209,17 @@ inline SSLSocketStream::SSLSocketStream( read_timeout_usec_(read_timeout_usec), write_timeout_sec_(write_timeout_sec), write_timeout_usec_(write_timeout_usec), - max_timeout_msec_(max_timeout_msec), start_time(start_time) { + max_timeout_msec_(max_timeout_msec), start_time_(start_time) { SSL_clear_mode(ssl, SSL_MODE_AUTO_RETRY); } inline SSLSocketStream::~SSLSocketStream() = default; inline bool SSLSocketStream::is_readable() const { + return SSL_pending(ssl_) > 0; +} + +inline bool SSLSocketStream::wait_readable() const { if (max_timeout_msec_ <= 0) { return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; } @@ -9166,7 +9232,7 @@ inline bool SSLSocketStream::is_readable() const { return select_read(sock_, read_timeout_sec, read_timeout_usec) > 0; } -inline bool SSLSocketStream::is_writable() const { +inline bool SSLSocketStream::wait_writable() const { return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 && is_socket_alive(sock_) && !is_ssl_peer_could_be_closed(ssl_, sock_); } @@ -9174,7 +9240,7 @@ inline bool SSLSocketStream::is_writable() const { inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { if (SSL_pending(ssl_) > 0) { return SSL_read(ssl_, ptr, static_cast(size)); - } else if (is_readable()) { + } else if (wait_readable()) { auto ret = SSL_read(ssl_, ptr, static_cast(size)); if (ret < 0) { auto err = SSL_get_error(ssl_, ret); @@ -9188,7 
+9254,7 @@ inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { #endif if (SSL_pending(ssl_) > 0) { return SSL_read(ssl_, ptr, static_cast(size)); - } else if (is_readable()) { + } else if (wait_readable()) { std::this_thread::sleep_for(std::chrono::microseconds{10}); ret = SSL_read(ssl_, ptr, static_cast(size)); if (ret >= 0) { return ret; } @@ -9205,7 +9271,7 @@ inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { } inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) { - if (is_writable()) { + if (wait_writable()) { auto handle_size = static_cast( std::min(size, (std::numeric_limits::max)())); @@ -9220,7 +9286,7 @@ inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) { #else while (--n >= 0 && err == SSL_ERROR_WANT_WRITE) { #endif - if (is_writable()) { + if (wait_writable()) { std::this_thread::sleep_for(std::chrono::microseconds{10}); ret = SSL_write(ssl_, ptr, static_cast(handle_size)); if (ret >= 0) { return ret; } @@ -9249,12 +9315,10 @@ inline socket_t SSLSocketStream::socket() const { return sock_; } inline time_t SSLSocketStream::duration() const { return std::chrono::duration_cast( - std::chrono::steady_clock::now() - start_time) + std::chrono::steady_clock::now() - start_time_) .count(); } -static SSLInit sslinit_; - } // namespace detail // SSL HTTP server implementation @@ -9623,12 +9687,18 @@ inline bool SSLClient::initialize_ssl(Socket &socket, Error &error) { } if (server_certificate_verification_) { + auto verification_status = SSLVerifierResponse::NoDecisionMade; + if (server_certificate_verifier_) { - if (!server_certificate_verifier_(ssl2)) { - error = Error::SSLServerVerification; - return false; - } - } else { + verification_status = server_certificate_verifier_(ssl2); + } + + if (verification_status == SSLVerifierResponse::CertificateRejected) { + error = Error::SSLServerVerification; + return false; + } + + if (verification_status == SSLVerifierResponse::NoDecisionMade) { verify_result_ = SSL_get_verify_result(ssl2); if (verify_result_ != X509_V_OK) { @@ -9740,8 +9810,8 @@ SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const { auto type = GEN_DNS; - struct in6_addr addr6{}; - struct in_addr addr{}; + struct in6_addr addr6 = {}; + struct in_addr addr = {}; size_t addr_len = 0; #ifndef __MINGW32__ @@ -10389,7 +10459,7 @@ inline void Client::enable_server_hostname_verification(bool enabled) { } inline void Client::set_server_certificate_verifier( - std::function verifier) { + std::function verifier) { cli_->set_server_certificate_verifier(verifier); } #endif @@ -10433,8 +10503,4 @@ inline SSL_CTX *Client::ssl_context() const { } // namespace httplib -#if defined(_WIN32) && defined(CPPHTTPLIB_USE_POLL) -#undef poll -#endif - #endif // CPPHTTPLIB_HTTPLIB_H diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 18caa9127..77dd316d9 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -4459,15 +4459,24 @@ int main(int argc, char ** argv) { llama_backend_free(); }; - // bind HTTP listen port bool was_bound = false; - if (params.port == 0) { - int bound_port = svr->bind_to_any_port(params.hostname); - if ((was_bound = (bound_port >= 0))) { - params.port = bound_port; - } + if (string_ends_with(std::string(params.hostname), ".sock")) { + LOG_INF("%s: setting address family to AF_UNIX\n", __func__); + svr->set_address_family(AF_UNIX); + // bind_to_port requires a second arg, any value other than 0 should + // simply get ignored + was_bound = 
svr->bind_to_port(params.hostname, 8080); } else { - was_bound = svr->bind_to_port(params.hostname, params.port); + LOG_INF("%s: binding port with default address family\n", __func__); + // bind HTTP listen port + if (params.port == 0) { + int bound_port = svr->bind_to_any_port(params.hostname); + if ((was_bound = (bound_port >= 0))) { + params.port = bound_port; + } + } else { + was_bound = svr->bind_to_port(params.hostname, params.port); + } } if (!was_bound) { From ab6ab8f809bd514d88e02a63869b4d619f13fa86 Mon Sep 17 00:00:00 2001 From: Radoslav Gerganov Date: Fri, 28 Mar 2025 08:18:04 +0200 Subject: [PATCH 19/26] rpc : send hash when tensor data is above some fixed threshold (#12496) * rpc : send hash when tensor data is above some fixed threshold ref #10095 * rpc : put cache under $HOME/.cache/llama.cpp * try to fix win32 build * another try to fix win32 build * remove llama as dependency --- examples/rpc/CMakeLists.txt | 6 +- examples/rpc/rpc-server.cpp | 146 ++++++++++++++++++++++++++++++-- ggml/include/ggml-rpc.h | 4 +- ggml/src/ggml-rpc/ggml-rpc.cpp | 150 +++++++++++++++++++++++++++++++-- 4 files changed, 290 insertions(+), 16 deletions(-) diff --git a/examples/rpc/CMakeLists.txt b/examples/rpc/CMakeLists.txt index ae48fb98d..c2c748148 100644 --- a/examples/rpc/CMakeLists.txt +++ b/examples/rpc/CMakeLists.txt @@ -1,2 +1,4 @@ -add_executable(rpc-server rpc-server.cpp) -target_link_libraries(rpc-server PRIVATE ggml llama) +set(TARGET rpc-server) +add_executable(${TARGET} rpc-server.cpp) +target_link_libraries(${TARGET} PRIVATE ggml) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/rpc/rpc-server.cpp b/examples/rpc/rpc-server.cpp index 8b1b23eda..3d590feb0 100644 --- a/examples/rpc/rpc-server.cpp +++ b/examples/rpc/rpc-server.cpp @@ -1,3 +1,7 @@ +#if defined(_MSC_VER) +#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING +#endif + #include "ggml-cpu.h" #ifdef GGML_USE_CUDA @@ -18,26 +22,142 @@ #include "ggml-rpc.h" #ifdef _WIN32 +# define DIRECTORY_SEPARATOR '\\' +# include # include +# include +# include #else +# define DIRECTORY_SEPARATOR '/' # include +# include #endif +#include #include #include +#include +#include + +namespace fs = std::filesystem; + +// NOTE: this is copied from common.cpp to avoid linking with libcommon +// returns true if successful, false otherwise +static bool fs_create_directory_with_parents(const std::string & path) { +#ifdef _WIN32 + std::wstring_convert> converter; + std::wstring wpath = converter.from_bytes(path); + + // if the path already exists, check whether it's a directory + const DWORD attributes = GetFileAttributesW(wpath.c_str()); + if ((attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_DIRECTORY)) { + return true; + } + + size_t pos_slash = 0; + + // process path from front to back, procedurally creating directories + while ((pos_slash = path.find('\\', pos_slash)) != std::string::npos) { + const std::wstring subpath = wpath.substr(0, pos_slash); + const wchar_t * test = subpath.c_str(); + + const bool success = CreateDirectoryW(test, NULL); + if (!success) { + const DWORD error = GetLastError(); + + // if the path already exists, ensure that it's a directory + if (error == ERROR_ALREADY_EXISTS) { + const DWORD attributes = GetFileAttributesW(subpath.c_str()); + if (attributes == INVALID_FILE_ATTRIBUTES || !(attributes & FILE_ATTRIBUTE_DIRECTORY)) { + return false; + } + } else { + return false; + } + } + + pos_slash += 1; + } + + return true; +#else + // if the path already exists, check 
whether it's a directory + struct stat info; + if (stat(path.c_str(), &info) == 0) { + return S_ISDIR(info.st_mode); + } + + size_t pos_slash = 1; // skip leading slashes for directory creation + + // process path from front to back, procedurally creating directories + while ((pos_slash = path.find('/', pos_slash)) != std::string::npos) { + const std::string subpath = path.substr(0, pos_slash); + struct stat info; + + // if the path already exists, ensure that it's a directory + if (stat(subpath.c_str(), &info) == 0) { + if (!S_ISDIR(info.st_mode)) { + return false; + } + } else { + // create parent directories + const int ret = mkdir(subpath.c_str(), 0755); + if (ret != 0) { + return false; + } + } + + pos_slash += 1; + } + + return true; +#endif // _WIN32 +} + +// NOTE: this is copied from common.cpp to avoid linking with libcommon +static std::string fs_get_cache_directory() { + std::string cache_directory = ""; + auto ensure_trailing_slash = [](std::string p) { + // Make sure to add trailing slash + if (p.back() != DIRECTORY_SEPARATOR) { + p += DIRECTORY_SEPARATOR; + } + return p; + }; + if (getenv("LLAMA_CACHE")) { + cache_directory = std::getenv("LLAMA_CACHE"); + } else { +#ifdef __linux__ + if (std::getenv("XDG_CACHE_HOME")) { + cache_directory = std::getenv("XDG_CACHE_HOME"); + } else { + cache_directory = std::getenv("HOME") + std::string("/.cache/"); + } +#elif defined(__APPLE__) + cache_directory = std::getenv("HOME") + std::string("/Library/Caches/"); +#elif defined(_WIN32) + cache_directory = std::getenv("LOCALAPPDATA"); +#endif // __linux__ + cache_directory = ensure_trailing_slash(cache_directory); + cache_directory += "llama.cpp"; + } + return ensure_trailing_slash(cache_directory); +} struct rpc_server_params { std::string host = "127.0.0.1"; int port = 50052; size_t backend_mem = 0; + bool use_cache = false; }; static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) { fprintf(stderr, "Usage: %s [options]\n\n", argv[0]); fprintf(stderr, "options:\n"); - fprintf(stderr, " -h, --help show this help message and exit\n"); - fprintf(stderr, " -H HOST, --host HOST host to bind to (default: %s)\n", params.host.c_str()); - fprintf(stderr, " -p PORT, --port PORT port to bind to (default: %d)\n", params.port); - fprintf(stderr, " -m MEM, --mem MEM backend memory size (in MB)\n"); + fprintf(stderr, " -h, --help show this help message and exit\n"); + fprintf(stderr, " -H HOST, --host HOST host to bind to (default: %s)\n", params.host.c_str()); + fprintf(stderr, " -p PORT, --port PORT port to bind to (default: %d)\n", params.port); + fprintf(stderr, " -m MEM, --mem MEM backend memory size (in MB)\n"); + fprintf(stderr, " -c, --cache enable local file cache\n"); fprintf(stderr, "\n"); } @@ -58,6 +178,8 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params & if (params.port <= 0 || params.port > 65535) { return false; } + } else if (arg == "-c" || arg == "--cache") { + params.use_cache = true; } else if (arg == "-m" || arg == "--mem") { if (++i >= argc) { return false; @@ -164,8 +286,20 @@ int main(int argc, char * argv[]) { } else { get_backend_memory(&free_mem, &total_mem); } - printf("Starting RPC server on %s, backend memory: %zu MB\n", endpoint.c_str(), free_mem / (1024 * 1024)); - ggml_backend_rpc_start_server(backend, endpoint.c_str(), free_mem, total_mem); + const char * cache_dir = nullptr; + std::string cache_dir_str = fs_get_cache_directory() + "rpc/"; + if (params.use_cache) { + if 
(!fs_create_directory_with_parents(cache_dir_str)) { + fprintf(stderr, "Failed to create cache directory: %s\n", cache_dir_str.c_str()); + return 1; + } + cache_dir = cache_dir_str.c_str(); + } + printf("Starting RPC server\n"); + printf(" endpoint : %s\n", endpoint.c_str()); + printf(" local cache : %s\n", cache_dir ? cache_dir : "n/a"); + printf(" backend memory : %zu MB\n", free_mem / (1024 * 1024)); + ggml_backend_rpc_start_server(backend, endpoint.c_str(), cache_dir, free_mem, total_mem); ggml_backend_free(backend); return 0; } diff --git a/ggml/include/ggml-rpc.h b/ggml/include/ggml-rpc.h index ade6c3b0e..4e0d210f8 100644 --- a/ggml/include/ggml-rpc.h +++ b/ggml/include/ggml-rpc.h @@ -17,7 +17,9 @@ GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const c GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total); -GGML_BACKEND_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem); +GGML_BACKEND_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, + const char * cache_dir, + size_t free_mem, size_t total_mem); GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void); diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 6c3b80b08..862b9b666 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -26,6 +26,10 @@ # include #endif #include +#include +#include + +namespace fs = std::filesystem; #ifdef _WIN32 typedef SOCKET sockfd_t; @@ -80,6 +84,7 @@ enum rpc_cmd { RPC_CMD_FREE_BUFFER, RPC_CMD_BUFFER_CLEAR, RPC_CMD_SET_TENSOR, + RPC_CMD_SET_TENSOR_HASH, RPC_CMD_GET_TENSOR, RPC_CMD_COPY_TENSOR, RPC_CMD_GRAPH_COMPUTE, @@ -89,6 +94,9 @@ enum rpc_cmd { RPC_CMD_COUNT, }; +// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold +const size_t HASH_THRESHOLD = 10 * 1024 * 1024; + struct rpc_msg_get_alloc_size_req { rpc_tensor tensor; }; @@ -135,6 +143,10 @@ struct rpc_msg_buffer_clear_req { uint8_t value; }; +struct rpc_msg_set_tensor_hash_rsp { + uint8_t result; +}; + struct rpc_msg_get_tensor_req { rpc_tensor tensor; uint64_t offset; @@ -187,6 +199,18 @@ struct ggml_backend_rpc_buffer_context { // RPC helper functions +// Computes FNV-1a hash of the data +static uint64_t fnv_hash(const uint8_t * data, size_t len) { + const uint64_t fnv_prime = 0x100000001b3ULL; + uint64_t hash = 0xcbf29ce484222325ULL; + + for (size_t i = 0; i < len; ++i) { + hash ^= data[i]; + hash *= fnv_prime; + } + return hash; +} + static std::shared_ptr make_socket(sockfd_t fd) { #ifdef _WIN32 if (fd == INVALID_SOCKET) { @@ -483,10 +507,26 @@ static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; - // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) | + rpc_tensor rpc_tensor = serialize_tensor(tensor); + if (size > HASH_THRESHOLD) { + // input serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes) + size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + sizeof(uint64_t); + std::vector input(input_size, 0); + uint64_t hash = fnv_hash((const uint8_t*)data, size); + memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor)); + memcpy(input.data() + 
sizeof(rpc_tensor), &offset, sizeof(offset)); + memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &hash, sizeof(hash)); + rpc_msg_set_tensor_hash_rsp response; + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, input.data(), input.size(), &response, sizeof(response)); + GGML_ASSERT(status); + if (response.result) { + // the server has the same data, no need to send it + return; + } + } + // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size; std::vector input(input_size, 0); - rpc_tensor rpc_tensor = serialize_tensor(tensor); memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor)); memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset)); memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size); @@ -772,7 +812,9 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, si class rpc_server { public: - rpc_server(ggml_backend_t backend) : backend(backend) {} + rpc_server(ggml_backend_t backend, const char * cache_dir) + : backend(backend), cache_dir(cache_dir) { + } ~rpc_server(); void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response); @@ -782,6 +824,7 @@ public: bool free_buffer(const rpc_msg_free_buffer_req & request); bool buffer_clear(const rpc_msg_buffer_clear_req & request); bool set_tensor(const std::vector & input); + bool set_tensor_hash(const std::vector & input, rpc_msg_set_tensor_hash_rsp & response); bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector & response); bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response); bool graph_compute(const std::vector & input, rpc_msg_graph_compute_rsp & response); @@ -789,6 +832,7 @@ public: bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response); private: + bool get_cached_file(uint64_t hash, std::vector & data); ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor); ggml_tensor * create_node(uint64_t id, struct ggml_context * ctx, @@ -797,6 +841,7 @@ private: ggml_backend_t backend; + const char * cache_dir; std::unordered_set buffers; }; @@ -960,11 +1005,85 @@ bool rpc_server::set_tensor(const std::vector & input) { } const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset); + if (cache_dir && size > HASH_THRESHOLD) { + uint64_t hash = fnv_hash((const uint8_t*)data, size); + char hash_str[17]; + snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash); + // save to cache_dir/hash_str + fs::path cache_file = fs::path(cache_dir) / hash_str; + std::ofstream ofs(cache_file, std::ios::binary); + ofs.write((const char *)data, size); + printf("[%s] saved to '%s'\n", __func__, cache_file.c_str()); + } ggml_backend_tensor_set(tensor, data, offset, size); ggml_free(ctx); return true; } +bool rpc_server::get_cached_file(uint64_t hash, std::vector & data) { + if (!cache_dir) { + return false; + } + char hash_str[17]; + snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash); + fs::path cache_file = fs::path(cache_dir) / hash_str; + if (!fs::exists(cache_file)) { + return false; + } + std::ifstream ifs(cache_file, std::ios::binary); + ifs.seekg(0, std::ios::end); + size_t size = ifs.tellg(); + ifs.seekg(0, std::ios::beg); + data.resize(size); + ifs.read((char *)data.data(), size); + return true; +} + +bool rpc_server::set_tensor_hash(const std::vector & input, rpc_msg_set_tensor_hash_rsp & response) 
+{ + // serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes) | + if (input.size() != sizeof(rpc_tensor) + 16) { + return false; + } + const rpc_tensor * in_tensor = (const rpc_tensor *)input.data(); + uint64_t offset; + memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset)); + const uint64_t * hash = (const uint64_t *)(input.data() + sizeof(rpc_tensor) + sizeof(offset)); + std::vector cached_file; + if (!get_cached_file(*hash, cached_file)) { + response.result = 0; + return true; + } + size_t size = cached_file.size(); + struct ggml_init_params params { + /*.mem_size =*/ ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + struct ggml_context * ctx = ggml_init(params); + ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor); + if (tensor == nullptr) { + GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__); + ggml_free(ctx); + return false; + } + GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size, *hash); + + // sanitize tensor->data + { + const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer); + const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer); + + if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) { + GGML_ABORT("[%s] tensor->data out of bounds\n", __func__); + } + } + ggml_backend_tensor_set(tensor, cached_file.data(), offset, size); + response.result = 1; + ggml_free(ctx); + return true; +} + bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) { struct ggml_init_params params { /*.mem_size =*/ ggml_tensor_overhead(), @@ -1148,8 +1267,9 @@ rpc_server::~rpc_server() { } } -static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) { - rpc_server server(backend); +static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir, + sockfd_t sockfd, size_t free_mem, size_t total_mem) { + rpc_server server(backend, cache_dir); while (true) { uint8_t cmd; if (!recv_data(sockfd, &cmd, 1)) { @@ -1260,6 +1380,20 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre } break; } + case RPC_CMD_SET_TENSOR_HASH: { + std::vector input; + if (!recv_msg(sockfd, input)) { + return; + } + rpc_msg_set_tensor_hash_rsp response; + if (!server.set_tensor_hash(input, response)) { + return; + } + if (!send_msg(sockfd, &response, sizeof(response))) { + return; + } + break; + } case RPC_CMD_INIT_TENSOR: { rpc_msg_init_tensor_req request; if (!recv_msg(sockfd, &request,sizeof(request))) { @@ -1335,7 +1469,9 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre } } -void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) { +void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, + const char * cache_dir, + size_t free_mem, size_t total_mem) { std::string host; int port; if (!parse_endpoint(endpoint, host, port)) { @@ -1364,7 +1500,7 @@ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint } printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem); fflush(stdout); - rpc_serve_client(backend, client_socket->fd, free_mem, total_mem); + rpc_serve_client(backend, cache_dir, client_socket->fd, free_mem, total_mem); printf("Client connection closed\n"); fflush(stdout); } From 
13731766db91ec927c1b61bf502ac7a9be2b11b9 Mon Sep 17 00:00:00 2001 From: amritahs-ibm Date: Fri, 28 Mar 2025 13:13:22 +0530 Subject: [PATCH 20/26] llamafile : ppc64le GEMV forwarding for FP32. (#12594) This patch enables usage of MMA when one of the dimensions of the matrix(ie either M or N) is 1. This is useful in case of token generation where N < 2. The concept of 'GEMV Forwarding' is used where when one of the matrix has a single row/column, the elements are broadcasted, instead of using packing routine to prepack the matrix elements. This change results in 5% - 15% improvement in total speed(ie all tokens/total time), across various batch sizes. This is in comparision with the corresponding dot product implementation. The patch is tested with FP32 models of Meta-Lllama-3-8B, Mistral-7B, Llama-2-7B-chat-hf on a IBM POWER10 machine. Signed-off-by: Amrita H S --- ggml/src/ggml-cpu/llamafile/sgemm.cpp | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp index 92dfbc2d2..f6374f789 100644 --- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -2680,13 +2680,25 @@ class tinyBLAS_PPC { __builtin_mma_xxsetaccz(&acc_0); vec_t vec_A[4] {0}, vec_B[4] = {0}; for (int l=0; l= 4 && RM == 1) { + /* 'GEMV Forwarding' concept is used in first two conditional loops. + * when one of the matrix has a single row/column, the elements are + * broadcasted, instead of using packing routine to prepack the + * matrix elements. + */ + if (RM == 1) { TA* a = const_cast(A+(ii)*lda+l); - packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B); + packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B); vec_A[0] = (vec_t)vec_xl(0,a); vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1)); vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2)); vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3)); + } else if (RN == 1) { + packTranspose(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A); + TB* b = const_cast(B+(jj)*ldb+l); + vec_B[0] = (vec_t)vec_xl(0,b); + vec_B[1] = (vec_t)vec_splats(*((TB*)&vec_B+1)); + vec_B[2] = (vec_t)vec_splats(*((TB*)&vec_B+2)); + vec_B[3] = (vec_t)vec_splats(*((TB*)&vec_B+3)); } else { packTranspose(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A); packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B); @@ -2790,8 +2802,10 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 assert(params->ith < params->nth); // only enable sgemm for prompt processing +#if !defined(__MMA__) if (n < 2) return false; +#endif if (Ctype != GGML_TYPE_F32) return false; From ef03229ff423dd1991f4f44ef1352f03334d86eb Mon Sep 17 00:00:00 2001 From: Radoslav Gerganov Date: Fri, 28 Mar 2025 09:44:13 +0200 Subject: [PATCH 21/26] rpc : update README for cache usage (#12620) --- examples/rpc/README.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/examples/rpc/README.md b/examples/rpc/README.md index 312bb634d..561f19fda 100644 --- a/examples/rpc/README.md +++ b/examples/rpc/README.md @@ -72,3 +72,14 @@ $ bin/llama-cli -m ../models/tinyllama-1b/ggml-model-f16.gguf -p "Hello, my name This way you can offload model layers to both local and remote devices. +### Local cache + +The RPC server can use a local cache to store large tensors and avoid transferring them over the network. +This can speed up model loading significantly, especially when using large models. 
+To enable the cache, use the `-c` option: + +```bash +$ bin/rpc-server -c +``` + +By default, the cache is stored in the `$HOME/.cache/llama.cpp/rpc` directory and can be controlled via the `LLAMA_CACHE` environment variable. From 5d01670266859444366e4f333ade5e0e5e2ae63d Mon Sep 17 00:00:00 2001 From: Benson Wong Date: Fri, 28 Mar 2025 01:05:44 -0700 Subject: [PATCH 22/26] server : include speculative decoding stats when timings_per_token is enabled (#12603) * Include speculative decoding stats when timings_per_token is true New fields added to the `timings` object: - draft_n : number of draft tokens generated - draft_accepted_n : number of draft tokens accepted - draft_accept_ratio: ratio of accepted/generated * Remove redundant draft_accept_ratio var * add draft acceptance rate to server console output --- examples/server/server.cpp | 42 +++++++++++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 77dd316d9..17a292da1 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -489,8 +489,12 @@ struct result_timings { double predicted_per_token_ms; double predicted_per_second; + // Optional speculative metrics - only included when > 0 + int32_t draft_n = 0; + int32_t draft_n_accepted = 0; + json to_json() const { - return { + json base = { {"prompt_n", prompt_n}, {"prompt_ms", prompt_ms}, {"prompt_per_token_ms", prompt_per_token_ms}, @@ -501,6 +505,13 @@ struct result_timings { {"predicted_per_token_ms", predicted_per_token_ms}, {"predicted_per_second", predicted_per_second}, }; + + if (draft_n > 0) { + base["draft_n"] = draft_n; + base["draft_n_accepted"] = draft_n_accepted; + } + + return base; } }; @@ -1299,6 +1310,10 @@ struct server_slot { std::function callback_on_release; + // Speculative decoding stats + int32_t n_draft_total = 0; // Total draft tokens generated + int32_t n_draft_accepted = 0; // Draft tokens actually accepted + void reset() { SLT_DBG(*this, "%s", "\n"); @@ -1315,6 +1330,10 @@ struct server_slot { generated_tokens.clear(); generated_token_probs.clear(); + + // clear speculative decoding stats + n_draft_total = 0; + n_draft_accepted = 0; } bool is_non_causal() const { @@ -1381,6 +1400,12 @@ struct server_slot { timings.predicted_per_token_ms = t_token_generation / n_decoded; timings.predicted_per_second = 1e3 / t_token_generation * n_decoded; + // Add speculative metrics + if (n_draft_total > 0) { + timings.draft_n = n_draft_total; + timings.draft_n_accepted = n_draft_accepted; + } + return timings; } @@ -1428,6 +1453,15 @@ struct server_slot { t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second, t_token_generation, n_decoded, t_gen, n_gen_second, t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded); + + if (n_draft_total > 0) { + const float draft_ratio = (float) n_draft_accepted / n_draft_total; + SLT_INF(*this, + "\n" + "draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n", + draft_ratio, n_draft_accepted, n_draft_total + ); + } } json to_json() const { @@ -3290,6 +3324,9 @@ struct server_context { llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id); + // keep track of total number of tokens generated in the draft + slot.n_draft_total += draft.size(); + // ignore small drafts if (slot.params.speculative.n_min > (int) draft.size()) { SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min); @@ 
-3315,6 +3352,9 @@ struct server_context { slot.n_past += ids.size(); slot.n_decoded += ids.size(); + // update how many tokens out of draft was accepted + slot.n_draft_accepted += ids.size() - 1; + slot.cache_tokens.push_back(id); slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1); From dd373dd3bf81eced3e711fb7cb49123a6105933e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Fri, 28 Mar 2025 18:08:52 +0100 Subject: [PATCH 23/26] llama: fix error on bad grammar (#12628) --- common/sampling.cpp | 3 +++ include/llama.h | 4 ++++ src/llama-sampling.cpp | 5 +++++ 3 files changed, 12 insertions(+) diff --git a/common/sampling.cpp b/common/sampling.cpp index baf22066d..1735b6501 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -208,6 +208,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co trigger_patterns_c.data(), trigger_patterns_c.size(), trigger_tokens.data(), trigger_tokens.size()) : llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"); + if (!grmr) { + return nullptr; + } } auto * result = new common_sampler { diff --git a/include/llama.h b/include/llama.h index 25a9f8278..c66a23709 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1265,6 +1265,10 @@ extern "C" { float tau, float eta); + /// @details Intializes a GBNF grammar, see grammars/README.md for details. + /// @param vocab The vocabulary that this grammar will be used with. + /// @param grammar_str The production rules for the grammar, encoded as a string. Returns an empty grammar if empty. Returns NULL if parsing of grammar_str fails. + /// @param grammar_root The name of the start symbol for the grammar. LLAMA_API struct llama_sampler * llama_sampler_init_grammar( const struct llama_vocab * vocab, const char * grammar_str, diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index c25977ca3..d14979850 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1477,6 +1477,7 @@ static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sam const auto * ctx = (const llama_sampler_grammar *) smpl->ctx; auto * result = llama_sampler_init_grammar_impl(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0, nullptr, 0); + GGML_ASSERT(result); // copy the state { @@ -1548,6 +1549,10 @@ static struct llama_sampler * llama_sampler_init_grammar_impl( /* .grammar_root = */ grammar_root, /* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens), }; + if (!ctx->grammar) { + delete ctx; + return nullptr; + } } else { *ctx = { /* .vocab = */ vocab, From b86f6007234da4bff51a3ebef2bdb952b52059c6 Mon Sep 17 00:00:00 2001 From: Icenowy Zheng Date: Sat, 29 Mar 2025 01:51:06 +0800 Subject: [PATCH 24/26] vulkan: fix coopmat shader generation when cross-compiling (#12272) * vulkan: fix coopmat shader generation when cross-compiling Previously the status of coopmat{,2} support isn't passed to the vulkan-shaders-gen project building on the host, which leads to build failure because of the cross-compiling code expecting coopmat{,2} shaders that didn't get generated. Fix this by passing the coopmat{,2} support status to vulkan-shaders subproject. 
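
For readability, a condensed CMake sketch of the approach (trimmed from the hunks below; only the coopmat case is shown, coopmat2 follows the same pattern): the glslc probe runs once, its result is stored as an INTERNAL cache variable, and that value is forwarded to the host-built `vulkan-shaders-gen` ExternalProject through `CMAKE_ARGS`, so a cross build never re-probes with the target toolchain. This is a sketch, not the full CMakeLists.txt.

```cmake
include(ExternalProject)

# Probe glslc for GL_KHR_cooperative_matrix only if the result is not already cached.
if (NOT DEFINED GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
    execute_process(
        COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3
                "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp"
        OUTPUT_VARIABLE glslc_output
        ERROR_VARIABLE  glslc_error)
    if (glslc_error MATCHES "extension not supported: GL_KHR_cooperative_matrix")
        set(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT OFF CACHE INTERNAL "Whether coopmat is supported by glslc")
    else()
        add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
        set(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT ON CACHE INTERNAL "Whether coopmat is supported by glslc")
    endif()
endif()

# When cross-compiling, vulkan-shaders-gen is configured with the host toolchain;
# passing the cached value through CMAKE_ARGS keeps both builds in agreement
# about which coopmat shader variants get generated.
ExternalProject_Add(vulkan-shaders-gen
    SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders
    CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE}
               -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR}
               -DGGML_VULKAN_COOPMAT_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT_GLSLC_SUPPORT}
    BUILD_COMMAND   ${CMAKE_COMMAND} --build .
    INSTALL_COMMAND ${CMAKE_COMMAND} --install .)
```
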
Signed-off-by: Icenowy Zheng * Only call coop-mat shaders once * Fix whitespace --------- Signed-off-by: Icenowy Zheng Co-authored-by: bandoti <141645996+bandoti@users.noreply.github.com> --- ggml/src/ggml-vulkan/CMakeLists.txt | 54 +++++++++++-------- .../ggml-vulkan/vulkan-shaders/CMakeLists.txt | 6 +++ 2 files changed, 38 insertions(+), 22 deletions(-) diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index d970f7e20..8ef28e2d5 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -23,32 +23,40 @@ if (Vulkan_FOUND) ../../include/ggml-vulkan.h ) - # Compile a test shader to determine whether GL_KHR_cooperative_matrix is supported. - # If it's not, there will be an error to stderr. - # If it's supported, set a define to indicate that we should compile those shaders - execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp" - OUTPUT_VARIABLE glslc_output - ERROR_VARIABLE glslc_error) + if(NOT DEFINED GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + # Compile a test shader to determine whether GL_KHR_cooperative_matrix is supported. + # If it's not, there will be an error to stderr. + # If it's supported, set a define to indicate that we should compile those shaders + execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp" + OUTPUT_VARIABLE glslc_output + ERROR_VARIABLE glslc_error) - if (${glslc_error} MATCHES ".*extension not supported: GL_KHR_cooperative_matrix.*") - message(STATUS "GL_KHR_cooperative_matrix not supported by glslc") - else() - message(STATUS "GL_KHR_cooperative_matrix supported by glslc") - add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (${glslc_error} MATCHES ".*extension not supported: GL_KHR_cooperative_matrix.*") + message(STATUS "GL_KHR_cooperative_matrix not supported by glslc") + set(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT OFF CACHE INTERNAL "Whether coopmat is supported by glslc") + else() + message(STATUS "GL_KHR_cooperative_matrix supported by glslc") + add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + set(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT ON CACHE INTERNAL "Whether coopmat is supported by glslc") + endif() endif() - # Compile a test shader to determine whether GL_NV_cooperative_matrix2 is supported. - # If it's not, there will be an error to stderr. - # If it's supported, set a define to indicate that we should compile those shaders - execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp" - OUTPUT_VARIABLE glslc_output - ERROR_VARIABLE glslc_error) + if(NOT DEFINED GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + # Compile a test shader to determine whether GL_NV_cooperative_matrix2 is supported. + # If it's not, there will be an error to stderr. 
+ # If it's supported, set a define to indicate that we should compile those shaders + execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp" + OUTPUT_VARIABLE glslc_output + ERROR_VARIABLE glslc_error) - if (${glslc_error} MATCHES ".*extension not supported: GL_NV_cooperative_matrix2.*") - message(STATUS "GL_NV_cooperative_matrix2 not supported by glslc") - else() - message(STATUS "GL_NV_cooperative_matrix2 supported by glslc") - add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (${glslc_error} MATCHES ".*extension not supported: GL_NV_cooperative_matrix2.*") + message(STATUS "GL_NV_cooperative_matrix2 not supported by glslc") + set(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT OFF CACHE INTERNAL "Whether coopmat2 is supported by glslc") + else() + message(STATUS "GL_NV_cooperative_matrix2 supported by glslc") + add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + set(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT ON CACHE INTERNAL "Whether coopmat2 is supported by glslc") + endif() endif() target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan) @@ -119,6 +127,8 @@ if (Vulkan_FOUND) SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE} -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR} + -DGGML_VULKAN_COOPMAT_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT_GLSLC_SUPPORT} + -DGGML_VULKAN_COOPMAT2_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT} BUILD_COMMAND ${CMAKE_COMMAND} --build . INSTALL_COMMAND ${CMAKE_COMMAND} --install . INSTALL_DIR ${CMAKE_BINARY_DIR} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt b/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt index 51c78b7d2..b1e175021 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt @@ -1,5 +1,11 @@ find_package (Threads REQUIRED) +if (GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) +endif() +if (GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) +endif() set(TARGET vulkan-shaders-gen) add_executable(${TARGET} vulkan-shaders-gen.cpp) install(TARGETS ${TARGET} RUNTIME) From b4ae50810e4304d052e630784c14bde7e79e4132 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 28 Mar 2025 20:21:59 +0200 Subject: [PATCH 25/26] metal : improve FA + improve MoE (#12612) * ggml : FA with different K, V head sizes (CPU) ggml-ci * metal : add FA with HS=192 * metal : extend FA to support different K and V head sizes ggml-ci * metal : add FA vector kernels for heads K 192 and V 128 ggml-ci * ggml : restrict op on other backends to equal head sizes ggml-ci * metal : optimize FA-vec kernel ggml-ci * metal : FA remove mq registers * metal : improve MoE mul_mat_id condition ggml-ci * metal : fix comments + remove unnecessary addition ggml-ci * metal : avoid too much shared memory usage with mul_mat_id ggml-ci --- ggml/include/ggml.h | 10 +- ggml/src/ggml-cpu/ggml-cpu.c | 50 +- ggml/src/ggml-cuda/ggml-cuda.cu | 7 + ggml/src/ggml-metal/ggml-metal-impl.h | 9 +- ggml/src/ggml-metal/ggml-metal.m | 955 +++++++++++++++----------- ggml/src/ggml-metal/ggml-metal.metal | 508 ++++++++------ ggml/src/ggml-vulkan/ggml-vulkan.cpp | 4 + ggml/src/ggml.c | 2 +- src/llama-context.cpp | 5 - tests/test-backend-ops.cpp | 69 +- 10 files changed, 913 insertions(+), 706 deletions(-) diff --git a/ggml/include/ggml.h 
b/ggml/include/ggml.h index cb3edb10d..452c967b0 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1791,11 +1791,11 @@ extern "C" { #define GGML_KQ_MASK_PAD 64 - // q: [n_embd, n_batch, n_head, 1] - // k: [n_embd, n_kv, n_head_kv, 1] - // v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !! - // mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !! - // res: [n_embd, n_head, n_batch, 1] !! permuted !! + // q: [n_embd_k, n_batch, n_head, 1] + // k: [n_embd_k, n_kv, n_head_kv, 1] + // v: [n_embd_v, n_kv, n_head_kv, 1] !! not transposed !! + // mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !! + // res: [n_embd_v, n_head, n_batch, 1] !! permuted !! GGML_API struct ggml_tensor * ggml_flash_attn_ext( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 2dbe83558..fde837aed 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -12238,10 +12238,11 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int ith = params->ith; const int nth = params->nth; - const int64_t D = neq0; - const int64_t N = neq1; + const int64_t DK = nek0; + const int64_t DV = nev0; + const int64_t N = neq1; - GGML_ASSERT(ne0 == D); + GGML_ASSERT(ne0 == DV); GGML_ASSERT(ne2 == N); // input tensor rows must be contiguous @@ -12249,12 +12250,11 @@ static void ggml_compute_forward_flash_attn_ext_f16( GGML_ASSERT(nbk0 == ggml_type_size(k->type)); GGML_ASSERT(nbv0 == ggml_type_size(v->type)); - GGML_ASSERT(neq0 == D); - GGML_ASSERT(nek0 == D); - GGML_ASSERT(nev0 == D); + GGML_ASSERT(neq0 == DK); + GGML_ASSERT(nek0 == DK); + GGML_ASSERT(nev0 == DV); GGML_ASSERT(neq1 == N); - GGML_ASSERT(nev0 == D); // dst cannot be transposed or permuted GGML_ASSERT(nb0 == sizeof(float)); @@ -12320,15 +12320,15 @@ static void ggml_compute_forward_flash_attn_ext_f16( float S = 0.0f; // sum float M = -INFINITY; // maximum KQ value - float * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator - float * V32 = (VKQ32 + 1*D); // (temporary) FP32 V buffer - ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator - ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16 + float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator + float * V32 = (VKQ32 + 1*DV); // (temporary) FP32 V buffer + ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); // (temporary) FP16 VKQ accumulator + ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*DV); // (temporary) buffer for Q converted to quantized/FP16 if (v->type == GGML_TYPE_F16) { - memset(VKQ16, 0, D*sizeof(ggml_fp16_t)); + memset(VKQ16, 0, DV*sizeof(ggml_fp16_t)); } else { - memset(VKQ32, 0, D*sizeof(float)); + memset(VKQ32, 0, DV*sizeof(float)); } const ggml_fp16_t * mp = mask ? 
(ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL; @@ -12342,7 +12342,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int iv2 = iq2 / rv2; const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)); - q_to_vec_dot(pq, Q_q, D); + q_to_vec_dot(pq, Q_q, DK); // online softmax / attention // loop over n_kv and n_head_kv @@ -12356,7 +12356,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( float s; // KQ value const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3); - kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1); + kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1); s = s*scale; // scale KQ value @@ -12380,14 +12380,14 @@ static void ggml_compute_forward_flash_attn_ext_f16( ms = expf(Mold - M); // V = V*expf(Mold - M) - ggml_vec_scale_f16(D, VKQ16, ms); + ggml_vec_scale_f16(DV, VKQ16, ms); } else { // no new maximum, ms == 1.0f, vs != 1.0f vs = expf(s - M); } // V += v*expf(s - M) - ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs); + ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) v_data, vs); } else { if (s > M) { // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f @@ -12395,30 +12395,30 @@ static void ggml_compute_forward_flash_attn_ext_f16( ms = expf(Mold - M); // V = V*expf(Mold - M) - ggml_vec_scale_f32(D, VKQ32, ms); + ggml_vec_scale_f32(DV, VKQ32, ms); } else { // no new maximum, ms == 1.0f, vs != 1.0f vs = expf(s - M); } - v_to_float(v_data, V32, D); + v_to_float(v_data, V32, DV); // V += v*expf(s - M) - ggml_vec_mad_f32(D, VKQ32, V32, vs); + ggml_vec_mad_f32(DV, VKQ32, V32, vs); } S = S*ms + vs; // scale and increment sum with partial sum } if (v->type == GGML_TYPE_F16) { - for (int64_t d = 0; d < D; ++d) { + for (int64_t d = 0; d < DV; ++d) { VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]); } } // V /= S const float S_inv = 1.0f/S; - ggml_vec_scale_f32(D, VKQ32, S_inv); + ggml_vec_scale_f32(DV, VKQ32, S_inv); // dst indices const int i1 = iq1; @@ -15277,7 +15277,6 @@ struct ggml_cplan ggml_graph_plan( size_t cur = 0; if (!ggml_cpu_extra_work_size(n_threads, node, &cur)) { - switch (node->op) { case GGML_OP_CPY: case GGML_OP_DUP: @@ -15386,9 +15385,10 @@ struct ggml_cplan ggml_graph_plan( } break; case GGML_OP_FLASH_ATTN_EXT: { - const int64_t ne00 = node->src[0]->ne[0]; // D + const int64_t ne10 = node->src[1]->ne[0]; // DK + const int64_t ne20 = node->src[2]->ne[0]; // DV - cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread + cur = sizeof(float)*(1*ne10 + 2*ne20)*n_tasks; // 1x head size K + 2x head size V (per thread) } break; case GGML_OP_FLASH_ATTN_BACK: { diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 3bb472ffb..f2ad692f6 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3232,6 +3232,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g #ifndef FLASH_ATTN_AVAILABLE return false; #endif // FLASH_ATTN_AVAILABLE + if (op->src[1]->ne[0] != op->src[2]->ne[0]) { + // different head sizes of K and V are not supported yet + return false; + } + if (op->src[0]->ne[0] == 192) { + return false; + } if (op->src[0]->ne[3] != 1) { return false; } diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index ca5a00b03..8721b272d 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -219,9 +219,12 @@ typedef struct { int32_t ne11; int32_t ne_12_2; // assume K and V are same shape int32_t ne_12_3; - uint64_t nb_12_1; 
- uint64_t nb_12_2; - uint64_t nb_12_3; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + uint64_t nb21; + uint64_t nb22; + uint64_t nb23; uint64_t nb31; int32_t ne1; int32_t ne2; diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 195d96782..3942013f4 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -351,42 +351,56 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, @@ -395,6 +409,20 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192, + 
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, @@ -758,313 +786,341 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de // simd_sum and simd_max requires MTLGPUFamilyApple7 - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB, sub, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU, elu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true); - 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction); - 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5, mul_mv_ext_f16_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2, mul_mv_ext_q4_0_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3, mul_mv_ext_q4_0_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4, mul_mv_ext_q4_0_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5, mul_mv_ext_q4_0_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2, mul_mv_ext_q4_1_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3, mul_mv_ext_q4_1_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4, mul_mv_ext_q4_1_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5, mul_mv_ext_q4_1_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2, mul_mv_ext_q5_0_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3, mul_mv_ext_q5_0_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4, mul_mv_ext_q5_0_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5, mul_mv_ext_q5_0_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2, mul_mv_ext_q5_1_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3, mul_mv_ext_q5_1_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4, mul_mv_ext_q5_1_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5, mul_mv_ext_q5_1_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2, mul_mv_ext_q8_0_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction); - 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, mul_mv_ext_q4_K_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, mul_mv_ext_q4_K_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, mul_mv_ext_q4_K_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5, mul_mv_ext_q4_K_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2, mul_mv_ext_q5_K_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3, mul_mv_ext_q5_K_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4, mul_mv_ext_q5_K_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5, mul_mv_ext_q5_K_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2, mul_mv_ext_q6_K_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3, mul_mv_ext_q6_K_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4, mul_mv_ext_q6_K_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5, mul_mv_ext_q6_K_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2, mul_mv_ext_iq4_nl_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3, mul_mv_ext_iq4_nl_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4, mul_mv_ext_iq4_nl_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5, mul_mv_ext_iq4_nl_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, has_simdgroup_reduction); - 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, has_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, has_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, has_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, 
mul_mm_q5_1_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, has_simdgroup_mm); - 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32, conv_transpose_1d_f32_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, conv_transpose_1d_f16_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32, pad_reflect_1d_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, flash_attn_ext_bf16_h112, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, flash_attn_ext_bf16_h128, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat); - 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm); - 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, flash_attn_ext_vec_bf16_h256, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32, cpy_q4_0_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16, cpy_q4_0_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32, cpy_q4_1_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16, cpy_q4_1_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32, cpy_q5_0_f32, true); - 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16, cpy_q5_0_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32, cpy_q5_1_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16, cpy_q5_1_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32, cpy_q8_0_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16, cpy_q8_0_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB, sub, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU, elu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, 
get_rows_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction); + 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5, mul_mv_ext_f16_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2, mul_mv_ext_q4_0_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3, mul_mv_ext_q4_0_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4, mul_mv_ext_q4_0_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5, mul_mv_ext_q4_0_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2, mul_mv_ext_q4_1_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3, mul_mv_ext_q4_1_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4, mul_mv_ext_q4_1_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5, mul_mv_ext_q4_1_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2, mul_mv_ext_q5_0_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3, mul_mv_ext_q5_0_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4, mul_mv_ext_q5_0_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5, mul_mv_ext_q5_0_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2, mul_mv_ext_q5_1_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3, mul_mv_ext_q5_1_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4, mul_mv_ext_q5_1_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5, mul_mv_ext_q5_1_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2, mul_mv_ext_q8_0_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction); + 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, mul_mv_ext_q4_K_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, mul_mv_ext_q4_K_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, mul_mv_ext_q4_K_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5, mul_mv_ext_q4_K_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2, mul_mv_ext_q5_K_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3, mul_mv_ext_q5_K_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4, mul_mv_ext_q5_K_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5, mul_mv_ext_q5_K_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2, mul_mv_ext_q6_K_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3, mul_mv_ext_q6_K_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4, mul_mv_ext_q6_K_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5, mul_mv_ext_q6_K_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2, mul_mv_ext_iq4_nl_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3, mul_mv_ext_iq4_nl_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4, mul_mv_ext_iq4_nl_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5, mul_mv_ext_iq4_nl_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, has_simdgroup_reduction); + 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, has_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, has_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, has_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, 
mul_mm_q5_1_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, has_simdgroup_mm); + 
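
The flash-attention registrations that follow cover, besides the existing single head sizes (64 through 256, now including 192), variants whose K and V head sizes differ, named with an hk/hv suffix such as hk192_hv128. The sketch below only illustrates how such a (K, V) head-size pair maps onto a variant name; the real backend selects kernels through the GGML_METAL_KERNEL_TYPE_* enum, not through strings.

    // Illustrative only: derive a flash-attention variant name from the K/V
    // head sizes, mirroring the naming scheme of the registered kernels.
    #include <cstdio>
    #include <string>

    static std::string fattn_variant(int dk, int dv) {
        if (dk == dv) {
            return "flash_attn_ext_f16_h" + std::to_string(dk);   // e.g. ..._h128
        }
        return "flash_attn_ext_f16_hk" + std::to_string(dk) +
               "_hv" + std::to_string(dv);                         // e.g. ..._hk192_hv128
    }

    int main() {
        std::printf("%s\n", fattn_variant(128, 128).c_str());
        std::printf("%s\n", fattn_variant(192, 128).c_str());
        return 0;
    }
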
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32, conv_transpose_1d_f32_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, conv_transpose_1d_f16_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32, pad_reflect_1d_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192, flash_attn_ext_f16_h192, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, flash_attn_ext_f16_hk192_hv128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, flash_attn_ext_bf16_h112, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, 
flash_attn_ext_bf16_h128, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192, flash_attn_ext_bf16_h192, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, flash_attn_ext_bf16_hk192_hv128, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192, flash_attn_ext_q4_0_h192, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, flash_attn_ext_q4_0_hk192_hv128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192, flash_attn_ext_q4_1_h192, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, flash_attn_ext_q4_1_hk192_hv128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192, flash_attn_ext_q5_0_h192, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, flash_attn_ext_q5_0_hk192_hv128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, 
has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192, flash_attn_ext_q5_1_h192, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, flash_attn_ext_q5_1_hk192_hv128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, flash_attn_ext_q8_0_h192, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192, flash_attn_ext_vec_f16_h192, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192, flash_attn_ext_vec_bf16_h192, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192, flash_attn_ext_vec_q4_0_h192, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192, flash_attn_ext_vec_q4_1_h192, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192, flash_attn_ext_vec_q5_0_h192, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192, flash_attn_ext_vec_q5_1_h192, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192, flash_attn_ext_vec_q8_0_h192, 
has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128, flash_attn_ext_vec_f16_hk192_hv128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128, flash_attn_ext_vec_bf16_hk192_hv128, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128, flash_attn_ext_vec_q4_0_hk192_hv128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128, flash_attn_ext_vec_q4_1_hk192_hv128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128, flash_attn_ext_vec_q5_0_hk192_hv128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128, flash_attn_ext_vec_q5_1_hk192_hv128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128, flash_attn_ext_vec_q8_0_hk192_hv128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, flash_attn_ext_vec_bf16_h256, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32, cpy_q4_0_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16, cpy_q4_0_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32, cpy_q4_1_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16, 
cpy_q4_1_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32, cpy_q5_0_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16, cpy_q5_0_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32, cpy_q5_1_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16, cpy_q5_1_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32, cpy_q8_0_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16, cpy_q8_0_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true); } return ctx; @@ -2800,20 +2856,19 @@ static void ggml_metal_encode_node( // ne21 = n_rows const int dst_rows = ne20*ne21; const int dst_rows_min = n_as; - const int dst_rows_max = (device.maxThreadgroupMemoryLength - 32 - 8192)/4; + const int dst_rows_max = (device.maxThreadgroupMemoryLength/2 - 8192)/4; // max size of the rowids array in the kernel shared buffer - GGML_ASSERT(dst_rows <= dst_rows_max); + //GGML_ASSERT(dst_rows <= dst_rows_max); // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel - // !!! - // TODO: for now, always use mat-vec kernels until we figure out how to improve the - // indirect matrix multiplication - // !!! 
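
The indirect matrix-matrix path below is now taken only when the rowids array fits in the kernel's shared buffer: dst_rows_max halves the device's threadgroup memory budget, subtracts an 8192-byte reservation, and divides the remainder by the 4 bytes used per row-id entry. A worked example under an assumed 32 KB maxThreadgroupMemoryLength (the actual value is device dependent):

    // Assumed 32 KB threadgroup memory budget; real devices report their own
    // maxThreadgroupMemoryLength, so treat this as an illustration only.
    #include <cstdio>

    int main() {
        const int maxThreadgroupMemoryLength = 32*1024;            // assumption
        const int dst_rows_max = (maxThreadgroupMemoryLength/2 - 8192)/4;
        std::printf("dst_rows_max = %d\n", dst_rows_max);          // 2048 row ids
        return 0;
    }
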
if ([device supportsFamily:MTLGPUFamilyApple7] && ne00 % 32 == 0 && ne00 >= 64 && - dst_rows > dst_rows_min) { + //ne01 / ne02 >= 512 && // NOTE: this is based on Mixtral shapes, might need adjustments + dst_rows > dst_rows_min && + dst_rows <= dst_rows_max) { + // some Metal matrix data types require aligned pointers // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) switch (src0->type) { @@ -3732,7 +3787,9 @@ static void ggml_metal_encode_node( GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == src2->type); - GGML_ASSERT(ggml_are_same_shape (src1, src2)); + //GGML_ASSERT(ggml_are_same_shape (src1, src2)); + GGML_ASSERT(ne11 == ne21); + GGML_ASSERT(ne12 == ne22); struct ggml_tensor * src3 = node->src[3]; @@ -3779,125 +3836,161 @@ static void ggml_metal_encode_node( // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0) // for now avoiding mainly to keep the number of templates/kernels a bit lower - if (ne01 >= 4 || (ne00%128 != 0)) { + // these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612 + if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 192)) { switch (src1->type) { case GGML_TYPE_F16: { - switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } + if (ne00 == 192 && ne20 == 128) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128].pipeline; + } else { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; + case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } } } break; case GGML_TYPE_BF16: { - switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break; - 
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } + if (ne00 == 192 && ne20 == 128) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128].pipeline; + } else { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break; + case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } } } break; case GGML_TYPE_Q4_0: { - switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } + if (ne00 == 192 && ne20 == 128) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128].pipeline; + } else { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline; break; + case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } } } break; case GGML_TYPE_Q4_1: { - switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break; - case 80: 
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } + if (ne00 == 192 && ne20 == 128) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128].pipeline; + } else { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline; break; + case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } } } break; case GGML_TYPE_Q5_0: { - switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } + if (ne00 == 192 && ne20 == 128) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128].pipeline; + } else { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline; break; + case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + 
GGML_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } } } break; case GGML_TYPE_Q5_1: { - switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } + if (ne00 == 192 && ne20 == 128) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128].pipeline; + } else { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline; break; + case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } } } break; case GGML_TYPE_Q8_0: { - switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } + if (ne00 == 192 && ne20 == 128) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128].pipeline; + } else { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break; + case 128: pipeline = 
ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break; + case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } } } break; default: @@ -3929,6 +4022,42 @@ static void ggml_metal_encode_node( } } } break; + case 192: + { + if (ne20 == 128) { + switch (src1->type) { + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported type: %d\n", src1->type); + GGML_LOG_ERROR("add template specialization for this type\n"); + GGML_ABORT("add template specialization for this type"); + } + } + } else { + switch (src1->type) { + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported type: %d\n", src1->type); + GGML_LOG_ERROR("add template specialization for this type\n"); + GGML_ABORT("add template specialization for this type"); + } + } + } + } break; case 256: { switch (src1->type) { @@ -3966,9 +4095,12 @@ static void ggml_metal_encode_node( /*.ne11 =*/ ne11, /*.ne_12_2 =*/ ne12, /*.ne_12_3 =*/ ne13, - /*.nb_12_1 =*/ nb11, - /*.nb_12_2 =*/ nb12, - /*.nb_12_3 =*/ nb13, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.nb21 =*/ nb21, + /*.nb22 =*/ nb22, + /*.nb23 =*/ nb23, /*.nb31 =*/ nb31, /*.ne1 =*/ ne1, /*.ne2 =*/ ne2, @@ -4047,10 +4179,9 @@ static void ggml_metal_encode_node( // ne00*(nsg) // each simdgroup has a full f16 head vector in shared mem to accumulate results // -#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + ne00*(nsg))*(sizeof(float)/2), 16)) +#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 2*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16)) int64_t 
nsgmax = 2; - while (true) { const size_t smem = FATTN_SMEM(nsgmax); if (smem > device.maxThreadgroupMemoryLength) { diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 38f03efba..1c0ca5adf 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -48,7 +48,7 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) template void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) { - reg = (type4)(*(src + il)); + reg = (type4)(*(src)); } #if defined(GGML_METAL_USE_BF16) @@ -56,6 +56,11 @@ template void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) { reg = (type4x4)(*src); } + +template +void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg) { + reg = (type4)(*(src)); +} #endif template @@ -3100,7 +3105,8 @@ template< typename vd4x4_t, // key type in device memory short nl_v, void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), - short D, // head size + short DK, // K head size + short DV, // V head size short Q = 8, // queries per threadgroup short KV = 8, // key/value processed per each simdgroup short C = 32> // cache items per threadgroup @@ -3122,20 +3128,23 @@ kernel void kernel_flash_attn_ext( const int iq2 = tgpig[1]; const int iq1 = tgpig[0]*Q; - const short D4 = D/4; - const short D8 = D/8; - const short D16 = D/16; + const short DK4 = DK/4; + const short DK8 = DK/8; + const short DK16 = DK/16; + const short DV4 = DV/4; + const short DV8 = DV/8; + const short DV16 = DV/16; const short NW = N_SIMDWIDTH; const short SH = (2*C + Q); // shared memory per simdgroup (s_t == float) const short TS = nsg*SH; // shared memory size per query in (s_t == float) - const short T = D + 2*TS; // shared memory size per query in (half) + const short T = DK + 2*TS; // shared memory size per query in (half) - threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*D); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*D); // same as above but in q4_t - threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*D); // reuse query data for accumulation - threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*D); // same as above but in o4_t - threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix + threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t + threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*DK); // reuse query data for accumulation + threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*DK); // same as above but in o4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*DK); // scratch buffer for attention, mask and diagonal matrix threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t @@ -3144,23 +3153,23 @@ kernel void kernel_flash_attn_ext( threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) - o8x8_t lo[D8]; + o8x8_t lo[DV8]; // load heads from Q to shared memory 
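    // Worked example (illustrative, not part of the kernel): for the hk192_hv128
    // instantiation with the default Q = 8 and C = 32, the constants above become
    //
    //   DK4 = 48, DK8 = 24, DK16 = 12    // 192-element K/Q head
    //   DV4 = 32, DV8 = 16, DV16 =  8    // 128-element V/O head
    //   SH  = 2*C + Q = 72               // scratch floats per simdgroup
    //
    // so lo[] carries DV8 = 16 simdgroup 8x8 accumulators. Reusing the query region
    // sq/so for the output accumulation stays valid because every instantiation has
    // DV <= DK: the query area spans Q*DK halves while the stored O needs only Q*DV.
    // (The loop below fills sq4 with each of the Q query rows as DK4 four-wide vectors.)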
for (short j = sgitg; j < Q; j += nsg) { device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03)); - for (short i = tiisg; i < D4; i += NW) { + for (short i = tiisg; i < DK4; i += NW) { if (iq1 + j < args.ne01) { - sq4[j*D4 + i] = (q4_t) q4[i]; + sq4[j*DK4 + i] = (q4_t) q4[i]; } else { - sq4[j*D4 + i] = (q4_t) 0.0f; + sq4[j*DK4 + i] = (q4_t) 0.0f; } } } // zero out lo - for (short i = 0; i < D8; ++i) { + for (short i = 0; i < DV8; ++i) { lo[i] = make_filled_simdgroup_matrix((o_t) 0.0f); } @@ -3190,13 +3199,6 @@ kernel void kernel_flash_attn_ext( const short ikv2 = iq2/(args.ne02/args.ne_12_2); const short ikv3 = iq3/(args.ne03/args.ne_12_3); - // load the queries from shared memory into local memory - q8x8_t mq[D8]; - - for (short i = 0; i < D8; ++i) { - simdgroup_load(mq[i], sq + i*8, D); - } - const bool has_mask = mask != q; half slope = 1.0f; @@ -3249,20 +3251,22 @@ kernel void kernel_flash_attn_ext( // this is compile-time check, so it does not have runtime overhead if (is_same::value) { // we can read directly from global memory - device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); + device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13)); - #pragma unroll(D8) - for (short i = 0; i < D8; ++i) { + #pragma unroll(DK8) + for (short i = 0; i < DK8; ++i) { k8x8_t mk; - simdgroup_load(mk, pk + i*8, args.nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10 + simdgroup_load(mk, pk + i*8, args.nb11/sizeof(k_t), 0, true); // transpose // TODO: use ne10 - simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); + q8x8_t mq; + simdgroup_load(mq, sq + i*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); } } else { - for (short ii = 0; ii < D16; ii += 4) { - device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); + for (short ii = 0; ii < DK16; ii += 4) { + device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13)); - if (D16%4 == 0) { + if (DK16%4 == 0) { // the head is evenly divisible by 4*16 = 64, so no need for bound checks { k4x4_t tmp; @@ -3275,15 +3279,18 @@ kernel void kernel_flash_attn_ext( #pragma unroll(4) for (short k = 0; k < 4; ++k) { k8x8_t mk; + q8x8_t mq; simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose - simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk); + simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose - simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk); + simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); } } else { - if (ii + tx < D16) { + if (ii + tx < DK16) { k4x4_t tmp; deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); sk4x4[4*ty + tx] = tmp; @@ -3291,14 +3298,17 @@ kernel void kernel_flash_attn_ext( simdgroup_barrier(mem_flags::mem_threadgroup); - for (short k = 0; k < 4 && ii + k < D16; ++k) { + for (short k = 0; k < 4 && ii + k < DK16; ++k) { k8x8_t mk; + q8x8_t mq; simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose - simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk); + simdgroup_load(mq, sq + (2*(ii + 
k) + 0)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose - simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk); + simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); } } } @@ -3350,8 +3360,8 @@ kernel void kernel_flash_attn_ext( s8x8_t mm; simdgroup_load(mm, ss + 2*C, TS, 0, false); - #pragma unroll(D8) - for (short i = 0; i < D8; ++i) { + #pragma unroll(DV8) + for (short i = 0; i < DV8; ++i) { simdgroup_multiply(lo[i], mm, lo[i]); } } @@ -3364,20 +3374,20 @@ kernel void kernel_flash_attn_ext( if (is_same::value) { // we can read directly from global memory - device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); + device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23)); - #pragma unroll(D8) - for (short i = 0; i < D8; ++i) { + #pragma unroll(DV8) + for (short i = 0; i < DV8; ++i) { v8x8_t mv; - simdgroup_load(mv, pv + i*8, args.nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20 + simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20 simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]); } } else { - for (short ii = 0; ii < D16; ii += 4) { - device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); + for (short ii = 0; ii < DV16; ii += 4) { + device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23)); - if (D16%4 == 0) { + if (DV16%4 == 0) { // no need for bound checks { v4x4_t tmp; @@ -3398,7 +3408,7 @@ kernel void kernel_flash_attn_ext( simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]); } } else { - if (ii + tx < D16) { + if (ii + tx < DV16) { v4x4_t tmp; deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp); sv4x4[4*ty + tx] = tmp; @@ -3406,7 +3416,7 @@ kernel void kernel_flash_attn_ext( simdgroup_barrier(mem_flags::mem_threadgroup); - for (short k = 0; k < 4 && ii + k < D16; ++k) { + for (short k = 0; k < 4 && ii + k < DV16; ++k) { v8x8_t mv; simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false); @@ -3440,8 +3450,8 @@ kernel void kernel_flash_attn_ext( // each simdgroup stores its output to shared memory, reusing sq if (sgitg == sg) { - for (short i = 0; i < D8; ++i) { - simdgroup_store(lo[i], so + i*8, D, 0, false); + for (short i = 0; i < DV8; ++i) { + simdgroup_store(lo[i], so + i*8, DV, 0, false); } } @@ -3480,11 +3490,11 @@ kernel void kernel_flash_attn_ext( simdgroup_load(ms0, ss + 2*C, TS, 0, false); simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false); - #pragma unroll(D8) - for (short i = 0; i < D8; ++i) { + #pragma unroll(DV8) + for (short i = 0; i < DV8; ++i) { o8x8_t t; - simdgroup_load (t, so + i*8, D, 0, false); + simdgroup_load (t, so + i*8, DV, 0, false); simdgroup_multiply(t, ms1, t); simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t); @@ -3495,8 +3505,8 @@ kernel void kernel_flash_attn_ext( // store result to shared memory (reuse sq) if (sgitg == 0) { - for (short i = 0; i < D8; ++i) { - simdgroup_store(lo[i], so + i*8, D, 0, false); + for (short i = 0; i < DV8; ++i) { + simdgroup_store(lo[i], so + i*8, DV, 0, false); } } @@ -3507,8 +3517,8 @@ kernel void kernel_flash_attn_ext( for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) { const float S = 
ss[j*TS + 0]; - for (short i = tiisg; i < D4; i += NW) { - dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S; + for (short i = tiisg; i < DV4; i += NW) { + dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4 + i] = (float4) so4[j*DV4 + i]/S; } } } @@ -3525,80 +3535,94 @@ kernel void kernel_flash_attn_ext( float, simdgroup_float8x8, \ half, half4, simdgroup_half8x8 -typedef decltype(kernel_flash_attn_ext) flash_attn_ext_t; +typedef decltype(kernel_flash_attn_ext) flash_attn_ext_t; -template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t 
kernel_flash_attn_ext; #endif -template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template 
[[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template 
[[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #undef FA_TYPES template< - typename q4_t, // query types in shared memory - typename q4x4_t, - typename k4x4_t, // key types in shared memory - typename v4x4_t, // value types in shared memory - typename qk_t, // Q*K types - typename s_t, // soft-max types + typename q4_t, // query types in shared memory + typename k4_t, // key types in shared memory + typename v4_t, // value types in shared memory + typename qk_t, // Q*K types + typename s_t, // soft-max types typename s4_t, - typename s4x4_t, - typename o4x4_t, // attention accumulation types - typename kd4x4_t, // key type in device memory + typename o4_t, // attention accumulation types + typename kd4_t, // key type in device memory short nl_k, - void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &), - typename vd4x4_t, // key type in device memory + void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &), + typename vd4_t, // key type in device memory short nl_v, - void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), - short D, // head size - short Q = 1, // queries per threadgroup - short C = 32> // cache items per threadgroup + void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &), + short DK, // K head size + short DV, // V head size + short NE = 4, // head elements per thread + short Q = 1, // queries per threadgroup + short C = 32> // cache items per threadgroup kernel void kernel_flash_attn_ext_vec( constant ggml_metal_kargs_flash_attn_ext & args, device const char * q, @@ -3617,29 +3641,28 @@ kernel void kernel_flash_attn_ext_vec( const int iq2 = tgpig[1]; const int iq1 = tgpig[0]; - const short D4 = D/4; - const short D16 = D/16; + const short DK4 = DK/4; + const short DV4 = DV/4; const short NW = N_SIMDWIDTH; - const short NL = NW/4; // note: this can be adjusted to support D%64 == 0 and D%32 == 0 - const short SH = 2*C; // shared memory per simdgroup + const short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads + const short SH = 2*C; // shared memory per simdgroup - const short T = D + nsg*SH; // shared memory size per query in (half) + const short T = DK + nsg*SH; // shared memory size per query in (half) - //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*D); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*D); // same as above but in q4_t - threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shmem_f16 + 0*D); // same as above but in q4x4_t - threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*D); // scratch buffer for attention - threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*D); // same as above but in s4_t - threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + C + Q*D); // scratch buffer for mask - threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shmem_f16 + sgitg*D + Q*T); // scratch buffer for the results + //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t + threadgroup s_t * ss = 
(threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention + threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t + threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + C + Q*DK); // scratch buffer for mask + threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results - // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) - o4x4_t lo[D16/NL]; + // store the result for all queries in local memory (the O matrix from the paper) + o4_t lo[DV4/NL]; // load heads from Q to shared memory device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03)); - for (short i = tiisg; i < D4; i += NW) { + for (short i = tiisg; i < DK4; i += NW) { if (iq1 < args.ne01) { sq4[i] = (q4_t) q4[i]; } else { @@ -3648,8 +3671,8 @@ kernel void kernel_flash_attn_ext_vec( } // zero out lo - for (short i = 0; i < D16/NL; ++i) { - lo[i] = (o4x4_t) 0.0f; + for (short i = 0; i < DV4/NL; ++i) { + lo[i] = (o4_t) 0.0f; } // zero out shared memory SH @@ -3674,14 +3697,6 @@ kernel void kernel_flash_attn_ext_vec( const short ikv2 = iq2/(args.ne02/args.ne_12_2); const short ikv3 = iq3/(args.ne03/args.ne_12_3); - // load the queries from shared memory into local memory - q4x4_t mq[D16/NL]; - - #pragma unroll(D16/NL) - for (short ii = 0; ii < D16; ii += NL) { - mq[ii/NL] = sq4x4[ii + tx]; - } - const bool has_mask = mask != q; // pointer to the mask @@ -3713,43 +3728,56 @@ kernel void kernel_flash_attn_ext_vec( // Q*K^T { - // each simdgroup processes 1 query and 4 (NW/NL) keys - for (short cc = 0; cc < C/4; ++cc) { - qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 }; + // each simdgroup processes 1 query and NE (NW/NL) head elements + for (short cc = 0; cc < C/NE; ++cc) { + qk_t mqk = 0.0f; - device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); + device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13)); - #pragma unroll(D16/NL) - for (short ii = 0; ii < D16; ii += NL) { + #pragma unroll(DK4/NL) + for (short ii = 0; ii < DK4; ii += NL) { const short i = ii + tx; - k4x4_t mk; - deq_k(pk + i/nl_k, i%nl_k, mk); + k4_t mk; + deq_k_t4(pk + i/nl_k, i%nl_k, mk); // note: this is less precise than the version below - //mqka[0] += dot(mq[ii/NL][0], mk[0]); - //mqka[1] += dot(mq[ii/NL][1], mk[1]); - //mqka[2] += dot(mq[ii/NL][2], mk[2]); - //mqka[3] += dot(mq[ii/NL][3], mk[3]); + //mqka[0] += dot(mq[0], mk[0]); + //mqka[1] += dot(mq[1], mk[1]); + //mqka[2] += dot(mq[2], mk[2]); + //mqka[3] += dot(mq[3], mk[3]); - mqka[0] += dot((float4) mq[ii/NL][0], (float4) mk[0]); - mqka[1] += dot((float4) mq[ii/NL][1], (float4) mk[1]); - mqka[2] += dot((float4) mq[ii/NL][2], (float4) mk[2]); - mqka[3] += dot((float4) mq[ii/NL][3], (float4) mk[3]); + //q4x4_t mq = sq4x4[i]; + //mqka[0] += dot((float4) mq[0], (float4) mk[0]); + //mqka[1] += dot((float4) mq[1], (float4) mk[1]); + //mqka[2] += dot((float4) mq[2], (float4) mk[2]); + //mqka[3] += dot((float4) mq[3], (float4) mk[3]); + + mqk += dot((float4) mk, (float4) sq4[i]); } - qk_t mqk = mqka[0] + mqka[1] + mqka[2] + mqka[3]; + static_assert(NE > 1, "NE must be > 1"); // note: not sure why NE == 1 fails - // simdgroup reduce + // simdgroup reduce (NE = 4) // [ 0 .. 7] -> [ 0] // [ 8 .. 
15] -> [ 8] // [16 .. 23] -> [16] // [24 .. 31] -> [24] - //mqk += simd_shuffle_down(mqk, 16); - //mqk += simd_shuffle_down(mqk, 8); - mqk += simd_shuffle_down(mqk, 4); - mqk += simd_shuffle_down(mqk, 2); - mqk += simd_shuffle_down(mqk, 1); + if (NE <= 1) { + mqk += simd_shuffle_down(mqk, 16); + } + if (NE <= 2) { + mqk += simd_shuffle_down(mqk, 8); + } + if (NE <= 4) { + mqk += simd_shuffle_down(mqk, 4); + } + if (NE <= 8) { + mqk += simd_shuffle_down(mqk, 2); + } + if (NE <= 16) { + mqk += simd_shuffle_down(mqk, 1); + } // mqk = mqk*scale + mask*slope if (tx == 0) { @@ -3759,9 +3787,9 @@ kernel void kernel_flash_attn_ext_vec( mqk = args.logit_softcap*precise::tanh(mqk); } - mqk += sm[4*cc + ty]*slope; + mqk += sm[NE*cc + ty]*slope; - ss[4*cc + ty] = mqk; + ss[NE*cc + ty] = mqk; } } } @@ -3784,8 +3812,8 @@ kernel void kernel_flash_attn_ext_vec( ss[tiisg] = vs; // O = diag(ms)*O - #pragma unroll(D16/NL) - for (short ii = 0; ii < D16; ii += NL) { + #pragma unroll(DV4/NL) + for (short ii = 0; ii < DV4; ii += NL) { lo[ii/NL] *= ms; } } @@ -3794,17 +3822,18 @@ kernel void kernel_flash_attn_ext_vec( // O = O + (Q*K^T)*V { - for (short cc = 0; cc < C/4; ++cc) { - device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); + //#pragma unroll(C/NE) + for (short cc = 0; cc < C/NE; ++cc) { + device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23)); - const s4x4_t ms(ss[4*cc + ty]); + const s4_t ms(ss[NE*cc + ty]); - #pragma unroll(D16/NL) - for (short ii = 0; ii < D16; ii += NL) { + #pragma unroll(DV4/NL) + for (short ii = 0; ii < DV4; ii += NL) { const short i = ii + tx; - v4x4_t mv; - deq_v(pv4 + i/nl_v, i%nl_v, mv); + v4_t mv; + deq_v_t4(pv4 + i/nl_v, i%nl_v, mv); lo[ii/NL] += mv*ms; } @@ -3819,7 +3848,7 @@ kernel void kernel_flash_attn_ext_vec( } } - // simdgroup reduce + // simdgroup reduce (NE = 4) // [ 0, 8, 16, 24] -> [ 0] // [ 1, 9, 17, 25] -> [ 1] // [ 2, 10, 18, 26] -> [ 2] @@ -3828,37 +3857,48 @@ kernel void kernel_flash_attn_ext_vec( // [ 5, 13, 21, 29] -> [ 5] // [ 6, 14, 22, 30] -> [ 6] // [ 7, 15, 23, 31] -> [ 7] - for (short ii = 0; ii < D16; ii += NL) { - lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16); - lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 8); - //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 4); - //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 2); - //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 1); + for (short ii = 0; ii < DV4; ii += NL) { + if (NE > 1) { + lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16); + lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16); + lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16); + lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16); + } - lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16); - lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 8); - //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 4); - //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 2); - //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 1); + if (NE > 2) { + lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 8); + lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 8); + lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 8); + lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 8); + } - lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16); - lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 8); - //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 4); - //lo[ii/NL][2] += 
simd_shuffle_down(lo[ii/NL][2], 2); - //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 1); + if (NE > 4) { + lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 4); + lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 4); + lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 4); + lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 4); + } - lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16); - lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 8); - //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 4); - //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 2); - //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 1); + if (NE > 8) { + lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 2); + lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 2); + lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 2); + lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 2); + } + + if (NE > 16) { + lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 1); + lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 1); + lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 1); + lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 1); + } } threadgroup_barrier(mem_flags::mem_threadgroup); // store results to shared memory - for (short i = tiisg; i < D16; i += NL) { - sr4x4[i] = lo[i/NL]; + for (short i = tiisg; i < DV4; i += NL) { + sr4[i] = lo[i/NL]; } threadgroup_barrier(mem_flags::mem_threadgroup); @@ -3885,22 +3925,22 @@ kernel void kernel_flash_attn_ext_vec( } // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - for (short i = tiisg; i < D16; i += NW) { - sr4x4[i] = sr4x4[i]*ms0 + sr4x4[i + r*D16]*ms1; + for (short i = tiisg; i < DV4; i += NW) { + sr4[i] = sr4[i]*ms0 + sr4[i + r*DV4]*ms1; } } threadgroup_barrier(mem_flags::mem_threadgroup); } - device float4x4 * dst44 = (device float4x4 *) dst; + device float4 * dst4 = (device float4 *) dst; // final rescale with 1/S and store to global memory if (sgitg == 0) { const float S = ss[0]; - for (short i = tiisg; i < D16; i += NW) { - dst44[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S; + for (short i = tiisg; i < DV4; i += NW) { + dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*DV4 + i] = (float4) sr4[i]/S; } } } @@ -3909,34 +3949,54 @@ kernel void kernel_flash_attn_ext_vec( // in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max // #define FA_TYPES \ - half4, half4x4, \ - half4x4, \ - half4x4, \ - float, \ - half, half4, half4x4, \ - half4x4 + half4, \ + half4, \ + half4, \ + float, \ + half, half4, \ + half4 -typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t; +typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t; -template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template 
[[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +template [[host_name("kernel_flash_attn_ext_vec_f16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#endif +template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if 
defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#endif +template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #undef FA_TYPES diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 37fa8eec5..bc16567dc 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -8764,6 +8764,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm default: return false; } + if (op->src[1]->ne[0] != op->src[2]->ne[0]) { + // different head sizes of K and V are not supported yet + return false; + } if (op->src[0]->type != GGML_TYPE_F32) { return false; } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 2e081d591..161dd3fa9 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -4369,7 +4369,7 @@ struct ggml_tensor * ggml_flash_attn_ext( } // permute(0, 2, 1, 3) - int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; + int64_t ne[4] = { v->ne[0], q->ne[2], q->ne[1], q->ne[3] }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); float params[] = { scale, max_bias, logit_softcap }; diff --git a/src/llama-context.cpp b/src/llama-context.cpp index aa363df63..9467c3a01 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -2316,11 +2316,6 @@ llama_context * llama_init_from_model( params.flash_attn = false; } - if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) { - LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__); - params.flash_attn = false; - } - if (ggml_is_quantized(params.type_v) && !params.flash_attn) { LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__); return nullptr; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 28f860a7f..426a9557c 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -3217,7 +3217,8 @@ struct test_leaky_relu : public test_case { // GGML_OP_FLASH_ATTN_EXT struct test_flash_attn_ext : public test_case { - const int64_t hs; // head size + const int64_t hsk; // K head size + const int64_t hsv; // V head size const int64_t nh; // num heads const int64_t nr; // repeat in Q, tests for grouped-query attention const int64_t kv; // kv size @@ -3233,7 +3234,7 @@ struct test_flash_attn_ext : public test_case { std::array permute; std::string vars() override { - return VARS_TO_STR11(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute); + return VARS_TO_STR12(hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute); } double max_nmse_err() override { @@ -3243,17 +3244,18 @@ struct test_flash_attn_ext : public test_case { uint64_t op_flops(ggml_tensor * t) override { GGML_UNUSED(t); // Just counting matmul costs: - // Q*K^T is nb x hs x kv, P*V is nb x kv x hs, per head - return 2 * 2 * nh*nr * nb * hs * kv; + // Q*K^T is nb x 
hsk x kv, P*V is nb x kv x hsv, per head + return 2 * nh*nr * nb * (hsk + hsv) * kv; } - test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t nr = 1, int64_t kv = 96, int64_t nb = 8, + test_flash_attn_ext(int64_t hsk = 128, int64_t hsv = 128, int64_t nh = 32, int64_t nr = 1, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_prec prec = GGML_PREC_F32, ggml_type type_KV = GGML_TYPE_F16, std::array permute = {0, 1, 2, 3}) - : hs(hs), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {} + : hsk(hsk), hsv(hsv), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {} ggml_tensor * build_graph(ggml_context * ctx) override { - const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV)); + const int64_t hsk_padded = GGML_PAD(hsk, ggml_blck_size(type_KV)); + const int64_t hsv_padded = GGML_PAD(hsv, ggml_blck_size(type_KV)); auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) -> ggml_tensor * { int64_t ne[4] = {ne0, ne1, ne2, ne3}; @@ -3268,13 +3270,13 @@ struct test_flash_attn_ext : public test_case { return t; }; - ggml_tensor * q = create_permuted(GGML_TYPE_F32, hs_padded, nb, nh*nr, 1); + ggml_tensor * q = create_permuted(GGML_TYPE_F32, hsk_padded, nb, nh*nr, 1); ggml_set_name(q, "q"); - ggml_tensor * k = create_permuted(type_KV, hs_padded, kv, nh, 1); + ggml_tensor * k = create_permuted(type_KV, hsk_padded, kv, nh, 1); ggml_set_name(k, "k"); - ggml_tensor * v = create_permuted(type_KV, hs_padded, kv, nh, 1); + ggml_tensor * v = create_permuted(type_KV, hsv_padded, kv, nh, 1); ggml_set_name(v, "v"); ggml_tensor * m = nullptr; @@ -3283,7 +3285,7 @@ struct test_flash_attn_ext : public test_case { ggml_set_name(m, "m"); } - ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias, logit_softcap); + ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hsk), max_bias, logit_softcap); ggml_flash_attn_ext_set_prec(out, prec); ggml_set_name(out, "out"); @@ -4412,27 +4414,32 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_timestep_embedding()); test_cases.emplace_back(new test_leaky_relu()); - for (int hs : { 64, 80, 128, 256, }) { - for (bool mask : { true, false } ) { - for (float max_bias : { 0.0f, 8.0f }) { - if (!mask && max_bias > 0.0f) continue; - for (float logit_softcap : {0.0f, 10.0f}) { - if (hs != 128 && logit_softcap != 0.0f) continue; - for (int nh : { 4, }) { - for (int nr : { 1, 4, 16 }) { - if (nr == 16 && hs != 128) continue; - for (int kv : { 512, 1024, }) { - if (nr != 1 && kv != 512) continue; - for (int nb : { 1, 3, 32, 35, }) { - for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) { - if (hs != 128 && prec == GGML_PREC_DEFAULT) continue; - for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) { - test_cases.emplace_back(new test_flash_attn_ext( - hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV)); - // run fewer test cases permuted - if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) { + for (int hsk : { 64, 80, 128, 192, 256, }) { + for (int hsv : { 64, 80, 128, 192, 256, }) { + if (hsk != 192 && hsk != hsv) continue; + if (hsk == 192 && (hsv != 128 && hsv != 192)) continue; + + for (bool mask : { true, false } ) { + for (float max_bias 
: { 0.0f, 8.0f }) { + if (!mask && max_bias > 0.0f) continue; + for (float logit_softcap : {0.0f, 10.0f}) { + if (hsk != 128 && logit_softcap != 0.0f) continue; + for (int nh : { 4, }) { + for (int nr : { 1, 4, 16 }) { + if (nr == 16 && hsk != 128) continue; + for (int kv : { 512, 1024, }) { + if (nr != 1 && kv != 512) continue; + for (int nb : { 1, 3, 32, 35, }) { + for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) { + if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue; + for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) { test_cases.emplace_back(new test_flash_attn_ext( - hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3})); + hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV)); + // run fewer test cases permuted + if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) { + test_cases.emplace_back(new test_flash_attn_ext( + hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3})); + } } } } From 3714c3ee1a62ed64ac328ec7d699410ad1219150 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Fri, 28 Mar 2025 22:13:02 +0100 Subject: [PATCH 26/26] llama : fix incorrect Qwen2Moe ffn_moe_out graph callback (#12631) --- src/llama-model.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index a442abeb8..a4f06112d 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -6323,7 +6323,7 @@ struct llm_build_qwen2moe : public llm_graph_context { false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); - cb(cur, "ffn_moe_out", il); + cb(moe_out, "ffn_moe_out", il); // FFN shared expert {
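        // Note on the one-line change above: cb() tags the tensor passed to it with the
        // given name (and exposes it to the user graph callback, e.g. for debugging or
        // output capture), so the "ffn_moe_out" label has to go on moe_out, the routed
        // experts' result, not on cur, which at this point is still the experts' input.
        // A rough sketch of the intended pattern (tensor names here are illustrative):
        //
        //   ggml_tensor * moe_out = build_moe_ffn(cur, ...);  // routed experts
        //   cb(moe_out, "ffn_moe_out", il);                   // name what was just built
        //   ggml_tensor * shexp   = ...;                      // shared-expert branch on cur
        //   cur = ggml_add(ctx0, moe_out, shexp);
        //   cb(cur, "ffn_out", il);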