diff --git a/.devops/cann.Dockerfile b/.devops/cann.Dockerfile new file mode 100644 index 000000000..02f3e03b5 --- /dev/null +++ b/.devops/cann.Dockerfile @@ -0,0 +1,130 @@ +# ============================================================================== +# ARGUMENTS +# ============================================================================== + +# Define the CANN base image for easier version updates later +ARG CANN_BASE_IMAGE=quay.io/ascend/cann:8.1.rc1-910b-openeuler22.03-py3.10 + +# ============================================================================== +# BUILD STAGE +# Compile all binary files and libraries +# ============================================================================== +FROM ${CANN_BASE_IMAGE} AS build + +# Define the Ascend chip model for compilation. Default is Ascend910B3 +ARG ASCEND_SOC_TYPE=Ascend910B3 + +# -- Install build dependencies -- +RUN yum install -y gcc g++ cmake make git libcurl-devel python3 python3-pip && \ + yum clean all && \ + rm -rf /var/cache/yum + +# -- Set the working directory -- +WORKDIR /app + +# -- Copy project files -- +COPY . . + +# -- Set CANN environment variables (required for compilation) -- +# Using ENV instead of `source` allows environment variables to persist across the entire image layer +ENV ASCEND_TOOLKIT_HOME=/usr/local/Ascend/ascend-toolkit/latest +ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${LD_LIBRARY_PATH} +ENV PATH=${ASCEND_TOOLKIT_HOME}/bin:${PATH} +ENV ASCEND_OPP_PATH=${ASCEND_TOOLKIT_HOME}/opp +ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/runtime/lib64/stub:$LD_LIBRARY_PATH +# ... You can add other environment variables from the original file as needed ... +# For brevity, only core variables are listed here. You can paste the original ENV list here. + +# -- Build llama.cpp -- +# Use the passed ASCEND_SOC_TYPE argument and add general build options +RUN source /usr/local/Ascend/ascend-toolkit/set_env.sh --force \ + && \ + cmake -B build \ + -DGGML_CANN=ON \ + -DCMAKE_BUILD_TYPE=Release \ + -DSOC_TYPE=${ASCEND_SOC_TYPE} \ + . && \ + cmake --build build --config Release -j$(nproc) + +# -- Organize build artifacts for copying in later stages -- +# Create a lib directory to store all .so files +RUN mkdir -p /app/lib && \ + find build -name "*.so" -exec cp {} /app/lib \; + +# Create a full directory to store all executables and Python scripts +RUN mkdir -p /app/full && \ + cp build/bin/* /app/full/ && \ + cp *.py /app/full/ && \ + cp -r gguf-py /app/full/ && \ + cp -r requirements /app/full/ && \ + cp requirements.txt /app/full/ + # If you have a tools.sh script, make sure it is copied here + # cp .devops/tools.sh /app/full/tools.sh + +# ============================================================================== +# BASE STAGE +# Create a minimal base image with CANN runtime and common libraries +# ============================================================================== +FROM ${CANN_BASE_IMAGE} AS base + +# -- Install runtime dependencies -- +RUN yum install -y libgomp curl && \ + yum clean all && \ + rm -rf /var/cache/yum + +# -- Set CANN environment variables (required for runtime) -- +ENV ASCEND_TOOLKIT_HOME=/usr/local/Ascend/ascend-toolkit/latest +ENV LD_LIBRARY_PATH=/app:${ASCEND_TOOLKIT_HOME}/lib64:${LD_LIBRARY_PATH} +ENV PATH=${ASCEND_TOOLKIT_HOME}/bin:${PATH} +ENV ASCEND_OPP_PATH=${ASCEND_TOOLKIT_HOME}/opp +# ... You can add other environment variables from the original file as needed ... + +WORKDIR /app + +# Copy compiled .so files from the build stage +COPY --from=build /app/lib/ /app + +# ============================================================================== +# FINAL STAGES (TARGETS) +# ============================================================================== + +### Target: full +# Complete image with all tools, Python bindings, and dependencies +# ============================================================================== +FROM base AS full + +COPY --from=build /app/full /app + +# Install Python dependencies +RUN yum install -y git python3 python3-pip && \ + pip3 install --no-cache-dir --upgrade pip setuptools wheel && \ + pip3 install --no-cache-dir -r requirements.txt && \ + yum clean all && \ + rm -rf /var/cache/yum + +# You need to provide a tools.sh script as the entrypoint +ENTRYPOINT ["/app/tools.sh"] +# If there is no tools.sh, you can set the default to start the server +# ENTRYPOINT ["/app/llama-server"] + +### Target: light +# Lightweight image containing only llama-cli +# ============================================================================== +FROM base AS light + +COPY --from=build /app/full/llama-cli /app + +ENTRYPOINT [ "/app/llama-cli" ] + +### Target: server +# Dedicated server image containing only llama-server +# ============================================================================== +FROM base AS server + +ENV LLAMA_ARG_HOST=0.0.0.0 + +COPY --from=build /app/full/llama-server /app + +HEALTHCHECK --interval=5m CMD [ "curl", "-f", "http://localhost:8080/health" ] + +ENTRYPOINT [ "/app/llama-server" ] diff --git a/common/arg.cpp b/common/arg.cpp index 104a6f8ca..3c275325b 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -979,6 +979,10 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context for (auto & seq_breaker : params.sampling.dry_sequence_breakers) { string_process_escapes(seq_breaker); } + for (auto & pair : params.speculative.replacements) { + string_process_escapes(pair.first); + string_process_escapes(pair.second); + } } if (!params.kv_overrides.empty()) { @@ -2093,6 +2097,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.no_kv_offload = true; } ).set_env("LLAMA_ARG_NO_KV_OFFLOAD")); + add_opt(common_arg( + {"-nr", "--no-repack"}, + "disable weight repacking", + [](common_params & params) { + params.no_extra_bufts = true; + } + ).set_env("LLAMA_ARG_NO_REPACK")); add_opt(common_arg( {"-ctk", "--cache-type-k"}, "TYPE", string_format( @@ -2371,6 +2382,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } } )); + add_opt(common_arg( + {"--cpu-moe"}, + "use CPU for Mixture of Experts (MoE) weights", + [](common_params & params) { + params.tensor_buft_overrides.push_back({"\\.ffn_up_exps\\.weight$", ggml_backend_cpu_buffer_type()}); + params.tensor_buft_overrides.push_back({"\\.ffn_down_exps\\.weight$", ggml_backend_cpu_buffer_type()}); + params.tensor_buft_overrides.push_back({"\\.ffn_gate_exps\\.weight$", ggml_backend_cpu_buffer_type()}); + } + ).set_env("LLAMA_ARG_CPU_MOE")); add_opt(common_arg( {"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N", "number of layers to store in VRAM", @@ -3251,6 +3271,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.speculative.model.path = value; } ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODEL_DRAFT")); + add_opt(common_arg( + {"--spec-replace"}, "TARGET", "DRAFT", + "translate the string in TARGET into DRAFT if the draft model and main model are not compatible", + [](common_params & params, const std::string & tgt, const std::string & dft) { + params.speculative.replacements.push_back({ tgt, dft }); + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"-ctkd", "--cache-type-k-draft"}, "TYPE", string_format( @@ -3440,28 +3467,11 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } ).set_examples({LLAMA_EXAMPLE_SERVER})); - // diffusion parameters add_opt(common_arg( { "--diffusion-steps" }, "N", string_format("number of diffusion steps (default: %d)", params.diffusion.steps), [](common_params & params, int value) { params.diffusion.steps = value; } ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); - add_opt(common_arg( - { "--diffusion-eps" }, "F", - string_format("epsilon for timesteps (default: %.6f)", (double) params.diffusion.eps), - [](common_params & params, const std::string & value) { params.diffusion.eps = std::stof(value); } - ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); - add_opt(common_arg( - { "--diffusion-algorithm" }, "N", - string_format("diffusion algorithm: 0=ORIGIN, 1=MASKGIT_PLUS, 2=TOPK_MARGIN, 3=ENTROPY (default: %d)", - params.diffusion.algorithm), - [](common_params & params, int value) { params.diffusion.algorithm = value; } - ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); - add_opt(common_arg( - { "--diffusion-alg-temp" }, "F", - string_format("algorithm temperature (default: %.3f)", (double) params.diffusion.alg_temp), - [](common_params & params, const std::string & value) { params.diffusion.alg_temp = std::stof(value); } - ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); add_opt(common_arg( { "--diffusion-visual" }, string_format("enable visual diffusion mode (show progressive generation) (default: %s)", @@ -3469,5 +3479,39 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params) { params.diffusion.visual_mode = true; } ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); + add_opt(common_arg( + { "--diffusion-eps" }, "F", + string_format("epsilon for timesteps (default: %.6f)", (double) params.diffusion.eps), + [](common_params & params, const std::string & value) { params.diffusion.eps = std::stof(value); } + ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); + add_opt(common_arg( + { "--diffusion-algorithm" }, "N", + string_format("diffusion algorithm: 0=ORIGIN, 1=ENTROPY_BASED, 2=MARGIN_BASED, 3=RANDOM, 4=LOW_CONFIDENCE (default: %d)", + params.diffusion.algorithm), + [](common_params & params, int value) { params.diffusion.algorithm = value; } + ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); + add_opt(common_arg( + { "--diffusion-alg-temp" }, "F", + string_format("dream algorithm temperature (default: %.3f)", (double) params.diffusion.alg_temp), + [](common_params & params, const std::string & value) { params.diffusion.alg_temp = std::stof(value); } + ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); + + add_opt(common_arg( + { "--diffusion-block-length" }, "N", + string_format("llada block length for generation (default: %d)", params.diffusion.block_length), + [](common_params & params, int value) { params.diffusion.block_length = value; } + ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); + add_opt(common_arg( + { "--diffusion-cfg-scale" }, "F", + string_format("llada classifier-free guidance scale (default: %.3f)", (double) params.diffusion.cfg_scale), + [](common_params & params, const std::string & value) { params.diffusion.cfg_scale = std::stof(value); } + ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); + add_opt(common_arg( + { "--diffusion-add-gumbel-noise" }, "F", + string_format("add gumbel noise to the logits if temp > 0.0 (default: %s)", params.diffusion.add_gumbel_noise ? "true" : "false"), + [](common_params & params, const std::string & value) { params.diffusion.add_gumbel_noise = std::stof(value); } + ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); + + return ctx_arg; } diff --git a/common/common.cpp b/common/common.cpp index 871158ae3..27e948463 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1130,6 +1130,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) { mparams.use_mmap = params.use_mmap; mparams.use_mlock = params.use_mlock; mparams.check_tensors = params.check_tensors; + mparams.use_extra_bufts = !params.no_extra_bufts; if (params.kv_overrides.empty()) { mparams.kv_overrides = NULL; diff --git a/common/common.h b/common/common.h index 6ba1df613..0eb663e2a 100644 --- a/common/common.h +++ b/common/common.h @@ -197,6 +197,7 @@ struct common_params_speculative { int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default) float p_split = 0.1f; // speculative decoding split probability float p_min = 0.75f; // minimum speculative decoding probability (greedy) + std::vector> replacements; // main to speculative model replacements ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V @@ -216,11 +217,17 @@ struct common_params_vocoder { }; struct common_params_diffusion { - int32_t steps = 64; // number of diffusion steps - float eps = 1e-3f; // epsilon for timesteps - int32_t algorithm = 0; // diffusion algorithm (0=ORIGIN, 1=MASKGIT_PLUS, 2=TOPK_MARGIN, 3=ENTROPY) - float alg_temp = 0.0f; // algorithm temperature - bool visual_mode = false; // show progressive diffusion on screen + int32_t steps = 128; + bool visual_mode = false; + + float eps = 0; // epsilon for timesteps + int32_t block_length = 0; // block length for generation + + int32_t algorithm = 4; // default algorithm: low-confidence + float alg_temp = 0.0f; // algorithm temperature + + float cfg_scale = 0; // classifier-free guidance scale + bool add_gumbel_noise = false; // add gumbel noise to the logits if temp > 0.0 }; enum common_reasoning_format { @@ -348,6 +355,7 @@ struct common_params { bool warmup = true; // warmup run bool check_tensors = false; // validate tensor data bool no_op_offload = false; // globally disable offload host tensor operations to device + bool no_extra_bufts = false; // disable extra buffer types (used for weight repacking) bool single_turn = false; // single turn chat conversation diff --git a/common/speculative.cpp b/common/speculative.cpp index 843bd1ddb..262b2c23e 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -1,30 +1,39 @@ #include "speculative.h" +#include "ggml.h" +#include "llama.h" #include "log.h" #include "common.h" #include "sampling.h" #include #include +#include #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 struct common_speculative { - struct llama_context * ctx; + struct llama_context * ctx_tgt; // only used for retokenizing from ctx_dft + struct llama_context * ctx_dft; struct common_sampler * smpl; llama_batch batch; - llama_tokens prompt; + llama_tokens prompt_dft; + bool vocab_dft_compatible = true; // whether retokenization is needed + std::map tgt_dft_replacements = {}; }; struct common_speculative * common_speculative_init( + struct llama_context * ctx_tgt, struct llama_context * ctx_dft) { auto * result = new common_speculative { - /* .ctx = */ ctx_dft, - /* .smpl = */ nullptr, - /* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1), - /* .prompt = */ {}, + /* .ctx_tgt = */ ctx_tgt, + /* .ctx_dft = */ ctx_dft, + /* .smpl = */ nullptr, + /* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1), + /* .prompt_dft = */ {}, + /* .vocab_dft_compatible = */ false, }; // TODO: optimize or pass from outside? @@ -59,6 +68,9 @@ struct common_speculative * common_speculative_init( } #endif + result->vocab_dft_compatible = common_speculative_are_compatible(ctx_tgt, ctx_dft); + LOG_DBG("vocab_dft_compatible = %d\n", result->vocab_dft_compatible); + return result; } @@ -75,8 +87,8 @@ void common_speculative_free(struct common_speculative * spec) { } bool common_speculative_are_compatible( - const struct llama_context * ctx_tgt, - const struct llama_context * ctx_dft) { + const struct llama_context * ctx_tgt, + const struct llama_context * ctx_dft) { const struct llama_model * model_tgt = llama_get_model(ctx_tgt); const struct llama_model * model_dft = llama_get_model(ctx_dft); @@ -90,31 +102,32 @@ bool common_speculative_are_compatible( LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft); if (vocab_type_tgt != vocab_type_dft) { - LOG_ERR("%s: draft model vocab type must match target model to use speculation but " - "vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt); + LOG_DBG("%s: draft model vocab type must match target model to use speculation but ", __func__); + LOG_DBG("vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt); return false; } - if (llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) || + if ( + llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) || llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) || llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) || - llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)) { - LOG_ERR("%s: draft vocab special tokens must match target vocab to use speculation\n", __func__); - LOG_ERR("%s: tgt: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_tgt), llama_vocab_get_add_bos(vocab_tgt), llama_vocab_eos(vocab_tgt), llama_vocab_get_add_eos(vocab_tgt)); - LOG_ERR("%s: dft: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_dft), llama_vocab_get_add_bos(vocab_dft), llama_vocab_eos(vocab_dft), llama_vocab_get_add_eos(vocab_dft)); + llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft) + ) { + LOG_DBG("%s: draft model special tokens must match target model to use speculation\n", __func__); return false; } { const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt); const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft); - - const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft); + const int vocab_diff = n_vocab_tgt > n_vocab_dft + ? n_vocab_tgt - n_vocab_dft + : n_vocab_dft - n_vocab_tgt; if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) { - LOG_ERR("%s: draft model vocab must closely match target model to use speculation but " - "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n", - __func__, n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE); + LOG_DBG("%s: draft model vocab must closely match target model to use speculation but ", __func__); + LOG_DBG("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n", + n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE); return false; } @@ -122,8 +135,8 @@ bool common_speculative_are_compatible( const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i); const char * token_text_dft = llama_vocab_get_text(vocab_dft, i); if (std::strcmp(token_text_tgt, token_text_dft) != 0) { - LOG_ERR("%s: draft vocab vocab must match target vocab to use speculation but " - "token %d content differs - target '%s', draft '%s'\n", __func__, i, + LOG_DBG("%s: draft model vocab must match target model to use speculation but ", __func__); + LOG_DBG("token %d content differs - target '%s', draft '%s'\n", i, common_token_to_piece(ctx_tgt, i).c_str(), common_token_to_piece(ctx_dft, i).c_str()); return false; @@ -134,32 +147,93 @@ bool common_speculative_are_compatible( return true; } +void common_speculative_add_replacement_tgt_dft( + struct common_speculative * spec, + const char *source, const char *dest) { + spec->tgt_dft_replacements[source] = dest; +} + +static std::string replace_to_dft( + struct common_speculative * spec, + const std::string& input) { + std::string result = input; + for (const auto & pair : spec->tgt_dft_replacements) { + size_t pos = result.find(pair.first); + while (pos != std::string::npos) { + result.replace(pos, pair.first.length(), pair.second); + pos = result.find(pair.first, pos + pair.second.length()); + } + } + return result; +} + +static std::string replace_to_tgt( + struct common_speculative * spec, + const std::string& input) { + std::string result = input; + for (const auto& pair : spec->tgt_dft_replacements) { + size_t pos = result.find(pair.second); + while (pos != std::string::npos) { + result.replace(pos, pair.second.length(), pair.first); + pos = result.find(pair.second, pos + pair.first.length()); + } + } + return result; +} + + llama_tokens common_speculative_gen_draft( struct common_speculative * spec, struct common_speculative_params params, - const llama_tokens & prompt_tgt, + const llama_tokens & prompt_tgt_main_model, // specified in target model vocab llama_token id_last) { auto & batch = spec->batch; - auto & ctx = spec->ctx; + auto & ctx_tgt = spec->ctx_tgt; + auto & ctx_dft = spec->ctx_dft; auto & smpl = spec->smpl; - auto & prompt = spec->prompt; + auto & prompt_dft = spec->prompt_dft; - auto * mem = llama_get_memory(ctx); + auto * mem_dft = llama_get_memory(ctx_dft); int reuse_i = 0; int reuse_n = 0; - const int n_ctx = llama_n_ctx(ctx) - params.n_draft; + const int n_ctx = llama_n_ctx(ctx_dft) - params.n_draft; + + llama_tokens prompt_tgt_draft_model; + if (!spec->vocab_dft_compatible) { + std::string text; + text = common_detokenize(ctx_tgt, prompt_tgt_main_model, true); + text = replace_to_dft(spec, text); + LOG_DBG("%s: main->draft detokenized string: '%s'\n", __func__, text.c_str()); + prompt_tgt_draft_model = common_tokenize(ctx_dft, text, false, true); + + // convert id_last to draft vocab. llama_detokenize is called directly to avoid an allocation + const auto * model_tgt = llama_get_model(ctx_tgt); + const auto * vocab_tgt = llama_model_get_vocab(model_tgt); + + int32_t n_chars = llama_detokenize(vocab_tgt, &id_last, 1, nullptr, 0, false, false); + GGML_ASSERT(n_chars < 0 && "failed to detokenize id_last"); + text.resize(-n_chars); + llama_detokenize(vocab_tgt, &id_last, 1, text.data(), text.size(), false, false); + text = replace_to_dft(spec, text); + + LOG_DBG("main->draft detokenized id_last(%d): '%s'\n", id_last, text.c_str()); + id_last = common_tokenize(ctx_dft, text, false, true)[0]; + } + // prompt_tgt's tokens will always be compatible with ctx_dft + const llama_tokens &prompt_tgt = + spec->vocab_dft_compatible ? prompt_tgt_main_model : prompt_tgt_draft_model; const int i_start = std::max(0, (int) prompt_tgt.size() - n_ctx); // reuse as much as possible from the old draft context // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt - for (int i = 0; i < (int) prompt.size(); ++i) { + for (int i = 0; i < (int) prompt_dft.size(); ++i) { int cur = 0; while (i_start + cur < (int) prompt_tgt.size() && - i + cur < (int) prompt.size() && - prompt_tgt[i_start + cur] == prompt[i + cur]) { + i + cur < (int) prompt_dft.size() && + prompt_tgt[i_start + cur] == prompt_dft[i + cur]) { cur++; } @@ -169,21 +243,20 @@ llama_tokens common_speculative_gen_draft( } } - LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt.size()); + LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt_dft.size()); llama_tokens result; result.reserve(params.n_draft); if (reuse_n == 0) { - llama_memory_clear(mem, false); - - prompt.clear(); + llama_memory_clear(mem_dft, false); + prompt_dft.clear(); } else { // this happens when a previous draft has been discarded (for example, due to being too small), but the // target model agreed with it. in this case, we simply pass back the previous results to save compute - if (reuse_i + reuse_n < (int) prompt.size() && prompt[reuse_i + reuse_n] == id_last) { - for (int i = reuse_i + reuse_n + 1; i < (int) prompt.size(); ++i) { - result.push_back(prompt[i]); + if (reuse_i + reuse_n < (int) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) { + for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) { + result.push_back(prompt_dft[i]); if (params.n_draft <= (int) result.size()) { break; @@ -194,16 +267,15 @@ llama_tokens common_speculative_gen_draft( } if (reuse_i > 0) { - llama_memory_seq_rm (mem, 0, 0, reuse_i); - llama_memory_seq_add(mem, 0, reuse_i, -1, -reuse_i); + llama_memory_seq_rm (mem_dft, 0, 0, reuse_i); + llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i); - prompt.erase(prompt.begin(), prompt.begin() + reuse_i); + prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i); } - if (reuse_n < (int) prompt.size()) { - llama_memory_seq_rm (mem, 0, reuse_n, -1); - - prompt.erase(prompt.begin() + reuse_n, prompt.end()); + if (reuse_n < (int) prompt_dft.size()) { + llama_memory_seq_rm (mem_dft, 0, reuse_n, -1); + prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end()); } } @@ -214,28 +286,28 @@ llama_tokens common_speculative_gen_draft( //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]); common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false); - prompt.push_back(prompt_tgt[i]); + prompt_dft.push_back(prompt_tgt[i]); } // we should rarely end-up here during normal decoding if (batch.n_tokens > 0) { //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); - llama_decode(ctx, batch); + llama_decode(ctx_dft, batch); } - const llama_pos n_past = prompt.size(); + const llama_pos n_past = prompt_dft.size(); LOG_DBG("%s: n_past = %d\n", __func__, n_past); common_batch_clear(batch); common_batch_add (batch, id_last, n_past, { 0 }, true); - prompt.push_back(id_last); + prompt_dft.push_back(id_last); - //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str()); + LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str()); - llama_decode(ctx, batch); + llama_decode(ctx_dft, batch); common_sampler_reset(smpl); @@ -243,13 +315,13 @@ llama_tokens common_speculative_gen_draft( for (int i = 0; i < params.n_draft; ++i) { common_batch_clear(batch); - common_sampler_sample(smpl, ctx, 0, true); + common_sampler_sample(smpl, ctx_dft, 0, true); const auto * cur_p = common_sampler_get_candidates(smpl); for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", - k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str()); + k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str()); } // add drafted token for each sequence @@ -271,10 +343,19 @@ llama_tokens common_speculative_gen_draft( common_batch_add(batch, id, n_past + i + 1, { 0 }, true); // evaluate the drafted tokens on the draft model - llama_decode(ctx, batch); + llama_decode(ctx_dft, batch); - prompt.push_back(id); + prompt_dft.push_back(id); } + if (!spec->vocab_dft_compatible) { + std::string detokenized = common_detokenize(ctx_dft, result, true); + detokenized = replace_to_tgt(spec, detokenized); + LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str()); + result = common_tokenize(ctx_tgt, detokenized, false, true); + if (result.size() > (size_t)params.n_draft) { + result.resize(params.n_draft); + } + } return result; } diff --git a/common/speculative.h b/common/speculative.h index 2b51a70ca..e69d7aaa1 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -12,7 +12,10 @@ struct common_speculative_params { float p_min = 0.75f; // min probability required to accept a token in the draft }; -struct common_speculative * common_speculative_init(struct llama_context * ctx_dft); +struct common_speculative * common_speculative_init( + struct llama_context * ctx_tgt, + struct llama_context * ctx_dft +); void common_speculative_free(struct common_speculative * spec); @@ -20,6 +23,10 @@ bool common_speculative_are_compatible( const struct llama_context * ctx_tgt, const struct llama_context * ctx_dft); +void common_speculative_add_replacement_tgt_dft( + struct common_speculative * spec, + const char *source, const char *dest); + // sample up to n_draft tokens and add them to the batch using the draft model llama_tokens common_speculative_gen_draft( struct common_speculative * spec, diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 3f5cefe00..feef03d1c 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -684,6 +684,9 @@ class TextModel(ModelBase): if chkhsh == "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664": # ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct res = "hunyuan" + if chkhsh == "bba3b3366b646dbdded5dbc42d59598b849371afc42f7beafa914afaa5b70aa6": + # ref: https://huggingface.co/tencent/Hunyuan-4B-Instruct + res = "hunyuan-dense" if chkhsh == "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6": # ref: https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base res = "falcon-h1" @@ -2904,6 +2907,107 @@ class DreamModel(TextModel): yield from super().modify_tensors(data_torch, name, bid) +@ModelBase.register("LLaDAModelLM") +class LLaDAModel(TextModel): + model_arch = gguf.MODEL_ARCH.LLADA + undo_permute = True + + def get_vocab_base(self) -> tuple[list[str], list[int], str]: + tokens: list[str] = [] + toktypes: list[int] = [] + + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True) + + vocab_dict = tokenizer.get_vocab() + vocab_size = self.hparams.get("vocab_size", len(vocab_dict)) + assert max(vocab_dict.values()) < vocab_size + + tokpre = self.get_vocab_base_pre(tokenizer) + + reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab_dict.items()} + added_vocab = tokenizer.get_added_vocab() + + for i in range(vocab_size): + if i not in reverse_vocab: + tokens.append(f"[PAD{i}]") + toktypes.append(gguf.TokenType.UNUSED) + elif reverse_vocab[i] in added_vocab: + tokens.append(reverse_vocab[i]) + # Check if it's a special token - treat special tokens as CONTROL tokens + if hasattr(tokenizer, 'added_tokens_decoder') and i in tokenizer.added_tokens_decoder: + if tokenizer.added_tokens_decoder[i].special: + toktypes.append(gguf.TokenType.CONTROL) + else: + toktypes.append(gguf.TokenType.USER_DEFINED) + else: + # Fallback: treat all added vocab as control tokens for special tokens like <|im_start|> + toktypes.append(gguf.TokenType.CONTROL) + else: + tokens.append(reverse_vocab[i]) + toktypes.append(gguf.TokenType.NORMAL) + + return tokens, toktypes, tokpre + + def set_vocab(self): + self._set_vocab_gpt2() + + # LLaDA specific parameters + self.gguf_writer.add_add_bos_token(True) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self._try_set_pooling_type() + + # Add parameters similar to LlamaModel + hparams = self.hparams + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + + if (rope_dim := hparams.get("head_dim")) is None: + n_heads = hparams.get("num_attention_heads", hparams.get("n_heads")) + rope_dim = hparams.get("hidden_size", hparams.get("d_model")) // n_heads + self.gguf_writer.add_rope_dimension_count(rope_dim) + + # Set context length for LLaDA + context_length = self.hparams.get("max_sequence_length", 4096) + self.gguf_writer.add_context_length(context_length) + + # Set embedding length (dimension size) + embedding_length = self.hparams.get("d_model", 4096) + self.gguf_writer.add_embedding_length(embedding_length) + + # Set feed forward length (MLP hidden size) + feed_forward_length = self.hparams.get("mlp_hidden_size", 12288) + self.gguf_writer.add_feed_forward_length(feed_forward_length) + + # LLaDA models use non-causal attention for diffusion, similar to Dream + self.gguf_writer.add_causal_attention(False) + + # LLaDA models don't shift their logits + self.gguf_writer.add_diffusion_shift_logits(False) + + @staticmethod + def permute(weights: Tensor, n_head: int, n_head_kv: int | None): + if n_head_kv is not None and n_head != n_head_kv: + n_head = n_head_kv + return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + .swapaxes(1, 2) + .reshape(weights.shape)) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + n_head = self.hparams.get("num_attention_heads", self.hparams.get("n_heads")) + n_kv_head = self.hparams.get("num_key_value_heads", self.hparams.get("n_kv_heads")) + + if self.undo_permute: + if name.endswith(("q_proj.weight", "q_proj.bias")): + data_torch = LLaDAModel.permute(data_torch, n_head, n_head) + if name.endswith(("k_proj.weight", "k_proj.bias")): + data_torch = LLaDAModel.permute(data_torch, n_head, n_kv_head) + + # LLaDA model tensors should be mapped directly since it's the base model + yield from super().modify_tensors(data_torch, name, bid) + + @ModelBase.register("Ernie4_5_ForCausalLM") class Ernie4_5Model(TextModel): model_arch = gguf.MODEL_ARCH.ERNIE4_5 @@ -7452,11 +7556,6 @@ class FalconH1Model(Mamba2Model): class HunYuanMoEModel(TextModel): model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # For handling tied embeddings - self._tok_embd = None - def set_vocab(self): from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True) @@ -7550,9 +7649,6 @@ class HunYuanMoEModel(TextModel): _experts: list[dict[str, Tensor]] | None = None def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - if name == "model.embed_tokens.weight": - self._tok_embd = data_torch.clone() - if name == "lm_head.weight": if self.hparams.get("tie_word_embeddings", False): logger.info("Skipping tied output layer 'lm_head.weight'") @@ -7597,6 +7693,98 @@ class HunYuanMoEModel(TextModel): raise ValueError(f"Unprocessed experts: {experts}") +@ModelBase.register("HunYuanDenseV1ForCausalLM") +class HunYuanModel(TextModel): + model_arch = gguf.MODEL_ARCH.HUNYUAN_DENSE + + def set_vocab(self): + if (self.dir_model / "tokenizer.json").is_file(): + self._set_vocab_gpt2() + else: + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True) + + # 1. Get the pre-tokenizer identifier hash + tokpre = self.get_vocab_base_pre(tokenizer) + + # 2. Reverse-engineer the merges list from mergeable_ranks + merges = [] + vocab = {} + mergeable_ranks = tokenizer.mergeable_ranks + for token, rank in mergeable_ranks.items(): + vocab[QwenModel.token_bytes_to_string(token)] = rank + if len(token) == 1: + continue + merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank) + if len(merged) == 2: + merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged))) + + # 3. Generate the tokens and toktypes lists + vocab_size = self.hparams["vocab_size"] + assert tokenizer.vocab_size == vocab_size + special_tokens = tokenizer.special_tokens + reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()} + tokens: list[str] = [] + toktypes: list[int] = [] + for i in range(vocab_size): + if i not in reverse_vocab: + tokens.append(f"[PAD{i}]") + toktypes.append(gguf.TokenType.UNUSED) + else: + token = reverse_vocab[i] + tokens.append(token) + if i in special_tokens.values(): + toktypes.append(gguf.TokenType.CONTROL) + else: + toktypes.append(gguf.TokenType.NORMAL) + + # 4. Write all vocab-related fields to the GGUF writer + self.gguf_writer.add_tokenizer_model("gpt2") + self.gguf_writer.add_tokenizer_pre(tokpre) + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + self.gguf_writer.add_token_merges(merges) + + # 5. Add special tokens and chat templates + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False) + special_vocab.add_to_gguf(self.gguf_writer) + # FIX for BOS token: Overwrite incorrect id read from config.json + if self.hparams['hidden_size'] == 4096: + self.gguf_writer.add_bos_token_id(127958) # only for 7b dense, fix <|bos|> token + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + + # Rope + rope_scaling = hparams.get("rope_scaling", {}) + if rope_scaling.get("type") == "dynamic": + # HunYuan uses NTK Aware Alpha based scaling. Original implementation: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + # 1000 corresponds to a usable context length of 256k (https://github.com/Tencent-Hunyuan/Hunyuan-A13B/blob/main/report/Hunyuan_A13B_Technical_Report.pdf) + alpha = rope_scaling.get("alpha", 50) + base = hparams.get("rope_theta", 10000.0) + dim = hparams["head_dim"] + scaled_base = base * (alpha ** (dim / (dim - 2))) + self.gguf_writer.add_rope_freq_base(scaled_base) + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + self.gguf_writer.add_rope_scaling_factor(1) + # There is no consistent way to calculate ctx from alpha, and the config is incorrectly set to 32k + self.gguf_writer.add_rope_scaling_orig_ctx_len(256 * 1024) # 256k context length + self.gguf_writer.add_context_length(256 * 1024) # 256k context length + + # if any of our assumptions about the values are wrong, something has changed and this may need to be updated + assert base == 10000.0 and self.hparams["max_position_embeddings"] in [32 * 1024, 256 * 1024] , \ + "HunYuan dynamic RoPE scaling assumptions changed, please update the logic or context length manually" + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if name == "lm_head.weight": + if self.hparams.get("tie_word_embeddings", False): + logger.info("Skipping tied output layer 'lm_head.weight'") + return [] + + return [(self.map_tensor_name(name), data_torch)] + + @ModelBase.register("SmolLM3ForCausalLM") class SmolLM3Model(LlamaModel): model_arch = gguf.MODEL_ARCH.SMOLLM3 diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index abaf2ea9a..c4904b539 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -140,6 +140,7 @@ pre_computed_hashes = [ {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"}, {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"}, {"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"}, + {"name": "hunyuan-dense", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-4B-Instruct", "chkhsh": "bba3b3366b646dbdded5dbc42d59598b849371afc42f7beafa914afaa5b70aa6"}, # falcon-h1 series uses 4 different tokenizers across model sizes (0.5b - 34b), hence we need to define 4 different hashes {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base", "chkhsh": "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6"}, {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-1B-Base", "chkhsh": "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86"}, diff --git a/docs/multimodal/minicpmo4.0.md b/docs/multimodal/minicpmo4.0.md new file mode 100644 index 000000000..49125ea05 --- /dev/null +++ b/docs/multimodal/minicpmo4.0.md @@ -0,0 +1,47 @@ +## MiniCPM-o 4 + +### Prepare models and code + +Download [MiniCPM-o-4](https://huggingface.co/openbmb/MiniCPM-o-4) PyTorch model from huggingface to "MiniCPM-o-4" folder. + + +### Build llama.cpp +Readme modification time: 20250206 + +If there are differences in usage, please refer to the official build [documentation](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md) + +Clone llama.cpp: +```bash +git clone https://github.com/ggerganov/llama.cpp +cd llama.cpp +``` + +Build llama.cpp using `CMake`: +```bash +cmake -B build +cmake --build build --config Release +``` + + +### Usage of MiniCPM-o 4 + +Convert PyTorch model to gguf files (You can also download the converted [gguf](https://huggingface.co/openbmb/MiniCPM-o-4-gguf) by us) + +```bash +python ./tools/mtmd/legacy-models/minicpmv-surgery.py -m ../MiniCPM-o-4 +python ./tools/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-o-4 --minicpmv-projector ../MiniCPM-o-4/minicpmv.projector --output-dir ../MiniCPM-o-4/ --minicpmv_version 6 +python ./convert_hf_to_gguf.py ../MiniCPM-o-4/model + +# quantize int4 version +./build/bin/llama-quantize ../MiniCPM-o-4/model/ggml-model-f16.gguf ../MiniCPM-o-4/model/ggml-model-Q4_K_M.gguf Q4_K_M +``` + + +Inference on Linux or Mac +```bash +# run in single-turn mode +./build/bin/llama-mtmd-cli -m ../MiniCPM-o-4/model/ggml-model-f16.gguf --mmproj ../MiniCPM-o-4/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?" + +# run in conversation mode +./build/bin/llama-mtmd-cli -m ../MiniCPM-o-4/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-o-4/mmproj-model-f16.gguf +``` diff --git a/docs/multimodal/minicpmv4.0.md b/docs/multimodal/minicpmv4.0.md new file mode 100644 index 000000000..65887d960 --- /dev/null +++ b/docs/multimodal/minicpmv4.0.md @@ -0,0 +1,47 @@ +## MiniCPM-V 4 + +### Prepare models and code + +Download [MiniCPM-V-4](https://huggingface.co/openbmb/MiniCPM-V-4) PyTorch model from huggingface to "MiniCPM-V-4" folder. + + +### Build llama.cpp +Readme modification time: 20250206 + +If there are differences in usage, please refer to the official build [documentation](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md) + +Clone llama.cpp: +```bash +git clone https://github.com/ggerganov/llama.cpp +cd llama.cpp +``` + +Build llama.cpp using `CMake`: +```bash +cmake -B build +cmake --build build --config Release +``` + + +### Usage of MiniCPM-V 4 + +Convert PyTorch model to gguf files (You can also download the converted [gguf](https://huggingface.co/openbmb/MiniCPM-V-4-gguf) by us) + +```bash +python ./tools/mtmd/legacy-models/minicpmv-surgery.py -m ../MiniCPM-V-4 +python ./tools/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-V-4 --minicpmv-projector ../MiniCPM-V-4/minicpmv.projector --output-dir ../MiniCPM-V-4/ --minicpmv_version 5 +python ./convert_hf_to_gguf.py ../MiniCPM-V-4/model + +# quantize int4 version +./build/bin/llama-quantize ../MiniCPM-V-4/model/ggml-model-f16.gguf ../MiniCPM-V-4/model/ggml-model-Q4_K_M.gguf Q4_K_M +``` + + +Inference on Linux or Mac +```bash +# run in single-turn mode +./build/bin/llama-mtmd-cli -m ../MiniCPM-V-4/model/ggml-model-f16.gguf --mmproj ../MiniCPM-V-4/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?" + +# run in conversation mode +./build/bin/llama-mtmd-cli -m ../MiniCPM-V-4/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-V-4/mmproj-model-f16.gguf +``` diff --git a/examples/diffusion/README.md b/examples/diffusion/README.md new file mode 100644 index 000000000..26de5668a --- /dev/null +++ b/examples/diffusion/README.md @@ -0,0 +1,13 @@ +# Diffusion Text Generation + +This directory contains implementations for Diffusion LLMs (DLLMs) + +More Info: +- https://github.com/ggml-org/llama.cpp/pull/14644 +- https://github.com/ggml-org/llama.cpp/pull/14771 + + +Example of using Dream architechture: `llama-diffusion-cli -m dream7b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-eps 0.001 --diffusion-algorithm 3 --diffusion-steps 256 --diffusion-visual` + +Example of using LLaDA architechture: `llama-diffusion-cli -m llada-8b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-block-length 32 --diffusion-steps 256 --diffusion-visual` + diff --git a/examples/diffusion/diffusion-cli.cpp b/examples/diffusion/diffusion-cli.cpp index 3e11ce116..8431dcea8 100644 --- a/examples/diffusion/diffusion-cli.cpp +++ b/examples/diffusion/diffusion-cli.cpp @@ -5,344 +5,128 @@ #include "log.h" #include -#include -#include + #include #include +#include #include #include +#include +#include -typedef bool (*diffusion_step_callback_t)(int32_t step, - int32_t total_steps, - const llama_token * tokens, - int32_t n_tokens, - void * user_data); +enum diffusion_algorithm { ORIGIN = 0, ENTROPY_BASED = 1, MARGIN_BASED = 2, RANDOM = 3, CONFIDENCE_BASED = 4 }; -enum diffusion_alg { - DIFFUSION_ALG_ORIGIN = 0, - DIFFUSION_ALG_MASKGIT_PLUS = 1, - DIFFUSION_ALG_TOPK_MARGIN = 2, - DIFFUSION_ALG_ENTROPY = 3, +// Unified transfer scheduling methods +enum transfer_schedule { + TIMESTEP_BASED = 0, // Dream-style: (1.0 - s/t) * remaining + BLOCK_BASED = 1, // LLaDA-style: process in blocks with get_num_transfer_tokens }; +typedef bool (*diffusion_step_callback_t)(int32_t step, + int32_t total_steps, + const llama_token * tokens, + int32_t n_tokens, + void * user_data); + struct diffusion_params { - int32_t steps; - float eps; - float temperature; - float top_p; - int32_t top_k; - llama_token mask_token_id; - enum diffusion_alg algorithm; - float alg_temp; - diffusion_step_callback_t step_callback; - void * step_callback_user_data; - int32_t seed; + int32_t steps = 0; + float temperature = 0; + llama_token mask_token_id = LLAMA_TOKEN_NULL; + diffusion_step_callback_t step_callback = nullptr; + void * step_callback_user_data = nullptr; + int32_t seed = 0; + bool visual_mode = false; + bool shift_logits = false; // Shift logits by -1 after decode + + float top_p = 0.; + int32_t top_k = 0.; + + diffusion_algorithm algorithm = CONFIDENCE_BASED; + transfer_schedule schedule = TIMESTEP_BASED; + + float cfg_scale = 0.; // Config scale for classifier-free guidance + float eps = 0.; // Timestep scheduling + int32_t block_length = 0; // Block size (for block scheduling) + float alg_temp = 0; // algorithm temperature (0.0 = deterministic) + bool add_gumbel_noise = false; // Add gumbel noise to the logits if temp > 0.0 + + int32_t max_length = 0; // Maximum sequence length }; - -static diffusion_params diffusion_default_params() { - diffusion_params params = {}; - params.steps = 64; - params.eps = 1e-3f; - params.temperature = 0.2f; - params.top_p = 0.95f; - params.top_k = 0; - params.mask_token_id = LLAMA_TOKEN_NULL; - params.algorithm = DIFFUSION_ALG_ORIGIN; - params.alg_temp = 0.0f; - params.step_callback = nullptr; - params.step_callback_user_data = nullptr; - params.seed = 0; - return params; -} - -static void diffusion_generate(llama_context * ctx, - const llama_token * input_tokens, - llama_token * output_tokens, - int32_t n_input, - int32_t max_length, - struct diffusion_params params, - int32_t & n_generated) { - - n_generated = 0; - if (!ctx || !input_tokens || !output_tokens || n_input <= 0 || max_length <= n_input) { - return; - } - - const llama_model * model = llama_get_model(ctx); - - // Initialize with input and pad with mask tokens - std::copy(input_tokens, input_tokens + n_input, output_tokens); - std::fill(output_tokens + n_input, output_tokens + max_length, params.mask_token_id); - - std::mt19937 rng(params.seed); - - std::vector timesteps(params.steps + 1); - for (int32_t i = 0; i <= params.steps; i++) { - timesteps[i] = 1.0f - (float) i / params.steps * (1.0f - params.eps); - } - - llama_set_causal_attn(ctx, false); - - int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model)); - - std::vector candidates(n_vocab); - - std::vector conf_candidates; - conf_candidates.reserve(max_length); - - std::vector mask_positions; - mask_positions.reserve(max_length); - - struct llama_sampler * sampler = llama_sampler_chain_init(llama_sampler_chain_default_params()); - if (params.top_k > 0) { - llama_sampler_chain_add(sampler, llama_sampler_init_top_k(params.top_k)); - } - if (params.top_p < 1.0f) { - llama_sampler_chain_add(sampler, llama_sampler_init_top_p(params.top_p, 1)); - } - if (params.temperature > 0.0f) { - llama_sampler_chain_add(sampler, llama_sampler_init_temp(params.temperature)); - } - llama_sampler_chain_add(sampler, llama_sampler_init_dist(params.seed)); - - struct llama_sampler * dist_sampler = llama_sampler_init_dist(params.seed); - - llama_batch batch = llama_batch_init(max_length, 0, 1); - batch.n_tokens = max_length; - - int64_t total_sampling_time = 0; - int64_t total_time = 0; - - int64_t time_start = ggml_time_us(); - for (int32_t step = 0; step < params.steps; step++) { - if (params.step_callback) { - if (!params.step_callback(step, params.steps, output_tokens, max_length, params.step_callback_user_data)) { - break; - } - } - - for (int32_t i = 0; i < max_length; i++) { - batch.token[i] = output_tokens[i]; - batch.pos[i] = i; - batch.n_seq_id[i] = 1; - batch.seq_id[i][0] = 0; - batch.logits[i] = 1; - } - - int ret = llama_decode(ctx, batch); - if (ret != 0) { - LOG_ERR("%s: failed to decode at step %d, ret = %d\n", __func__, step, ret); - break; - } - - float * raw_logits = llama_get_logits(ctx); - if (!raw_logits) { - LOG_ERR("%s: failed to get logits at step %d\n", __func__, step); - break; - } - - auto get_logits_for_pos = [&](int32_t pos) -> const float * { - return pos == 0 ? raw_logits : raw_logits + (pos - 1) * n_vocab; - }; - - int64_t time_start_sampling = ggml_time_us(); - - mask_positions.clear(); - for (int32_t i = 0; i < max_length; i++) { - if (output_tokens[i] == params.mask_token_id) { - mask_positions.push_back(i); - } - } - - if (mask_positions.empty()) { - break; - } - - float t = timesteps[step]; - float s = timesteps[step + 1]; - - if (params.algorithm == DIFFUSION_ALG_ORIGIN) { - float p_transfer = (step < params.steps - 1) ? (1.0f - s / t) : 1.0f; - - for (int32_t pos : mask_positions) { - if (std::uniform_real_distribution(0.0f, 1.0f)(rng) < p_transfer) { - const float * pos_logits = get_logits_for_pos(pos); - for (int32_t token_id = 0; token_id < n_vocab; token_id++) { - candidates[token_id].id = token_id; - candidates[token_id].logit = pos_logits[token_id]; - candidates[token_id].p = 0.0f; - } - - llama_token_data_array cur_p = { - /* .data = */ candidates.data(), - /* .size = */ (size_t) n_vocab, // Reset size to full vocab - /* .selected = */ -1, - /* .sorted = */ false, - }; - - llama_sampler_apply(sampler, &cur_p); - output_tokens[pos] = cur_p.data[cur_p.selected].id; - } - } - } else { - std::vector> confidences; - std::vector sampled_tokens(mask_positions.size()); - - for (size_t i = 0; i < mask_positions.size(); i++) { - int32_t pos = mask_positions[i]; - const float * pos_logits = get_logits_for_pos(pos); - - for (int32_t token_id = 0; token_id < n_vocab; token_id++) { - candidates[token_id].logit = pos_logits[token_id]; - candidates[token_id].p = 0.0f; - candidates[token_id].id = token_id; - } - - llama_token_data_array cur_p = { - /* .data = */ candidates.data(), - /* .size = */ candidates.size(), - /* .selected = */ -1, - /* .sorted = */ false, - }; - - llama_sampler_apply(sampler, &cur_p); - - llama_token sampled_token = cur_p.data[cur_p.selected].id; - - float confidence = 0.0f; - if (params.algorithm == DIFFUSION_ALG_ENTROPY) { - const float epsilon = 1e-10f; - for (size_t j = 0; j < cur_p.size; j++) { - float prob = cur_p.data[j].p; - confidence += prob * logf(prob + epsilon); - } - } else if (params.algorithm == DIFFUSION_ALG_TOPK_MARGIN) { - confidence = cur_p.data[0].p - cur_p.data[1].p; - } else { - confidence = cur_p.data[cur_p.selected].p; - } - - sampled_tokens[i] = sampled_token; - confidences.emplace_back(confidence, i); - } - - int32_t num_transfer = - (step < params.steps - 1) ? (int32_t) (mask_positions.size() * (1.0f - s / t)) : mask_positions.size(); - - if (num_transfer > 0) { - if (params.alg_temp == 0.0f) { - std::partial_sort(confidences.begin(), confidences.begin() + num_transfer, confidences.end(), - [](const std::pair & a, const std::pair & b) { - if (a.first != b.first) { - return a.first > b.first; - } - return a.second < b.second; - }); - } else { - conf_candidates.clear(); - - for (int32_t pos = 0; pos < max_length; pos++) { - float conf_logit = -std::numeric_limits::infinity(); - - auto it = std::find(mask_positions.begin(), mask_positions.end(), pos); - if (it != mask_positions.end()) { - size_t mask_idx = std::distance(mask_positions.begin(), it); - conf_logit = confidences[mask_idx].first / params.alg_temp; // Apply temperature scaling - } - - conf_candidates.emplace_back(llama_token_data{ pos, conf_logit, 0.0f }); - } - - llama_token_data_array conf_array = { - /* .data = */ conf_candidates.data(), - /* .size = */ conf_candidates.size(), - /* .selected = */ -1, - /* .sorted = */ false, - }; - - for (int32_t i = 0; i < num_transfer; i++) { - // Apply distribution sampler to get selected index - llama_sampler_apply(dist_sampler, &conf_array); - int selected_idx = conf_array.selected; - confidences[i].second = conf_candidates[selected_idx].id; - - conf_candidates[selected_idx].p = 0.0f; - conf_array.selected = -1; - } - } - - if (params.alg_temp == 0.0f) { - // Deterministic - use confidence order - for (int32_t i = 0; i < num_transfer; i++) { - int32_t mask_idx = confidences[i].second; - int32_t pos = mask_positions[mask_idx]; - llama_token token = sampled_tokens[mask_idx]; - output_tokens[pos] = token; - } - } else { - for (int32_t i = 0; i < num_transfer; i++) { - int32_t pos = confidences[i].second; - auto it = std::find(mask_positions.begin(), mask_positions.end(), pos); - if (it != mask_positions.end()) { - int32_t mask_idx = std::distance(mask_positions.begin(), it); - output_tokens[pos] = sampled_tokens[mask_idx]; - } - } - } - } - } - int64_t time_end_sampling = ggml_time_us(); - total_sampling_time += time_end_sampling - time_start_sampling; - } - int64_t time_end = ggml_time_us(); - total_time += time_end - time_start; - - LOG_INF("\ntotal time: %0.2fms, time per step: %0.2fms, sampling time per step: %0.2fms\n", - total_time / 1000.0, total_time / 1000.0 / params.steps, total_sampling_time / 1000.0 / params.steps); - - - llama_batch_free(batch); - llama_sampler_free(sampler); - llama_sampler_free(dist_sampler); - - n_generated = max_length; -} - - - - -static std::string format_input_text(const std::string & prompt, bool use_chat_template, llama_model * model) { - if (!use_chat_template) { - return prompt; - } - - auto chat_templates = common_chat_templates_init(model, ""); - - common_chat_templates_inputs inputs; - common_chat_msg user_msg; - user_msg.role = "user"; - user_msg.content = prompt; - inputs.add_generation_prompt = true; - inputs.messages.push_back(user_msg); - - auto result = common_chat_templates_apply(chat_templates.get(), inputs); - - return result.prompt; -} - struct callback_data { - const common_params_diffusion * diff_params; - const llama_vocab * vocab; - int32_t n_input; + diffusion_params * diff_params; + const llama_vocab * vocab; + int32_t n_input; }; -static bool diffusion_step_callback(int32_t step, - int32_t total_steps, +static float calculate_confidence(const llama_token_data_array & cur_p, + diffusion_algorithm algorithm, + std::mt19937 & rng) { + switch (algorithm) { + case CONFIDENCE_BASED: + return cur_p.data[cur_p.selected].p; // Selected token probability + + case ENTROPY_BASED: + { + float entropy = 0.0f; + const float epsilon = 1e-10f; + for (size_t i = 0; i < cur_p.size; i++) { + float prob = cur_p.data[i].p; + entropy += prob * logf(prob + epsilon); + } + return -entropy; // Higher entropy = lower confidence + } + + case MARGIN_BASED: + return (cur_p.size > 1) ? cur_p.data[0].p - cur_p.data[1].p : cur_p.data[0].p; + + case RANDOM: + { + std::uniform_real_distribution uniform(0.0f, 1.0f); + return uniform(rng); // Random confidence + } + + case ORIGIN: + return cur_p.data[cur_p.selected].p; + + default: + return 0.0f; + } +} + +// Unified transfer count calculation function +static int32_t calculate_transfer_count(int32_t step, + int32_t total_steps, + int32_t remaining_masked, + transfer_schedule schedule, + float eps, + const std::vector & num_transfer_tokens = {}) { + switch (schedule) { + case TIMESTEP_BASED: + { + float t = 1.0f - (float) step / total_steps * (1.0f - eps); + float s = 1.0f - (float) (step + 1) / total_steps * (1.0f - eps); + float p_transfer = (step < total_steps - 1) ? (1.0f - s / t) : 1.0f; + return (int32_t) (remaining_masked * p_transfer); + } + + case BLOCK_BASED: + if (!num_transfer_tokens.empty() && step < (int32_t) num_transfer_tokens.size()) { + return num_transfer_tokens[step]; + } + return remaining_masked / (total_steps - step); // Fallback + + default: + return remaining_masked / (total_steps - step); + } +} + +static bool diffusion_step_callback(int32_t step, + int32_t total_steps, const llama_token * tokens, - int32_t n_tokens, - void * user_data) { - (void)user_data; + int32_t n_tokens, + void * user_data) { + (void) user_data; callback_data * data = static_cast(user_data); @@ -350,11 +134,11 @@ static bool diffusion_step_callback(int32_t step, int progress_percent = (step * 100) / total_steps; int progress_bars = (step * 50) / total_steps; LOG_INF("\rdiffusion step: %d/%d [%s%s] %d%%", - step, - total_steps, - std::string(progress_bars, '=').c_str(), - std::string(50 - progress_bars, ' ').c_str(), - progress_percent); + step, + total_steps, + std::string(progress_bars, '=').c_str(), + std::string(50 - progress_bars, ' ').c_str(), + progress_percent); }; if (data->diff_params->visual_mode) { @@ -391,6 +175,360 @@ static bool diffusion_step_callback(int32_t step, return true; } +static void add_gumbel_noise(float * logits, int32_t n_vocab, float temperature, std::mt19937 & rng) { + if (temperature == 0.0f) { + return; + } + + std::uniform_real_distribution uniform(0.0, 1.0); + for (int32_t i = 0; i < n_vocab; i++) { + double noise = uniform(rng); + // Prevent log(0) + noise = std::max(noise, 1e-20); + double gumbel_noise = std::pow(-std::log(noise), temperature); + logits[i] = std::exp(logits[i]) / gumbel_noise; + } +} + +static std::vector get_num_transfer_tokens(int32_t mask_count, int32_t steps) { + std::vector num_transfer_tokens(steps); + + int32_t base = mask_count / steps; + int32_t remainder = mask_count % steps; + + for (int32_t i = 0; i < steps; i++) { + num_transfer_tokens[i] = base + (i < remainder ? 1 : 0); + } + + return num_transfer_tokens; +} + +static void diffusion_generate(llama_context * ctx, + const llama_token * input_tokens, + llama_token * output_tokens, + int32_t n_input, + const diffusion_params & params, + int32_t & n_generated) { + n_generated = 0; + if (!ctx || !input_tokens || !output_tokens || n_input <= 0 || params.max_length <= n_input) { + return; + } + + const llama_model * model = llama_get_model(ctx); + + // Initialize with input and pad with mask tokens + std::copy(input_tokens, input_tokens + n_input, output_tokens); + std::fill(output_tokens + n_input, output_tokens + params.max_length, params.mask_token_id); + + std::mt19937 rng(params.seed); + + llama_set_causal_attn(ctx, false); + + int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model)); + + std::vector candidates(n_vocab); + std::vector conf_candidates; + conf_candidates.reserve(params.max_length); + std::vector mask_positions; + mask_positions.reserve(params.max_length); + + // Setup sampler chain + struct llama_sampler * sampler = llama_sampler_chain_init(llama_sampler_chain_default_params()); + if (params.top_k > 0) { + llama_sampler_chain_add(sampler, llama_sampler_init_top_k(params.top_k)); + } + if (params.top_p < 1.0f) { + llama_sampler_chain_add(sampler, llama_sampler_init_top_p(params.top_p, 1)); + } + if (params.temperature > 0.0f) { + llama_sampler_chain_add(sampler, llama_sampler_init_temp(params.temperature)); + } + llama_sampler_chain_add(sampler, llama_sampler_init_dist(params.seed)); + + struct llama_sampler * dist_sampler = llama_sampler_init_dist(params.seed); + + llama_batch batch = llama_batch_init(params.max_length, 0, 1); + batch.n_tokens = params.max_length; + + // Pre-allocate buffers for CFG if needed + int32_t logits_size = n_vocab * params.max_length; + std::vector cond_logits_buffer; + std::vector un_x_buffer; + if (params.cfg_scale > 0.0f) { + cond_logits_buffer.resize(logits_size); + un_x_buffer.resize(params.max_length); + } + + // For block-based processing + std::vector num_transfer_tokens; + int32_t num_blocks = 1; + int32_t steps_per_block = params.steps; + + if (params.schedule == BLOCK_BASED) { + GGML_ASSERT(params.max_length % params.block_length == 0); + num_blocks = params.max_length / params.block_length; + GGML_ASSERT(params.steps % num_blocks == 0); + steps_per_block = params.steps / num_blocks; + } + + std::vector confidence(params.max_length); + + int64_t total_sampling_time = 0; + int64_t total_time = 0; + int64_t time_start = ggml_time_us(); + + for (int block_num = 0; block_num < num_blocks; block_num++) { + int32_t block_start = (params.schedule == BLOCK_BASED) ? n_input + block_num * params.block_length : 0; + int32_t block_end = (params.schedule == BLOCK_BASED) ? + std::min(n_input + (block_num + 1) * params.block_length, params.max_length) : + params.max_length; + + // Count masked tokens in current block for block-based processing + if (params.schedule == BLOCK_BASED) { + int32_t block_mask_count = 0; + for (int i = block_start; i < block_end; i++) { + if (output_tokens[i] == params.mask_token_id) { + block_mask_count++; + } + } + num_transfer_tokens = get_num_transfer_tokens(block_mask_count, steps_per_block); + } + + for (int32_t step = 0; step < steps_per_block; step++) { + int32_t global_step = block_num * steps_per_block + step; + + if (params.step_callback) { + if (!params.step_callback( + global_step, params.steps, output_tokens, params.max_length, params.step_callback_user_data)) { + break; + } + } + + // Setup batch + for (int32_t i = 0; i < params.max_length; i++) { + batch.token[i] = output_tokens[i]; + batch.pos[i] = i; + batch.n_seq_id[i] = 1; + batch.seq_id[i][0] = 0; + batch.logits[i] = 1; + } + + float * logits = nullptr; + + if (params.cfg_scale > 0.0f) { + int ret = llama_decode(ctx, batch); + if (ret != 0) { + LOG_ERR("Failed to generate conditional"); + break; + } + float * cond_logits_ptr = llama_get_logits(ctx); + std::memcpy(cond_logits_buffer.data(), cond_logits_ptr, logits_size * sizeof(float)); + + // Unconditional generation (mask input) + std::copy(output_tokens, output_tokens + params.max_length, un_x_buffer.begin()); + for (int32_t i = 0; i < n_input; i++) { + un_x_buffer[i] = params.mask_token_id; + } + + for (int32_t i = 0; i < params.max_length; i++) { + batch.token[i] = un_x_buffer[i]; + } + ret = llama_decode(ctx, batch); + if (ret != 0) { + LOG_ERR("Failed to generate unconditional"); + break; + } + float * uncond_logits = llama_get_logits(ctx); + + // Apply CFG + for (int32_t i = 0; i < logits_size; i++) { + cond_logits_buffer[i] = + uncond_logits[i] + (params.cfg_scale + 1.0f) * (cond_logits_buffer[i] - uncond_logits[i]); + } + logits = cond_logits_buffer.data(); + } else { + int ret = llama_decode(ctx, batch); + if (ret != 0) { + LOG_ERR("%s: failed to decode at step %d, ret = %d\n", __func__, global_step, ret); + break; + } + logits = llama_get_logits(ctx); + } + + if (!logits) { + LOG_ERR("%s: failed to get logits at step %d\n", __func__, global_step); + break; + } + + auto get_logits_for_pos = [&](int32_t pos) -> const float * { + if (params.shift_logits) { + return pos == 0 ? logits : logits + (pos - 1) * n_vocab; + } + return logits + (pos) *n_vocab; + }; + + int64_t time_start_sampling = ggml_time_us(); + + mask_positions.clear(); + for (int32_t i = 0; i < params.max_length; i++) { + if (output_tokens[i] == params.mask_token_id) { + // For block-based, only consider current block + if (params.schedule != BLOCK_BASED || (i >= block_start && i < block_end)) { + mask_positions.push_back(i); + } + } + } + + if (mask_positions.empty()) { + break; + } + + if (params.add_gumbel_noise && params.temperature > 0.0f) { + add_gumbel_noise(logits, n_vocab, params.temperature, rng); + } + + if (params.algorithm == ORIGIN) { + int32_t transfer_count = calculate_transfer_count( + step, steps_per_block, mask_positions.size(), params.schedule, params.eps, num_transfer_tokens); + float p_transfer = (float) transfer_count / mask_positions.size(); + + for (int32_t pos : mask_positions) { + if (std::uniform_real_distribution(0.0f, 1.0f)(rng) < p_transfer) { + const float * pos_logits = get_logits_for_pos(pos); + for (int32_t token_id = 0; token_id < n_vocab; token_id++) { + candidates[token_id].id = token_id; + candidates[token_id].logit = pos_logits[token_id]; + candidates[token_id].p = 0.0f; + } + + llama_token_data_array cur_p = { + candidates.data(), + (size_t) n_vocab, + -1, + false, + }; + + llama_sampler_apply(sampler, &cur_p); + output_tokens[pos] = cur_p.data[cur_p.selected].id; + } + } + } else { + std::vector> confidences; + std::vector sampled_tokens(mask_positions.size()); + + for (size_t i = 0; i < mask_positions.size(); i++) { + int32_t pos = mask_positions[i]; + const float * pos_logits = get_logits_for_pos(pos); + + for (int32_t token_id = 0; token_id < n_vocab; token_id++) { + candidates[token_id].logit = pos_logits[token_id]; + candidates[token_id].p = 0.0f; + candidates[token_id].id = token_id; + } + + llama_token_data_array cur_p = { + candidates.data(), + candidates.size(), + -1, + false, + }; + + llama_sampler_apply(sampler, &cur_p); + llama_token sampled_token = cur_p.data[cur_p.selected].id; + + float conf = calculate_confidence(cur_p, params.algorithm, rng); + + sampled_tokens[i] = sampled_token; + confidences.emplace_back(conf, i); + } + + int32_t transfer_count = calculate_transfer_count( + step, steps_per_block, mask_positions.size(), params.schedule, params.eps, num_transfer_tokens); + + if (transfer_count > 0) { + if (params.alg_temp == 0.0f) { + std::partial_sort(confidences.begin(), + confidences.begin() + std::min(transfer_count, (int32_t) confidences.size()), + confidences.end(), + [](const std::pair & a, const std::pair & b) { + if (a.first != b.first) { + return a.first > b.first; + } + return a.second < b.second; + }); + + for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) { + int32_t mask_idx = confidences[i].second; + int32_t pos = mask_positions[mask_idx]; + output_tokens[pos] = sampled_tokens[mask_idx]; + } + } else { + conf_candidates.clear(); + for (size_t i = 0; i < confidences.size(); i++) { + float conf_logit = confidences[i].first / params.alg_temp; + conf_candidates.emplace_back(llama_token_data{ (int32_t) i, conf_logit, 0.0f }); + } + + llama_token_data_array conf_array = { + conf_candidates.data(), + conf_candidates.size(), + -1, + false, + }; + + for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) { + llama_sampler_apply(dist_sampler, &conf_array); + int32_t selected_idx = conf_array.selected; + int32_t mask_idx = selected_idx; + int32_t pos = mask_positions[mask_idx]; + output_tokens[pos] = sampled_tokens[mask_idx]; + + conf_candidates[selected_idx].p = 0.0f; + conf_array.selected = -1; + } + } + } + } + + int64_t time_end_sampling = ggml_time_us(); + total_sampling_time += time_end_sampling - time_start_sampling; + } + } + + int64_t time_end = ggml_time_us(); + total_time += time_end - time_start; + + LOG_INF("\ntotal time: %0.2fms, time per step: %0.2fms, sampling time per step: %0.2fms\n", + total_time / 1000.0, + total_time / 1000.0 / params.steps, + total_sampling_time / 1000.0 / params.steps); + + llama_batch_free(batch); + llama_sampler_free(sampler); + llama_sampler_free(dist_sampler); + + n_generated = params.max_length; +} + +static std::string format_input_text(const std::string & prompt, bool use_chat_template, llama_model * model) { + if (!use_chat_template) { + return prompt; + } + + auto chat_templates = common_chat_templates_init(model, ""); + + common_chat_templates_inputs inputs; + common_chat_msg user_msg; + user_msg.role = "user"; + user_msg.content = prompt; + inputs.add_generation_prompt = true; + inputs.messages.push_back(user_msg); + + auto result = common_chat_templates_apply(chat_templates.get(), inputs); + + return result.prompt; +} + int main(int argc, char ** argv) { ggml_time_init(); @@ -400,11 +538,6 @@ int main(int argc, char ** argv) { return 1; } - const char * alg_names[] = { "ORIGIN", "MASKGIT_PLUS", "TOPK_MARGIN", "ENTROPY" }; - const char * alg_name = (params.diffusion.algorithm >= 0 && params.diffusion.algorithm <= 3) ? - alg_names[params.diffusion.algorithm] : - "UNKNOWN"; - common_init(); llama_backend_init(); @@ -421,6 +554,12 @@ int main(int argc, char ** argv) { return 1; } + if (!llama_model_is_diffusion(model)) { + LOG_ERR("error: unsupported model for diffusion"); + llama_model_free(model); + return 1; + } + llama_context_params ctx_params = llama_context_default_params(); ctx_params.n_ctx = params.n_ctx; ctx_params.n_batch = params.n_batch; @@ -442,10 +581,12 @@ int main(int argc, char ** argv) { const llama_vocab * vocab = llama_model_get_vocab(model); std::string formatted_prompt = format_input_text(params.prompt, params.enable_chat_template, model); - std::vector input_tokens = common_tokenize(vocab, formatted_prompt, + std::vector input_tokens = common_tokenize(vocab, + formatted_prompt, /*add special tokens*/ true, /*parse special*/ true); - int n_input = input_tokens.size(); + + int n_input = input_tokens.size(); if (n_input >= params.n_ctx) { LOG_ERR("error: input too long (%d tokens), max context is %d\n", n_input, params.n_ctx); @@ -454,44 +595,79 @@ int main(int argc, char ** argv) { return 1; } - struct diffusion_params ldiff_params = diffusion_default_params(); - ldiff_params.steps = params.diffusion.steps; - ldiff_params.eps = params.diffusion.eps; - ldiff_params.temperature = params.sampling.temp; - ldiff_params.top_p = params.sampling.top_p; - ldiff_params.top_k = params.sampling.top_k; - ldiff_params.algorithm = static_cast(params.diffusion.algorithm); - ldiff_params.alg_temp = params.diffusion.alg_temp; - ldiff_params.seed = params.sampling.seed; - llama_token mask_token_id = llama_vocab_mask(vocab); GGML_ASSERT(mask_token_id != LLAMA_TOKEN_NULL); - LOG_INF("diffusion_params: - %-25s llama_token = %d\n", "mask_token_id", mask_token_id); - LOG_INF("diffusion_params: - %-25s u32 = %d\n", "steps", params.diffusion.steps); - LOG_INF("diffusion_params: - %-25s f32 = %.6f\n", "eps", params.diffusion.eps); - LOG_INF("diffusion_params: - %-25s u32 = %d (%s)\n", "algorithm", params.diffusion.algorithm, - alg_name); - LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "alg_temp", params.diffusion.alg_temp); - - ldiff_params.mask_token_id = mask_token_id; - - callback_data cb_data = { ¶ms.diffusion, vocab, n_input }; - - ldiff_params.step_callback = diffusion_step_callback; - ldiff_params.step_callback_user_data = &cb_data; - - int32_t n_generated = 0; + bool visual_mode = params.diffusion.visual_mode; + int32_t n_generated = 0; std::vector output_tokens(params.n_ubatch); - diffusion_generate(ctx, input_tokens.data(), output_tokens.data(), n_input, params.n_ubatch, - ldiff_params, n_generated); + + struct diffusion_params diff_params; + + char shift_logits_str[8]; + if (llama_model_meta_val_str(model, "diffusion.shift_logits", shift_logits_str, sizeof(shift_logits_str)) >= 0) { + diff_params.shift_logits = (strcmp(shift_logits_str, "true") == 0); + } else { + diff_params.shift_logits = true; + } + + //Use either eps or block length, but not both + GGML_ASSERT((params.diffusion.eps == 0) ^ (params.diffusion.block_length == 0)); + + if (params.diffusion.eps) { + diff_params.schedule = TIMESTEP_BASED; + diff_params.eps = params.diffusion.eps; + } else if (params.diffusion.block_length) { + diff_params.schedule = BLOCK_BASED; + diff_params.block_length = params.diffusion.block_length; + } + + diff_params.mask_token_id = mask_token_id; + diff_params.seed = params.sampling.seed; + diff_params.temperature = params.sampling.temp; + diff_params.steps = params.diffusion.steps; + diff_params.algorithm = static_cast(params.diffusion.algorithm); + diff_params.max_length = params.n_ubatch; + diff_params.top_p = params.sampling.top_p; + diff_params.top_k = params.sampling.top_k; + diff_params.visual_mode = params.diffusion.visual_mode; + diff_params.add_gumbel_noise = params.diffusion.add_gumbel_noise; + + diff_params.step_callback = diffusion_step_callback; + callback_data cb_data = { &diff_params, vocab, n_input }; + diff_params.step_callback_user_data = &cb_data; + + const char * alg_names[] = { "ORIGIN", "ENTROPY_BASED", "MARGIN_BASED", "RANDOM", "CONFIDENCE_BASED" }; + const char * sched_names[] = { "TIMESTEP_BASED", "BLOCK_BASED" }; + const char * alg_name = + (diff_params.algorithm >= 0 && diff_params.algorithm <= 4) ? alg_names[diff_params.algorithm] : "UNKNOWN"; + const char * sched_name = + (diff_params.schedule >= 0 && diff_params.schedule <= 1) ? sched_names[diff_params.schedule] : "UNKNOWN"; + + LOG_INF("diffusion_params: - %-25s llama_token = %d\n", "mask_token_id", mask_token_id); + LOG_INF("diffusion_params: - %-25s u32 = %d\n", "steps", diff_params.steps); + LOG_INF("diffusion_params: - %-25s u32 = %d\n", "max_length", diff_params.max_length); + LOG_INF("diffusion_params: - %-25s enum = %d (%s)\n", "algorithm", diff_params.algorithm, alg_name); + LOG_INF("diffusion_params: - %-25s enum = %d (%s)\n", "schedule", diff_params.schedule, sched_name); + LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "temperature", diff_params.temperature); + if (diff_params.schedule == TIMESTEP_BASED) { + LOG_INF("diffusion_params: - %-25s f32 = %.6f\n", "eps", diff_params.eps); + LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "alg_temp", diff_params.alg_temp); + } + if (diff_params.schedule == BLOCK_BASED) { + LOG_INF("diffusion_params: - %-25s u32 = %d\n", "block_length", diff_params.block_length); + LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "cfg_scale", diff_params.cfg_scale); + } + + diffusion_generate(ctx, input_tokens.data(), output_tokens.data(), n_input, diff_params, n_generated); if (n_generated > 0) { - if (params.diffusion.visual_mode) { + if (visual_mode) { //clear screen and move cursor to top-left LOG_INF("\033[2J\033[H"); } + output_tokens.erase(output_tokens.begin(), output_tokens.begin() + n_input); std::string output_data = common_detokenize(vocab, output_tokens, false); LOG_INF("\n%s\n", output_data.c_str()); diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 10e534251..f02cfe8fa 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -37,17 +37,21 @@ #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K +#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K +#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64) // repack.cpp #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K +#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K +#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64) // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 @@ -72,11 +76,13 @@ #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K +#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K +#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #elif defined(__loongarch64) // quants.c @@ -92,11 +98,13 @@ #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K +#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K +#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #elif defined(__riscv) // quants.c @@ -119,10 +127,12 @@ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K +#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K +#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #elif defined(__s390x__) // quants.c @@ -147,11 +157,13 @@ #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K +#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K +#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #elif defined(__wasm__) // quants.c @@ -175,10 +187,12 @@ #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K +#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K +#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #endif diff --git a/ggml/src/ggml-cpu/arch/x86/repack.cpp b/ggml/src/ggml-cpu/arch/x86/repack.cpp index beeb260bc..57c6778f5 100644 --- a/ggml/src/ggml-cpu/arch/x86/repack.cpp +++ b/ggml/src/ggml-cpu/arch/x86/repack.cpp @@ -849,6 +849,319 @@ void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } +void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__AVX2__) + // Lookup table to convert signed nibbles to signed bytes + __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0)); + signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0); + // Shuffle masks to rearrange delta values to multiply with appropriate scales + __m128i deltamask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0); + // Permute mask used for easier vector processing at later stages + __m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0); + + const __m256i m3b = _mm256_set1_epi8(3); + const __m128i m4b_sse = _mm_set1_epi8(0xF); + + //Mask to get appropriate scales + __m128i scalemask1 = _mm_set_epi8(14,14,6,6,12,12,4,4,10,10,2,2,8,8,0,0); + __m128i scalemask2 = _mm_set_epi8(15,15,7,7,13,13,5,5,11,11,3,3,9,9,1,1); + + int64_t b_nb = n / QK_K; + + const block_q2_Kx8 * b_ptr_start = (const block_q2_Kx8 *)vx; + const block_q8_K * a_ptr_start = (const block_q8_K *)vy; + + // Process Q8_K blocks one by one + for (int64_t y = 0; y < nr; y++) { + + // Pointers to LHS blocks of block_q8_K format + const block_q8_K * a_ptr = a_ptr_start + (y * nb); + + // Take group of eight interleaved block_q2_K structures at each pass of the loop and perform dot product operation + for(int64_t x = 0; x < nc / 8; x++) { + + // Pointers to RHS blocks + const block_q2_Kx8 * b_ptr = b_ptr_start + (x * b_nb); + + // Master FP accumulators + __m256 acc_row = _mm256_setzero_ps(); + __m256 acc_min_rows = _mm256_setzero_ps(); + + for (int64_t b = 0; b < nb; b++) { + + // Load and convert to FP32 delta from block_q8_K + const __m256 row_scale_f32 = _mm256_set1_ps((a_ptr[b].d)); + + // Load the delta values for the 8 blocks interleaved in block_q2_Kx8 + // col_scale_f32 rearranged so as to multiply with appropriate quants + const __m256 col_scale_f32 = GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, deltamask); + const __m256 col_dmin_f32 = GGML_F32Cx8_LOAD(b_ptr[b].dmin); + + __m256i iacc_b = _mm256_setzero_si256(); + __m256i iacc_min_b = _mm256_setzero_si256(); + + // Processes eight sub blocks from each Q2_K in each iteration + for(int sb = 0; sb < QK_K / 128; sb++) { + + // Load the eight block_q2_K for eight sub blocks quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7 + const __m256i rhs_raw_vec_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + sb * 256)); + const __m256i rhs_raw_vec_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 32 + sb * 256)); + const __m256i rhs_raw_vec_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 64 + sb * 256)); + const __m256i rhs_raw_vec_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 96 + sb * 256)); + const __m256i rhs_raw_vec_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 128 + sb * 256)); + const __m256i rhs_raw_vec_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 160 + sb * 256)); + const __m256i rhs_raw_vec_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 192 + sb * 256)); + const __m256i rhs_raw_vec_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 224 + sb * 256)); + + // 2-bit -> 8-bit + // Values of the 0th,2nd,4th,6th sub blocks of eight block_q2_K structures for the sb loop + const __m256i rhs_vec_0123_00 = _mm256_and_si256(rhs_raw_vec_0123_0, m3b); //B00(0-7) B01(0-7) B02(0-7) B03(0-7) + const __m256i rhs_vec_0123_20 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 2), m3b); //B20(0-7) B21(0-7) B22(0-7) B23(0-7) + const __m256i rhs_vec_0123_40 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 4), m3b); //B40(0-7) B41(0-7) B42(0-7) B43(0-7) + const __m256i rhs_vec_0123_60 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 6), m3b); //B60(0-7) B61(0-7) B62(0-7) B63(0-7) + + const __m256i rhs_vec_4567_00 = _mm256_and_si256(rhs_raw_vec_4567_0, m3b); //B04(0-7) B05(0-7) B06(0-7) B07(0-7) + const __m256i rhs_vec_4567_20 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 2), m3b); //B24(0-7) B25(0-7) B26(0-7) B27(0-7) + const __m256i rhs_vec_4567_40 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 4), m3b); //B44(0-7) B45(0-7) B46(0-7) B47(0-7) + const __m256i rhs_vec_4567_60 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 6), m3b); //B64(0-7) B65(0-7) B66(0-7) B67(0-7) + + const __m256i rhs_vec_0123_01 = _mm256_and_si256(rhs_raw_vec_0123_1, m3b); //B00(8-15) B01(8-15) B02(8-15) B03(8-15) + const __m256i rhs_vec_0123_21 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 2), m3b); //B20(8-15) B21(8-15) B22(8-15) B23(8-15) + const __m256i rhs_vec_0123_41 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 4), m3b); //B40(8-15) B41(8-15) B42(8-15) B43(8-15) + const __m256i rhs_vec_0123_61 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 6), m3b); //B60(8-15) B61(8-15) B62(8-15) B63(8-15) + + const __m256i rhs_vec_4567_01 = _mm256_and_si256(rhs_raw_vec_4567_1, m3b); //B04(8-15) B05(8-15) B06(8-15) B07(8-15) + const __m256i rhs_vec_4567_21 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 2), m3b); //B24(8-15) B25(8-15) B26(8-15) B27(8-15) + const __m256i rhs_vec_4567_41 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 4), m3b); //B44(8-15) B45(8-15) B46(8-15) B47(8-15) + const __m256i rhs_vec_4567_61 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 6), m3b); //B64(8-15) B65(8-15) B66(8-15) B67(8-15) + + // Values of the 1st,3rd,5th,7th sub blocks of eight block_q2_K structures for the sb loop + const __m256i rhs_vec_0123_10 = _mm256_and_si256(rhs_raw_vec_0123_2, m3b); //B10(0-7) B11(0-7) B12(0-7) B13(0-7) + const __m256i rhs_vec_0123_30 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_2, 2), m3b); //B30(0-7) B31(0-7) B32(0-7) B33(0-7) + const __m256i rhs_vec_0123_50 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_2, 4), m3b); //B50(0-7) B51(0-7) B52(0-7) B53(0-7) + const __m256i rhs_vec_0123_70 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_2, 6), m3b); //B70(0-7) B71(0-7) B72(0-7) B73(0-7) + + const __m256i rhs_vec_4567_10 = _mm256_and_si256(rhs_raw_vec_4567_2, m3b); //B14(0-7) B15(0-7) B16(0-7) B17(0-7) + const __m256i rhs_vec_4567_30 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_2, 2), m3b); //B34(0-7) B35(0-7) B36(0-7) B37(0-7) + const __m256i rhs_vec_4567_50 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_2, 4), m3b); //B54(0-7) B55(0-7) B56(0-7) B57(0-7) + const __m256i rhs_vec_4567_70 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_2, 6), m3b); //B74(0-7) B75(0-7) B76(0-7) B77(0-7) + + const __m256i rhs_vec_0123_11 = _mm256_and_si256(rhs_raw_vec_0123_3, m3b); //B10(8-15) B11(8-15) B12(8-15) B13(8-15) + const __m256i rhs_vec_0123_31 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_3, 2), m3b); //B30(8-15) B31(8-15) B32(8-15) B33(8-15) + const __m256i rhs_vec_0123_51 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_3, 4), m3b); //B50(8-15) B51(8-15) B52(8-15) B53(8-15) + const __m256i rhs_vec_0123_71 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_3, 6), m3b); //B70(8-15) B71(8-15) B72(8-15) B73(8-15) + + const __m256i rhs_vec_4567_11 = _mm256_and_si256(rhs_raw_vec_4567_3, m3b); //B14(8-15) B15(8-15) B16(8-15) B17(8-15) + const __m256i rhs_vec_4567_31 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_3, 2), m3b); //B34(8-15) B35(8-15) B36(8-15) B37(8-15) + const __m256i rhs_vec_4567_51 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_3, 4), m3b); //B54(8-15) B55(8-15) B56(8-15) B57(8-15) + const __m256i rhs_vec_4567_71 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_3, 6), m3b); //B74(8-15) B75(8-15) B76(8-15) B77(8-15) + + //Scales and Mins of corresponding sub blocks from different Q2_K structures are stored together + //s00 m00 s01 m01 s10 m10 s11 m11 s20 m20 s21 m21 s30 m30 s31 m31 s40 m40 s41 m41 s50 m50 s51 m51 s60 m60 s61 m61 s70 m70 s71 m71 + + const __m128i mins_and_scales_01 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + sb * 64)); + const __m128i mins_and_scales_23 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + 16 + sb * 64)); + const __m128i mins_and_scales_45 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + 32 + sb * 64)); + const __m128i mins_and_scales_67 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + 48 + sb * 64)); + + // Extract scales which is lower half from mins_and_scales + const __m128i scales_01 = _mm_and_si128(mins_and_scales_01, m4b_sse); + const __m128i scales_23 = _mm_and_si128(mins_and_scales_23, m4b_sse); + const __m128i scales_45 = _mm_and_si128(mins_and_scales_45, m4b_sse); + const __m128i scales_67 = _mm_and_si128(mins_and_scales_67, m4b_sse); + + // Extract mins which is upper half from mins_and_scales + const __m256i mins_01 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_01, 4), m4b_sse)); + const __m256i mins_23 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_23, 4), m4b_sse)); + const __m256i mins_45 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_45, 4), m4b_sse)); + const __m256i mins_67 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_67, 4), m4b_sse)); + + // Scales of sub blocks in the sb loop + // Scales of the 0th sub block from each super block + __m128i scales_rearrange_0 = _mm_shuffle_epi8(scales_01, scalemask1); + __m256i scales_0 = _mm256_cvtepu8_epi16(scales_rearrange_0); + + // Scales of the 1st sub block from each super block + __m128i scales_rearrange_1 = _mm_shuffle_epi8(scales_01, scalemask2); + __m256i scales_1 = _mm256_cvtepu8_epi16(scales_rearrange_1); + + // Scales of the 2nd sub block from each super block + __m128i scales_rearrange_2 = _mm_shuffle_epi8(scales_23, scalemask1); + __m256i scales_2 = _mm256_cvtepu8_epi16(scales_rearrange_2); + + // Scales of the 3rd sub block from each super block + __m128i scales_rearrange_3 = _mm_shuffle_epi8(scales_23, scalemask2); + __m256i scales_3 = _mm256_cvtepu8_epi16(scales_rearrange_3); + + // Scales of the 4th sub block from each super block + __m128i scales_rearrange_4 = _mm_shuffle_epi8(scales_45, scalemask1); + __m256i scales_4 = _mm256_cvtepu8_epi16(scales_rearrange_4); + + // Scales of the 5th sub block from each super block + __m128i scales_rearrange_5 = _mm_shuffle_epi8(scales_45, scalemask2); + __m256i scales_5 = _mm256_cvtepu8_epi16(scales_rearrange_5); + + // Scales of the 6th sub block from each super block + __m128i scales_rearrange_6 = _mm_shuffle_epi8(scales_67, scalemask1); + __m256i scales_6 = _mm256_cvtepu8_epi16(scales_rearrange_6); + + // Scales of the 7th sub block from each super block + __m128i scales_rearrange_7 = _mm_shuffle_epi8(scales_67, scalemask2); + __m256i scales_7 = _mm256_cvtepu8_epi16(scales_rearrange_7); + + // Load the sub block values corresponding to sb in block_q8_K in batches of 16 bytes and replicate the same across 256 bit vector + __m256i lhs_vec_0 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + sb * 128))); + __m256i lhs_vec_1 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 16 + sb * 128))); + __m256i lhs_vec_2 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 32 + sb * 128))); + __m256i lhs_vec_3 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 48 + sb * 128))); + __m256i lhs_vec_4 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 64 + sb * 128))); + __m256i lhs_vec_5 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 80 + sb * 128))); + __m256i lhs_vec_6 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 96 + sb * 128))); + __m256i lhs_vec_7 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 112 + sb * 128))); + + lhs_vec_0 = _mm256_permute2f128_si256(lhs_vec_0, lhs_vec_0, 0); + lhs_vec_1 = _mm256_permute2f128_si256(lhs_vec_1, lhs_vec_1, 0); + lhs_vec_2 = _mm256_permute2f128_si256(lhs_vec_2, lhs_vec_2, 0); + lhs_vec_3 = _mm256_permute2f128_si256(lhs_vec_3, lhs_vec_3, 0); + lhs_vec_4 = _mm256_permute2f128_si256(lhs_vec_4, lhs_vec_4, 0); + lhs_vec_5 = _mm256_permute2f128_si256(lhs_vec_5, lhs_vec_5, 0); + lhs_vec_6 = _mm256_permute2f128_si256(lhs_vec_6, lhs_vec_6, 0); + lhs_vec_7 = _mm256_permute2f128_si256(lhs_vec_7, lhs_vec_7, 0); + + __m256i iacc_0 = _mm256_setzero_si256(); + __m256i iacc_1 = _mm256_setzero_si256(); + __m256i iacc_2 = _mm256_setzero_si256(); + __m256i iacc_3 = _mm256_setzero_si256(); + __m256i iacc_4 = _mm256_setzero_si256(); + __m256i iacc_5 = _mm256_setzero_si256(); + __m256i iacc_6 = _mm256_setzero_si256(); + __m256i iacc_7 = _mm256_setzero_si256(); + + // Dot product done within 32 bit lanes and accumulated in the same vector + // First done for 0th sub block and then for seven (1st - 7th) other sub blocks processed for each sb (sb < QK_K/128 loop) // B0(0-3) B4(0-3) B1(0-3) B5(0-3) B2(0-3) B6(0-3) B3(0-3) B7(0-3) with A0(0-3) + // B0(4-7) B4(4-7) B1(4-7) B5(4-7) B2(4-7) B6(4-7) B3(4-7) B7(4-7) with A0(4-7) + // B0(8-11) B4(8-11) B1(8-11) B5(8-11) B2(8-11) B6(8-11) B3(8-11) B7(8-11) with A0(8-11) + // B0(12-15) B4(12-15) B1(12-15) B5(12-15) B2(12-15) B6(12-15) B3(12-15) B7(12-15) with A0(12-15) + + iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_00 ,_mm256_shuffle_epi32(rhs_vec_4567_00, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 0))); + iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_00, 177) ,rhs_vec_4567_00, 170), _mm256_shuffle_epi32(lhs_vec_0, 85))); + + iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_01 ,_mm256_shuffle_epi32(rhs_vec_4567_01, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 170))); + iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_01, 177) ,rhs_vec_4567_01, 170), _mm256_shuffle_epi32(lhs_vec_0, 255))); + + iacc_0 = _mm256_madd_epi16(iacc_0, scales_0); + + iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_10 ,_mm256_shuffle_epi32(rhs_vec_4567_10, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 0))); + iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_10, 177) ,rhs_vec_4567_10, 170), _mm256_shuffle_epi32(lhs_vec_1, 85))); + + iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_11 ,_mm256_shuffle_epi32(rhs_vec_4567_11, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 170))); + iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_11, 177) ,rhs_vec_4567_11, 170), _mm256_shuffle_epi32(lhs_vec_1, 255))); + + iacc_1 = _mm256_madd_epi16(iacc_1, scales_1); + + iacc_2 = _mm256_add_epi16(iacc_2, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_20 ,_mm256_shuffle_epi32(rhs_vec_4567_20, 177), 170), _mm256_shuffle_epi32(lhs_vec_2, 0))); + iacc_2 = _mm256_add_epi16(iacc_2, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_20, 177) ,rhs_vec_4567_20, 170), _mm256_shuffle_epi32(lhs_vec_2, 85))); + + iacc_2 = _mm256_add_epi16(iacc_2, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_21 ,_mm256_shuffle_epi32(rhs_vec_4567_21, 177), 170), _mm256_shuffle_epi32(lhs_vec_2, 170))); + iacc_2 = _mm256_add_epi16(iacc_2, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_21, 177) ,rhs_vec_4567_21, 170), _mm256_shuffle_epi32(lhs_vec_2, 255))); + + iacc_2 = _mm256_madd_epi16(iacc_2, scales_2); + + iacc_3 = _mm256_add_epi16(iacc_3, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_30 ,_mm256_shuffle_epi32(rhs_vec_4567_30, 177), 170), _mm256_shuffle_epi32(lhs_vec_3, 0))); + iacc_3 = _mm256_add_epi16(iacc_3, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_30, 177) ,rhs_vec_4567_30, 170), _mm256_shuffle_epi32(lhs_vec_3, 85))); + + iacc_3 = _mm256_add_epi16(iacc_3, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_31 ,_mm256_shuffle_epi32(rhs_vec_4567_31, 177), 170), _mm256_shuffle_epi32(lhs_vec_3, 170))); + iacc_3 = _mm256_add_epi16(iacc_3, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_31, 177) ,rhs_vec_4567_31, 170), _mm256_shuffle_epi32(lhs_vec_3, 255))); + + iacc_3 = _mm256_madd_epi16(iacc_3, scales_3); + + iacc_4 = _mm256_add_epi16(iacc_4, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_40 ,_mm256_shuffle_epi32(rhs_vec_4567_40, 177), 170), _mm256_shuffle_epi32(lhs_vec_4, 0))); + iacc_4 = _mm256_add_epi16(iacc_4, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_40, 177) ,rhs_vec_4567_40, 170), _mm256_shuffle_epi32(lhs_vec_4, 85))); + + iacc_4 = _mm256_add_epi16(iacc_4, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_41 ,_mm256_shuffle_epi32(rhs_vec_4567_41, 177), 170), _mm256_shuffle_epi32(lhs_vec_4, 170))); + iacc_4 = _mm256_add_epi16(iacc_4, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_41, 177) ,rhs_vec_4567_41, 170), _mm256_shuffle_epi32(lhs_vec_4, 255))); + + iacc_4 = _mm256_madd_epi16(iacc_4, scales_4); + + iacc_5 = _mm256_add_epi16(iacc_5, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_50 ,_mm256_shuffle_epi32(rhs_vec_4567_50, 177), 170), _mm256_shuffle_epi32(lhs_vec_5, 0))); + iacc_5 = _mm256_add_epi16(iacc_5, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_50, 177) ,rhs_vec_4567_50, 170), _mm256_shuffle_epi32(lhs_vec_5, 85))); + + iacc_5 = _mm256_add_epi16(iacc_5, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_51 ,_mm256_shuffle_epi32(rhs_vec_4567_51, 177), 170), _mm256_shuffle_epi32(lhs_vec_5, 170))); + iacc_5 = _mm256_add_epi16(iacc_5, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_51, 177) ,rhs_vec_4567_51, 170), _mm256_shuffle_epi32(lhs_vec_5, 255))); + + iacc_5 = _mm256_madd_epi16(iacc_5, scales_5); + + iacc_6 = _mm256_add_epi16(iacc_6, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_60 ,_mm256_shuffle_epi32(rhs_vec_4567_60, 177), 170), _mm256_shuffle_epi32(lhs_vec_6, 0))); + iacc_6 = _mm256_add_epi16(iacc_6, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_60, 177) ,rhs_vec_4567_60, 170), _mm256_shuffle_epi32(lhs_vec_6, 85))); + + iacc_6 = _mm256_add_epi16(iacc_6, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_61 ,_mm256_shuffle_epi32(rhs_vec_4567_61, 177), 170), _mm256_shuffle_epi32(lhs_vec_6, 170))); + iacc_6 = _mm256_add_epi16(iacc_6, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_61, 177) ,rhs_vec_4567_61, 170), _mm256_shuffle_epi32(lhs_vec_6, 255))); + + iacc_6 = _mm256_madd_epi16(iacc_6, scales_6); + + iacc_7 = _mm256_add_epi16(iacc_7, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_70 ,_mm256_shuffle_epi32(rhs_vec_4567_70, 177), 170), _mm256_shuffle_epi32(lhs_vec_7, 0))); + iacc_7 = _mm256_add_epi16(iacc_7, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_70, 177) ,rhs_vec_4567_70, 170), _mm256_shuffle_epi32(lhs_vec_7, 85))); + + iacc_7 = _mm256_add_epi16(iacc_7, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_71 ,_mm256_shuffle_epi32(rhs_vec_4567_71, 177), 170), _mm256_shuffle_epi32(lhs_vec_7, 170))); + iacc_7 = _mm256_add_epi16(iacc_7, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_71, 177) ,rhs_vec_4567_71, 170), _mm256_shuffle_epi32(lhs_vec_7, 255))); + + iacc_7 = _mm256_madd_epi16(iacc_7, scales_7); + + // Accumulate the iacc value for one sb + __m256i iacc_sb = _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(iacc_0, iacc_1), _mm256_add_epi32(iacc_2, iacc_3)), _mm256_add_epi32(_mm256_add_epi32(iacc_4, iacc_5), _mm256_add_epi32(iacc_6, iacc_7))); + + __m128i q8sums = _mm_loadu_si128((const __m128i *)(a_ptr[b].bsums + sb * 8)); + __m256i q8s = _mm256_castsi128_si256(q8sums); + q8s= _mm256_permute2f128_si256(q8s, q8s, 0); + + // Broadcast the bsums of the two corresponding subblocks of q8_k + // Multiply-Add with corresponding mins of Q2_Kx8 with bsums + __m256i iacc_min_sb_01 = _mm256_madd_epi16(_mm256_shuffle_epi32(q8s, 0), mins_01); + __m256i iacc_min_sb_23 = _mm256_madd_epi16(_mm256_shuffle_epi32(q8s, 85), mins_23); + __m256i iacc_min_sb_45 = _mm256_madd_epi16(_mm256_shuffle_epi32(q8s, 170), mins_45); + __m256i iacc_min_sb_67 = _mm256_madd_epi16(_mm256_shuffle_epi32(q8s, 255), mins_67); + + __m256i iacc_min_sb = _mm256_add_epi32(_mm256_add_epi32(iacc_min_sb_01, iacc_min_sb_23), _mm256_add_epi32(iacc_min_sb_45,iacc_min_sb_67)); + + // Accumulate for the complete block + iacc_b = _mm256_add_epi32(iacc_b, iacc_sb); + iacc_min_b = _mm256_add_epi32(iacc_min_b, iacc_min_sb); + } + + //Multiply-Add with scale values for complete super block + acc_row = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_b), _mm256_mul_ps(col_scale_f32, row_scale_f32), acc_row); + acc_min_rows = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_min_b), _mm256_mul_ps(col_dmin_f32, row_scale_f32), acc_min_rows); + } + // Accumulated output values permuted so as to be stored in appropriate order post accumulation + acc_row = _mm256_permutevar8x32_ps(acc_row, finalpermutemask); + _mm256_storeu_ps(s + (y * nr + x * 8), _mm256_sub_ps(acc_row, acc_min_rows)); + } + } +#else + + ggml_gemv_q2_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); + +#endif +} + void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -3050,3 +3363,2886 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); #endif } + +void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__AVX2__) || defined(__AVX512F__) + const block_q2_Kx8 * b_ptr_start = (const block_q2_Kx8 * ) vx; + const block_q8_Kx4 * a_ptr_start = (const block_q8_Kx4 * ) vy; + int64_t b_nb = n / QK_K; + int64_t y = 0; + + // Permute mask used for easier vector processing at later stages + __m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4); + int64_t xstart = 0; + int anr = nr - nr % 16; // Used to align nr with boundary of 16 + + // Mask to convert 2 bit and 4 bit values into a bytes + const __m256i m3b = _mm256_set1_epi8(3); + const __m128i m4b_sse = _mm_set1_epi8(0xF); + + //Mask to get appropriate scales + __m128i scalesmask1_sse = _mm_set_epi8(14,14,12,12,10,10,8,8,6,6,4,4,2,2,0,0); + __m128i scalesmask2_sse = _mm_set_epi8(15,15,13,13,11,11,9,9,7,7,5,5,3,3,1,1); + + __m256i scalesmask1 = _mm256_castsi128_si256(scalesmask1_sse); + scalesmask1 = _mm256_permute2f128_si256(scalesmask1, scalesmask1, 0); + __m256i scalesmask2 = _mm256_castsi128_si256(scalesmask2_sse); + scalesmask2 = _mm256_permute2f128_si256(scalesmask2, scalesmask2, 0); + +#ifdef __AVX512F__ + + int anc = nc - nc % 16; // Used to align nc with boundary of 16 + + // Mask to mask out nibbles from packed bytes + const __m256i m4b = _mm256_set1_epi8(0x0F); + // Mask to mask out nibbles from packed bytes expanded to 512 bit length + const __m512i m3bexpanded = _mm512_set1_epi8(3); + //Take group of four block_q8_Kx4 structures at each pass of the loop and perform dot product operation + for (; y < anr / 4; y += 4) { + + const block_q8_Kx4 * a_ptrs[4]; + + a_ptrs[0] = a_ptr_start + (y * nb); + for (int i = 0; i < 3; ++i) { + a_ptrs[i + 1] = a_ptrs[i] + nb; + } + + // Take group of eight block_q2_kx8 structures at each pass of the loop and perform dot product operation + for (int64_t x = 0; x < anc / 8; x += 2) { + + const block_q2_Kx8 * b_ptr_0 = b_ptr_start + ((x) * b_nb); + const block_q2_Kx8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb); + + // Master FP accumulators + __m512 acc_rows[16]; + for (int i = 0; i < 16; i++) { + acc_rows[i] = _mm512_setzero_ps(); + } + + __m512 acc_min_rows[16]; + for (int i = 0; i < 16; i++) { + acc_min_rows[i] = _mm512_setzero_ps(); + } + // For super block + for (int64_t b = 0; b < nb; b++) { + // Delta values - Load the sixteen scale values from two block_q2_kx8 structures + const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d); + + // dmin values - Load the sixteen dmin values from two block_q2_kx8 structures + const __m512 col_dmin_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].dmin, b_ptr_1[b].dmin); + + // Loop to iterate over the sixteen sub blocks of a super block - eight sub blocks are processed per iteration + for (int sb = 0; sb < QK_K / 128; sb++) { + + // Load the eight block_q2_k for eight sub blocks quantized values interleaved with each other in chunks of eight bytes - B0,B1 ....B6,B7 + const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + sb * 256)); + const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 32 + sb * 256)); + const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 64 + sb * 256)); + const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 96 + sb * 256)); + const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 128 + sb * 256)); + const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 160 + sb * 256)); + const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 192 + sb * 256)); + const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 224 + sb * 256)); + + const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + sb * 256)); + const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 32 + sb * 256)); + const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 64 + sb * 256)); + const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 96 + sb * 256)); + const __m256i rhs_raw_mat_89AB_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 128 + sb * 256)); + const __m256i rhs_raw_mat_CDEF_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 160 + sb * 256)); + const __m256i rhs_raw_mat_89AB_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 192 + sb * 256)); + const __m256i rhs_raw_mat_CDEF_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 224 + sb * 256)); + + const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240); + const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240); + const __m256i rhs_raw_mat_0145_2 = _mm256_blend_epi32(rhs_raw_mat_0123_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_2, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_2, requiredOrder), rhs_raw_mat_4567_2, 240); + const __m256i rhs_raw_mat_0145_3 = _mm256_blend_epi32(rhs_raw_mat_0123_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_3, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_3, requiredOrder), rhs_raw_mat_4567_3, 240); + + const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240); + const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240); + const __m256i rhs_raw_mat_89CD_2 = _mm256_blend_epi32(rhs_raw_mat_89AB_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_2, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_2, requiredOrder), rhs_raw_mat_CDEF_2, 240); + const __m256i rhs_raw_mat_89CD_3 = _mm256_blend_epi32(rhs_raw_mat_89AB_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_3, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_3, requiredOrder), rhs_raw_mat_CDEF_3, 240); + + const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1); + const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1); + const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1); + const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1); + + const __m512i rhs_raw_mat_014589CD_2 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_2), rhs_raw_mat_89CD_2, 1); + const __m512i rhs_raw_mat_2367ABEF_2 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_2), rhs_raw_mat_ABEF_2, 1); + const __m512i rhs_raw_mat_014589CD_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_3), rhs_raw_mat_89CD_3, 1); + const __m512i rhs_raw_mat_2367ABEF_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_3), rhs_raw_mat_ABEF_3, 1); + + //2-bit -> 8-bit + const __m512i rhs_mat_014589CD_00 = _mm512_and_si512(rhs_raw_mat_014589CD_0,m3bexpanded); //B00(0-7) B01(0-7) B04(0-7) B05(0-7) B08(0-7) B09(0-7) B0C(0-7) B0D(0-7) + const __m512i rhs_mat_2367ABEF_00 = _mm512_and_si512(rhs_raw_mat_2367ABEF_0,m3bexpanded); //B02(0-7) B03(0-7) B06(0-7) B07(0-7) B0A(0-7) B0B(0-7) B0E(0-7) B0F(0-7) + const __m512i rhs_mat_014589CD_01 = _mm512_and_si512(rhs_raw_mat_014589CD_1,m3bexpanded); //B00(8-15) B01(8-15) B04(8-15) B05(8-15) B08(8-15) B09(8-15) B0C(8-15) B0D(8-15) + const __m512i rhs_mat_2367ABEF_01 = _mm512_and_si512(rhs_raw_mat_2367ABEF_1,m3bexpanded); //B02(8-15) B03(8-15) B06(8-15) B07(8-15) B0A(8-15) B0B(8-15) B0E(8-15) B0F(8-15) + const __m512i rhs_mat_014589CD_10 = _mm512_and_si512(rhs_raw_mat_014589CD_2,m3bexpanded); //B10(0-7) B11(0-7) B14(0-7) B15(0-7) B18(0-7) B19(0-7) B1C(0-7) B1D(0-7) + const __m512i rhs_mat_2367ABEF_10 = _mm512_and_si512(rhs_raw_mat_2367ABEF_2,m3bexpanded); //B12(0-7) B13(0-7) B16(0-7) B17(0-7) B1A(0-7) B1B(0-7) B1E(0-7) B1F(0-7) + const __m512i rhs_mat_014589CD_11 = _mm512_and_si512(rhs_raw_mat_014589CD_3,m3bexpanded); //B10(8-15) B11(8-15) B14(8-15) B15(8-15) B18(8-15) B19(8-15) B1C(8-15) B1D(8-15) + const __m512i rhs_mat_2367ABEF_11 = _mm512_and_si512(rhs_raw_mat_2367ABEF_3,m3bexpanded); //B12(8-15) B13(8-15) B16(8-15) B17(8-15) B1A(8-15) B1B(8-15) B1E(8-15) B1F(8-15) + + const __m512i rhs_mat_014589CD_20 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 2), m3bexpanded); //B20(0-7) B21(0-7) B24(0-7) B25(0-7) B28(0-7) B29(0-7) B2C(0-7) B2D(0-7) + const __m512i rhs_mat_2367ABEF_20 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 2), m3bexpanded); //B22(0-7) B23(0-7) B26(0-7) B27(0-7) B2A(0-7) B2B(0-7) B2E(0-7) B2F(0-7) + + const __m512i rhs_mat_014589CD_21 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 2), m3bexpanded); //B20(8-15) B21(8-15) B24(8-15) B25(8-15) B28(8-15) B29(8-15) B2C(8-15) B2D(8-15) + const __m512i rhs_mat_2367ABEF_21 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 2), m3bexpanded); //B22(8-15) B23(8-15) B26(8-15) B27(8-15) B2A(8-15) B2B(8-15) B2E(8-15) B2F(8-15) + + const __m512i rhs_mat_014589CD_30 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_2, 2), m3bexpanded); //B30(0-7) B31(0-7) B34(0-7) B35(0-7) B38(0-7) B39(0-7) B3C(0-7) B3D(0-7) + const __m512i rhs_mat_2367ABEF_30 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_2, 2), m3bexpanded); //B32(0-7) B33(0-7) B36(0-7) B37(0-7) B3A(0-7) B3B(0-7) B3E(0-7) B3F(0-7) + + const __m512i rhs_mat_014589CD_31 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_3, 2), m3bexpanded); //B30(8-15) B31(8-15) B34(8-15) B35(8-15) B38(8-15) B39(8-15) B3C(8-15) B3D(8-15) + const __m512i rhs_mat_2367ABEF_31 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_3, 2), m3bexpanded); //B32(8-15) B33(8-15) B36(8-15) B37(8-15) B3A(8-15) B3B(8-15) B3E(8-15) B3F(8-15) + + const __m512i rhs_mat_014589CD_40 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m3bexpanded); //B40(0-7) B41(0-7) B44(0-7) B45(0-7) B48(0-7) B49(0-7) B4C(0-7) B4D(0-7) + const __m512i rhs_mat_2367ABEF_40 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m3bexpanded); //B42(0-7) B43(0-7) B46(0-7) B47(0-7) B4A(0-7) B4B(0-7) B4E(0-7) B4F(0-7) + + const __m512i rhs_mat_014589CD_41 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m3bexpanded); //B40(8-15) B41(8-15) B44(8-15) B45(8-15) B48(8-15) B49(8-15) B4C(8-15) B4D(8-15) + const __m512i rhs_mat_2367ABEF_41 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m3bexpanded); //B42(8-15) B43(8-15) B46(8-15) B47(8-15) B4A(8-15) B4B(8-15) B4E(8-15) B4F(8-15) + + const __m512i rhs_mat_014589CD_50 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_2, 4), m3bexpanded); //B50(0-7) B51(0-7) B54(0-7) B55(0-7) B58(0-7) B59(0-7) B5C(0-7) B5D(0-7) + const __m512i rhs_mat_2367ABEF_50 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_2, 4), m3bexpanded); //B52(0-7) B53(0-7) B56(0-7) B57(0-7) B5A(0-7) B5B(0-7) B5E(0-7) B5F(0-7) + + const __m512i rhs_mat_014589CD_51 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_3, 4), m3bexpanded); //B50(8-15) B51(8-15) B54(8-15) B55(8-15) B58(8-15) B59(8-15) B5C(8-15) B5D(8-15) + const __m512i rhs_mat_2367ABEF_51 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_3, 4), m3bexpanded); //B52(8-15) B53(8-15) B56(8-15) B57(8-15) B5A(8-15) B5B(8-15) B5E(8-15) B5F(8-15) + + const __m512i rhs_mat_014589CD_60 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 6), m3bexpanded); //B60(0-7) B61(0-7) B64(0-7) B65(0-7) B68(0-7) B69(0-7) B6C(0-7) B6D(0-7) + const __m512i rhs_mat_2367ABEF_60 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 6), m3bexpanded); //B62(0-7) B63(0-7) B66(0-7) B67(0-7) B6A(0-7) B6B(0-7) B6E(0-7) B6F(0-7) + + const __m512i rhs_mat_014589CD_61 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 6), m3bexpanded); //B60(8-15) B61(8-15) B64(8-15) B65(8-15) B68(8-15) B69(8-15) B6C(8-15) B6D(8-15) + const __m512i rhs_mat_2367ABEF_61 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 6), m3bexpanded); //B62(8-15) B63(8-15) B66(8-15) B67(8-15) B6A(8-15) B6B(8-15) B6E(8-15) B6F(8-15) + + const __m512i rhs_mat_014589CD_70 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_2, 6), m3bexpanded); //B70(0-7) B71(0-7) B74(0-7) B75(0-7) B78(0-7) B79(0-7) B7C(0-7) B7D(0-7) + const __m512i rhs_mat_2367ABEF_70 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_2, 6), m3bexpanded); //B72(0-7) B73(0-7) B76(0-7) B77(0-7) B7A(0-7) B7B(0-7) B7E(0-7) B7F(0-7) + + const __m512i rhs_mat_014589CD_71 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_3, 6), m3bexpanded); //B70(8-15) B71(8-15) B74(8-15) B75(8-15) B78(8-15) B79(8-15) B7C(8-15) B7D(8-15) + const __m512i rhs_mat_2367ABEF_71 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_3, 6), m3bexpanded); //B72(8-15) B73(8-15) B76(8-15) B77(8-15) B7A(8-15) B7B(8-15) B7E(8-15) B7F(8-15) + + const __m512i rhs_mat_014589CD_00_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3) B08(0-3) B09(0-3) B08(0-3) B09(0-3) B0C(0-3) B0D(0-3) B0C(0-3) B0D(0-3) + const __m512i rhs_mat_2367ABEF_00_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)136); //B02(0-3) B03(0-3) B02(0-3) B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3) B0A(0-3) B0B(0-3) B0A(0-3) B0B(0-3) B0E(0-3) B0F(0-3) B0E(0-3) B0F(0-3) + + const __m512i rhs_mat_014589CD_01_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11) B08(8-11) B09(8-11) B08(8-11) B09(8-11) B0C(8-11) B0D(8-11) B0C(8-11) B0D(8-11) + const __m512i rhs_mat_2367ABEF_01_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11) B0A(8-11) B0B(8-11) B0A(8-11) B0B(8-11) B0E(8-11) B0F(8-11) B0E(8-11) B0F(8-11) + + const __m512i rhs_mat_014589CD_10_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3) B18(0-3) B19(0-3) B18(0-3) B19(0-3) B1C(0-3) B1D(0-3) B1C(0-3) B1D(0-3) + const __m512i rhs_mat_2367ABEF_10_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) B16(0-3) B17(0-3) B1A(0-3) B1B(0-3) B1A(0-3) B1B(0-3) B1E(0-3) B1F(0-3) B1E(0-3) B1F(0-3) + + const __m512i rhs_mat_014589CD_11_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11) B18(8-11) B19(8-11) B18(8-11) B19(8-11) B1C(8-11) B1D(8-11) B1C(8-11) B1D(8-11) + const __m512i rhs_mat_2367ABEF_11_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11) B1A(8-11) B1B(8-11) B1A(8-11) B1B(8-11) B1E(8-11) B1F(8-11) B1E(8-11) B1F(8-11) + + const __m512i rhs_mat_014589CD_20_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_20, (_MM_PERM_ENUM)136); //B20(0-3) B21(0-3) B20(0-3) B21(0-3) B24(0-3) B25(0-3) B24(0-3) B25(0-3) B28(0-3) B29(0-3) B28(0-3) B29(0-3) B2C(0-3) B2D(0-3) B2C(0-3) B2D(0-3) + const __m512i rhs_mat_2367ABEF_20_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_20, (_MM_PERM_ENUM)136); //B22(0-3) B23(0-3) B22(0-3) B23(0-3) B26(0-3) B27(0-3) B26(0-3) B27(0-3) B2A(0-3) B2B(0-3) B2A(0-3) B2B(0-3) B2E(0-3) B2F(0-3) B2E(0-3) B2F(0-3) + + const __m512i rhs_mat_014589CD_21_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_21, (_MM_PERM_ENUM)136); //B20(8-11) B21(8-11) B20(8-11) B21(8-11) B24(8-11) B25(8-11) B24(8-11) B25(8-11) B28(8-11) B29(8-11) B28(8-11) B29(8-11) B2C(8-11) B2D(8-11) B2C(8-11) B2D(8-11) + const __m512i rhs_mat_2367ABEF_21_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_21, (_MM_PERM_ENUM)136); //B22(8-11) B23(8-11) B22(8-11) B23(8-11) B26(8-11) B27(8-11) B26(8-11) B27(8-11) B2A(8-11) B2B(8-11) B2A(8-11) B2B(8-11) B2E(8-11) B2F(8-11) B2E(8-11) B2F(8-11) + + const __m512i rhs_mat_014589CD_30_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_30, (_MM_PERM_ENUM)136); ///B30(0-3) B31(0-3) B30(0-3) B31(0-3) B34(0-3) B35(0-3) B34(0-3) B35(0-3) B38(0-3) B39(0-3) B38(0-3) B39(0-3) B3C(0-3) B3D(0-3) B3C(0-3) B3D(0-3) + const __m512i rhs_mat_2367ABEF_30_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_30, (_MM_PERM_ENUM)136); //B32(0-3) B33(0-3) B32(0-3) B33(0-3) B36(0-3) B37(0-3) B36(0-3) B37(0-3) B3A(0-3) B3B(0-3) B3A(0-3) B3B(0-3) B3E(0-3) B3F(0-3) B3E(0-3) B3F(0-3) + + const __m512i rhs_mat_014589CD_31_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_31, (_MM_PERM_ENUM)136); //B30(8-11) B31(8-11) B30(8-11) B31(8-11) B34(8-11) B35(8-11) B34(8-11) B35(8-11) B38(8-11) B39(8-11) B38(8-11) B39(8-11) B3C(8-11) B3D(8-11) B3C(8-11) B3D(8-11) + const __m512i rhs_mat_2367ABEF_31_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_31, (_MM_PERM_ENUM)136); //B32(8-11) B33(8-11) B32(8-11) B33(8-11) B36(8-11) B37(8-11) B36(8-11) B37(8-11) B3A(8-11) B3B(8-11) B3A(8-11) B3B(8-11) B3E(8-11) B3F(8-11) B3E(8-11) B3F(8-11) + + const __m512i rhs_mat_014589CD_40_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_40, (_MM_PERM_ENUM)136); //B40(0-3) B41(0-3) B40(0-3) B41(0-3) B44(0-3) B45(0-3) B44(0-3) B45(0-3) B48(0-3) B49(0-3) B48(0-3) B49(0-3) B4C(0-3) B4D(0-3) B4C(0-3) B4D(0-3) + const __m512i rhs_mat_2367ABEF_40_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_40, (_MM_PERM_ENUM)136); //B42(0-3) B43(0-3) B42(0-3) B43(0-3) B46(0-3) B47(0-3) B46(0-3) B47(0-3) B4A(0-3) B4B(0-3) B4A(0-3) B4B(0-3) B4E(0-3) B4F(0-3) B4E(0-3) B4F(0-3) + + const __m512i rhs_mat_014589CD_41_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_41, (_MM_PERM_ENUM)136); //B40(8-11) B41(8-11) B40(8-11) B41(8-11) B44(8-11) B45(8-11) B44(8-11) B45(8-11) B48(8-11) B49(8-11) B48(8-11) B49(8-11) B4C(8-11) B4D(8-11) B4C(8-11) B4D(8-11) + const __m512i rhs_mat_2367ABEF_41_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_41, (_MM_PERM_ENUM)136); //B42(8-11) B43(8-11) B42(8-11) B43(8-11) B46(8-11) B47(8-11) B46(8-11) B47(8-11) B4A(8-11) B4B(8-11) B4A(8-11) B4B(8-11) B4E(8-11) B4F(8-11) B4E(8-11) B4F(8-11) + + const __m512i rhs_mat_014589CD_50_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_50, (_MM_PERM_ENUM)136); //B50(0-3) B51(0-3) B50(0-3) B51(0-3) B54(0-3) B55(0-3) B54(0-3) B55(0-3) B58(0-3) B59(0-3) B58(0-3) B59(0-3) B5C(0-3) B5D(0-3) B5C(0-3) B5D(0-3) + const __m512i rhs_mat_2367ABEF_50_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_50, (_MM_PERM_ENUM)136); //B52(0-3) B53(0-3) B52(0-3) B53(0-3) B56(0-3) B57(0-3) B56(0-3) B57(0-3) B5A(0-3) B5B(0-3) B5A(0-3) B5B(0-3) B5E(0-3) B5F(0-3) B5E(0-3) B5F(0-3) + + const __m512i rhs_mat_014589CD_51_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_51, (_MM_PERM_ENUM)136); //B50(8-11) B51(8-11) B50(8-11) B51(8-11) B54(8-11) B55(8-11) B54(8-11) B55(8-11) B58(8-11) B59(8-11) B58(8-11) B59(8-11) B5C(8-11) B5D(8-11) B5C(8-11) B5D(8-11) + const __m512i rhs_mat_2367ABEF_51_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_51, (_MM_PERM_ENUM)136); //B52(8-11) B53(8-11) B52(8-11) B53(8-11) B56(8-11) B57(8-11) B56(8-11) B57(8-11) B5A(8-11) B5B(8-11) B5A(8-11) B5B(8-11) B5E(8-11) B5F(8-11) B5E(8-11) B5F(8-11) + + const __m512i rhs_mat_014589CD_60_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_60, (_MM_PERM_ENUM)136); //B60(0-3) B61(0-3) B60(0-3) B61(0-3) B64(0-3) B65(0-3) B64(0-3) B65(0-3) B68(0-3) B69(0-3) B68(0-3) B69(0-3) B6C(0-3) B6D(0-3) B6C(0-3) B6D(0-3) + const __m512i rhs_mat_2367ABEF_60_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_60, (_MM_PERM_ENUM)136); //B62(0-3) B63(0-3) B62(0-3) B63(0-3) B66(0-3) B67(0-3) B66(0-3) B67(0-3) B6A(0-3) B6B(0-3) B6A(0-3) B6B(0-3) B6E(0-3) B6F(0-3) B6E(0-3) B6F(0-3) + + const __m512i rhs_mat_014589CD_61_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_61, (_MM_PERM_ENUM)136); //B60(8-11) B61(8-11) B60(8-11) B61(8-11) B64(8-11) B65(8-11) B64(8-11) B65(8-11) B68(8-11) B69(8-11) B68(8-11) B69(8-11) B6C(8-11) B6D(8-11) B6C(8-11) B6D(8-11) + const __m512i rhs_mat_2367ABEF_61_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_61, (_MM_PERM_ENUM)136); //B62(8-11) B63(8-11) B62(8-11) B63(8-11) B66(8-11) B67(8-11) B66(8-11) B67(8-11) B6A(8-11) B6B(8-11) B6A(8-11) B6B(8-11) B6E(8-11) B6F(8-11) B6E(8-11) B6F(8-11) + + const __m512i rhs_mat_014589CD_70_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_70, (_MM_PERM_ENUM)136); //B70(0-3) B71(0-3) B70(0-3) B71(0-3) B74(0-3) B75(0-3) B74(0-3) B75(0-3) B78(0-3) B79(0-3) B78(0-3) B79(0-3) B7C(0-3) B7D(0-3) B7C(0-3) B7D(0-3) + const __m512i rhs_mat_2367ABEF_70_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_70, (_MM_PERM_ENUM)136); //B72(0-3) B73(0-3) B72(0-3) B73(0-3) B76(0-3) B77(0-3) B76(0-3) B77(0-3) B7A(0-3) B7B(0-3) B7A(0-3) B7B(0-3) B7E(0-3) B7F(0-3) B7E(0-3) B7F(0-3) + + const __m512i rhs_mat_014589CD_71_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_71, (_MM_PERM_ENUM)136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11) B08(8-11) B09(8-11) B08(8-11) B09(8-11) B0C(8-11) B0D(8-11) B0C(8-11) B0D(8-11) + const __m512i rhs_mat_2367ABEF_71_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_71, (_MM_PERM_ENUM)136); //B72(8-11) B73(8-11) B72(8-11) B73(8-11) B76(8-11) B77(8-11) B76(8-11) B77(8-11) B7A(8-11) B7B(8-11) B7A(8-11) B7B(8-11) B7E(8-11) B7F(8-11) B7E(8-11) B7F(8-11) + + const __m512i rhs_mat_014589CD_00_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)221); //B00(4-7) B01(4-7) B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7) B08(4-7) B09(4-7) B08(4-7) B09(4-7) B0C(4-7) B0D(4-7) B0C(4-7) B0D(4-7) + const __m512i rhs_mat_2367ABEF_00_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)221); //B02(4-7) B03(4-7) B02(4-7) B03(4-7) B06(4-7) B07(4-7) B06(4-7) B07(4-7) B0A(4-7) B0B(4-7) B0A(4-7) B0B(4-7) B0E(4-7) B0F(4-7) B0E(4-7) B0F(4-7) + + const __m512i rhs_mat_014589CD_01_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15) B08(12-15) B09(12-15) B08(12-15) B09(12-15) B0C(12-15) B0D(12-15) B0C(12-15) B0D(12-15) + const __m512i rhs_mat_2367ABEF_01_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15) B0A(12-15) B0B(12-15) B0A(12-15) B0B(12-15) B0E(12-15) B0F(12-15) B0E(12-15) B0F(12-15) + + const __m512i rhs_mat_014589CD_10_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7) B18(4-7) B19(4-7) B18(4-7) B19(4-7) B1C(4-7) B1D(4-7) B1C(4-7) B1D(4-7) + const __m512i rhs_mat_2367ABEF_10_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7) B1A(4-7) B1B(4-7) B1A(4-7) B1B(4-7) B1E(4-7) B1F(4-7) B1E(4-7) B1F(4-7) + + const __m512i rhs_mat_014589CD_11_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15) B18(12-15) B19(12-15) B18(12-15) B19(12-15) B1C(12-15) B1D(12-15) B1C(12-15) B1D(12-15) + const __m512i rhs_mat_2367ABEF_11_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15) B1A(12-15) B1B(12-15) B1A(12-15) B1B(12-15) B1E(12-15) B1F(12-15) B1E(12-15) B1F(12-15) + + const __m512i rhs_mat_014589CD_20_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_20, (_MM_PERM_ENUM)221); //B20(4-7) B21(4-7) B20(4-7) B21(4-7) B24(4-7) B25(4-7) B24(4-7) B25(4-7) B28(4-7) B29(4-7) B28(4-7) B29(4-7) B2C(4-7) B2D(4-7) B2C(4-7) B2D(4-7) + const __m512i rhs_mat_2367ABEF_20_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_20, (_MM_PERM_ENUM)221); //B22(4-7) B23(4-7) B22(4-7) B23(4-7) B26(4-7) B27(4-7) B26(4-7) B27(4-7) B2A(4-7) B2B(4-7) B2A(4-7) B2B(4-7) B2E(4-7) B2F(4-7) B2E(4-7) B2F(4-7) + + const __m512i rhs_mat_014589CD_21_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_21, (_MM_PERM_ENUM)221); //B20(12-15) B21(12-15) B20(12-15) B21(12-15) B24(12-15) B25(12-15) B24(12-15) B25(12-15) B28(12-15) B29(12-15) B28(12-15) B29(12-15) B2C(12-15) B2D(12-15) B2C(12-15) B2D(12-15) + const __m512i rhs_mat_2367ABEF_21_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_21, (_MM_PERM_ENUM)221); //B22(12-15) B23(12-15) B22(12-15) B23(12-15) B26(12-15) B27(12-15) B26(12-15) B27(12-15) B2A(12-15) B2B(12-15) B2A(12-15) B2B(12-15) B2E(12-15) B2F(12-15) B2E(12-15) B2F(12-15) + + const __m512i rhs_mat_014589CD_30_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_30, (_MM_PERM_ENUM)221); //B30(4-7) B31(4-7) B30(4-7) B31(4-7) B34(4-7) B35(4-7) B34(4-7) B35(4-7) B38(4-7) B39(4-7) B38(4-7) B39(4-7) B3C(4-7) B3D(4-7) B3C(4-7) B3D(4-7) + const __m512i rhs_mat_2367ABEF_30_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_30, (_MM_PERM_ENUM)221); //B32(4-7) B33(4-7) B32(4-7) B33(4-7) B36(4-7) B37(4-7) B36(4-7) B37(4-7) B3A(4-7) B3B(4-7) B3A(4-7) B3B(4-7) B3E(4-7) B3F(4-7) B3E(4-7) B3F(4-7) + + const __m512i rhs_mat_014589CD_31_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_31, (_MM_PERM_ENUM)221); //B30(12-15) B31(12-15) B30(12-15) B31(12-15) B34(12-15) B35(12-15) B34(12-15) B35(12-15) B38(12-15) B39(12-15) B38(12-15) B39(12-15) B3C(12-15) B3D(12-15) B3C(12-15) B3D(12-15) + const __m512i rhs_mat_2367ABEF_31_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_31, (_MM_PERM_ENUM)221); //B32(12-15) B33(12-15) B32(12-15) B33(12-15) B36(12-15) B37(12-15) B36(12-15) B37(12-15) B3A(12-15) B3B(12-15) B3A(12-15) B3B(12-15) B3E(12-15) B3F(12-15) B3E(12-15) B3F(12-15) + + const __m512i rhs_mat_014589CD_40_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_40, (_MM_PERM_ENUM)221); //B40(4-7) B41(4-7) B40(4-7) B41(4-7) B44(4-7) B45(4-7) B44(4-7) B45(4-7) B48(4-7) B49(4-7) B48(4-7) B49(4-7) B4C(4-7) B4D(4-7) B4C(4-7) B4D(4-7) + const __m512i rhs_mat_2367ABEF_40_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_40, (_MM_PERM_ENUM)221); //B42(4-7) B43(4-7) B42(4-7) B43(4-7) B46(4-7) B47(4-7) B46(4-7) B47(4-7) B4A(4-7) B4B(4-7) B4A(4-7) B4B(4-7) B4E(4-7) B4F(4-7) B4E(4-7) B4F(4-7) + + const __m512i rhs_mat_014589CD_41_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_41, (_MM_PERM_ENUM)221); //B40(12-15) B41(12-15) B40(12-15) B41(12-15) B44(12-15) B45(12-15) B44(12-15) B45(12-15) B48(12-15) B49(12-15) B48(12-15) B49(12-15) B4C(12-15) B4D(12-15) B4C(12-15) B4D(12-15) + const __m512i rhs_mat_2367ABEF_41_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_41, (_MM_PERM_ENUM)221); //B42(12-15) B43(12-15) B42(12-15) B43(12-15) B46(12-15) B47(12-15) B46(12-15) B47(12-15) B4A(12-15) B4B(12-15) B4A(12-15) B4B(12-15) B4E(12-15) B4F(12-15) B4E(12-15) B4F(12-15) + + const __m512i rhs_mat_014589CD_50_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_50, (_MM_PERM_ENUM)221); //B50(4-7) B51(4-7) B50(4-7) B51(4-7) B54(4-7) B55(4-7) B54(4-7) B55(4-7) B58(4-7) B59(4-7) B58(4-7) B59(4-7) B5C(4-7) B5D(4-7) B5C(4-7) B5D(4-7) + const __m512i rhs_mat_2367ABEF_50_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_50, (_MM_PERM_ENUM)221); //B52(4-7) B53(4-7) B52(4-7) B53(4-7) B56(4-7) B57(4-7) B56(4-7) B57(4-7) B5A(4-7) B5B(4-7) B5A(4-7) B5B(4-7) B5E(4-7) B5F(4-7) B5E(4-7) B5F(4-7) + + const __m512i rhs_mat_014589CD_51_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_51, (_MM_PERM_ENUM)221); //B50(12-15) B51(12-15) B50(12-15) B51(12-15) B54(12-15) B55(12-15) B54(12-15) B55(12-15) B58(12-15) B59(12-15) B58(12-15) B59(12-15) B5C(12-15) B5D(12-15) B5C(12-15) B5D(12-15) + const __m512i rhs_mat_2367ABEF_51_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_51, (_MM_PERM_ENUM)221); //B52(12-15) B53(12-15) B52(12-15) B53(12-15) B56(12-15) B57(12-15) B56(12-15) B57(12-15) B5A(12-15) B5B(12-15) B5A(12-15) B5B(12-15) B5E(12-15) B5F(12-15) B5E(12-15) B5F(12-15) + + const __m512i rhs_mat_014589CD_60_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_60, (_MM_PERM_ENUM)221); //B60(4-7) B61(4-7) B60(4-7) B61(4-7) B64(4-7) B65(4-7) B64(4-7) B65(4-7) B68(4-7) B69(4-7) B68(4-7) B69(4-7) B6C(4-7) B6D(4-7) B6C(4-7) B6D(4-7) + const __m512i rhs_mat_2367ABEF_60_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_60, (_MM_PERM_ENUM)221); //B62(4-7) B63(4-7) B62(4-7) B63(4-7) B66(4-7) B67(4-7) B66(4-7) B67(4-7) B6A(4-7) B6B(4-7) B6A(4-7) B6B(4-7) B6E(4-7) B6F(4-7) B6E(4-7) B6F(4-7) + + const __m512i rhs_mat_014589CD_61_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_61, (_MM_PERM_ENUM)221); //B60(12-15) B61(12-15) B60(12-15) B61(12-15) B64(12-15) B65(12-15) B64(12-15) B65(12-15) B68(12-15) B69(12-15) B68(12-15) B69(12-15) B6C(12-15) B6D(12-15) B6C(12-15) B6D(12-15) + const __m512i rhs_mat_2367ABEF_61_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_61, (_MM_PERM_ENUM)221); //B62(12-15) B63(12-15) B62(12-15) B63(12-15) B66(12-15) B67(12-15) B66(12-15) B67(12-15) B6A(12-15) B6B(12-15) B6A(12-15) B6B(12-15) B6E(12-15) B6F(12-15) B6E(12-15) B6F(12-15) + + const __m512i rhs_mat_014589CD_70_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_70, (_MM_PERM_ENUM)221); //B70(4-7) B71(4-7) B70(4-7) B71(4-7) B74(4-7) B75(4-7) B74(4-7) B75(4-7) B78(4-7) B79(4-7) B78(4-7) B79(4-7) B7C(4-7) B7D(4-7) B7C(4-7) B7D(4-7) + const __m512i rhs_mat_2367ABEF_70_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_70, (_MM_PERM_ENUM)221); //B72(4-7) B73(4-7) B72(4-7) B73(4-7) B76(4-7) B77(4-7) B76(4-7) B77(4-7) B7A(4-7) B7B(4-7) B7A(4-7) B7B(4-7) B7E(4-7) B7F(4-7) B7E(4-7) B7F(4-7) + + const __m512i rhs_mat_014589CD_71_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_71, (_MM_PERM_ENUM)221); //B70(12-15) B71(12-15) B70(12-15) B71(12-15) B74(12-15) B75(12-15) B74(12-15) B75(12-15) B78(12-15) B79(12-15) B78(12-15) B79(12-15) B7C(12-15) B7D(12-15) B7C(12-15) B7D(12-15) + const __m512i rhs_mat_2367ABEF_71_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_71, (_MM_PERM_ENUM)221); //B72(12-15) B73(12-15) B72(12-15) B73(12-15) B76(12-15) B77(12-15) B76(12-15) B77(12-15) B7A(12-15) B7B(12-15) B7A(12-15) B7B(12-15) B7E(12-15) B7F(12-15) B7E(12-15) B7F(12-15) + + //notation:superblock subblock + //s00 m00 s01 m01 s10 m10 s11 m11 s20 m20 s21 m21 s30 m30 s31 m31 s40 m40 s41 m41 s50 m50 s51 m51 s60 m60 s61 m61 s70 m70 s71 m71 + + const __m128i mins_and_scales_01_0 = _mm_loadu_si128((const __m128i *)(b_ptr_0[b].scales + sb * 64)); + const __m128i mins_and_scales_23_0 = _mm_loadu_si128((const __m128i *)(b_ptr_0[b].scales + 16 + sb * 64)); + const __m128i mins_and_scales_45_0 = _mm_loadu_si128((const __m128i *)(b_ptr_0[b].scales + 32 + sb * 64)); + const __m128i mins_and_scales_67_0 = _mm_loadu_si128((const __m128i *)(b_ptr_0[b].scales + 48 + sb * 64)); + + const __m128i mins_and_scales_01_1 = _mm_loadu_si128((const __m128i *)(b_ptr_1[b].scales + sb * 64)); + const __m128i mins_and_scales_23_1 = _mm_loadu_si128((const __m128i *)(b_ptr_1[b].scales + 16 + sb * 64)); + const __m128i mins_and_scales_45_1 = _mm_loadu_si128((const __m128i *)(b_ptr_1[b].scales + 32 + sb * 64)); + const __m128i mins_and_scales_67_1 = _mm_loadu_si128((const __m128i *)(b_ptr_1[b].scales + 48 + sb * 64)); + + // Combine mins and scales for sub-blocks: 0-1, 2-3, 4-5, 6-7 in the sb loop + const __m256i mins_and_scales_01 = _mm256_insertf128_si256(_mm256_castsi128_si256(mins_and_scales_01_0), mins_and_scales_01_1, 1); + const __m256i mins_and_scales_23 = _mm256_insertf128_si256(_mm256_castsi128_si256(mins_and_scales_23_0), mins_and_scales_23_1, 1); + const __m256i mins_and_scales_45 = _mm256_insertf128_si256(_mm256_castsi128_si256(mins_and_scales_45_0), mins_and_scales_45_1, 1); + const __m256i mins_and_scales_67 = _mm256_insertf128_si256(_mm256_castsi128_si256(mins_and_scales_67_0), mins_and_scales_67_1, 1); + + // Extract scales which is lower half from mins_and_scales + const __m256i scales_01 = _mm256_and_si256(mins_and_scales_01, m4b); + const __m256i scales_23 = _mm256_and_si256(mins_and_scales_23, m4b); + const __m256i scales_45 = _mm256_and_si256(mins_and_scales_45, m4b); + const __m256i scales_67 = _mm256_and_si256(mins_and_scales_67, m4b); + + // Extract mins which is upper half from mins_and_scales + const __m512i mins_01 = _mm512_cvtepu8_epi16(_mm256_and_si256(_mm256_srli_epi16(mins_and_scales_01, 4), m4b)); + const __m512i mins_23 = _mm512_cvtepu8_epi16(_mm256_and_si256(_mm256_srli_epi16(mins_and_scales_23, 4), m4b)); + const __m512i mins_45 = _mm512_cvtepu8_epi16(_mm256_and_si256(_mm256_srli_epi16(mins_and_scales_45, 4), m4b)); + const __m512i mins_67 = _mm512_cvtepu8_epi16(_mm256_and_si256(_mm256_srli_epi16(mins_and_scales_67, 4), m4b)); + + const __m512i scales_0 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_01,scalesmask1)); + const __m512i scales_1 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_01,scalesmask2)); + const __m512i scales_2 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_23,scalesmask1)); + const __m512i scales_3 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_23,scalesmask2)); + const __m512i scales_4 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_45,scalesmask1)); + const __m512i scales_5 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_45,scalesmask2)); + const __m512i scales_6 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_67,scalesmask1)); + const __m512i scales_7 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_67,scalesmask2)); + + const __m512i scale_014589CD_0 = _mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)68); + const __m512i scale_2367ABEF_0 = _mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)238); + + const __m512i scale_014589CD_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)68); + const __m512i scale_2367ABEF_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)238); + + const __m512i scale_014589CD_2 = _mm512_shuffle_epi32(scales_2, (_MM_PERM_ENUM)68); + const __m512i scale_2367ABEF_2 = _mm512_shuffle_epi32(scales_2, (_MM_PERM_ENUM)238); + + const __m512i scale_014589CD_3 = _mm512_shuffle_epi32(scales_3, (_MM_PERM_ENUM)68); + const __m512i scale_2367ABEF_3 = _mm512_shuffle_epi32(scales_3, (_MM_PERM_ENUM)238); + + const __m512i scale_014589CD_4 = _mm512_shuffle_epi32(scales_4, (_MM_PERM_ENUM)68); + const __m512i scale_2367ABEF_4 = _mm512_shuffle_epi32(scales_4, (_MM_PERM_ENUM)238); + + const __m512i scale_014589CD_5 = _mm512_shuffle_epi32(scales_5, (_MM_PERM_ENUM)68); + const __m512i scale_2367ABEF_5 = _mm512_shuffle_epi32(scales_5, (_MM_PERM_ENUM)238); + + const __m512i scale_014589CD_6 = _mm512_shuffle_epi32(scales_6, (_MM_PERM_ENUM)68); + const __m512i scale_2367ABEF_6 = _mm512_shuffle_epi32(scales_6, (_MM_PERM_ENUM)238); + + const __m512i scale_014589CD_7 = _mm512_shuffle_epi32(scales_7, (_MM_PERM_ENUM)68); + const __m512i scale_2367ABEF_7 = _mm512_shuffle_epi32(scales_7, (_MM_PERM_ENUM)238); + + + for (int rp = 0; rp < 4; rp++) { + + // Load the four block_q8_k quantized values interleaved with each other in chunks of eight bytes - A0,A1,A2,A3 + // Loaded as set of 128 bit vectors and repeated and stored into a 256 bit vector before again repeating into 512 bit vector + __m256i lhs_mat_ymm_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 512 * sb))); + __m256i lhs_mat_ymm_01_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 0); + __m256i lhs_mat_ymm_23_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 17); + __m256i lhs_mat_ymm_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 32 + 512 * sb))); + __m256i lhs_mat_ymm_01_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 0); + __m256i lhs_mat_ymm_23_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 17); + __m256i lhs_mat_ymm_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 64 + 512 * sb))); + __m256i lhs_mat_ymm_01_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 0); + __m256i lhs_mat_ymm_23_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 17); + __m256i lhs_mat_ymm_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 96 + 512 * sb))); + __m256i lhs_mat_ymm_01_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 0); + __m256i lhs_mat_ymm_23_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 17); + __m256i lhs_mat_ymm_0123_20 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 128 + 512 * sb))); + __m256i lhs_mat_ymm_01_20 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_20, lhs_mat_ymm_0123_20, 0); + __m256i lhs_mat_ymm_23_20 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_20, lhs_mat_ymm_0123_20, 17); + __m256i lhs_mat_ymm_0123_21 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 160 + 512 * sb))); + __m256i lhs_mat_ymm_01_21 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_21, lhs_mat_ymm_0123_21, 0); + __m256i lhs_mat_ymm_23_21 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_21, lhs_mat_ymm_0123_21, 17); + __m256i lhs_mat_ymm_0123_30 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 192 + 512 * sb))); + __m256i lhs_mat_ymm_01_30 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_30, lhs_mat_ymm_0123_30, 0); + __m256i lhs_mat_ymm_23_30 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_30, lhs_mat_ymm_0123_30, 17); + __m256i lhs_mat_ymm_0123_31 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 224 + 512 * sb))); + __m256i lhs_mat_ymm_01_31 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_31, lhs_mat_ymm_0123_31, 0); + __m256i lhs_mat_ymm_23_31 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_31, lhs_mat_ymm_0123_31, 17); + + __m256i lhs_mat_ymm_0123_40 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 256 + 512 * sb))); + __m256i lhs_mat_ymm_01_40 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_40, lhs_mat_ymm_0123_40, 0); + __m256i lhs_mat_ymm_23_40 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_40, lhs_mat_ymm_0123_40, 17); + __m256i lhs_mat_ymm_0123_41 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 288 + 512 * sb))); + __m256i lhs_mat_ymm_01_41 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_41, lhs_mat_ymm_0123_41, 0); + __m256i lhs_mat_ymm_23_41 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_41, lhs_mat_ymm_0123_41, 17); + __m256i lhs_mat_ymm_0123_50 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 320 + 512 * sb))); + __m256i lhs_mat_ymm_01_50 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_50, lhs_mat_ymm_0123_50, 0); + __m256i lhs_mat_ymm_23_50 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_50, lhs_mat_ymm_0123_50, 17); + __m256i lhs_mat_ymm_0123_51 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 352 + 512 * sb))); + __m256i lhs_mat_ymm_01_51 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_51, lhs_mat_ymm_0123_51, 0); + __m256i lhs_mat_ymm_23_51 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_51, lhs_mat_ymm_0123_51, 17); + __m256i lhs_mat_ymm_0123_60 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 384 + 512 * sb))); + __m256i lhs_mat_ymm_01_60 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_60, lhs_mat_ymm_0123_60, 0); + __m256i lhs_mat_ymm_23_60 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_60, lhs_mat_ymm_0123_60, 17); + __m256i lhs_mat_ymm_0123_61 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 416 + 512 * sb))); + __m256i lhs_mat_ymm_01_61 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_61, lhs_mat_ymm_0123_61, 0); + __m256i lhs_mat_ymm_23_61 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_61, lhs_mat_ymm_0123_61, 17); + __m256i lhs_mat_ymm_0123_70 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 448 + 512 * sb))); + __m256i lhs_mat_ymm_01_70 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_70, lhs_mat_ymm_0123_70, 0); + __m256i lhs_mat_ymm_23_70 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_70, lhs_mat_ymm_0123_70, 17); + __m256i lhs_mat_ymm_0123_71 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 480 + 512 * sb))); + __m256i lhs_mat_ymm_01_71 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_71, lhs_mat_ymm_0123_71, 0); + __m256i lhs_mat_ymm_23_71 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_71, lhs_mat_ymm_0123_71, 17); + + + __m512i lhs_mat_01_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_00), lhs_mat_ymm_01_00, 1); + __m512i lhs_mat_23_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_00), lhs_mat_ymm_23_00, 1); + __m512i lhs_mat_01_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_01), lhs_mat_ymm_01_01, 1); + __m512i lhs_mat_23_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_01), lhs_mat_ymm_23_01, 1); + + __m512i lhs_mat_01_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_10), lhs_mat_ymm_01_10, 1); + __m512i lhs_mat_23_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_10), lhs_mat_ymm_23_10, 1); + __m512i lhs_mat_01_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_11), lhs_mat_ymm_01_11, 1); + __m512i lhs_mat_23_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_11), lhs_mat_ymm_23_11, 1); + + __m512i lhs_mat_01_20 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_20), lhs_mat_ymm_01_20, 1); + __m512i lhs_mat_23_20 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_20), lhs_mat_ymm_23_20, 1); + __m512i lhs_mat_01_21 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_21), lhs_mat_ymm_01_21, 1); + __m512i lhs_mat_23_21 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_21), lhs_mat_ymm_23_21, 1); + + __m512i lhs_mat_01_30 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_30), lhs_mat_ymm_01_30, 1); + __m512i lhs_mat_23_30 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_30), lhs_mat_ymm_23_30, 1); + __m512i lhs_mat_01_31 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_31), lhs_mat_ymm_01_31, 1); + __m512i lhs_mat_23_31 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_31), lhs_mat_ymm_23_31, 1); + + __m512i lhs_mat_01_40 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_40), lhs_mat_ymm_01_40, 1); + __m512i lhs_mat_23_40 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_40), lhs_mat_ymm_23_40, 1); + __m512i lhs_mat_01_41 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_41), lhs_mat_ymm_01_41, 1); + __m512i lhs_mat_23_41 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_41), lhs_mat_ymm_23_41, 1); + + __m512i lhs_mat_01_50 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_50), lhs_mat_ymm_01_50, 1); + __m512i lhs_mat_23_50 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_50), lhs_mat_ymm_23_50, 1); + __m512i lhs_mat_01_51 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_51), lhs_mat_ymm_01_51, 1); + __m512i lhs_mat_23_51 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_51), lhs_mat_ymm_23_51, 1); + + __m512i lhs_mat_01_60 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_60), lhs_mat_ymm_01_60, 1); + __m512i lhs_mat_23_60 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_60), lhs_mat_ymm_23_60, 1); + __m512i lhs_mat_01_61 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_61), lhs_mat_ymm_01_61, 1); + __m512i lhs_mat_23_61 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_61), lhs_mat_ymm_23_61, 1); + + __m512i lhs_mat_01_70 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_70), lhs_mat_ymm_01_70, 1); + __m512i lhs_mat_23_70 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_70), lhs_mat_ymm_23_70, 1); + __m512i lhs_mat_01_71 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_71), lhs_mat_ymm_01_71, 1); + __m512i lhs_mat_23_71 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_71), lhs_mat_ymm_23_71, 1); + + // Bsums are loaded for the different Q8_K blocks + __m128i lhs_raw_bsums_01_0123 = _mm_loadu_si128((const __m128i *)((a_ptrs[rp][b].bsums + 32 * sb))); + __m128i lhs_raw_bsums_23_0123 = _mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].bsums + 8 + 32 * sb)); + __m128i lhs_raw_bsums_01_4567 = _mm_loadu_si128((const __m128i *)((a_ptrs[rp][b].bsums + 16 + 32 * sb))); + __m128i lhs_raw_bsums_23_4567 = _mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].bsums + 24 + 32 * sb)); + + __m256i lhs_bsums_ymm_01_0123 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_01_0123), lhs_raw_bsums_01_0123, 1); + __m512i lhs_bsums_01_0123 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_ymm_01_0123), lhs_bsums_ymm_01_0123, 1); + __m256i lhs_bsums_ymm_23_0123 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_23_0123), lhs_raw_bsums_23_0123, 1); + __m512i lhs_bsums_23_0123 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_ymm_23_0123), lhs_bsums_ymm_23_0123, 1); __m256i lhs_bsums_ymm_01_4567 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_01_4567), lhs_raw_bsums_01_4567, 1); + __m512i lhs_bsums_01_4567 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_ymm_01_4567), lhs_bsums_ymm_01_4567, 1); + __m256i lhs_bsums_ymm_23_4567 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_23_4567), lhs_raw_bsums_23_4567, 1); + __m512i lhs_bsums_23_4567 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_ymm_23_4567), lhs_bsums_ymm_23_4567, 1); + + // Shuffle pattern one - left side input + const __m512i lhs_mat_01_00_sp1 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) + const __m512i lhs_mat_23_00_sp1 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)160); //A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) + + const __m512i lhs_mat_01_01_sp1 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)160); //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) + const __m512i lhs_mat_23_01_sp1 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)160); //A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) + + const __m512i lhs_mat_01_10_sp1 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) + const __m512i lhs_mat_23_10_sp1 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)160); //A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) + + const __m512i lhs_mat_01_11_sp1 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)160); //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) + const __m512i lhs_mat_23_11_sp1 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)160); //A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) + + const __m512i lhs_mat_01_20_sp1 = _mm512_shuffle_epi32(lhs_mat_01_20, (_MM_PERM_ENUM)160); //A20(0-3) A20(0-3) A21(0-3) A21(0-3) A20(0-3) A20(0-3) A21(0-3) A21(0-3) A20(0-3) A20(0-3) A21(0-3) A21(0-3) A20(0-3) A20(0-3) A21(0-3) A21(0-3) + const __m512i lhs_mat_23_20_sp1 = _mm512_shuffle_epi32(lhs_mat_23_20, (_MM_PERM_ENUM)160); //A22(0-3) A22(0-3) A23(0-3) A23(0-3) A22(0-3) A22(0-3) A23(0-3) A23(0-3) A22(0-3) A22(0-3) A23(0-3) A23(0-3) A22(0-3) A22(0-3) A23(0-3) A23(0-3) + + const __m512i lhs_mat_01_21_sp1 = _mm512_shuffle_epi32(lhs_mat_01_21, (_MM_PERM_ENUM)160); //A20(8-11) A20(8-11) A21(8-11) A21(8-11) A20(8-11) A20(8-11) A21(8-11) A21(8-11) A20(8-11) A20(8-11) A21(8-11) A21(8-11) A20(8-11) A20(8-11) A21(8-11) A21(8-11) + const __m512i lhs_mat_23_21_sp1 = _mm512_shuffle_epi32(lhs_mat_23_21, (_MM_PERM_ENUM)160); //A22(8-11) A22(8-11) A23(8-11) A23(8-11) A22(8-11) A22(8-11) A23(8-11) A23(8-11) A22(8-11) A22(8-11) A23(8-11) A23(8-11) A22(8-11) A22(8-11) A23(8-11) A23(8-11) + + const __m512i lhs_mat_01_30_sp1 = _mm512_shuffle_epi32(lhs_mat_01_30, (_MM_PERM_ENUM)160); //A30(0-3) A30(0-3) A31(0-3) A31(0-3) A30(0-3) A30(0-3) A31(0-3) A31(0-3) A30(0-3) A30(0-3) A31(0-3) A31(0-3) A30(0-3) A30(0-3) A31(0-3) A31(0-3) + const __m512i lhs_mat_23_30_sp1 = _mm512_shuffle_epi32(lhs_mat_23_30, (_MM_PERM_ENUM)160); //A32(0-3) A32(0-3) A33(0-3) A33(0-3) A32(0-3) A32(0-3) A33(0-3) A33(0-3) A32(0-3) A32(0-3) A33(0-3) A33(0-3) A32(0-3) A32(0-3) A33(0-3) A33(0-3) + + const __m512i lhs_mat_01_31_sp1 = _mm512_shuffle_epi32(lhs_mat_01_31, (_MM_PERM_ENUM)160); //A30(8-11) A30(8-11) A31(8-11) A31(8-11) A30(8-11) A30(8-11) A31(8-11) A31(8-11) A30(8-11) A30(8-11) A31(8-11) A31(8-11) A30(8-11) A30(8-11) A31(8-11) A31(8-11) + const __m512i lhs_mat_23_31_sp1 = _mm512_shuffle_epi32(lhs_mat_23_31, (_MM_PERM_ENUM)160); //A32(8-11) A32(8-11) A33(8-11) A33(8-11) A32(8-11) A32(8-11) A33(8-11) A33(8-11) A32(8-11) A32(8-11) A33(8-11) A33(8-11) A32(8-11) A32(8-11) A33(8-11) A33(8-11) + + const __m512i lhs_mat_01_40_sp1 = _mm512_shuffle_epi32(lhs_mat_01_40, (_MM_PERM_ENUM)160); //A40(0-3) A40(0-3) A41(0-3) A41(0-3) A40(0-3) A40(0-3) A41(0-3) A41(0-3) A40(0-3) A40(0-3) A41(0-3) A41(0-3) A40(0-3) A40(0-3) A41(0-3) A41(0-3) + const __m512i lhs_mat_23_40_sp1 = _mm512_shuffle_epi32(lhs_mat_23_40, (_MM_PERM_ENUM)160); //A42(0-3) A42(0-3) A43(0-3) A43(0-3) A42(0-3) A42(0-3) A43(0-3) A43(0-3) A42(0-3) A42(0-3) A43(0-3) A43(0-3) A42(0-3) A42(0-3) A43(0-3) A43(0-3) + + const __m512i lhs_mat_01_41_sp1 = _mm512_shuffle_epi32(lhs_mat_01_41, (_MM_PERM_ENUM)160); //A40(8-11) A40(8-11) A41(8-11) A41(8-11) A40(8-11) A40(8-11) A41(8-11) A41(8-11) A40(8-11) A40(8-11) A41(8-11) A41(8-11) A40(8-11) A40(8-11) A41(8-11) A41(8-11) + const __m512i lhs_mat_23_41_sp1 = _mm512_shuffle_epi32(lhs_mat_23_41, (_MM_PERM_ENUM)160); //A42(8-11) A42(8-11) A43(8-11) A43(8-11) A42(8-11) A42(8-11) A43(8-11) A43(8-11) A42(8-11) A42(8-11) A43(8-11) A43(8-11) A42(8-11) A42(8-11) A43(8-11) A43(8-11) + + const __m512i lhs_mat_01_50_sp1 = _mm512_shuffle_epi32(lhs_mat_01_50, (_MM_PERM_ENUM)160); //A50(0-3) A50(0-3) A51(0-3) A51(0-3) A50(0-3) A50(0-3) A51(0-3) A51(0-3) A50(0-3) A50(0-3) A51(0-3) A51(0-3) A50(0-3) A50(0-3) A51(0-3) A51(0-3) + const __m512i lhs_mat_23_50_sp1 = _mm512_shuffle_epi32(lhs_mat_23_50, (_MM_PERM_ENUM)160); //A52(0-3) A52(0-3) A53(0-3) A53(0-3) A52(0-3) A52(0-3) A53(0-3) A53(0-3) A52(0-3) A52(0-3) A53(0-3) A53(0-3) A52(0-3) A52(0-3) A53(0-3) A53(0-3) + + const __m512i lhs_mat_01_51_sp1 = _mm512_shuffle_epi32(lhs_mat_01_51, (_MM_PERM_ENUM)160); //A50(8-11) A50(8-11) A51(8-11) A51(8-11) A50(8-11) A50(8-11) A51(8-11) A51(8-11) A50(8-11) A50(8-11) A51(8-11) A51(8-11) A50(8-11) A50(8-11) A51(8-11) A51(8-11) + const __m512i lhs_mat_23_51_sp1 = _mm512_shuffle_epi32(lhs_mat_23_51, (_MM_PERM_ENUM)160); //A52(8-11) A52(8-11) A53(8-11) A53(8-11) A52(8-11) A52(8-11) A53(8-11) A53(8-11) A52(8-11) A52(8-11) A53(8-11) A53(8-11) A52(8-11) A52(8-11) A53(8-11) A53(8-11) + + const __m512i lhs_mat_01_60_sp1 = _mm512_shuffle_epi32(lhs_mat_01_60, (_MM_PERM_ENUM)160); //A60(0-3) A60(0-3) A61(0-3) A61(0-3) A60(0-3) A60(0-3) A61(0-3) A61(0-3) A60(0-3) A60(0-3) A61(0-3) A61(0-3) A60(0-3) A60(0-3) A61(0-3) A61(0-3) + const __m512i lhs_mat_23_60_sp1 = _mm512_shuffle_epi32(lhs_mat_23_60, (_MM_PERM_ENUM)160); //A62(0-3) A62(0-3) A63(0-3) A63(0-3) A62(0-3) A62(0-3) A63(0-3) A63(0-3) A62(0-3) A62(0-3) A63(0-3) A63(0-3) A62(0-3) A62(0-3) A63(0-3) A63(0-3) + + const __m512i lhs_mat_01_61_sp1 = _mm512_shuffle_epi32(lhs_mat_01_61, (_MM_PERM_ENUM)160); //A60(8-11) A60(8-11) A61(8-11) A61(8-11) A60(8-11) A60(8-11) A61(8-11) A61(8-11) A60(8-11) A60(8-11) A61(8-11) A61(8-11) A60(8-11) A60(8-11) A61(8-11) A61(8-11) + const __m512i lhs_mat_23_61_sp1 = _mm512_shuffle_epi32(lhs_mat_23_61, (_MM_PERM_ENUM)160); //A62(8-11) A62(8-11) A63(8-11) A63(8-11) A62(8-11) A62(8-11) A63(8-11) A63(8-11) A62(8-11) A62(8-11) A63(8-11) A63(8-11) A62(8-11) A62(8-11) A63(8-11) A63(8-11) + + const __m512i lhs_mat_01_70_sp1 = _mm512_shuffle_epi32(lhs_mat_01_70, (_MM_PERM_ENUM)160); //A70(0-3) A70(0-3) A71(0-3) A71(0-3) A70(0-3) A70(0-3) A71(0-3) A71(0-3) A70(0-3) A70(0-3) A71(0-3) A71(0-3) A70(0-3) A70(0-3) A71(0-3) A71(0-3) + const __m512i lhs_mat_23_70_sp1 = _mm512_shuffle_epi32(lhs_mat_23_70, (_MM_PERM_ENUM)160); //A72(0-3) A72(0-3) A73(0-3) A73(0-3) A72(0-3) A72(0-3) A73(0-3) A73(0-3) A72(0-3) A72(0-3) A73(0-3) A73(0-3) A72(0-3) A72(0-3) A73(0-3) A73(0-3) + + const __m512i lhs_mat_01_71_sp1 = _mm512_shuffle_epi32(lhs_mat_01_71, (_MM_PERM_ENUM)160); //A70(8-11) A70(8-11) A71(8-11) A71(8-11) A70(8-11) A70(8-11) A71(8-11) A71(8-11) A70(8-11) A70(8-11) A71(8-11) A71(8-11) A70(8-11) A70(8-11) A71(8-11) A71(8-11) + const __m512i lhs_mat_23_71_sp1 = _mm512_shuffle_epi32(lhs_mat_23_71, (_MM_PERM_ENUM)160); //A72(8-11) A72(8-11) A73(8-11) A73(8-11) A72(8-11) A72(8-11) A73(8-11) A73(8-11) A72(8-11) A72(8-11) A73(8-11) A73(8-11) A72(8-11) A72(8-11) A73(8-11) A73(8-11) + + const __m512i lhs_mat_01_00_sp2 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) + const __m512i lhs_mat_23_00_sp2 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)245); //A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) + + const __m512i lhs_mat_01_01_sp2 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) + const __m512i lhs_mat_23_01_sp2 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)245); //A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) + + const __m512i lhs_mat_01_10_sp2 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) + const __m512i lhs_mat_23_10_sp2 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)245); //A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) + + const __m512i lhs_mat_01_11_sp2 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) + const __m512i lhs_mat_23_11_sp2 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)245); //A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) + + const __m512i lhs_mat_01_20_sp2 = _mm512_shuffle_epi32(lhs_mat_01_20, (_MM_PERM_ENUM)245); //A20(4-7) A20(4-7) A21(4-7) A21(4-7) A20(4-7) A20(4-7) A21(4-7) A21(4-7) A20(4-7) A20(4-7) A21(4-7) A21(4-7) A20(4-7) A20(4-7) A21(4-7) A21(4-7) + const __m512i lhs_mat_23_20_sp2 = _mm512_shuffle_epi32(lhs_mat_23_20, (_MM_PERM_ENUM)245); //A22(4-7) A22(4-7) A23(4-7) A23(4-7) A22(4-7) A22(4-7) A23(4-7) A23(4-7) A22(4-7) A22(4-7) A23(4-7) A23(4-7) A22(4-7) A22(4-7) A23(4-7) A23(4-7) + + const __m512i lhs_mat_01_21_sp2 = _mm512_shuffle_epi32(lhs_mat_01_21, (_MM_PERM_ENUM)245); //A20(12-15) A20(12-15) A21(12-15) A21(12-15) A20(12-15) A20(12-15) A21(12-15) A21(12-15) A20(12-15) A20(12-15) A21(12-15) A21(12-15) A20(12-15) A20(12-15) A21(12-15) A21(12-15) + const __m512i lhs_mat_23_21_sp2 = _mm512_shuffle_epi32(lhs_mat_23_21, (_MM_PERM_ENUM)245); //A22(12-15) A22(12-15) A23(12-15) A23(12-15) A22(12-15) A22(12-15) A23(12-15) A23(12-15) A22(12-15) A22(12-15) A23(12-15) A23(12-15) A22(12-15) A22(12-15) A23(12-15) A23(12-15) + + const __m512i lhs_mat_01_30_sp2 = _mm512_shuffle_epi32(lhs_mat_01_30, (_MM_PERM_ENUM)245); //A30(4-7) A30(4-7) A31(4-7) A31(4-7) A30(4-7) A30(4-7) A31(4-7) A31(4-7) A30(4-7) A30(4-7) A31(4-7) A31(4-7) A30(4-7) A30(4-7) A31(4-7) A31(4-7) + const __m512i lhs_mat_23_30_sp2 = _mm512_shuffle_epi32(lhs_mat_23_30, (_MM_PERM_ENUM)245); //A32(4-7) A32(4-7) A33(4-7) A33(4-7) A32(4-7) A32(4-7) A33(4-7) A33(4-7) A32(4-7) A32(4-7) A33(4-7) A33(4-7) A32(4-7) A32(4-7) A33(4-7) A33(4-7) + + const __m512i lhs_mat_01_31_sp2 = _mm512_shuffle_epi32(lhs_mat_01_31, (_MM_PERM_ENUM)245); //A30(12-15) A30(12-15) A31(12-15) A31(12-15) A30(12-15) A30(12-15) A31(12-15) A31(12-15) A30(12-15) A30(12-15) A31(12-15) A31(12-15) A30(12-15) A30(12-15) A31(12-15) A31(12-15) + const __m512i lhs_mat_23_31_sp2 = _mm512_shuffle_epi32(lhs_mat_23_31, (_MM_PERM_ENUM)245); //A32(12-15) A32(12-15) A33(12-15) A33(12-15) A32(12-15) A32(12-15) A33(12-15) A33(12-15) A32(12-15) A32(12-15) A33(12-15) A33(12-15) A32(12-15) A32(12-15) A33(12-15) A33(12-15) + + const __m512i lhs_mat_01_40_sp2 = _mm512_shuffle_epi32(lhs_mat_01_40, (_MM_PERM_ENUM)245); //A40(4-7) A40(4-7) A41(4-7) A41(4-7) A40(4-7) A40(4-7) A41(4-7) A41(4-7) A40(4-7) A40(4-7) A41(4-7) A41(4-7) A40(4-7) A40(4-7) A41(4-7) A41(4-7) + const __m512i lhs_mat_23_40_sp2 = _mm512_shuffle_epi32(lhs_mat_23_40, (_MM_PERM_ENUM)245); //A42(4-7) A42(4-7) A43(4-7) A43(4-7) A42(4-7) A42(4-7) A43(4-7) A43(4-7) A42(4-7) A42(4-7) A43(4-7) A43(4-7) A42(4-7) A42(4-7) A43(4-7) A43(4-7) + + const __m512i lhs_mat_01_41_sp2 = _mm512_shuffle_epi32(lhs_mat_01_41, (_MM_PERM_ENUM)245); //A40(12-15) A40(12-15) A41(12-15) A41(12-15) A40(12-15) A40(12-15) A41(12-15) A41(12-15) A40(12-15) A40(12-15) A41(12-15) A41(12-15) A40(12-15) A40(12-15) A41(12-15) A41(12-15) + const __m512i lhs_mat_23_41_sp2 = _mm512_shuffle_epi32(lhs_mat_23_41, (_MM_PERM_ENUM)245); //A42(12-15) A42(12-15) A43(12-15) A43(12-15) A42(12-15) A42(12-15) A43(12-15) A43(12-15) A42(12-15) A42(12-15) A43(12-15) A43(12-15) A42(12-15) A42(12-15) A43(12-15) A43(12-15) + + const __m512i lhs_mat_01_50_sp2 = _mm512_shuffle_epi32(lhs_mat_01_50, (_MM_PERM_ENUM)245); //A50(4-7) A50(4-7) A51(4-7) A51(4-7) A50(4-7) A50(4-7) A51(4-7) A51(4-7) A50(4-7) A50(4-7) A51(4-7) A51(4-7) A50(4-7) A50(4-7) A51(4-7) A51(4-7) + const __m512i lhs_mat_23_50_sp2 = _mm512_shuffle_epi32(lhs_mat_23_50, (_MM_PERM_ENUM)245); //A52(4-7) A52(4-7) A53(4-7) A53(4-7) A52(4-7) A52(4-7) A53(4-7) A53(4-7) A52(4-7) A52(4-7) A53(4-7) A53(4-7) A52(4-7) A52(4-7) A53(4-7) A53(4-7) + + const __m512i lhs_mat_01_51_sp2 = _mm512_shuffle_epi32(lhs_mat_01_51, (_MM_PERM_ENUM)245); //A50(12-15) A50(12-15) A51(12-15) A51(12-15) A50(12-15) A50(12-15) A51(12-15) A51(12-15) A50(12-15) A50(12-15) A51(12-15) A51(12-15) A50(12-15) A50(12-15) A51(12-15) A51(12-15) + const __m512i lhs_mat_23_51_sp2 = _mm512_shuffle_epi32(lhs_mat_23_51, (_MM_PERM_ENUM)245); //A52(12-15) A52(12-15) A53(12-15) A53(12-15) A52(12-15) A52(12-15) A53(12-15) A53(12-15) A52(12-15) A52(12-15) A53(12-15) A53(12-15) A52(12-15) A52(12-15) A53(12-15) A53(12-15) + + const __m512i lhs_mat_01_60_sp2 = _mm512_shuffle_epi32(lhs_mat_01_60, (_MM_PERM_ENUM)245); //A60(4-7) A60(4-7) A61(4-7) A61(4-7) A60(4-7) A60(4-7) A61(4-7) A61(4-7) A60(4-7) A60(4-7) A61(4-7) A61(4-7) A60(4-7) A60(4-7) A61(4-7) A61(4-7) + const __m512i lhs_mat_23_60_sp2 = _mm512_shuffle_epi32(lhs_mat_23_60, (_MM_PERM_ENUM)245); //A62(4-7) A62(4-7) A63(4-7) A63(4-7) A62(4-7) A62(4-7) A63(4-7) A63(4-7) A62(4-7) A62(4-7) A63(4-7) A63(4-7) A62(4-7) A62(4-7) A63(4-7) A63(4-7) + + const __m512i lhs_mat_01_61_sp2 = _mm512_shuffle_epi32(lhs_mat_01_61, (_MM_PERM_ENUM)245); //A60(12-15) A60(12-15) A61(12-15) A61(12-15) A60(12-15) A60(12-15) A61(12-15) A61(12-15) A60(12-15) A60(12-15) A61(12-15) A61(12-15) A60(12-15) A60(12-15) A61(12-15) A61(12-15) + const __m512i lhs_mat_23_61_sp2 = _mm512_shuffle_epi32(lhs_mat_23_61, (_MM_PERM_ENUM)245); //A62(12-15) A62(12-15) A63(12-15) A63(12-15) A62(12-15) A62(12-15) A63(12-15) A63(12-15) A62(12-15) A62(12-15) A63(12-15) A63(12-15) A62(12-15) A62(12-15) A63(12-15) A63(12-15) + + const __m512i lhs_mat_01_70_sp2 = _mm512_shuffle_epi32(lhs_mat_01_70, (_MM_PERM_ENUM)245); //A70(4-7) A70(4-7) A71(4-7) A71(4-7) A70(4-7) A70(4-7) A71(4-7) A71(4-7) A70(4-7) A70(4-7) A71(4-7) A71(4-7) A70(4-7) A70(4-7) A71(4-7) A71(4-7) + const __m512i lhs_mat_23_70_sp2 = _mm512_shuffle_epi32(lhs_mat_23_70, (_MM_PERM_ENUM)245); //A72(4-7) A72(4-7) A73(4-7) A73(4-7) A72(4-7) A72(4-7) A73(4-7) A73(4-7) A72(4-7) A72(4-7) A73(4-7) A73(4-7) A72(4-7) A72(4-7) A73(4-7) A73(4-7) + + const __m512i lhs_mat_01_71_sp2 = _mm512_shuffle_epi32(lhs_mat_01_71, (_MM_PERM_ENUM)245); //A70(12-15) A70(12-15) A71(12-15) A71(12-15) A70(12-15) A70(12-15) A71(12-15) A71(12-15) A70(12-15) A70(12-15) A71(12-15) A71(12-15) A70(12-15) A70(12-15) A71(12-15) A71(12-15) + const __m512i lhs_mat_23_71_sp2 = _mm512_shuffle_epi32(lhs_mat_23_71, (_MM_PERM_ENUM)245); //A72(12-15) A72(12-15) A73(12-15) A73(12-15) A72(12-15) A72(12-15) A73(12-15) A73(12-15) A72(12-15) A72(12-15) A73(12-15) A73(12-15) A72(12-15) A72(12-15) A73(12-15) A73(12-15) + + // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane + __m512i iacc_mat_00_0_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_01_00_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_01_01_sp1)); + __m512i iacc_mat_01_0_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_01_00_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_01_01_sp1)); + + __m512i iacc_mat_10_0_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_23_00_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_23_01_sp1)); + __m512i iacc_mat_11_0_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_23_00_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_23_01_sp1)); + + __m512i iacc_mat_00_1_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_01_10_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_01_11_sp1)); + __m512i iacc_mat_01_1_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_01_10_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_01_11_sp1)); + + __m512i iacc_mat_10_1_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_23_10_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_23_11_sp1)); + __m512i iacc_mat_11_1_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_23_10_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_23_11_sp1)); + + __m512i iacc_mat_00_2_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_20_sp1, lhs_mat_01_20_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_21_sp1, lhs_mat_01_21_sp1)); + __m512i iacc_mat_01_2_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_20_sp1, lhs_mat_01_20_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_21_sp1, lhs_mat_01_21_sp1)); + + __m512i iacc_mat_10_2_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_20_sp1, lhs_mat_23_20_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_21_sp1, lhs_mat_23_21_sp1)); + __m512i iacc_mat_11_2_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_20_sp1, lhs_mat_23_20_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_21_sp1, lhs_mat_23_21_sp1)); + + __m512i iacc_mat_00_3_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_30_sp1, lhs_mat_01_30_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_31_sp1, lhs_mat_01_31_sp1)); + __m512i iacc_mat_01_3_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_30_sp1, lhs_mat_01_30_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_31_sp1, lhs_mat_01_31_sp1)); + + __m512i iacc_mat_10_3_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_30_sp1, lhs_mat_23_30_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_31_sp1, lhs_mat_23_31_sp1)); + __m512i iacc_mat_11_3_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_30_sp1, lhs_mat_23_30_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_31_sp1, lhs_mat_23_31_sp1)); + + __m512i iacc_mat_00_4_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_40_sp1, lhs_mat_01_40_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_41_sp1, lhs_mat_01_41_sp1)); + __m512i iacc_mat_01_4_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_40_sp1, lhs_mat_01_40_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_41_sp1, lhs_mat_01_41_sp1)); + + __m512i iacc_mat_10_4_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_40_sp1, lhs_mat_23_40_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_41_sp1, lhs_mat_23_41_sp1)); + __m512i iacc_mat_11_4_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_40_sp1, lhs_mat_23_40_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_41_sp1, lhs_mat_23_41_sp1)); + + __m512i iacc_mat_00_5_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_50_sp1, lhs_mat_01_50_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_51_sp1, lhs_mat_01_51_sp1)); + __m512i iacc_mat_01_5_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_50_sp1, lhs_mat_01_50_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_51_sp1, lhs_mat_01_51_sp1)); + + __m512i iacc_mat_10_5_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_50_sp1, lhs_mat_23_50_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_51_sp1, lhs_mat_23_51_sp1)); + __m512i iacc_mat_11_5_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_50_sp1, lhs_mat_23_50_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_51_sp1, lhs_mat_23_51_sp1)); + + __m512i iacc_mat_00_6_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_60_sp1, lhs_mat_01_60_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_61_sp1, lhs_mat_01_61_sp1)); + __m512i iacc_mat_01_6_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_60_sp1, lhs_mat_01_60_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_61_sp1, lhs_mat_01_61_sp1)); + + __m512i iacc_mat_10_6_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_60_sp1, lhs_mat_23_60_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_61_sp1, lhs_mat_23_61_sp1)); + __m512i iacc_mat_11_6_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_60_sp1, lhs_mat_23_60_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_61_sp1, lhs_mat_23_61_sp1)); + + __m512i iacc_mat_00_7_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_70_sp1, lhs_mat_01_70_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_71_sp1, lhs_mat_01_71_sp1)); + __m512i iacc_mat_01_7_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_70_sp1, lhs_mat_01_70_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_71_sp1, lhs_mat_01_71_sp1)); + + __m512i iacc_mat_10_7_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_70_sp1, lhs_mat_23_70_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_71_sp1, lhs_mat_23_71_sp1)); + __m512i iacc_mat_11_7_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_70_sp1, lhs_mat_23_70_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_71_sp1, lhs_mat_23_71_sp1)); + + + __m512i iacc_mat_00_0_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_01_00_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_01_01_sp2)); + __m512i iacc_mat_01_0_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_01_00_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_01_01_sp2)); + + __m512i iacc_mat_10_0_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_23_00_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_23_01_sp2)); + __m512i iacc_mat_11_0_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_23_00_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_23_01_sp2)); + + __m512i iacc_mat_00_1_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_01_10_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_01_11_sp2)); + __m512i iacc_mat_01_1_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_01_10_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_01_11_sp2)); + + __m512i iacc_mat_10_1_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_23_10_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_23_11_sp2)); + __m512i iacc_mat_11_1_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_23_10_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_23_11_sp2)); + + __m512i iacc_mat_00_2_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_20_sp2, lhs_mat_01_20_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_21_sp2, lhs_mat_01_21_sp2)); + __m512i iacc_mat_01_2_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_20_sp2, lhs_mat_01_20_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_21_sp2, lhs_mat_01_21_sp2)); + + __m512i iacc_mat_10_2_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_20_sp2, lhs_mat_23_20_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_21_sp2, lhs_mat_23_21_sp2)); + __m512i iacc_mat_11_2_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_20_sp2, lhs_mat_23_20_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_21_sp2, lhs_mat_23_21_sp2)); + + __m512i iacc_mat_00_3_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_30_sp2, lhs_mat_01_30_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_31_sp2, lhs_mat_01_31_sp2)); + __m512i iacc_mat_01_3_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_30_sp2, lhs_mat_01_30_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_31_sp2, lhs_mat_01_31_sp2)); + + __m512i iacc_mat_10_3_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_30_sp2, lhs_mat_23_30_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_31_sp2, lhs_mat_23_31_sp2)); + __m512i iacc_mat_11_3_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_30_sp2, lhs_mat_23_30_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_31_sp2, lhs_mat_23_31_sp2)); + + __m512i iacc_mat_00_4_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_40_sp2, lhs_mat_01_40_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_41_sp2, lhs_mat_01_41_sp2)); + __m512i iacc_mat_01_4_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_40_sp2, lhs_mat_01_40_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_41_sp2, lhs_mat_01_41_sp2)); + + __m512i iacc_mat_10_4_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_40_sp2, lhs_mat_23_40_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_41_sp2, lhs_mat_23_41_sp2)); + __m512i iacc_mat_11_4_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_40_sp2, lhs_mat_23_40_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_41_sp2, lhs_mat_23_41_sp2)); + + __m512i iacc_mat_00_5_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_50_sp2, lhs_mat_01_50_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_51_sp2, lhs_mat_01_51_sp2)); + __m512i iacc_mat_01_5_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_50_sp2, lhs_mat_01_50_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_51_sp2, lhs_mat_01_51_sp2)); + + __m512i iacc_mat_10_5_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_50_sp2, lhs_mat_23_50_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_51_sp2, lhs_mat_23_51_sp2)); + __m512i iacc_mat_11_5_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_50_sp2, lhs_mat_23_50_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_51_sp2, lhs_mat_23_51_sp2)); + + __m512i iacc_mat_00_6_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_60_sp2, lhs_mat_01_60_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_61_sp2, lhs_mat_01_61_sp2)); + __m512i iacc_mat_01_6_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_60_sp2, lhs_mat_01_60_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_61_sp2, lhs_mat_01_61_sp2)); + + __m512i iacc_mat_10_6_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_60_sp2, lhs_mat_23_60_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_61_sp2, lhs_mat_23_61_sp2)); + __m512i iacc_mat_11_6_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_60_sp2, lhs_mat_23_60_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_61_sp2, lhs_mat_23_61_sp2)); + + __m512i iacc_mat_00_7_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_70_sp2, lhs_mat_01_70_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_71_sp2, lhs_mat_01_71_sp2)); + __m512i iacc_mat_01_7_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_70_sp2, lhs_mat_01_70_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_71_sp2, lhs_mat_01_71_sp2)); + + __m512i iacc_mat_10_7_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_70_sp2, lhs_mat_23_70_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_71_sp2, lhs_mat_23_71_sp2)); + __m512i iacc_mat_11_7_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_70_sp2, lhs_mat_23_70_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_71_sp2, lhs_mat_23_71_sp2)); + + // Combine results from both shuffle patterns for each output block + __m512i iacc_mat_00_0 = _mm512_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2); + __m512i iacc_mat_01_0 = _mm512_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2); + __m512i iacc_mat_10_0 = _mm512_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2); + __m512i iacc_mat_11_0 = _mm512_add_epi16(iacc_mat_11_0_sp1, iacc_mat_11_0_sp2); + + __m512i iacc_mat_00_1 = _mm512_add_epi16(iacc_mat_00_1_sp1, iacc_mat_00_1_sp2); + __m512i iacc_mat_01_1 = _mm512_add_epi16(iacc_mat_01_1_sp1, iacc_mat_01_1_sp2); + __m512i iacc_mat_10_1 = _mm512_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2); + __m512i iacc_mat_11_1 = _mm512_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2); + + __m512i iacc_mat_00_2 = _mm512_add_epi16(iacc_mat_00_2_sp1, iacc_mat_00_2_sp2); + __m512i iacc_mat_01_2 = _mm512_add_epi16(iacc_mat_01_2_sp1, iacc_mat_01_2_sp2); + __m512i iacc_mat_10_2 = _mm512_add_epi16(iacc_mat_10_2_sp1, iacc_mat_10_2_sp2); + __m512i iacc_mat_11_2 = _mm512_add_epi16(iacc_mat_11_2_sp1, iacc_mat_11_2_sp2); + + __m512i iacc_mat_00_3 = _mm512_add_epi16(iacc_mat_00_3_sp1, iacc_mat_00_3_sp2); + __m512i iacc_mat_01_3 = _mm512_add_epi16(iacc_mat_01_3_sp1, iacc_mat_01_3_sp2); + __m512i iacc_mat_10_3 = _mm512_add_epi16(iacc_mat_10_3_sp1, iacc_mat_10_3_sp2); + __m512i iacc_mat_11_3 = _mm512_add_epi16(iacc_mat_11_3_sp1, iacc_mat_11_3_sp2); + + __m512i iacc_mat_00_4 = _mm512_add_epi16(iacc_mat_00_4_sp1, iacc_mat_00_4_sp2); + __m512i iacc_mat_01_4 = _mm512_add_epi16(iacc_mat_01_4_sp1, iacc_mat_01_4_sp2); + __m512i iacc_mat_10_4 = _mm512_add_epi16(iacc_mat_10_4_sp1, iacc_mat_10_4_sp2); + __m512i iacc_mat_11_4 = _mm512_add_epi16(iacc_mat_11_4_sp1, iacc_mat_11_4_sp2); + + __m512i iacc_mat_00_5 = _mm512_add_epi16(iacc_mat_00_5_sp1, iacc_mat_00_5_sp2); + __m512i iacc_mat_01_5 = _mm512_add_epi16(iacc_mat_01_5_sp1, iacc_mat_01_5_sp2); + __m512i iacc_mat_10_5 = _mm512_add_epi16(iacc_mat_10_5_sp1, iacc_mat_10_5_sp2); + __m512i iacc_mat_11_5 = _mm512_add_epi16(iacc_mat_11_5_sp1, iacc_mat_11_5_sp2); + + __m512i iacc_mat_00_6 = _mm512_add_epi16(iacc_mat_00_6_sp1, iacc_mat_00_6_sp2); + __m512i iacc_mat_01_6 = _mm512_add_epi16(iacc_mat_01_6_sp1, iacc_mat_01_6_sp2); + __m512i iacc_mat_10_6 = _mm512_add_epi16(iacc_mat_10_6_sp1, iacc_mat_10_6_sp2); + __m512i iacc_mat_11_6 = _mm512_add_epi16(iacc_mat_11_6_sp1, iacc_mat_11_6_sp2); + + __m512i iacc_mat_00_7 = _mm512_add_epi16(iacc_mat_00_7_sp1, iacc_mat_00_7_sp2); + __m512i iacc_mat_01_7 = _mm512_add_epi16(iacc_mat_01_7_sp1, iacc_mat_01_7_sp2); + __m512i iacc_mat_10_7 = _mm512_add_epi16(iacc_mat_10_7_sp1, iacc_mat_10_7_sp2); + __m512i iacc_mat_11_7 = _mm512_add_epi16(iacc_mat_11_7_sp1, iacc_mat_11_7_sp2); + + // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block + iacc_mat_00_0 = _mm512_madd_epi16(iacc_mat_00_0, scale_014589CD_0); + iacc_mat_01_0 = _mm512_madd_epi16(iacc_mat_01_0, scale_2367ABEF_0); + iacc_mat_10_0 = _mm512_madd_epi16(iacc_mat_10_0, scale_014589CD_0); + iacc_mat_11_0 = _mm512_madd_epi16(iacc_mat_11_0, scale_2367ABEF_0); + + iacc_mat_00_1 = _mm512_madd_epi16(iacc_mat_00_1, scale_014589CD_1); + iacc_mat_01_1 = _mm512_madd_epi16(iacc_mat_01_1, scale_2367ABEF_1); + iacc_mat_10_1 = _mm512_madd_epi16(iacc_mat_10_1, scale_014589CD_1); + iacc_mat_11_1 = _mm512_madd_epi16(iacc_mat_11_1, scale_2367ABEF_1); + + iacc_mat_00_2 = _mm512_madd_epi16(iacc_mat_00_2, scale_014589CD_2); + iacc_mat_01_2 = _mm512_madd_epi16(iacc_mat_01_2, scale_2367ABEF_2); + iacc_mat_10_2 = _mm512_madd_epi16(iacc_mat_10_2, scale_014589CD_2); + iacc_mat_11_2 = _mm512_madd_epi16(iacc_mat_11_2, scale_2367ABEF_2); + + iacc_mat_00_3 = _mm512_madd_epi16(iacc_mat_00_3, scale_014589CD_3); + iacc_mat_01_3 = _mm512_madd_epi16(iacc_mat_01_3, scale_2367ABEF_3); + iacc_mat_10_3 = _mm512_madd_epi16(iacc_mat_10_3, scale_014589CD_3); + iacc_mat_11_3 = _mm512_madd_epi16(iacc_mat_11_3, scale_2367ABEF_3); + + iacc_mat_00_4 = _mm512_madd_epi16(iacc_mat_00_4, scale_014589CD_4); + iacc_mat_01_4 = _mm512_madd_epi16(iacc_mat_01_4, scale_2367ABEF_4); + iacc_mat_10_4 = _mm512_madd_epi16(iacc_mat_10_4, scale_014589CD_4); + iacc_mat_11_4 = _mm512_madd_epi16(iacc_mat_11_4, scale_2367ABEF_4); + + iacc_mat_00_5 = _mm512_madd_epi16(iacc_mat_00_5, scale_014589CD_5); + iacc_mat_01_5 = _mm512_madd_epi16(iacc_mat_01_5, scale_2367ABEF_5); + iacc_mat_10_5 = _mm512_madd_epi16(iacc_mat_10_5, scale_014589CD_5); + iacc_mat_11_5 = _mm512_madd_epi16(iacc_mat_11_5, scale_2367ABEF_5); + + iacc_mat_00_6 = _mm512_madd_epi16(iacc_mat_00_6, scale_014589CD_6); + iacc_mat_01_6 = _mm512_madd_epi16(iacc_mat_01_6, scale_2367ABEF_6); + iacc_mat_10_6 = _mm512_madd_epi16(iacc_mat_10_6, scale_014589CD_6); + iacc_mat_11_6 = _mm512_madd_epi16(iacc_mat_11_6, scale_2367ABEF_6); + + iacc_mat_00_7 = _mm512_madd_epi16(iacc_mat_00_7, scale_014589CD_7); + iacc_mat_01_7 = _mm512_madd_epi16(iacc_mat_01_7, scale_2367ABEF_7); + iacc_mat_10_7 = _mm512_madd_epi16(iacc_mat_10_7, scale_014589CD_7); + iacc_mat_11_7 = _mm512_madd_epi16(iacc_mat_11_7, scale_2367ABEF_7); + + __m512i iacc_mat_00 = _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(iacc_mat_00_0, iacc_mat_00_1), _mm512_add_epi32(iacc_mat_00_2, iacc_mat_00_3)), _mm512_add_epi32(_mm512_add_epi32(iacc_mat_00_4, iacc_mat_00_5), _mm512_add_epi32(iacc_mat_00_6, iacc_mat_00_7))); + __m512i iacc_mat_01 = _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(iacc_mat_01_0, iacc_mat_01_1), _mm512_add_epi32(iacc_mat_01_2, iacc_mat_01_3)), _mm512_add_epi32(_mm512_add_epi32(iacc_mat_01_4, iacc_mat_01_5), _mm512_add_epi32(iacc_mat_01_6, iacc_mat_01_7))); + __m512i iacc_mat_10 = _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(iacc_mat_10_0, iacc_mat_10_1), _mm512_add_epi32(iacc_mat_10_2, iacc_mat_10_3)), _mm512_add_epi32(_mm512_add_epi32(iacc_mat_10_4, iacc_mat_10_5), _mm512_add_epi32(iacc_mat_10_6, iacc_mat_10_7))); + __m512i iacc_mat_11 = _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(iacc_mat_11_0, iacc_mat_11_1), _mm512_add_epi32(iacc_mat_11_2, iacc_mat_11_3)), _mm512_add_epi32(_mm512_add_epi32(iacc_mat_11_4, iacc_mat_11_5), _mm512_add_epi32(iacc_mat_11_6, iacc_mat_11_7))); + + // Straighten out to make 4 row vectors + __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78)); + __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01); + __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78)); + __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11); + + // Load the scale(d) values for all the 4 Q8_k blocks and repeat it across lanes + const __m128 row_scale_f32_sse = _mm_load_ps(a_ptrs[rp][b].d); + const __m256 row_scale_f32_ymm = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse); + const __m512 row_scale_f32 = _mm512_insertf32x8(_mm512_castps256_ps512(row_scale_f32_ymm), row_scale_f32_ymm, 1); + + // Multiply with appropiate scales and accumulate (for both d and dmin) below + acc_rows[rp * 4] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]); + acc_rows[rp * 4 + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]); + acc_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]); + acc_rows[rp * 4 + 3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]); + + // Take two bsums from two Q8_Ks at a time and multiply with corresponding mins values from each Q2_K + __m512i iacc_row_min_0_01 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_0123, (_MM_PERM_ENUM)0), mins_01); + __m512i iacc_row_min_1_01 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_0123, (_MM_PERM_ENUM)170), mins_01); + __m512i iacc_row_min_2_01 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_0123, (_MM_PERM_ENUM)0), mins_01); + __m512i iacc_row_min_3_01 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_0123, (_MM_PERM_ENUM)170), mins_01); + + __m512i iacc_row_min_0_23 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_0123, (_MM_PERM_ENUM)85), mins_23); + __m512i iacc_row_min_1_23 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_0123, (_MM_PERM_ENUM)255), mins_23); + __m512i iacc_row_min_2_23 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_0123, (_MM_PERM_ENUM)85), mins_23); + __m512i iacc_row_min_3_23 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_0123, (_MM_PERM_ENUM)255), mins_23); + + __m512i iacc_row_min_0_45 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_4567, (_MM_PERM_ENUM)0), mins_45); + __m512i iacc_row_min_1_45 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_4567, (_MM_PERM_ENUM)170), mins_45); + __m512i iacc_row_min_2_45 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_4567, (_MM_PERM_ENUM)0), mins_45); + __m512i iacc_row_min_3_45 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_4567, (_MM_PERM_ENUM)170), mins_45); + + __m512i iacc_row_min_0_67 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_4567, (_MM_PERM_ENUM)85), mins_67); + __m512i iacc_row_min_1_67 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_4567, (_MM_PERM_ENUM)255), mins_67); + __m512i iacc_row_min_2_67 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_4567, (_MM_PERM_ENUM)85), mins_67); + __m512i iacc_row_min_3_67 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_4567, (_MM_PERM_ENUM)255), mins_67); + + __m512i iacc_row_min_0 = _mm512_add_epi32(_mm512_add_epi32(iacc_row_min_0_01, iacc_row_min_0_23), _mm512_add_epi32(iacc_row_min_0_45,iacc_row_min_0_67)); + __m512i iacc_row_min_1 = _mm512_add_epi32(_mm512_add_epi32(iacc_row_min_1_01, iacc_row_min_1_23), _mm512_add_epi32(iacc_row_min_1_45,iacc_row_min_1_67)); + __m512i iacc_row_min_2 = _mm512_add_epi32(_mm512_add_epi32(iacc_row_min_2_01, iacc_row_min_2_23), _mm512_add_epi32(iacc_row_min_2_45,iacc_row_min_2_67)); + __m512i iacc_row_min_3 = _mm512_add_epi32(_mm512_add_epi32(iacc_row_min_3_01, iacc_row_min_3_23), _mm512_add_epi32(iacc_row_min_3_45,iacc_row_min_3_67)); + + acc_min_rows[rp * 4] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_0), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[rp * 4]); + acc_min_rows[rp * 4 + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_1), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[rp * 4 + 1]); + acc_min_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_2), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[rp * 4 + 2]); + acc_min_rows[rp * 4 + 3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_3), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[rp * 4 + 3]); + } + } + } + // Store the accumulated values + for (int i = 0; i < 16; i++) { + _mm512_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm512_sub_ps(acc_rows[i], acc_min_rows[i])); + } + } + } + + for (; y < nr / 4; y ++) { + + const block_q8_Kx4 * a_ptr = a_ptr_start + (y * nb); + + // Take group of eight block_q2_kx8 structures at each pass of the loop and perform dot product operation + for (int64_t x = 0; x < anc / 8; x += 2) { + + const block_q2_Kx8 * b_ptr_0 = b_ptr_start + ((x) * b_nb); + const block_q2_Kx8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb); + + // Master FP accumulators + __m512 acc_rows[4]; + for (int i = 0; i < 4; i++) { + acc_rows[i] = _mm512_setzero_ps(); + } + + __m512 acc_min_rows[4]; + for (int i = 0; i < 4; i++) { + acc_min_rows[i] = _mm512_setzero_ps(); + } + // For super block + for (int64_t b = 0; b < nb; b++) { + // Delta values - Load the sixteen scale values from two block_q2_kx8 structures + const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d); + + // dmin values - Load the sixteen dmin values from two block_q2_kx8 structures + const __m512 col_dmin_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].dmin, b_ptr_1[b].dmin); + + // Loop to iterate over the sixteen sub blocks of a super block - eight sub blocks are processed per iteration + for (int sb = 0; sb < QK_K / 128; sb++) { + + // Load the eight block_q2_k for eight sub blocks quantized values interleaved with each other in chunks of eight bytes - B0,B1 ....B6,B7 + const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + sb * 256)); + const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 32 + sb * 256)); + const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 64 + sb * 256)); + const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 96 + sb * 256)); + const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 128 + sb * 256)); + const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 160 + sb * 256)); + const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 192 + sb * 256)); + const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 224 + sb * 256)); + + const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + sb * 256)); + const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 32 + sb * 256)); + const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 64 + sb * 256)); + const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 96 + sb * 256)); + const __m256i rhs_raw_mat_89AB_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 128 + sb * 256)); + const __m256i rhs_raw_mat_CDEF_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 160 + sb * 256)); + const __m256i rhs_raw_mat_89AB_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 192 + sb * 256)); + const __m256i rhs_raw_mat_CDEF_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 224 + sb * 256)); + + const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240); + const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240); + const __m256i rhs_raw_mat_0145_2 = _mm256_blend_epi32(rhs_raw_mat_0123_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_2, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_2, requiredOrder), rhs_raw_mat_4567_2, 240); + const __m256i rhs_raw_mat_0145_3 = _mm256_blend_epi32(rhs_raw_mat_0123_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_3, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_3, requiredOrder), rhs_raw_mat_4567_3, 240); + + const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240); + const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240); + const __m256i rhs_raw_mat_89CD_2 = _mm256_blend_epi32(rhs_raw_mat_89AB_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_2, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_2, requiredOrder), rhs_raw_mat_CDEF_2, 240); + const __m256i rhs_raw_mat_89CD_3 = _mm256_blend_epi32(rhs_raw_mat_89AB_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_3, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_3, requiredOrder), rhs_raw_mat_CDEF_3, 240); + + const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1); + const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1); + const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1); + const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1); + + const __m512i rhs_raw_mat_014589CD_2 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_2), rhs_raw_mat_89CD_2, 1); + const __m512i rhs_raw_mat_2367ABEF_2 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_2), rhs_raw_mat_ABEF_2, 1); + const __m512i rhs_raw_mat_014589CD_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_3), rhs_raw_mat_89CD_3, 1); + const __m512i rhs_raw_mat_2367ABEF_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_3), rhs_raw_mat_ABEF_3, 1); + + //2-bit -> 8-bit + const __m512i rhs_mat_014589CD_00 = _mm512_and_si512(rhs_raw_mat_014589CD_0,m3bexpanded); //B00(0-7) B01(0-7) B04(0-7) B05(0-7) B08(0-7) B09(0-7) B0C(0-7) B0D(0-7) + const __m512i rhs_mat_2367ABEF_00 = _mm512_and_si512(rhs_raw_mat_2367ABEF_0,m3bexpanded); //B02(0-7) B03(0-7) B06(0-7) B07(0-7) B0A(0-7) B0B(0-7) B0E(0-7) B0F(0-7) + const __m512i rhs_mat_014589CD_01 = _mm512_and_si512(rhs_raw_mat_014589CD_1,m3bexpanded); //B00(8-15) B01(8-15) B04(8-15) B05(8-15) B08(8-15) B09(8-15) B0C(8-15) B0D(8-15) + const __m512i rhs_mat_2367ABEF_01 = _mm512_and_si512(rhs_raw_mat_2367ABEF_1,m3bexpanded); //B02(8-15) B03(8-15) B06(8-15) B07(8-15) B0A(8-15) B0B(8-15) B0E(8-15) B0F(8-15) + const __m512i rhs_mat_014589CD_10 = _mm512_and_si512(rhs_raw_mat_014589CD_2,m3bexpanded); //B10(0-7) B11(0-7) B14(0-7) B15(0-7) B18(0-7) B19(0-7) B1C(0-7) B1D(0-7) + const __m512i rhs_mat_2367ABEF_10 = _mm512_and_si512(rhs_raw_mat_2367ABEF_2,m3bexpanded); //B12(0-7) B13(0-7) B16(0-7) B17(0-7) B1A(0-7) B1B(0-7) B1E(0-7) B1F(0-7) + const __m512i rhs_mat_014589CD_11 = _mm512_and_si512(rhs_raw_mat_014589CD_3,m3bexpanded); //B10(8-15) B11(8-15) B14(8-15) B15(8-15) B18(8-15) B19(8-15) B1C(8-15) B1D(8-15) + const __m512i rhs_mat_2367ABEF_11 = _mm512_and_si512(rhs_raw_mat_2367ABEF_3,m3bexpanded); //B12(8-15) B13(8-15) B16(8-15) B17(8-15) B1A(8-15) B1B(8-15) B1E(8-15) B1F(8-15) + + const __m512i rhs_mat_014589CD_20 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 2), m3bexpanded); //B20(0-7) B21(0-7) B24(0-7) B25(0-7) B28(0-7) B29(0-7) B2C(0-7) B2D(0-7) + const __m512i rhs_mat_2367ABEF_20 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 2), m3bexpanded); //B22(0-7) B23(0-7) B26(0-7) B27(0-7) B2A(0-7) B2B(0-7) B2E(0-7) B2F(0-7) + + const __m512i rhs_mat_014589CD_21 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 2), m3bexpanded); //B20(8-15) B21(8-15) B24(8-15) B25(8-15) B28(8-15) B29(8-15) B2C(8-15) B2D(8-15) + const __m512i rhs_mat_2367ABEF_21 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 2), m3bexpanded); //B22(8-15) B23(8-15) B26(8-15) B27(8-15) B2A(8-15) B2B(8-15) B2E(8-15) B2F(8-15) + + const __m512i rhs_mat_014589CD_30 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_2, 2), m3bexpanded); //B30(0-7) B31(0-7) B34(0-7) B35(0-7) B38(0-7) B39(0-7) B3C(0-7) B3D(0-7) + const __m512i rhs_mat_2367ABEF_30 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_2, 2), m3bexpanded); //B32(0-7) B33(0-7) B36(0-7) B37(0-7) B3A(0-7) B3B(0-7) B3E(0-7) B3F(0-7) + + const __m512i rhs_mat_014589CD_31 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_3, 2), m3bexpanded); //B30(8-15) B31(8-15) B34(8-15) B35(8-15) B38(8-15) B39(8-15) B3C(8-15) B3D(8-15) + const __m512i rhs_mat_2367ABEF_31 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_3, 2), m3bexpanded); //B32(8-15) B33(8-15) B36(8-15) B37(8-15) B3A(8-15) B3B(8-15) B3E(8-15) B3F(8-15) + + const __m512i rhs_mat_014589CD_40 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m3bexpanded); //B40(0-7) B41(0-7) B44(0-7) B45(0-7) B48(0-7) B49(0-7) B4C(0-7) B4D(0-7) + const __m512i rhs_mat_2367ABEF_40 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m3bexpanded); //B42(0-7) B43(0-7) B46(0-7) B47(0-7) B4A(0-7) B4B(0-7) B4E(0-7) B4F(0-7) + + const __m512i rhs_mat_014589CD_41 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m3bexpanded); //B40(8-15) B41(8-15) B44(8-15) B45(8-15) B48(8-15) B49(8-15) B4C(8-15) B4D(8-15) + const __m512i rhs_mat_2367ABEF_41 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m3bexpanded); //B42(8-15) B43(8-15) B46(8-15) B47(8-15) B4A(8-15) B4B(8-15) B4E(8-15) B4F(8-15) + + const __m512i rhs_mat_014589CD_50 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_2, 4), m3bexpanded); //B50(0-7) B51(0-7) B54(0-7) B55(0-7) B58(0-7) B59(0-7) B5C(0-7) B5D(0-7) + const __m512i rhs_mat_2367ABEF_50 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_2, 4), m3bexpanded); //B52(0-7) B53(0-7) B56(0-7) B57(0-7) B5A(0-7) B5B(0-7) B5E(0-7) B5F(0-7) + + const __m512i rhs_mat_014589CD_51 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_3, 4), m3bexpanded); //B50(8-15) B51(8-15) B54(8-15) B55(8-15) B58(8-15) B59(8-15) B5C(8-15) B5D(8-15) + const __m512i rhs_mat_2367ABEF_51 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_3, 4), m3bexpanded); //B52(8-15) B53(8-15) B56(8-15) B57(8-15) B5A(8-15) B5B(8-15) B5E(8-15) B5F(8-15) + + const __m512i rhs_mat_014589CD_60 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 6), m3bexpanded); //B60(0-7) B61(0-7) B64(0-7) B65(0-7) B68(0-7) B69(0-7) B6C(0-7) B6D(0-7) + const __m512i rhs_mat_2367ABEF_60 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 6), m3bexpanded); //B62(0-7) B63(0-7) B66(0-7) B67(0-7) B6A(0-7) B6B(0-7) B6E(0-7) B6F(0-7) + + const __m512i rhs_mat_014589CD_61 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 6), m3bexpanded); //B60(8-15) B61(8-15) B64(8-15) B65(8-15) B68(8-15) B69(8-15) B6C(8-15) B6D(8-15) + const __m512i rhs_mat_2367ABEF_61 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 6), m3bexpanded); //B62(8-15) B63(8-15) B66(8-15) B67(8-15) B6A(8-15) B6B(8-15) B6E(8-15) B6F(8-15) + + const __m512i rhs_mat_014589CD_70 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_2, 6), m3bexpanded); //B70(0-7) B71(0-7) B74(0-7) B75(0-7) B78(0-7) B79(0-7) B7C(0-7) B7D(0-7) + const __m512i rhs_mat_2367ABEF_70 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_2, 6), m3bexpanded); //B72(0-7) B73(0-7) B76(0-7) B77(0-7) B7A(0-7) B7B(0-7) B7E(0-7) B7F(0-7) + + const __m512i rhs_mat_014589CD_71 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_3, 6), m3bexpanded); //B70(8-15) B71(8-15) B74(8-15) B75(8-15) B78(8-15) B79(8-15) B7C(8-15) B7D(8-15) + const __m512i rhs_mat_2367ABEF_71 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_3, 6), m3bexpanded); //B72(8-15) B73(8-15) B76(8-15) B77(8-15) B7A(8-15) B7B(8-15) B7E(8-15) B7F(8-15) + + const __m512i rhs_mat_014589CD_00_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3) B08(0-3) B09(0-3) B08(0-3) B09(0-3) B0C(0-3) B0D(0-3) B0C(0-3) B0D(0-3) + const __m512i rhs_mat_2367ABEF_00_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)136); //B02(0-3) B03(0-3) B02(0-3) B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3) B0A(0-3) B0B(0-3) B0A(0-3) B0B(0-3) B0E(0-3) B0F(0-3) B0E(0-3) B0F(0-3) + + const __m512i rhs_mat_014589CD_01_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11) B08(8-11) B09(8-11) B08(8-11) B09(8-11) B0C(8-11) B0D(8-11) B0C(8-11) B0D(8-11) + const __m512i rhs_mat_2367ABEF_01_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11) B0A(8-11) B0B(8-11) B0A(8-11) B0B(8-11) B0E(8-11) B0F(8-11) B0E(8-11) B0F(8-11) + + const __m512i rhs_mat_014589CD_10_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3) B18(0-3) B19(0-3) B18(0-3) B19(0-3) B1C(0-3) B1D(0-3) B1C(0-3) B1D(0-3) + const __m512i rhs_mat_2367ABEF_10_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) B16(0-3) B17(0-3) B1A(0-3) B1B(0-3) B1A(0-3) B1B(0-3) B1E(0-3) B1F(0-3) B1E(0-3) B1F(0-3) + + const __m512i rhs_mat_014589CD_11_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11) B18(8-11) B19(8-11) B18(8-11) B19(8-11) B1C(8-11) B1D(8-11) B1C(8-11) B1D(8-11) + const __m512i rhs_mat_2367ABEF_11_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11) B1A(8-11) B1B(8-11) B1A(8-11) B1B(8-11) B1E(8-11) B1F(8-11) B1E(8-11) B1F(8-11) + + const __m512i rhs_mat_014589CD_20_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_20, (_MM_PERM_ENUM)136); //B20(0-3) B21(0-3) B20(0-3) B21(0-3) B24(0-3) B25(0-3) B24(0-3) B25(0-3) B28(0-3) B29(0-3) B28(0-3) B29(0-3) B2C(0-3) B2D(0-3) B2C(0-3) B2D(0-3) + const __m512i rhs_mat_2367ABEF_20_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_20, (_MM_PERM_ENUM)136); //B22(0-3) B23(0-3) B22(0-3) B23(0-3) B26(0-3) B27(0-3) B26(0-3) B27(0-3) B2A(0-3) B2B(0-3) B2A(0-3) B2B(0-3) B2E(0-3) B2F(0-3) B2E(0-3) B2F(0-3) + + const __m512i rhs_mat_014589CD_21_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_21, (_MM_PERM_ENUM)136); //B20(8-11) B21(8-11) B20(8-11) B21(8-11) B24(8-11) B25(8-11) B24(8-11) B25(8-11) B28(8-11) B29(8-11) B28(8-11) B29(8-11) B2C(8-11) B2D(8-11) B2C(8-11) B2D(8-11) + const __m512i rhs_mat_2367ABEF_21_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_21, (_MM_PERM_ENUM)136); //B22(8-11) B23(8-11) B22(8-11) B23(8-11) B26(8-11) B27(8-11) B26(8-11) B27(8-11) B2A(8-11) B2B(8-11) B2A(8-11) B2B(8-11) B2E(8-11) B2F(8-11) B2E(8-11) B2F(8-11) + const __m512i rhs_mat_014589CD_30_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_30, (_MM_PERM_ENUM)136); ///B30(0-3) B31(0-3) B30(0-3) B31(0-3) B34(0-3) B35(0-3) B34(0-3) B35(0-3) B38(0-3) B39(0-3) B38(0-3) B39(0-3) B3C(0-3) B3D(0-3) B3C(0-3) B3D(0-3) + const __m512i rhs_mat_2367ABEF_30_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_30, (_MM_PERM_ENUM)136); //B32(0-3) B33(0-3) B32(0-3) B33(0-3) B36(0-3) B37(0-3) B36(0-3) B37(0-3) B3A(0-3) B3B(0-3) B3A(0-3) B3B(0-3) B3E(0-3) B3F(0-3) B3E(0-3) B3F(0-3) + + const __m512i rhs_mat_014589CD_31_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_31, (_MM_PERM_ENUM)136); //B30(8-11) B31(8-11) B30(8-11) B31(8-11) B34(8-11) B35(8-11) B34(8-11) B35(8-11) B38(8-11) B39(8-11) B38(8-11) B39(8-11) B3C(8-11) B3D(8-11) B3C(8-11) B3D(8-11) + const __m512i rhs_mat_2367ABEF_31_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_31, (_MM_PERM_ENUM)136); //B32(8-11) B33(8-11) B32(8-11) B33(8-11) B36(8-11) B37(8-11) B36(8-11) B37(8-11) B3A(8-11) B3B(8-11) B3A(8-11) B3B(8-11) B3E(8-11) B3F(8-11) B3E(8-11) B3F(8-11) + + const __m512i rhs_mat_014589CD_40_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_40, (_MM_PERM_ENUM)136); //B40(0-3) B41(0-3) B40(0-3) B41(0-3) B44(0-3) B45(0-3) B44(0-3) B45(0-3) B48(0-3) B49(0-3) B48(0-3) B49(0-3) B4C(0-3) B4D(0-3) B4C(0-3) B4D(0-3) + const __m512i rhs_mat_2367ABEF_40_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_40, (_MM_PERM_ENUM)136); //B42(0-3) B43(0-3) B42(0-3) B43(0-3) B46(0-3) B47(0-3) B46(0-3) B47(0-3) B4A(0-3) B4B(0-3) B4A(0-3) B4B(0-3) B4E(0-3) B4F(0-3) B4E(0-3) B4F(0-3) + + const __m512i rhs_mat_014589CD_41_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_41, (_MM_PERM_ENUM)136); //B40(8-11) B41(8-11) B40(8-11) B41(8-11) B44(8-11) B45(8-11) B44(8-11) B45(8-11) B48(8-11) B49(8-11) B48(8-11) B49(8-11) B4C(8-11) B4D(8-11) B4C(8-11) B4D(8-11) + const __m512i rhs_mat_2367ABEF_41_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_41, (_MM_PERM_ENUM)136); //B42(8-11) B43(8-11) B42(8-11) B43(8-11) B46(8-11) B47(8-11) B46(8-11) B47(8-11) B4A(8-11) B4B(8-11) B4A(8-11) B4B(8-11) B4E(8-11) B4F(8-11) B4E(8-11) B4F(8-11) + + const __m512i rhs_mat_014589CD_50_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_50, (_MM_PERM_ENUM)136); //B50(0-3) B51(0-3) B50(0-3) B51(0-3) B54(0-3) B55(0-3) B54(0-3) B55(0-3) B58(0-3) B59(0-3) B58(0-3) B59(0-3) B5C(0-3) B5D(0-3) B5C(0-3) B5D(0-3) + const __m512i rhs_mat_2367ABEF_50_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_50, (_MM_PERM_ENUM)136); //B52(0-3) B53(0-3) B52(0-3) B53(0-3) B56(0-3) B57(0-3) B56(0-3) B57(0-3) B5A(0-3) B5B(0-3) B5A(0-3) B5B(0-3) B5E(0-3) B5F(0-3) B5E(0-3) B5F(0-3) + + const __m512i rhs_mat_014589CD_51_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_51, (_MM_PERM_ENUM)136); //B50(8-11) B51(8-11) B50(8-11) B51(8-11) B54(8-11) B55(8-11) B54(8-11) B55(8-11) B58(8-11) B59(8-11) B58(8-11) B59(8-11) B5C(8-11) B5D(8-11) B5C(8-11) B5D(8-11) + const __m512i rhs_mat_2367ABEF_51_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_51, (_MM_PERM_ENUM)136); //B52(8-11) B53(8-11) B52(8-11) B53(8-11) B56(8-11) B57(8-11) B56(8-11) B57(8-11) B5A(8-11) B5B(8-11) B5A(8-11) B5B(8-11) B5E(8-11) B5F(8-11) B5E(8-11) B5F(8-11) + + const __m512i rhs_mat_014589CD_60_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_60, (_MM_PERM_ENUM)136); //B60(0-3) B61(0-3) B60(0-3) B61(0-3) B64(0-3) B65(0-3) B64(0-3) B65(0-3) B68(0-3) B69(0-3) B68(0-3) B69(0-3) B6C(0-3) B6D(0-3) B6C(0-3) B6D(0-3) + const __m512i rhs_mat_2367ABEF_60_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_60, (_MM_PERM_ENUM)136); //B62(0-3) B63(0-3) B62(0-3) B63(0-3) B66(0-3) B67(0-3) B66(0-3) B67(0-3) B6A(0-3) B6B(0-3) B6A(0-3) B6B(0-3) B6E(0-3) B6F(0-3) B6E(0-3) B6F(0-3) + + const __m512i rhs_mat_014589CD_61_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_61, (_MM_PERM_ENUM)136); //B60(8-11) B61(8-11) B60(8-11) B61(8-11) B64(8-11) B65(8-11) B64(8-11) B65(8-11) B68(8-11) B69(8-11) B68(8-11) B69(8-11) B6C(8-11) B6D(8-11) B6C(8-11) B6D(8-11) + const __m512i rhs_mat_2367ABEF_61_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_61, (_MM_PERM_ENUM)136); //B62(8-11) B63(8-11) B62(8-11) B63(8-11) B66(8-11) B67(8-11) B66(8-11) B67(8-11) B6A(8-11) B6B(8-11) B6A(8-11) B6B(8-11) B6E(8-11) B6F(8-11) B6E(8-11) B6F(8-11) + + const __m512i rhs_mat_014589CD_70_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_70, (_MM_PERM_ENUM)136); //B70(0-3) B71(0-3) B70(0-3) B71(0-3) B74(0-3) B75(0-3) B74(0-3) B75(0-3) B78(0-3) B79(0-3) B78(0-3) B79(0-3) B7C(0-3) B7D(0-3) B7C(0-3) B7D(0-3) + const __m512i rhs_mat_2367ABEF_70_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_70, (_MM_PERM_ENUM)136); //B72(0-3) B73(0-3) B72(0-3) B73(0-3) B76(0-3) B77(0-3) B76(0-3) B77(0-3) B7A(0-3) B7B(0-3) B7A(0-3) B7B(0-3) B7E(0-3) B7F(0-3) B7E(0-3) B7F(0-3) + + const __m512i rhs_mat_014589CD_71_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_71, (_MM_PERM_ENUM)136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11) B08(8-11) B09(8-11) B08(8-11) B09(8-11) B0C(8-11) B0D(8-11) B0C(8-11) B0D(8-11) + const __m512i rhs_mat_2367ABEF_71_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_71, (_MM_PERM_ENUM)136); //B72(8-11) B73(8-11) B72(8-11) B73(8-11) B76(8-11) B77(8-11) B76(8-11) B77(8-11) B7A(8-11) B7B(8-11) B7A(8-11) B7B(8-11) B7E(8-11) B7F(8-11) B7E(8-11) B7F(8-11) + + const __m512i rhs_mat_014589CD_00_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)221); //B00(4-7) B01(4-7) B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7) B08(4-7) B09(4-7) B08(4-7) B09(4-7) B0C(4-7) B0D(4-7) B0C(4-7) B0D(4-7) + const __m512i rhs_mat_2367ABEF_00_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)221); //B02(4-7) B03(4-7) B02(4-7) B03(4-7) B06(4-7) B07(4-7) B06(4-7) B07(4-7) B0A(4-7) B0B(4-7) B0A(4-7) B0B(4-7) B0E(4-7) B0F(4-7) B0E(4-7) B0F(4-7) + + const __m512i rhs_mat_014589CD_01_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15) B08(12-15) B09(12-15) B08(12-15) B09(12-15) B0C(12-15) B0D(12-15) B0C(12-15) B0D(12-15) + const __m512i rhs_mat_2367ABEF_01_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15) B0A(12-15) B0B(12-15) B0A(12-15) B0B(12-15) B0E(12-15) B0F(12-15) B0E(12-15) B0F(12-15) + + const __m512i rhs_mat_014589CD_10_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7) B18(4-7) B19(4-7) B18(4-7) B19(4-7) B1C(4-7) B1D(4-7) B1C(4-7) B1D(4-7) + const __m512i rhs_mat_2367ABEF_10_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7) B1A(4-7) B1B(4-7) B1A(4-7) B1B(4-7) B1E(4-7) B1F(4-7) B1E(4-7) B1F(4-7) + + const __m512i rhs_mat_014589CD_11_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15) B18(12-15) B19(12-15) B18(12-15) B19(12-15) B1C(12-15) B1D(12-15) B1C(12-15) B1D(12-15) + const __m512i rhs_mat_2367ABEF_11_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15) B1A(12-15) B1B(12-15) B1A(12-15) B1B(12-15) B1E(12-15) B1F(12-15) B1E(12-15) B1F(12-15) + + const __m512i rhs_mat_014589CD_20_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_20, (_MM_PERM_ENUM)221); //B20(4-7) B21(4-7) B20(4-7) B21(4-7) B24(4-7) B25(4-7) B24(4-7) B25(4-7) B28(4-7) B29(4-7) B28(4-7) B29(4-7) B2C(4-7) B2D(4-7) B2C(4-7) B2D(4-7) + const __m512i rhs_mat_2367ABEF_20_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_20, (_MM_PERM_ENUM)221); //B22(4-7) B23(4-7) B22(4-7) B23(4-7) B26(4-7) B27(4-7) B26(4-7) B27(4-7) B2A(4-7) B2B(4-7) B2A(4-7) B2B(4-7) B2E(4-7) B2F(4-7) B2E(4-7) B2F(4-7) + + const __m512i rhs_mat_014589CD_21_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_21, (_MM_PERM_ENUM)221); //B20(12-15) B21(12-15) B20(12-15) B21(12-15) B24(12-15) B25(12-15) B24(12-15) B25(12-15) B28(12-15) B29(12-15) B28(12-15) B29(12-15) B2C(12-15) B2D(12-15) B2C(12-15) B2D(12-15) + const __m512i rhs_mat_2367ABEF_21_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_21, (_MM_PERM_ENUM)221); //B22(12-15) B23(12-15) B22(12-15) B23(12-15) B26(12-15) B27(12-15) B26(12-15) B27(12-15) B2A(12-15) B2B(12-15) B2A(12-15) B2B(12-15) B2E(12-15) B2F(12-15) B2E(12-15) B2F(12-15) + + const __m512i rhs_mat_014589CD_30_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_30, (_MM_PERM_ENUM)221); //B30(4-7) B31(4-7) B30(4-7) B31(4-7) B34(4-7) B35(4-7) B34(4-7) B35(4-7) B38(4-7) B39(4-7) B38(4-7) B39(4-7) B3C(4-7) B3D(4-7) B3C(4-7) B3D(4-7) + const __m512i rhs_mat_2367ABEF_30_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_30, (_MM_PERM_ENUM)221); //B32(4-7) B33(4-7) B32(4-7) B33(4-7) B36(4-7) B37(4-7) B36(4-7) B37(4-7) B3A(4-7) B3B(4-7) B3A(4-7) B3B(4-7) B3E(4-7) B3F(4-7) B3E(4-7) B3F(4-7) + + const __m512i rhs_mat_014589CD_31_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_31, (_MM_PERM_ENUM)221); //B30(12-15) B31(12-15) B30(12-15) B31(12-15) B34(12-15) B35(12-15) B34(12-15) B35(12-15) B38(12-15) B39(12-15) B38(12-15) B39(12-15) B3C(12-15) B3D(12-15) B3C(12-15) B3D(12-15) + const __m512i rhs_mat_2367ABEF_31_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_31, (_MM_PERM_ENUM)221); //B32(12-15) B33(12-15) B32(12-15) B33(12-15) B36(12-15) B37(12-15) B36(12-15) B37(12-15) B3A(12-15) B3B(12-15) B3A(12-15) B3B(12-15) B3E(12-15) B3F(12-15) B3E(12-15) B3F(12-15) + + const __m512i rhs_mat_014589CD_40_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_40, (_MM_PERM_ENUM)221); //B40(4-7) B41(4-7) B40(4-7) B41(4-7) B44(4-7) B45(4-7) B44(4-7) B45(4-7) B48(4-7) B49(4-7) B48(4-7) B49(4-7) B4C(4-7) B4D(4-7) B4C(4-7) B4D(4-7) + const __m512i rhs_mat_2367ABEF_40_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_40, (_MM_PERM_ENUM)221); //B42(4-7) B43(4-7) B42(4-7) B43(4-7) B46(4-7) B47(4-7) B46(4-7) B47(4-7) B4A(4-7) B4B(4-7) B4A(4-7) B4B(4-7) B4E(4-7) B4F(4-7) B4E(4-7) B4F(4-7) + + const __m512i rhs_mat_014589CD_41_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_41, (_MM_PERM_ENUM)221); //B40(12-15) B41(12-15) B40(12-15) B41(12-15) B44(12-15) B45(12-15) B44(12-15) B45(12-15) B48(12-15) B49(12-15) B48(12-15) B49(12-15) B4C(12-15) B4D(12-15) B4C(12-15) B4D(12-15) + const __m512i rhs_mat_2367ABEF_41_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_41, (_MM_PERM_ENUM)221); //B42(12-15) B43(12-15) B42(12-15) B43(12-15) B46(12-15) B47(12-15) B46(12-15) B47(12-15) B4A(12-15) B4B(12-15) B4A(12-15) B4B(12-15) B4E(12-15) B4F(12-15) B4E(12-15) B4F(12-15) + + const __m512i rhs_mat_014589CD_50_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_50, (_MM_PERM_ENUM)221); //B50(4-7) B51(4-7) B50(4-7) B51(4-7) B54(4-7) B55(4-7) B54(4-7) B55(4-7) B58(4-7) B59(4-7) B58(4-7) B59(4-7) B5C(4-7) B5D(4-7) B5C(4-7) B5D(4-7) + const __m512i rhs_mat_2367ABEF_50_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_50, (_MM_PERM_ENUM)221); //B52(4-7) B53(4-7) B52(4-7) B53(4-7) B56(4-7) B57(4-7) B56(4-7) B57(4-7) B5A(4-7) B5B(4-7) B5A(4-7) B5B(4-7) B5E(4-7) B5F(4-7) B5E(4-7) B5F(4-7) + + const __m512i rhs_mat_014589CD_51_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_51, (_MM_PERM_ENUM)221); //B50(12-15) B51(12-15) B50(12-15) B51(12-15) B54(12-15) B55(12-15) B54(12-15) B55(12-15) B58(12-15) B59(12-15) B58(12-15) B59(12-15) B5C(12-15) B5D(12-15) B5C(12-15) B5D(12-15) + const __m512i rhs_mat_2367ABEF_51_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_51, (_MM_PERM_ENUM)221); //B52(12-15) B53(12-15) B52(12-15) B53(12-15) B56(12-15) B57(12-15) B56(12-15) B57(12-15) B5A(12-15) B5B(12-15) B5A(12-15) B5B(12-15) B5E(12-15) B5F(12-15) B5E(12-15) B5F(12-15) + + const __m512i rhs_mat_014589CD_60_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_60, (_MM_PERM_ENUM)221); //B60(4-7) B61(4-7) B60(4-7) B61(4-7) B64(4-7) B65(4-7) B64(4-7) B65(4-7) B68(4-7) B69(4-7) B68(4-7) B69(4-7) B6C(4-7) B6D(4-7) B6C(4-7) B6D(4-7) + const __m512i rhs_mat_2367ABEF_60_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_60, (_MM_PERM_ENUM)221); //B62(4-7) B63(4-7) B62(4-7) B63(4-7) B66(4-7) B67(4-7) B66(4-7) B67(4-7) B6A(4-7) B6B(4-7) B6A(4-7) B6B(4-7) B6E(4-7) B6F(4-7) B6E(4-7) B6F(4-7) + + const __m512i rhs_mat_014589CD_61_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_61, (_MM_PERM_ENUM)221); //B60(12-15) B61(12-15) B60(12-15) B61(12-15) B64(12-15) B65(12-15) B64(12-15) B65(12-15) B68(12-15) B69(12-15) B68(12-15) B69(12-15) B6C(12-15) B6D(12-15) B6C(12-15) B6D(12-15) + const __m512i rhs_mat_2367ABEF_61_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_61, (_MM_PERM_ENUM)221); //B62(12-15) B63(12-15) B62(12-15) B63(12-15) B66(12-15) B67(12-15) B66(12-15) B67(12-15) B6A(12-15) B6B(12-15) B6A(12-15) B6B(12-15) B6E(12-15) B6F(12-15) B6E(12-15) B6F(12-15) + + const __m512i rhs_mat_014589CD_70_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_70, (_MM_PERM_ENUM)221); //B70(4-7) B71(4-7) B70(4-7) B71(4-7) B74(4-7) B75(4-7) B74(4-7) B75(4-7) B78(4-7) B79(4-7) B78(4-7) B79(4-7) B7C(4-7) B7D(4-7) B7C(4-7) B7D(4-7) + const __m512i rhs_mat_2367ABEF_70_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_70, (_MM_PERM_ENUM)221); //B72(4-7) B73(4-7) B72(4-7) B73(4-7) B76(4-7) B77(4-7) B76(4-7) B77(4-7) B7A(4-7) B7B(4-7) B7A(4-7) B7B(4-7) B7E(4-7) B7F(4-7) B7E(4-7) B7F(4-7) + + const __m512i rhs_mat_014589CD_71_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_71, (_MM_PERM_ENUM)221); //B70(12-15) B71(12-15) B70(12-15) B71(12-15) B74(12-15) B75(12-15) B74(12-15) B75(12-15) B78(12-15) B79(12-15) B78(12-15) B79(12-15) B7C(12-15) B7D(12-15) B7C(12-15) B7D(12-15) + const __m512i rhs_mat_2367ABEF_71_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_71, (_MM_PERM_ENUM)221); //B72(12-15) B73(12-15) B72(12-15) B73(12-15) B76(12-15) B77(12-15) B76(12-15) B77(12-15) B7A(12-15) B7B(12-15) B7A(12-15) B7B(12-15) B7E(12-15) B7F(12-15) B7E(12-15) B7F(12-15) + + //notation:superblock subblock + //s00 m00 s01 m01 s10 m10 s11 m11 s20 m20 s21 m21 s30 m30 s31 m31 s40 m40 s41 m41 s50 m50 s51 m51 s60 m60 s61 m61 s70 m70 s71 m71 + + const __m128i mins_and_scales_01_0 = _mm_loadu_si128((const __m128i *)(b_ptr_0[b].scales + sb * 64)); + const __m128i mins_and_scales_23_0 = _mm_loadu_si128((const __m128i *)(b_ptr_0[b].scales + 16 + sb * 64)); + const __m128i mins_and_scales_45_0 = _mm_loadu_si128((const __m128i *)(b_ptr_0[b].scales + 32 + sb * 64)); + const __m128i mins_and_scales_67_0 = _mm_loadu_si128((const __m128i *)(b_ptr_0[b].scales + 48 + sb * 64)); + + const __m128i mins_and_scales_01_1 = _mm_loadu_si128((const __m128i *)(b_ptr_1[b].scales + sb * 64)); + const __m128i mins_and_scales_23_1 = _mm_loadu_si128((const __m128i *)(b_ptr_1[b].scales + 16 + sb * 64)); + const __m128i mins_and_scales_45_1 = _mm_loadu_si128((const __m128i *)(b_ptr_1[b].scales + 32 + sb * 64)); + const __m128i mins_and_scales_67_1 = _mm_loadu_si128((const __m128i *)(b_ptr_1[b].scales + 48 + sb * 64)); + + // Combine mins and scales for sub-blocks: 0-1, 2-3, 4-5, 6-7 in the sb loop + const __m256i mins_and_scales_01 = _mm256_insertf128_si256(_mm256_castsi128_si256(mins_and_scales_01_0), mins_and_scales_01_1, 1); + const __m256i mins_and_scales_23 = _mm256_insertf128_si256(_mm256_castsi128_si256(mins_and_scales_23_0), mins_and_scales_23_1, 1); + const __m256i mins_and_scales_45 = _mm256_insertf128_si256(_mm256_castsi128_si256(mins_and_scales_45_0), mins_and_scales_45_1, 1); + const __m256i mins_and_scales_67 = _mm256_insertf128_si256(_mm256_castsi128_si256(mins_and_scales_67_0), mins_and_scales_67_1, 1); + + // Extract scales which is lower half from mins_and_scales + const __m256i scales_01 = _mm256_and_si256(mins_and_scales_01, m4b); + const __m256i scales_23 = _mm256_and_si256(mins_and_scales_23, m4b); + const __m256i scales_45 = _mm256_and_si256(mins_and_scales_45, m4b); + const __m256i scales_67 = _mm256_and_si256(mins_and_scales_67, m4b); + + // Extract mins which is upper half from mins_and_scales + const __m512i mins_01 = _mm512_cvtepu8_epi16(_mm256_and_si256(_mm256_srli_epi16(mins_and_scales_01, 4), m4b)); + const __m512i mins_23 = _mm512_cvtepu8_epi16(_mm256_and_si256(_mm256_srli_epi16(mins_and_scales_23, 4), m4b)); + const __m512i mins_45 = _mm512_cvtepu8_epi16(_mm256_and_si256(_mm256_srli_epi16(mins_and_scales_45, 4), m4b)); + const __m512i mins_67 = _mm512_cvtepu8_epi16(_mm256_and_si256(_mm256_srli_epi16(mins_and_scales_67, 4), m4b)); + + const __m512i scales_0 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_01, scalesmask1)); + const __m512i scales_1 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_01, scalesmask2)); + const __m512i scales_2 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_23, scalesmask1)); + const __m512i scales_3 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_23, scalesmask2)); + const __m512i scales_4 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_45, scalesmask1)); + const __m512i scales_5 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_45, scalesmask2)); + const __m512i scales_6 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_67, scalesmask1)); + const __m512i scales_7 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_67, scalesmask2)); + + const __m512i scale_014589CD_0 = _mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)68); + const __m512i scale_2367ABEF_0 = _mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)238); + + const __m512i scale_014589CD_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)68); + const __m512i scale_2367ABEF_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)238); + + const __m512i scale_014589CD_2 = _mm512_shuffle_epi32(scales_2, (_MM_PERM_ENUM)68); + const __m512i scale_2367ABEF_2 = _mm512_shuffle_epi32(scales_2, (_MM_PERM_ENUM)238); + + const __m512i scale_014589CD_3 = _mm512_shuffle_epi32(scales_3, (_MM_PERM_ENUM)68); + const __m512i scale_2367ABEF_3 = _mm512_shuffle_epi32(scales_3, (_MM_PERM_ENUM)238); + + const __m512i scale_014589CD_4 = _mm512_shuffle_epi32(scales_4, (_MM_PERM_ENUM)68); + const __m512i scale_2367ABEF_4 = _mm512_shuffle_epi32(scales_4, (_MM_PERM_ENUM)238); + + const __m512i scale_014589CD_5 = _mm512_shuffle_epi32(scales_5, (_MM_PERM_ENUM)68); + const __m512i scale_2367ABEF_5 = _mm512_shuffle_epi32(scales_5, (_MM_PERM_ENUM)238); + + const __m512i scale_014589CD_6 = _mm512_shuffle_epi32(scales_6, (_MM_PERM_ENUM)68); + const __m512i scale_2367ABEF_6 = _mm512_shuffle_epi32(scales_6, (_MM_PERM_ENUM)238); + + const __m512i scale_014589CD_7 = _mm512_shuffle_epi32(scales_7, (_MM_PERM_ENUM)68); + const __m512i scale_2367ABEF_7 = _mm512_shuffle_epi32(scales_7, (_MM_PERM_ENUM)238); + + // Load the four block_q8_k quantized values interleaved with each other in chunks of eight bytes - A0,A1,A2,A3 + // Loaded as set of 128 bit vectors and repeated into a 256 bit vector + __m256i lhs_mat_ymm_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 512 * sb))); + __m256i lhs_mat_ymm_01_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 0); + __m256i lhs_mat_ymm_23_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 17); + __m256i lhs_mat_ymm_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 32 + 512 * sb))); + __m256i lhs_mat_ymm_01_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 0); + __m256i lhs_mat_ymm_23_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 17); + __m256i lhs_mat_ymm_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 64 + 512 * sb))); + __m256i lhs_mat_ymm_01_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 0); + __m256i lhs_mat_ymm_23_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 17); + __m256i lhs_mat_ymm_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 96 + 512 * sb))); + __m256i lhs_mat_ymm_01_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 0); + __m256i lhs_mat_ymm_23_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 17); + __m256i lhs_mat_ymm_0123_20 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 128 + 512 * sb))); + __m256i lhs_mat_ymm_01_20 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_20, lhs_mat_ymm_0123_20, 0); + __m256i lhs_mat_ymm_23_20 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_20, lhs_mat_ymm_0123_20, 17); + __m256i lhs_mat_ymm_0123_21 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 160 + 512 * sb))); + __m256i lhs_mat_ymm_01_21 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_21, lhs_mat_ymm_0123_21, 0); + __m256i lhs_mat_ymm_23_21 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_21, lhs_mat_ymm_0123_21, 17); + __m256i lhs_mat_ymm_0123_30 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 192 + 512 * sb))); + __m256i lhs_mat_ymm_01_30 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_30, lhs_mat_ymm_0123_30, 0); + __m256i lhs_mat_ymm_23_30 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_30, lhs_mat_ymm_0123_30, 17); + __m256i lhs_mat_ymm_0123_31 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 224 + 512 * sb))); + __m256i lhs_mat_ymm_01_31 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_31, lhs_mat_ymm_0123_31, 0); + __m256i lhs_mat_ymm_23_31 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_31, lhs_mat_ymm_0123_31, 17); + + __m256i lhs_mat_ymm_0123_40 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 256 + 512 * sb))); + __m256i lhs_mat_ymm_01_40 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_40, lhs_mat_ymm_0123_40, 0); + __m256i lhs_mat_ymm_23_40 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_40, lhs_mat_ymm_0123_40, 17); + __m256i lhs_mat_ymm_0123_41 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 288 + 512 * sb))); + __m256i lhs_mat_ymm_01_41 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_41, lhs_mat_ymm_0123_41, 0); + __m256i lhs_mat_ymm_23_41 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_41, lhs_mat_ymm_0123_41, 17); + __m256i lhs_mat_ymm_0123_50 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 320 + 512 * sb))); + __m256i lhs_mat_ymm_01_50 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_50, lhs_mat_ymm_0123_50, 0); + __m256i lhs_mat_ymm_23_50 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_50, lhs_mat_ymm_0123_50, 17); + __m256i lhs_mat_ymm_0123_51 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 352 + 512 * sb))); + __m256i lhs_mat_ymm_01_51 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_51, lhs_mat_ymm_0123_51, 0); + __m256i lhs_mat_ymm_23_51 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_51, lhs_mat_ymm_0123_51, 17); + __m256i lhs_mat_ymm_0123_60 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 384 + 512 * sb))); + __m256i lhs_mat_ymm_01_60 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_60, lhs_mat_ymm_0123_60, 0); + __m256i lhs_mat_ymm_23_60 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_60, lhs_mat_ymm_0123_60, 17); + __m256i lhs_mat_ymm_0123_61 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 416 + 512 * sb))); + __m256i lhs_mat_ymm_01_61 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_61, lhs_mat_ymm_0123_61, 0); + __m256i lhs_mat_ymm_23_61 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_61, lhs_mat_ymm_0123_61, 17); + __m256i lhs_mat_ymm_0123_70 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 448 + 512 * sb))); + __m256i lhs_mat_ymm_01_70 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_70, lhs_mat_ymm_0123_70, 0); + __m256i lhs_mat_ymm_23_70 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_70, lhs_mat_ymm_0123_70, 17); + __m256i lhs_mat_ymm_0123_71 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 480 + 512 * sb))); + __m256i lhs_mat_ymm_01_71 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_71, lhs_mat_ymm_0123_71, 0); + __m256i lhs_mat_ymm_23_71 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_71, lhs_mat_ymm_0123_71, 17); + + __m512i lhs_mat_01_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_00), lhs_mat_ymm_01_00, 1); + __m512i lhs_mat_23_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_00), lhs_mat_ymm_23_00, 1); + __m512i lhs_mat_01_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_01), lhs_mat_ymm_01_01, 1); + __m512i lhs_mat_23_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_01), lhs_mat_ymm_23_01, 1); + + __m512i lhs_mat_01_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_10), lhs_mat_ymm_01_10, 1); + __m512i lhs_mat_23_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_10), lhs_mat_ymm_23_10, 1); + __m512i lhs_mat_01_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_11), lhs_mat_ymm_01_11, 1); + __m512i lhs_mat_23_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_11), lhs_mat_ymm_23_11, 1); + + __m512i lhs_mat_01_20 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_20), lhs_mat_ymm_01_20, 1); + __m512i lhs_mat_23_20 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_20), lhs_mat_ymm_23_20, 1); + __m512i lhs_mat_01_21 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_21), lhs_mat_ymm_01_21, 1); + __m512i lhs_mat_23_21 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_21), lhs_mat_ymm_23_21, 1); + + __m512i lhs_mat_01_30 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_30), lhs_mat_ymm_01_30, 1); + __m512i lhs_mat_23_30 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_30), lhs_mat_ymm_23_30, 1); + __m512i lhs_mat_01_31 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_31), lhs_mat_ymm_01_31, 1); + __m512i lhs_mat_23_31 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_31), lhs_mat_ymm_23_31, 1); + + __m512i lhs_mat_01_40 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_40), lhs_mat_ymm_01_40, 1); + __m512i lhs_mat_23_40 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_40), lhs_mat_ymm_23_40, 1); + __m512i lhs_mat_01_41 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_41), lhs_mat_ymm_01_41, 1); + __m512i lhs_mat_23_41 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_41), lhs_mat_ymm_23_41, 1); + + __m512i lhs_mat_01_50 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_50), lhs_mat_ymm_01_50, 1); + __m512i lhs_mat_23_50 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_50), lhs_mat_ymm_23_50, 1); + __m512i lhs_mat_01_51 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_51), lhs_mat_ymm_01_51, 1); + __m512i lhs_mat_23_51 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_51), lhs_mat_ymm_23_51, 1); + + __m512i lhs_mat_01_60 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_60), lhs_mat_ymm_01_60, 1); + __m512i lhs_mat_23_60 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_60), lhs_mat_ymm_23_60, 1); + __m512i lhs_mat_01_61 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_61), lhs_mat_ymm_01_61, 1); + __m512i lhs_mat_23_61 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_61), lhs_mat_ymm_23_61, 1); + + __m512i lhs_mat_01_70 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_70), lhs_mat_ymm_01_70, 1); + __m512i lhs_mat_23_70 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_70), lhs_mat_ymm_23_70, 1); + __m512i lhs_mat_01_71 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_71), lhs_mat_ymm_01_71, 1); + __m512i lhs_mat_23_71 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_71), lhs_mat_ymm_23_71, 1); + + // Bsums are loaded for the different Q8_K blocks + __m128i lhs_raw_bsums_01_0123 = _mm_loadu_si128((const __m128i *)((a_ptr[b].bsums + 32 * sb))); + __m128i lhs_raw_bsums_23_0123 = _mm_loadu_si128((const __m128i *)(a_ptr[b].bsums + 8 + 32 * sb)); + __m128i lhs_raw_bsums_01_4567 = _mm_loadu_si128((const __m128i *)((a_ptr[b].bsums + 16 + 32 * sb))); + __m128i lhs_raw_bsums_23_4567 = _mm_loadu_si128((const __m128i *)(a_ptr[b].bsums + 24 + 32 * sb)); + + __m256i lhs_bsums_ymm_01_0123 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_01_0123), lhs_raw_bsums_01_0123, 1); + __m512i lhs_bsums_01_0123 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_ymm_01_0123), lhs_bsums_ymm_01_0123, 1); + __m256i lhs_bsums_ymm_23_0123 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_23_0123), lhs_raw_bsums_23_0123, 1); + __m512i lhs_bsums_23_0123 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_ymm_23_0123), lhs_bsums_ymm_23_0123, 1); + __m256i lhs_bsums_ymm_01_4567 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_01_4567), lhs_raw_bsums_01_4567, 1); + __m512i lhs_bsums_01_4567 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_ymm_01_4567), lhs_bsums_ymm_01_4567, 1); + __m256i lhs_bsums_ymm_23_4567 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_23_4567), lhs_raw_bsums_23_4567, 1); + __m512i lhs_bsums_23_4567 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_ymm_23_4567), lhs_bsums_ymm_23_4567, 1); + + // Shuffle pattern one - left side input + const __m512i lhs_mat_01_00_sp1 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) + const __m512i lhs_mat_23_00_sp1 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)160); //A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) + + const __m512i lhs_mat_01_01_sp1 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)160); //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) + const __m512i lhs_mat_23_01_sp1 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)160); //A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) + + const __m512i lhs_mat_01_10_sp1 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) + const __m512i lhs_mat_23_10_sp1 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)160); //A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) + + const __m512i lhs_mat_01_11_sp1 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)160); //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) + const __m512i lhs_mat_23_11_sp1 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)160); //A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) + + const __m512i lhs_mat_01_20_sp1 = _mm512_shuffle_epi32(lhs_mat_01_20, (_MM_PERM_ENUM)160); //A20(0-3) A20(0-3) A21(0-3) A21(0-3) A20(0-3) A20(0-3) A21(0-3) A21(0-3) A20(0-3) A20(0-3) A21(0-3) A21(0-3) A20(0-3) A20(0-3) A21(0-3) A21(0-3) + const __m512i lhs_mat_23_20_sp1 = _mm512_shuffle_epi32(lhs_mat_23_20, (_MM_PERM_ENUM)160); //A22(0-3) A22(0-3) A23(0-3) A23(0-3) A22(0-3) A22(0-3) A23(0-3) A23(0-3) A22(0-3) A22(0-3) A23(0-3) A23(0-3) A22(0-3) A22(0-3) A23(0-3) A23(0-3) + + const __m512i lhs_mat_01_21_sp1 = _mm512_shuffle_epi32(lhs_mat_01_21, (_MM_PERM_ENUM)160); //A20(8-11) A20(8-11) A21(8-11) A21(8-11) A20(8-11) A20(8-11) A21(8-11) A21(8-11) A20(8-11) A20(8-11) A21(8-11) A21(8-11) A20(8-11) A20(8-11) A21(8-11) A21(8-11) + const __m512i lhs_mat_23_21_sp1 = _mm512_shuffle_epi32(lhs_mat_23_21, (_MM_PERM_ENUM)160); //A22(8-11) A22(8-11) A23(8-11) A23(8-11) A22(8-11) A22(8-11) A23(8-11) A23(8-11) A22(8-11) A22(8-11) A23(8-11) A23(8-11) A22(8-11) A22(8-11) A23(8-11) A23(8-11) + + const __m512i lhs_mat_01_30_sp1 = _mm512_shuffle_epi32(lhs_mat_01_30, (_MM_PERM_ENUM)160); //A30(0-3) A30(0-3) A31(0-3) A31(0-3) A30(0-3) A30(0-3) A31(0-3) A31(0-3) A30(0-3) A30(0-3) A31(0-3) A31(0-3) A30(0-3) A30(0-3) A31(0-3) A31(0-3) + const __m512i lhs_mat_23_30_sp1 = _mm512_shuffle_epi32(lhs_mat_23_30, (_MM_PERM_ENUM)160); //A32(0-3) A32(0-3) A33(0-3) A33(0-3) A32(0-3) A32(0-3) A33(0-3) A33(0-3) A32(0-3) A32(0-3) A33(0-3) A33(0-3) A32(0-3) A32(0-3) A33(0-3) A33(0-3) + + const __m512i lhs_mat_01_31_sp1 = _mm512_shuffle_epi32(lhs_mat_01_31, (_MM_PERM_ENUM)160); //A30(8-11) A30(8-11) A31(8-11) A31(8-11) A30(8-11) A30(8-11) A31(8-11) A31(8-11) A30(8-11) A30(8-11) A31(8-11) A31(8-11) A30(8-11) A30(8-11) A31(8-11) A31(8-11) + const __m512i lhs_mat_23_31_sp1 = _mm512_shuffle_epi32(lhs_mat_23_31, (_MM_PERM_ENUM)160); //A32(8-11) A32(8-11) A33(8-11) A33(8-11) A32(8-11) A32(8-11) A33(8-11) A33(8-11) A32(8-11) A32(8-11) A33(8-11) A33(8-11) A32(8-11) A32(8-11) A33(8-11) A33(8-11) + + const __m512i lhs_mat_01_40_sp1 = _mm512_shuffle_epi32(lhs_mat_01_40, (_MM_PERM_ENUM)160); //A40(0-3) A40(0-3) A41(0-3) A41(0-3) A40(0-3) A40(0-3) A41(0-3) A41(0-3) A40(0-3) A40(0-3) A41(0-3) A41(0-3) A40(0-3) A40(0-3) A41(0-3) A41(0-3) + const __m512i lhs_mat_23_40_sp1 = _mm512_shuffle_epi32(lhs_mat_23_40, (_MM_PERM_ENUM)160); //A42(0-3) A42(0-3) A43(0-3) A43(0-3) A42(0-3) A42(0-3) A43(0-3) A43(0-3) A42(0-3) A42(0-3) A43(0-3) A43(0-3) A42(0-3) A42(0-3) A43(0-3) A43(0-3) + + const __m512i lhs_mat_01_41_sp1 = _mm512_shuffle_epi32(lhs_mat_01_41, (_MM_PERM_ENUM)160); //A40(8-11) A40(8-11) A41(8-11) A41(8-11) A40(8-11) A40(8-11) A41(8-11) A41(8-11) A40(8-11) A40(8-11) A41(8-11) A41(8-11) A40(8-11) A40(8-11) A41(8-11) A41(8-11) + const __m512i lhs_mat_23_41_sp1 = _mm512_shuffle_epi32(lhs_mat_23_41, (_MM_PERM_ENUM)160); //A42(8-11) A42(8-11) A43(8-11) A43(8-11) A42(8-11) A42(8-11) A43(8-11) A43(8-11) A42(8-11) A42(8-11) A43(8-11) A43(8-11) A42(8-11) A42(8-11) A43(8-11) A43(8-11) + + const __m512i lhs_mat_01_50_sp1 = _mm512_shuffle_epi32(lhs_mat_01_50, (_MM_PERM_ENUM)160); //A50(0-3) A50(0-3) A51(0-3) A51(0-3) A50(0-3) A50(0-3) A51(0-3) A51(0-3) A50(0-3) A50(0-3) A51(0-3) A51(0-3) A50(0-3) A50(0-3) A51(0-3) A51(0-3) + const __m512i lhs_mat_23_50_sp1 = _mm512_shuffle_epi32(lhs_mat_23_50, (_MM_PERM_ENUM)160); //A52(0-3) A52(0-3) A53(0-3) A53(0-3) A52(0-3) A52(0-3) A53(0-3) A53(0-3) A52(0-3) A52(0-3) A53(0-3) A53(0-3) A52(0-3) A52(0-3) A53(0-3) A53(0-3) + + const __m512i lhs_mat_01_51_sp1 = _mm512_shuffle_epi32(lhs_mat_01_51, (_MM_PERM_ENUM)160); //A50(8-11) A50(8-11) A51(8-11) A51(8-11) A50(8-11) A50(8-11) A51(8-11) A51(8-11) A50(8-11) A50(8-11) A51(8-11) A51(8-11) A50(8-11) A50(8-11) A51(8-11) A51(8-11) + const __m512i lhs_mat_23_51_sp1 = _mm512_shuffle_epi32(lhs_mat_23_51, (_MM_PERM_ENUM)160); //A52(8-11) A52(8-11) A53(8-11) A53(8-11) A52(8-11) A52(8-11) A53(8-11) A53(8-11) A52(8-11) A52(8-11) A53(8-11) A53(8-11) A52(8-11) A52(8-11) A53(8-11) A53(8-11) + + const __m512i lhs_mat_01_60_sp1 = _mm512_shuffle_epi32(lhs_mat_01_60, (_MM_PERM_ENUM)160); //A60(0-3) A60(0-3) A61(0-3) A61(0-3) A60(0-3) A60(0-3) A61(0-3) A61(0-3) A60(0-3) A60(0-3) A61(0-3) A61(0-3) A60(0-3) A60(0-3) A61(0-3) A61(0-3) + const __m512i lhs_mat_23_60_sp1 = _mm512_shuffle_epi32(lhs_mat_23_60, (_MM_PERM_ENUM)160); //A62(0-3) A62(0-3) A63(0-3) A63(0-3) A62(0-3) A62(0-3) A63(0-3) A63(0-3) A62(0-3) A62(0-3) A63(0-3) A63(0-3) A62(0-3) A62(0-3) A63(0-3) A63(0-3) + + const __m512i lhs_mat_01_61_sp1 = _mm512_shuffle_epi32(lhs_mat_01_61, (_MM_PERM_ENUM)160); //A60(8-11) A60(8-11) A61(8-11) A61(8-11) A60(8-11) A60(8-11) A61(8-11) A61(8-11) A60(8-11) A60(8-11) A61(8-11) A61(8-11) A60(8-11) A60(8-11) A61(8-11) A61(8-11) + const __m512i lhs_mat_23_61_sp1 = _mm512_shuffle_epi32(lhs_mat_23_61, (_MM_PERM_ENUM)160); //A62(8-11) A62(8-11) A63(8-11) A63(8-11) A62(8-11) A62(8-11) A63(8-11) A63(8-11) A62(8-11) A62(8-11) A63(8-11) A63(8-11) A62(8-11) A62(8-11) A63(8-11) A63(8-11) + + const __m512i lhs_mat_01_70_sp1 = _mm512_shuffle_epi32(lhs_mat_01_70, (_MM_PERM_ENUM)160); //A70(0-3) A70(0-3) A71(0-3) A71(0-3) A70(0-3) A70(0-3) A71(0-3) A71(0-3) A70(0-3) A70(0-3) A71(0-3) A71(0-3) A70(0-3) A70(0-3) A71(0-3) A71(0-3) + const __m512i lhs_mat_23_70_sp1 = _mm512_shuffle_epi32(lhs_mat_23_70, (_MM_PERM_ENUM)160); //A72(0-3) A72(0-3) A73(0-3) A73(0-3) A72(0-3) A72(0-3) A73(0-3) A73(0-3) A72(0-3) A72(0-3) A73(0-3) A73(0-3) A72(0-3) A72(0-3) A73(0-3) A73(0-3) + + const __m512i lhs_mat_01_71_sp1 = _mm512_shuffle_epi32(lhs_mat_01_71, (_MM_PERM_ENUM)160); //A70(8-11) A70(8-11) A71(8-11) A71(8-11) A70(8-11) A70(8-11) A71(8-11) A71(8-11) A70(8-11) A70(8-11) A71(8-11) A71(8-11) A70(8-11) A70(8-11) A71(8-11) A71(8-11) + const __m512i lhs_mat_23_71_sp1 = _mm512_shuffle_epi32(lhs_mat_23_71, (_MM_PERM_ENUM)160); //A72(8-11) A72(8-11) A73(8-11) A73(8-11) A72(8-11) A72(8-11) A73(8-11) A73(8-11) A72(8-11) A72(8-11) A73(8-11) A73(8-11) A72(8-11) A72(8-11) A73(8-11) A73(8-11) + + const __m512i lhs_mat_01_00_sp2 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) + const __m512i lhs_mat_23_00_sp2 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)245); //A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) + + const __m512i lhs_mat_01_01_sp2 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) + const __m512i lhs_mat_23_01_sp2 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)245); //A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) + + const __m512i lhs_mat_01_10_sp2 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) + const __m512i lhs_mat_23_10_sp2 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)245); //A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) + + const __m512i lhs_mat_01_11_sp2 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) + const __m512i lhs_mat_23_11_sp2 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)245); //A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) + + const __m512i lhs_mat_01_20_sp2 = _mm512_shuffle_epi32(lhs_mat_01_20, (_MM_PERM_ENUM)245); //A20(4-7) A20(4-7) A21(4-7) A21(4-7) A20(4-7) A20(4-7) A21(4-7) A21(4-7) A20(4-7) A20(4-7) A21(4-7) A21(4-7) A20(4-7) A20(4-7) A21(4-7) A21(4-7) + const __m512i lhs_mat_23_20_sp2 = _mm512_shuffle_epi32(lhs_mat_23_20, (_MM_PERM_ENUM)245); //A22(4-7) A22(4-7) A23(4-7) A23(4-7) A22(4-7) A22(4-7) A23(4-7) A23(4-7) A22(4-7) A22(4-7) A23(4-7) A23(4-7) A22(4-7) A22(4-7) A23(4-7) A23(4-7) + + const __m512i lhs_mat_01_21_sp2 = _mm512_shuffle_epi32(lhs_mat_01_21, (_MM_PERM_ENUM)245); //A20(12-15) A20(12-15) A21(12-15) A21(12-15) A20(12-15) A20(12-15) A21(12-15) A21(12-15) A20(12-15) A20(12-15) A21(12-15) A21(12-15) A20(12-15) A20(12-15) A21(12-15) A21(12-15) + const __m512i lhs_mat_23_21_sp2 = _mm512_shuffle_epi32(lhs_mat_23_21, (_MM_PERM_ENUM)245); //A22(12-15) A22(12-15) A23(12-15) A23(12-15) A22(12-15) A22(12-15) A23(12-15) A23(12-15) A22(12-15) A22(12-15) A23(12-15) A23(12-15) A22(12-15) A22(12-15) A23(12-15) A23(12-15) + + const __m512i lhs_mat_01_30_sp2 = _mm512_shuffle_epi32(lhs_mat_01_30, (_MM_PERM_ENUM)245); //A30(4-7) A30(4-7) A31(4-7) A31(4-7) A30(4-7) A30(4-7) A31(4-7) A31(4-7) A30(4-7) A30(4-7) A31(4-7) A31(4-7) A30(4-7) A30(4-7) A31(4-7) A31(4-7) + const __m512i lhs_mat_23_30_sp2 = _mm512_shuffle_epi32(lhs_mat_23_30, (_MM_PERM_ENUM)245); //A32(4-7) A32(4-7) A33(4-7) A33(4-7) A32(4-7) A32(4-7) A33(4-7) A33(4-7) A32(4-7) A32(4-7) A33(4-7) A33(4-7) A32(4-7) A32(4-7) A33(4-7) A33(4-7) + + const __m512i lhs_mat_01_31_sp2 = _mm512_shuffle_epi32(lhs_mat_01_31, (_MM_PERM_ENUM)245); //A30(12-15) A30(12-15) A31(12-15) A31(12-15) A30(12-15) A30(12-15) A31(12-15) A31(12-15) A30(12-15) A30(12-15) A31(12-15) A31(12-15) A30(12-15) A30(12-15) A31(12-15) A31(12-15) + const __m512i lhs_mat_23_31_sp2 = _mm512_shuffle_epi32(lhs_mat_23_31, (_MM_PERM_ENUM)245); //A32(12-15) A32(12-15) A33(12-15) A33(12-15) A32(12-15) A32(12-15) A33(12-15) A33(12-15) A32(12-15) A32(12-15) A33(12-15) A33(12-15) A32(12-15) A32(12-15) A33(12-15) A33(12-15) + + const __m512i lhs_mat_01_40_sp2 = _mm512_shuffle_epi32(lhs_mat_01_40, (_MM_PERM_ENUM)245); //A40(4-7) A40(4-7) A41(4-7) A41(4-7) A40(4-7) A40(4-7) A41(4-7) A41(4-7) A40(4-7) A40(4-7) A41(4-7) A41(4-7) A40(4-7) A40(4-7) A41(4-7) A41(4-7) + const __m512i lhs_mat_23_40_sp2 = _mm512_shuffle_epi32(lhs_mat_23_40, (_MM_PERM_ENUM)245); //A42(4-7) A42(4-7) A43(4-7) A43(4-7) A42(4-7) A42(4-7) A43(4-7) A43(4-7) A42(4-7) A42(4-7) A43(4-7) A43(4-7) A42(4-7) A42(4-7) A43(4-7) A43(4-7) + + const __m512i lhs_mat_01_41_sp2 = _mm512_shuffle_epi32(lhs_mat_01_41, (_MM_PERM_ENUM)245); //A40(12-15) A40(12-15) A41(12-15) A41(12-15) A40(12-15) A40(12-15) A41(12-15) A41(12-15) A40(12-15) A40(12-15) A41(12-15) A41(12-15) A40(12-15) A40(12-15) A41(12-15) A41(12-15) + const __m512i lhs_mat_23_41_sp2 = _mm512_shuffle_epi32(lhs_mat_23_41, (_MM_PERM_ENUM)245); //A42(12-15) A42(12-15) A43(12-15) A43(12-15) A42(12-15) A42(12-15) A43(12-15) A43(12-15) A42(12-15) A42(12-15) A43(12-15) A43(12-15) A42(12-15) A42(12-15) A43(12-15) A43(12-15) + + const __m512i lhs_mat_01_50_sp2 = _mm512_shuffle_epi32(lhs_mat_01_50, (_MM_PERM_ENUM)245); //A50(4-7) A50(4-7) A51(4-7) A51(4-7) A50(4-7) A50(4-7) A51(4-7) A51(4-7) A50(4-7) A50(4-7) A51(4-7) A51(4-7) A50(4-7) A50(4-7) A51(4-7) A51(4-7) + const __m512i lhs_mat_23_50_sp2 = _mm512_shuffle_epi32(lhs_mat_23_50, (_MM_PERM_ENUM)245); //A52(4-7) A52(4-7) A53(4-7) A53(4-7) A52(4-7) A52(4-7) A53(4-7) A53(4-7) A52(4-7) A52(4-7) A53(4-7) A53(4-7) A52(4-7) A52(4-7) A53(4-7) A53(4-7) + + const __m512i lhs_mat_01_51_sp2 = _mm512_shuffle_epi32(lhs_mat_01_51, (_MM_PERM_ENUM)245); //A50(12-15) A50(12-15) A51(12-15) A51(12-15) A50(12-15) A50(12-15) A51(12-15) A51(12-15) A50(12-15) A50(12-15) A51(12-15) A51(12-15) A50(12-15) A50(12-15) A51(12-15) A51(12-15) + const __m512i lhs_mat_23_51_sp2 = _mm512_shuffle_epi32(lhs_mat_23_51, (_MM_PERM_ENUM)245); //A52(12-15) A52(12-15) A53(12-15) A53(12-15) A52(12-15) A52(12-15) A53(12-15) A53(12-15) A52(12-15) A52(12-15) A53(12-15) A53(12-15) A52(12-15) A52(12-15) A53(12-15) A53(12-15) + + const __m512i lhs_mat_01_60_sp2 = _mm512_shuffle_epi32(lhs_mat_01_60, (_MM_PERM_ENUM)245); //A60(4-7) A60(4-7) A61(4-7) A61(4-7) A60(4-7) A60(4-7) A61(4-7) A61(4-7) A60(4-7) A60(4-7) A61(4-7) A61(4-7) A60(4-7) A60(4-7) A61(4-7) A61(4-7) + const __m512i lhs_mat_23_60_sp2 = _mm512_shuffle_epi32(lhs_mat_23_60, (_MM_PERM_ENUM)245); //A62(4-7) A62(4-7) A63(4-7) A63(4-7) A62(4-7) A62(4-7) A63(4-7) A63(4-7) A62(4-7) A62(4-7) A63(4-7) A63(4-7) A62(4-7) A62(4-7) A63(4-7) A63(4-7) + + const __m512i lhs_mat_01_61_sp2 = _mm512_shuffle_epi32(lhs_mat_01_61, (_MM_PERM_ENUM)245); //A60(12-15) A60(12-15) A61(12-15) A61(12-15) A60(12-15) A60(12-15) A61(12-15) A61(12-15) A60(12-15) A60(12-15) A61(12-15) A61(12-15) A60(12-15) A60(12-15) A61(12-15) A61(12-15) + const __m512i lhs_mat_23_61_sp2 = _mm512_shuffle_epi32(lhs_mat_23_61, (_MM_PERM_ENUM)245); //A62(12-15) A62(12-15) A63(12-15) A63(12-15) A62(12-15) A62(12-15) A63(12-15) A63(12-15) A62(12-15) A62(12-15) A63(12-15) A63(12-15) A62(12-15) A62(12-15) A63(12-15) A63(12-15) + + const __m512i lhs_mat_01_70_sp2 = _mm512_shuffle_epi32(lhs_mat_01_70, (_MM_PERM_ENUM)245); //A70(4-7) A70(4-7) A71(4-7) A71(4-7) A70(4-7) A70(4-7) A71(4-7) A71(4-7) A70(4-7) A70(4-7) A71(4-7) A71(4-7) A70(4-7) A70(4-7) A71(4-7) A71(4-7) + const __m512i lhs_mat_23_70_sp2 = _mm512_shuffle_epi32(lhs_mat_23_70, (_MM_PERM_ENUM)245); //A72(4-7) A72(4-7) A73(4-7) A73(4-7) A72(4-7) A72(4-7) A73(4-7) A73(4-7) A72(4-7) A72(4-7) A73(4-7) A73(4-7) A72(4-7) A72(4-7) A73(4-7) A73(4-7) + + const __m512i lhs_mat_01_71_sp2 = _mm512_shuffle_epi32(lhs_mat_01_71, (_MM_PERM_ENUM)245); //A70(12-15) A70(12-15) A71(12-15) A71(12-15) A70(12-15) A70(12-15) A71(12-15) A71(12-15) A70(12-15) A70(12-15) A71(12-15) A71(12-15) A70(12-15) A70(12-15) A71(12-15) A71(12-15) + const __m512i lhs_mat_23_71_sp2 = _mm512_shuffle_epi32(lhs_mat_23_71, (_MM_PERM_ENUM)245); //A72(12-15) A72(12-15) A73(12-15) A73(12-15) A72(12-15) A72(12-15) A73(12-15) A73(12-15) A72(12-15) A72(12-15) A73(12-15) A73(12-15) A72(12-15) A72(12-15) A73(12-15) A73(12-15) + + // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane + __m512i iacc_mat_00_0_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_01_00_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_01_01_sp1)); + __m512i iacc_mat_01_0_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_01_00_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_01_01_sp1)); + + __m512i iacc_mat_10_0_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_23_00_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_23_01_sp1)); + __m512i iacc_mat_11_0_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_23_00_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_23_01_sp1)); + + __m512i iacc_mat_00_1_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_01_10_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_01_11_sp1)); + __m512i iacc_mat_01_1_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_01_10_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_01_11_sp1)); + + __m512i iacc_mat_10_1_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_23_10_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_23_11_sp1)); + __m512i iacc_mat_11_1_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_23_10_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_23_11_sp1)); + + __m512i iacc_mat_00_2_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_20_sp1, lhs_mat_01_20_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_21_sp1, lhs_mat_01_21_sp1)); + __m512i iacc_mat_01_2_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_20_sp1, lhs_mat_01_20_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_21_sp1, lhs_mat_01_21_sp1)); + + __m512i iacc_mat_10_2_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_20_sp1, lhs_mat_23_20_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_21_sp1, lhs_mat_23_21_sp1)); + __m512i iacc_mat_11_2_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_20_sp1, lhs_mat_23_20_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_21_sp1, lhs_mat_23_21_sp1)); + + __m512i iacc_mat_00_3_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_30_sp1, lhs_mat_01_30_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_31_sp1, lhs_mat_01_31_sp1)); + __m512i iacc_mat_01_3_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_30_sp1, lhs_mat_01_30_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_31_sp1, lhs_mat_01_31_sp1)); + + __m512i iacc_mat_10_3_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_30_sp1, lhs_mat_23_30_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_31_sp1, lhs_mat_23_31_sp1)); + __m512i iacc_mat_11_3_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_30_sp1, lhs_mat_23_30_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_31_sp1, lhs_mat_23_31_sp1)); + + __m512i iacc_mat_00_4_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_40_sp1, lhs_mat_01_40_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_41_sp1, lhs_mat_01_41_sp1)); + __m512i iacc_mat_01_4_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_40_sp1, lhs_mat_01_40_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_41_sp1, lhs_mat_01_41_sp1)); + + __m512i iacc_mat_10_4_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_40_sp1, lhs_mat_23_40_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_41_sp1, lhs_mat_23_41_sp1)); + __m512i iacc_mat_11_4_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_40_sp1, lhs_mat_23_40_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_41_sp1, lhs_mat_23_41_sp1)); + + __m512i iacc_mat_00_5_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_50_sp1, lhs_mat_01_50_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_51_sp1, lhs_mat_01_51_sp1)); + __m512i iacc_mat_01_5_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_50_sp1, lhs_mat_01_50_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_51_sp1, lhs_mat_01_51_sp1)); + + __m512i iacc_mat_10_5_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_50_sp1, lhs_mat_23_50_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_51_sp1, lhs_mat_23_51_sp1)); + __m512i iacc_mat_11_5_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_50_sp1, lhs_mat_23_50_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_51_sp1, lhs_mat_23_51_sp1)); + + __m512i iacc_mat_00_6_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_60_sp1, lhs_mat_01_60_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_61_sp1, lhs_mat_01_61_sp1)); + __m512i iacc_mat_01_6_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_60_sp1, lhs_mat_01_60_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_61_sp1, lhs_mat_01_61_sp1)); + + __m512i iacc_mat_10_6_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_60_sp1, lhs_mat_23_60_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_61_sp1, lhs_mat_23_61_sp1)); + __m512i iacc_mat_11_6_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_60_sp1, lhs_mat_23_60_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_61_sp1, lhs_mat_23_61_sp1)); + + __m512i iacc_mat_00_7_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_70_sp1, lhs_mat_01_70_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_71_sp1, lhs_mat_01_71_sp1)); + __m512i iacc_mat_01_7_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_70_sp1, lhs_mat_01_70_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_71_sp1, lhs_mat_01_71_sp1)); + + __m512i iacc_mat_10_7_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_70_sp1, lhs_mat_23_70_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_71_sp1, lhs_mat_23_71_sp1)); + __m512i iacc_mat_11_7_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_70_sp1, lhs_mat_23_70_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_71_sp1, lhs_mat_23_71_sp1)); + + + __m512i iacc_mat_00_0_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_01_00_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_01_01_sp2)); + __m512i iacc_mat_01_0_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_01_00_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_01_01_sp2)); + + __m512i iacc_mat_10_0_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_23_00_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_23_01_sp2)); + __m512i iacc_mat_11_0_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_23_00_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_23_01_sp2)); + + __m512i iacc_mat_00_1_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_01_10_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_01_11_sp2)); + __m512i iacc_mat_01_1_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_01_10_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_01_11_sp2)); + + __m512i iacc_mat_10_1_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_23_10_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_23_11_sp2)); + __m512i iacc_mat_11_1_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_23_10_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_23_11_sp2)); + + __m512i iacc_mat_00_2_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_20_sp2, lhs_mat_01_20_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_21_sp2, lhs_mat_01_21_sp2)); + __m512i iacc_mat_01_2_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_20_sp2, lhs_mat_01_20_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_21_sp2, lhs_mat_01_21_sp2)); + + __m512i iacc_mat_10_2_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_20_sp2, lhs_mat_23_20_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_21_sp2, lhs_mat_23_21_sp2)); + __m512i iacc_mat_11_2_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_20_sp2, lhs_mat_23_20_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_21_sp2, lhs_mat_23_21_sp2)); + + __m512i iacc_mat_00_3_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_30_sp2, lhs_mat_01_30_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_31_sp2, lhs_mat_01_31_sp2)); + __m512i iacc_mat_01_3_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_30_sp2, lhs_mat_01_30_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_31_sp2, lhs_mat_01_31_sp2)); + + __m512i iacc_mat_10_3_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_30_sp2, lhs_mat_23_30_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_31_sp2, lhs_mat_23_31_sp2)); + __m512i iacc_mat_11_3_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_30_sp2, lhs_mat_23_30_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_31_sp2, lhs_mat_23_31_sp2)); + + __m512i iacc_mat_00_4_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_40_sp2, lhs_mat_01_40_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_41_sp2, lhs_mat_01_41_sp2)); + __m512i iacc_mat_01_4_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_40_sp2, lhs_mat_01_40_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_41_sp2, lhs_mat_01_41_sp2)); + + __m512i iacc_mat_10_4_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_40_sp2, lhs_mat_23_40_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_41_sp2, lhs_mat_23_41_sp2)); + __m512i iacc_mat_11_4_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_40_sp2, lhs_mat_23_40_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_41_sp2, lhs_mat_23_41_sp2)); + + __m512i iacc_mat_00_5_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_50_sp2, lhs_mat_01_50_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_51_sp2, lhs_mat_01_51_sp2)); + __m512i iacc_mat_01_5_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_50_sp2, lhs_mat_01_50_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_51_sp2, lhs_mat_01_51_sp2)); + + __m512i iacc_mat_10_5_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_50_sp2, lhs_mat_23_50_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_51_sp2, lhs_mat_23_51_sp2)); + __m512i iacc_mat_11_5_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_50_sp2, lhs_mat_23_50_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_51_sp2, lhs_mat_23_51_sp2)); + + __m512i iacc_mat_00_6_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_60_sp2, lhs_mat_01_60_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_61_sp2, lhs_mat_01_61_sp2)); + __m512i iacc_mat_01_6_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_60_sp2, lhs_mat_01_60_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_61_sp2, lhs_mat_01_61_sp2)); + + __m512i iacc_mat_10_6_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_60_sp2, lhs_mat_23_60_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_61_sp2, lhs_mat_23_61_sp2)); + __m512i iacc_mat_11_6_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_60_sp2, lhs_mat_23_60_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_61_sp2, lhs_mat_23_61_sp2)); + + __m512i iacc_mat_00_7_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_70_sp2, lhs_mat_01_70_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_71_sp2, lhs_mat_01_71_sp2)); + __m512i iacc_mat_01_7_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_70_sp2, lhs_mat_01_70_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_71_sp2, lhs_mat_01_71_sp2)); + + __m512i iacc_mat_10_7_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_70_sp2, lhs_mat_23_70_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_71_sp2, lhs_mat_23_71_sp2)); + __m512i iacc_mat_11_7_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_70_sp2, lhs_mat_23_70_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_71_sp2, lhs_mat_23_71_sp2)); + + // Combine results from both shuffle patterns for each output block + __m512i iacc_mat_00_0 = _mm512_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2); + __m512i iacc_mat_01_0 = _mm512_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2); + __m512i iacc_mat_10_0 = _mm512_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2); + __m512i iacc_mat_11_0 = _mm512_add_epi16(iacc_mat_11_0_sp1, iacc_mat_11_0_sp2); + + __m512i iacc_mat_00_1 = _mm512_add_epi16(iacc_mat_00_1_sp1, iacc_mat_00_1_sp2); + __m512i iacc_mat_01_1 = _mm512_add_epi16(iacc_mat_01_1_sp1, iacc_mat_01_1_sp2); + __m512i iacc_mat_10_1 = _mm512_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2); + __m512i iacc_mat_11_1 = _mm512_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2); + + __m512i iacc_mat_00_2 = _mm512_add_epi16(iacc_mat_00_2_sp1, iacc_mat_00_2_sp2); + __m512i iacc_mat_01_2 = _mm512_add_epi16(iacc_mat_01_2_sp1, iacc_mat_01_2_sp2); + __m512i iacc_mat_10_2 = _mm512_add_epi16(iacc_mat_10_2_sp1, iacc_mat_10_2_sp2); + __m512i iacc_mat_11_2 = _mm512_add_epi16(iacc_mat_11_2_sp1, iacc_mat_11_2_sp2); + + __m512i iacc_mat_00_3 = _mm512_add_epi16(iacc_mat_00_3_sp1, iacc_mat_00_3_sp2); + __m512i iacc_mat_01_3 = _mm512_add_epi16(iacc_mat_01_3_sp1, iacc_mat_01_3_sp2); + __m512i iacc_mat_10_3 = _mm512_add_epi16(iacc_mat_10_3_sp1, iacc_mat_10_3_sp2); + __m512i iacc_mat_11_3 = _mm512_add_epi16(iacc_mat_11_3_sp1, iacc_mat_11_3_sp2); + + __m512i iacc_mat_00_4 = _mm512_add_epi16(iacc_mat_00_4_sp1, iacc_mat_00_4_sp2); + __m512i iacc_mat_01_4 = _mm512_add_epi16(iacc_mat_01_4_sp1, iacc_mat_01_4_sp2); + __m512i iacc_mat_10_4 = _mm512_add_epi16(iacc_mat_10_4_sp1, iacc_mat_10_4_sp2); + __m512i iacc_mat_11_4 = _mm512_add_epi16(iacc_mat_11_4_sp1, iacc_mat_11_4_sp2); + + __m512i iacc_mat_00_5 = _mm512_add_epi16(iacc_mat_00_5_sp1, iacc_mat_00_5_sp2); + __m512i iacc_mat_01_5 = _mm512_add_epi16(iacc_mat_01_5_sp1, iacc_mat_01_5_sp2); + __m512i iacc_mat_10_5 = _mm512_add_epi16(iacc_mat_10_5_sp1, iacc_mat_10_5_sp2); + __m512i iacc_mat_11_5 = _mm512_add_epi16(iacc_mat_11_5_sp1, iacc_mat_11_5_sp2); + + __m512i iacc_mat_00_6 = _mm512_add_epi16(iacc_mat_00_6_sp1, iacc_mat_00_6_sp2); + __m512i iacc_mat_01_6 = _mm512_add_epi16(iacc_mat_01_6_sp1, iacc_mat_01_6_sp2); + __m512i iacc_mat_10_6 = _mm512_add_epi16(iacc_mat_10_6_sp1, iacc_mat_10_6_sp2); + __m512i iacc_mat_11_6 = _mm512_add_epi16(iacc_mat_11_6_sp1, iacc_mat_11_6_sp2); + + __m512i iacc_mat_00_7 = _mm512_add_epi16(iacc_mat_00_7_sp1, iacc_mat_00_7_sp2); + __m512i iacc_mat_01_7 = _mm512_add_epi16(iacc_mat_01_7_sp1, iacc_mat_01_7_sp2); + __m512i iacc_mat_10_7 = _mm512_add_epi16(iacc_mat_10_7_sp1, iacc_mat_10_7_sp2); + __m512i iacc_mat_11_7 = _mm512_add_epi16(iacc_mat_11_7_sp1, iacc_mat_11_7_sp2); + + // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block + iacc_mat_00_0 = _mm512_madd_epi16(iacc_mat_00_0, scale_014589CD_0); + iacc_mat_01_0 = _mm512_madd_epi16(iacc_mat_01_0, scale_2367ABEF_0); + iacc_mat_10_0 = _mm512_madd_epi16(iacc_mat_10_0, scale_014589CD_0); + iacc_mat_11_0 = _mm512_madd_epi16(iacc_mat_11_0, scale_2367ABEF_0); + + iacc_mat_00_1 = _mm512_madd_epi16(iacc_mat_00_1, scale_014589CD_1); + iacc_mat_01_1 = _mm512_madd_epi16(iacc_mat_01_1, scale_2367ABEF_1); + iacc_mat_10_1 = _mm512_madd_epi16(iacc_mat_10_1, scale_014589CD_1); + iacc_mat_11_1 = _mm512_madd_epi16(iacc_mat_11_1, scale_2367ABEF_1); + + iacc_mat_00_2 = _mm512_madd_epi16(iacc_mat_00_2, scale_014589CD_2); + iacc_mat_01_2 = _mm512_madd_epi16(iacc_mat_01_2, scale_2367ABEF_2); + iacc_mat_10_2 = _mm512_madd_epi16(iacc_mat_10_2, scale_014589CD_2); + iacc_mat_11_2 = _mm512_madd_epi16(iacc_mat_11_2, scale_2367ABEF_2); + + iacc_mat_00_3 = _mm512_madd_epi16(iacc_mat_00_3, scale_014589CD_3); + iacc_mat_01_3 = _mm512_madd_epi16(iacc_mat_01_3, scale_2367ABEF_3); + iacc_mat_10_3 = _mm512_madd_epi16(iacc_mat_10_3, scale_014589CD_3); + iacc_mat_11_3 = _mm512_madd_epi16(iacc_mat_11_3, scale_2367ABEF_3); + + iacc_mat_00_4 = _mm512_madd_epi16(iacc_mat_00_4, scale_014589CD_4); + iacc_mat_01_4 = _mm512_madd_epi16(iacc_mat_01_4, scale_2367ABEF_4); + iacc_mat_10_4 = _mm512_madd_epi16(iacc_mat_10_4, scale_014589CD_4); + iacc_mat_11_4 = _mm512_madd_epi16(iacc_mat_11_4, scale_2367ABEF_4); + + iacc_mat_00_5 = _mm512_madd_epi16(iacc_mat_00_5, scale_014589CD_5); + iacc_mat_01_5 = _mm512_madd_epi16(iacc_mat_01_5, scale_2367ABEF_5); + iacc_mat_10_5 = _mm512_madd_epi16(iacc_mat_10_5, scale_014589CD_5); + iacc_mat_11_5 = _mm512_madd_epi16(iacc_mat_11_5, scale_2367ABEF_5); + + iacc_mat_00_6 = _mm512_madd_epi16(iacc_mat_00_6, scale_014589CD_6); + iacc_mat_01_6 = _mm512_madd_epi16(iacc_mat_01_6, scale_2367ABEF_6); + iacc_mat_10_6 = _mm512_madd_epi16(iacc_mat_10_6, scale_014589CD_6); + iacc_mat_11_6 = _mm512_madd_epi16(iacc_mat_11_6, scale_2367ABEF_6); + + iacc_mat_00_7 = _mm512_madd_epi16(iacc_mat_00_7, scale_014589CD_7); + iacc_mat_01_7 = _mm512_madd_epi16(iacc_mat_01_7, scale_2367ABEF_7); + iacc_mat_10_7 = _mm512_madd_epi16(iacc_mat_10_7, scale_014589CD_7); + iacc_mat_11_7 = _mm512_madd_epi16(iacc_mat_11_7, scale_2367ABEF_7); + + __m512i iacc_mat_00 = _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(iacc_mat_00_0, iacc_mat_00_1), _mm512_add_epi32(iacc_mat_00_2, iacc_mat_00_3)), _mm512_add_epi32(_mm512_add_epi32(iacc_mat_00_4, iacc_mat_00_5), _mm512_add_epi32(iacc_mat_00_6, iacc_mat_00_7))); + __m512i iacc_mat_01 = _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(iacc_mat_01_0, iacc_mat_01_1), _mm512_add_epi32(iacc_mat_01_2, iacc_mat_01_3)), _mm512_add_epi32(_mm512_add_epi32(iacc_mat_01_4, iacc_mat_01_5), _mm512_add_epi32(iacc_mat_01_6, iacc_mat_01_7))); + __m512i iacc_mat_10 = _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(iacc_mat_10_0, iacc_mat_10_1), _mm512_add_epi32(iacc_mat_10_2, iacc_mat_10_3)), _mm512_add_epi32(_mm512_add_epi32(iacc_mat_10_4, iacc_mat_10_5), _mm512_add_epi32(iacc_mat_10_6, iacc_mat_10_7))); + __m512i iacc_mat_11 = _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(iacc_mat_11_0, iacc_mat_11_1), _mm512_add_epi32(iacc_mat_11_2, iacc_mat_11_3)), _mm512_add_epi32(_mm512_add_epi32(iacc_mat_11_4, iacc_mat_11_5), _mm512_add_epi32(iacc_mat_11_6, iacc_mat_11_7))); + + // Straighten out to make 4 row vectors + __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78)); + __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01); + __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78)); + __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11); + + // Load the scale(d) values for all the 4 Q8_k blocks and repeat it across lanes + const __m128 row_scale_f32_sse = _mm_load_ps(a_ptr[b].d); + const __m256 row_scale_f32_ymm = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse); + const __m512 row_scale_f32 = _mm512_insertf32x8(_mm512_castps256_ps512(row_scale_f32_ymm), row_scale_f32_ymm, 1); + + // Multiply with appropiate scales and accumulate (for both d and dmin) below + acc_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]); + acc_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]); + acc_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]); + acc_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]); + + // Take two bsums from two Q8_Ks at a time and multiply with corresponding mins values from each Q2_K + __m512i iacc_row_min_0_01 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_0123, (_MM_PERM_ENUM)0), mins_01); + __m512i iacc_row_min_1_01 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_0123, (_MM_PERM_ENUM)170), mins_01); + __m512i iacc_row_min_2_01 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_0123, (_MM_PERM_ENUM)0), mins_01); + __m512i iacc_row_min_3_01 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_0123, (_MM_PERM_ENUM)170), mins_01); + + __m512i iacc_row_min_0_23 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_0123, (_MM_PERM_ENUM)85), mins_23); + __m512i iacc_row_min_1_23 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_0123, (_MM_PERM_ENUM)255), mins_23); + __m512i iacc_row_min_2_23 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_0123, (_MM_PERM_ENUM)85), mins_23); + __m512i iacc_row_min_3_23 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_0123, (_MM_PERM_ENUM)255), mins_23); + + __m512i iacc_row_min_0_45 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_4567, (_MM_PERM_ENUM)0), mins_45); + __m512i iacc_row_min_1_45 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_4567, (_MM_PERM_ENUM)170), mins_45); + __m512i iacc_row_min_2_45 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_4567, (_MM_PERM_ENUM)0), mins_45); + __m512i iacc_row_min_3_45 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_4567, (_MM_PERM_ENUM)170), mins_45); + + __m512i iacc_row_min_0_67 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_4567, (_MM_PERM_ENUM)85), mins_67); + __m512i iacc_row_min_1_67 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_4567, (_MM_PERM_ENUM)255), mins_67); + __m512i iacc_row_min_2_67 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_4567, (_MM_PERM_ENUM)85), mins_67); + __m512i iacc_row_min_3_67 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_4567, (_MM_PERM_ENUM)255), mins_67); + + __m512i iacc_row_min_0 = _mm512_add_epi32(_mm512_add_epi32(iacc_row_min_0_01, iacc_row_min_0_23), _mm512_add_epi32(iacc_row_min_0_45,iacc_row_min_0_67)); + __m512i iacc_row_min_1 = _mm512_add_epi32(_mm512_add_epi32(iacc_row_min_1_01, iacc_row_min_1_23), _mm512_add_epi32(iacc_row_min_1_45,iacc_row_min_1_67)); + __m512i iacc_row_min_2 = _mm512_add_epi32(_mm512_add_epi32(iacc_row_min_2_01, iacc_row_min_2_23), _mm512_add_epi32(iacc_row_min_2_45,iacc_row_min_2_67)); + __m512i iacc_row_min_3 = _mm512_add_epi32(_mm512_add_epi32(iacc_row_min_3_01, iacc_row_min_3_23), _mm512_add_epi32(iacc_row_min_3_45,iacc_row_min_3_67)); + + acc_min_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_0), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[0]); + acc_min_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_1), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[1]); + acc_min_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_2), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[2]); + acc_min_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_3), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[3]); + } + } + // Store accumlated values + for (int i = 0; i < 4; i++) { + _mm512_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm512_sub_ps(acc_rows[i], acc_min_rows[i])); + } + } + } + + if (anc != nc) { + xstart = anc/8; + y = 0; + } + +#endif //AVX512F + + // Take group of four block_q8_Kx4 structures at each pass of the loop and perform dot product operation + for (; y < anr / 4; y += 4) { + + const block_q8_Kx4 * a_ptrs[4]; + + a_ptrs[0] = a_ptr_start + (y * nb); + for (int i = 0; i < 3; ++i) { + a_ptrs[i + 1] = a_ptrs[i] + nb; + } + + // Take group of eight block_q2_kx8 structures at each pass of the loop and perform dot product operation + for (int64_t x = xstart; x < nc / 8; x++) { + + const block_q2_Kx8 * b_ptr = b_ptr_start + (x * b_nb); + + // Master FP accumulators + __m256 acc_rows[16]; + for (int i = 0; i < 16; i++) { + acc_rows[i] = _mm256_setzero_ps(); + } + + __m256 acc_min_rows[16]; + for (int i = 0; i < 16; i++) { + acc_min_rows[i] = _mm256_setzero_ps(); + } + + // For super block + for (int64_t b = 0; b < nb; b++) { + // Delta values - Load the eight scale values of block_q2_kx8 + const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d); + + // dmin values - Load the eight dmin values of block_q2_kx8 + const __m256 col_dmin_f32 = GGML_F32Cx8_LOAD(b_ptr[b].dmin); + + // Loop to iterate over the sixteen sub blocks of a super block - eight sub blocks are processed per iteration + for (int sb = 0; sb < QK_K / 128; sb++) { + + // Load the eight block_q2_K for eight sub blocks quantized values interleaved with each other in chunks of eight bytes - B0,B1 ....B6,B7 + const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + sb * 256)); + const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32 + sb * 256)); + const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64 + sb * 256)); + const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96 + sb * 256)); + const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 128 + sb * 256)); + const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 160 + sb * 256)); + const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 192 + sb * 256)); + const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 224 + sb * 256)); + + // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values + //superblock sub block which part of sub block + const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240); + + const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240); + + const __m256i rhs_raw_mat_0145_2 = _mm256_blend_epi32(rhs_raw_mat_0123_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_2, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_2, requiredOrder), rhs_raw_mat_4567_2, 240); + + const __m256i rhs_raw_mat_0145_3 = _mm256_blend_epi32(rhs_raw_mat_0123_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_3, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_3, requiredOrder), rhs_raw_mat_4567_3, 240); + + // 2-bit -> 8-bit + // First sub block of the eight sub blocks processed in the iteration + const __m256i rhs_mat_0145_00 = _mm256_and_si256(rhs_raw_mat_0145_0, m3b); //B00(0-7) B01(0-7) B04(0-7) B05(0-7) + const __m256i rhs_mat_2367_00 = _mm256_and_si256(rhs_raw_mat_2367_0, m3b); //B02(0-7) B03(0-7) B06(0-7) B07(0-7) + + const __m256i rhs_mat_0145_01 = _mm256_and_si256(rhs_raw_mat_0145_1, m3b); //B00(8-15) B01(8-15) B04(8-15) B05(8-15) + const __m256i rhs_mat_2367_01 = _mm256_and_si256(rhs_raw_mat_2367_1, m3b); //B02(8-15) B03(8-15) B06(8-15) B07(8-15) + + // Second sub block of the eight sub blocks processed in the iteration + const __m256i rhs_mat_0145_10 = _mm256_and_si256(rhs_raw_mat_0145_2, m3b); //B10(0-7) B11(0-7) B14(0-7) B15(0-7) + const __m256i rhs_mat_2367_10 = _mm256_and_si256(rhs_raw_mat_2367_2, m3b); //B12(0-7) B13(0-7) B16(0-7) B17(0-7) + + const __m256i rhs_mat_0145_11 = _mm256_and_si256(rhs_raw_mat_0145_3, m3b); //B10(8-15) B11(8-15) B14(8-15) B15(8-15) + const __m256i rhs_mat_2367_11 = _mm256_and_si256(rhs_raw_mat_2367_3, m3b); //B12(8-15) B13(8-15) B16(8-15) B17(8-15) + + // Third sub block of the eight sub blocks processed in the iteration + const __m256i rhs_mat_0145_20 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 2), m3b); //B20(0-7) B21(0-7) B24(0-7) B25(0-7) + const __m256i rhs_mat_2367_20 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 2), m3b); //B22(0-7) B23(0-7) B26(0-7) B27(0-7) + + const __m256i rhs_mat_0145_21 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 2), m3b); //B20(8-15) B21(8-15) B24(8-15) B25(8-15) + const __m256i rhs_mat_2367_21 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 2), m3b); //B22(8-15) B23(8-15) B26(8-15) B27(8-15) + + // Fourth sub block of the eight sub blocks processed in the iteration + const __m256i rhs_mat_0145_30 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_2, 2), m3b); //B30(0-7) B31(0-7) B34(0-7) B35(0-7) + const __m256i rhs_mat_2367_30 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_2, 2), m3b); //B32(0-7) B33(0-7) B36(0-7) B37(0-7) + + const __m256i rhs_mat_0145_31 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_3, 2), m3b); //B30(8-15) B31(8-15) B34(8-15) B35(8-15) + const __m256i rhs_mat_2367_31 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_3, 2), m3b); //B32(8-15) B33(8-15) B36(8-15) B37(8-15) + + // Fifth sub block of the eight sub blocks processed in the iteration + const __m256i rhs_mat_0145_40 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m3b); //B40(0-7) B41(0-7) B44(0-7) B45(0-7) + const __m256i rhs_mat_2367_40 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m3b); //B42(0-7) B43(0-7) B46(0-7) B47(0-7) + + const __m256i rhs_mat_0145_41 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m3b); //B40(8-15) B41(8-15) B44(8-15) B45(8-15) + const __m256i rhs_mat_2367_41 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m3b); //B42(8-15) B43(8-15) B46(8-15) B47(8-15) + + // Sixth sub block of the eight sub blocks processed in the iteration + const __m256i rhs_mat_0145_50 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_2, 4), m3b); //B50(0-7) B51(0-7) B54(0-7) B55(0-7) + const __m256i rhs_mat_2367_50 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_2, 4), m3b); //B52(0-7) B53(0-7) B56(0-7) B57(0-7) + + const __m256i rhs_mat_0145_51 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_3, 4), m3b); //B50(8-15) B51(8-15) B54(8-15) B55(8-15) + const __m256i rhs_mat_2367_51 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_3, 4), m3b); //B52(8-15) B53(8-15) B56(8-15) B57(8-15) + + // Seventh sub block of the eight sub blocks processed in the iteration + const __m256i rhs_mat_0145_60 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 6), m3b); //B60(0-7) B61(0-7) B64(0-7) B65(0-7) + const __m256i rhs_mat_2367_60 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 6), m3b); //B62(0-7) B63(0-7) B66(0-7) B67(0-7) + + const __m256i rhs_mat_0145_61 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 6), m3b); //B60(8-15) B61(8-15) B64(8-15) B65(8-15) + const __m256i rhs_mat_2367_61 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 6), m3b); //B62(8-15) B63(8-15) B66(8-15) B67(8-15) + + // Eighth sub block of the eight sub blocks processed in the iteration + const __m256i rhs_mat_0145_70 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_2, 6), m3b); //B70(0-7) B71(0-7) B74(0-7) B75(0-7) + const __m256i rhs_mat_2367_70 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_2, 6), m3b); //B72(0-7) B73(0-7) B76(0-7) B77(0-7) + + const __m256i rhs_mat_0145_71 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_3, 6), m3b); //B70(8-15) B71(8-15) B74(8-15) B75(8-15) + const __m256i rhs_mat_2367_71 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_3, 6), m3b); //B72(8-15) B73(8-15) B76(8-15) B77(8-15) + + // Shuffle pattern one - right side input + const __m256i rhs_mat_0145_00_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_00, 136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3) + const __m256i rhs_mat_2367_00_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_00, 136); //B02(0-3) B03(0-3) B02(0-3) B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3) + + const __m256i rhs_mat_0145_01_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_01, 136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11) + const __m256i rhs_mat_2367_01_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_01, 136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11) + + const __m256i rhs_mat_0145_10_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_10, 136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3) + const __m256i rhs_mat_2367_10_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_10, 136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) B16(0-3) B17(0-3) + + const __m256i rhs_mat_0145_11_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_11, 136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11) + const __m256i rhs_mat_2367_11_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_11, 136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11) + + const __m256i rhs_mat_0145_20_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_20, 136); //B20(0-3) B21(0-3) B20(0-3) B21(0-3) B24(0-3) B25(0-3) B24(0-3) B25(0-3) + const __m256i rhs_mat_2367_20_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_20, 136); //B22(0-3) B23(0-3) B22(0-3) B23(0-3) B26(0-3) B27(0-3) B26(0-3) B27(0-3) + + const __m256i rhs_mat_0145_21_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_21, 136); //B20(8-11) B21(8-11) B20(8-11) B21(8-11) B24(8-11) B25(8-11) B24(8-11) B25(8-11) + const __m256i rhs_mat_2367_21_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_21, 136); //B22(8-11) B23(8-11) B22(8-11) B23(8-11) B26(8-11) B27(8-11) B26(8-11) B27(8-11) + + const __m256i rhs_mat_0145_30_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_30, 136); //B30(0-3) B31(0-3) B30(0-3) B31(0-3) B34(0-3) B35(0-3) B34(0-3) B35(0-3) + const __m256i rhs_mat_2367_30_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_30, 136); //B32(0-3) B33(0-3) B32(0-3) B33(0-3) B36(0-3) B37(0-3) B36(0-3) B37(0-3) + + const __m256i rhs_mat_0145_31_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_31, 136); //B30(8-11) B31(8-11) B30(8-11) B31(8-11) B34(8-11) B35(8-11) B34(8-11) B35(8-11 + const __m256i rhs_mat_2367_31_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_31, 136); //B32(8-11) B33(8-11) B32(8-11) B33(8-11) B36(8-11) B37(8-11) B36(8-11) B37(8-11) + + const __m256i rhs_mat_0145_40_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_40, 136); //B40(0-3) B41(0-3) B40(0-3) B41(0-3) B44(0-3) B45(0-3) B44(0-3) B45(0-3) + const __m256i rhs_mat_2367_40_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_40, 136); //B42(0-3) B43(0-3) B42(0-3) B43(0-3) B46(0-3) B47(0-3) B46(0-3) B47(0-3) + + const __m256i rhs_mat_0145_41_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_41, 136); //B40(8-11) B41(8-11) B40(8-11) B41(8-11) B44(8-11) B45(8-11) B44(8-11) B45(8-11) + const __m256i rhs_mat_2367_41_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_41, 136); //B42(8-11) B43(8-11) B42(8-11) B43(8-11) B46(8-11) B47(8-11) B46(8-11) B47(8-11) + + const __m256i rhs_mat_0145_50_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_50, 136); //B50(0-3) B51(0-3) B50(0-3) B51(0-3) B54(0-3) B55(0-3) B54(0-3) B55(0-3) + const __m256i rhs_mat_2367_50_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_50, 136); //B52(0-3) B53(0-3) B52(0-3) B53(0-3) B56(0-3) B57(0-3) B56(0-3) B57(0-3) + + const __m256i rhs_mat_0145_51_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_51, 136); //B50(8-11) B51(8-11) B50(8-11) B51(8-11) B54(8-11) B55(8-11) B54(8-11) B55(8-11) + const __m256i rhs_mat_2367_51_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_51, 136); //B52(8-11) B53(8-11) B52(8-11) B53(8-11) B56(8-11) B57(8-11) B56(8-11) B57(8-11) + + const __m256i rhs_mat_0145_60_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_60, 136); //B60(0-3) B61(0-3) B60(0-3) B61(0-3) B64(0-3) B65(0-3) B64(0-3) B65(0-3) + const __m256i rhs_mat_2367_60_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_60, 136); //B62(0-3) B63(0-3) B62(0-3) B63(0-3) B66(0-3) B67(0-3) B66(0-3) B67(0-3) + + const __m256i rhs_mat_0145_61_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_61, 136); //B60(8-11) B61(8-11) B60(8-11) B61(8-11) B64(8-11) B65(8-11) B64(8-11) B65(8-11) + const __m256i rhs_mat_2367_61_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_61, 136); //B62(8-11) B63(8-11) B62(8-11) B63(8-11) B66(8-11) B67(8-11) B66(8-11) B67(8-11) + + const __m256i rhs_mat_0145_70_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_70, 136); //B70(0-3) B71(0-3) B70(0-3) B71(0-3) B74(0-3) B75(0-3) B74(0-3) B75(0-3) + const __m256i rhs_mat_2367_70_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_70, 136); //B72(0-3) B73(0-3) B72(0-3) B73(0-3) B76(0-3) B77(0-3) B76(0-3) B77(0-3) + + const __m256i rhs_mat_0145_71_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_71, 136); //B70(8-11) B71(8-11) B70(8-11) B71(8-11) B74(8-11) B75(8-11) B74(8-11) B75(8-11) + const __m256i rhs_mat_2367_71_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_71, 136); //B72(8-11) B73(8-11) B72(8-11) B73(8-11) B76(8-11) B77(8-11) B76(8-11) B77(8-11) + + + // Shuffle pattern two - right side input + const __m256i rhs_mat_0145_00_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_00, 221); //B00(4-7) B01(4-7) B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7) + const __m256i rhs_mat_2367_00_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_00, 221); //B02(4-7) B03(4-7) B02(4-7) B03(4-7) B06(4-7) B07(4-7) B06(4-7) B07(4-7) + + const __m256i rhs_mat_0145_01_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_01, 221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15) + const __m256i rhs_mat_2367_01_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_01, 221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15) + + const __m256i rhs_mat_0145_10_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_10, 221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7) + const __m256i rhs_mat_2367_10_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_10, 221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7) + + const __m256i rhs_mat_0145_11_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_11, 221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15) + const __m256i rhs_mat_2367_11_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_11, 221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15) + + const __m256i rhs_mat_0145_20_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_20, 221); //B20(4-7) B21(4-7) B20(4-7) B21(4-7) B24(4-7) B25(4-7) B24(4-7) B25(4-7) + const __m256i rhs_mat_2367_20_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_20, 221); //B22(4-7) B23(4-7) B22(4-7) B23(4-7) B26(4-7) B27(4-7) B26(4-7) B27(4-7) + + const __m256i rhs_mat_0145_21_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_21, 221); //B20(12-15) B21(12-15) B20(12-15) B21(12-15) B24(12-15) B25(12-15) B24(12-15) B25(12-15) + const __m256i rhs_mat_2367_21_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_21, 221); //B22(12-15) B23(12-15) B22(12-15) B23(12-15) B26(12-15) B27(12-15) B26(12-15) B27(12-15) + + const __m256i rhs_mat_0145_30_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_30, 221); //B30(4-7) B31(4-7) B30(4-7) B31(4-7) B34(4-7) B35(4-7) B34(4-7) B35(4-7) + const __m256i rhs_mat_2367_30_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_30, 221); //B32(4-7) B33(4-7) B32(4-7) B33(4-7) B36(4-7) B37(4-7) B36(4-7) B37(4-7) + + const __m256i rhs_mat_0145_31_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_31, 221); //B30(12-15) B31(12-15) B30(12-15) B31(12-15) B34(12-15) B35(12-15) B34(12-15) B35(12-15) + const __m256i rhs_mat_2367_31_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_31, 221); //B32(12-15) B33(12-15) B32(12-15) B33(12-15) B36(12-15) B37(12-15) B36(12-15) B37(12-15) + + const __m256i rhs_mat_0145_40_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_40, 221); //B40(4-7) B41(4-7) B40(4-7) B41(4-7) B44(4-7) B45(4-7) B44(4-7) B45(4-7) + const __m256i rhs_mat_2367_40_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_40, 221); //B42(4-7) B43(4-7) B42(4-7) B43(4-7) B46(4-7) B47(4-7) B46(4-7) B47(4-7) + + const __m256i rhs_mat_0145_41_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_41, 221); //B40(12-15) B41(12-15) B40(12-15) B41(12-15) B44(12-15) B45(12-15) B44(12-15) B45(12-15) + const __m256i rhs_mat_2367_41_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_41, 221); //B42(12-15) B43(12-15) B42(12-15) B43(12-15) B46(12-15) B47(12-15) B46(12-15) B47(12-15) + + const __m256i rhs_mat_0145_50_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_50, 221); //B50(4-7) B51(4-7) B50(4-7) B51(4-7) B54(4-7) B55(4-7) B54(4-7) B55(4-7) + const __m256i rhs_mat_2367_50_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_50, 221); //B52(4-7) B53(4-7) B52(4-7) B53(4-7) B56(4-7) B57(4-7) B56(4-7) B57(4-7) + + const __m256i rhs_mat_0145_51_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_51, 221); //B50(12-15) B51(12-15) B50(12-15) B51(12-15) B54(12-15) B55(12-15) B54(12-15) B55(12-15) + const __m256i rhs_mat_2367_51_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_51, 221); //B52(12-15) B53(12-15) B52(12-15) B53(12-15) B56(12-15) B57(12-15) B56(12-15) B57(12-15) + + const __m256i rhs_mat_0145_60_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_60, 221); //B60(4-7) B61(4-7) B60(4-7) B61(4-7) B64(4-7) B65(4-7) B64(4-7) B65(4-7) + const __m256i rhs_mat_2367_60_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_60, 221); //B62(4-7) B63(4-7) B62(4-7) B63(4-7) B66(4-7) B67(4-7) B66(4-7) B67(4-7) + + const __m256i rhs_mat_0145_61_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_61, 221); //B60(12-15) B61(12-15) B60(12-15) B61(12-15) B64(12-15) B65(12-15) B64(12-15) B65(12-15) + const __m256i rhs_mat_2367_61_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_61, 221); //B62(12-15) B63(12-15) B62(12-15) B63(12-15) B66(12-15) B67(12-15) B66(12-15) B67(12-15) + + const __m256i rhs_mat_0145_70_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_70, 221); //B70(4-7) B71(4-7) B70(4-7) B71(4-7) B74(4-7) B75(4-7) B74(4-7) B75(4-7) + const __m256i rhs_mat_2367_70_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_70, 221); //B72(4-7) B73(4-7) B72(4-7) B73(4-7) B76(4-7) B77(4-7) B76(4-7) B77(4-7) + + const __m256i rhs_mat_0145_71_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_71, 221); //B70(12-15) B71(12-15) B70(12-15) B71(12-15) B74(12-15) B75(12-15) B74(12-15) B75(12-15) + const __m256i rhs_mat_2367_71_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_71, 221); //B72(12-15) B73(12-15) B72(12-15) B73(12-15) B76(12-15) B77(12-15) B76(12-15) B77(12-15) + + //Scales and Mins of corresponding sub blocks from different Q2_K structures are stored together + //s00 m00 s01 m01 s10 m10 s11 m11 s20 m20 s21 m21 s30 m30 s31 m31 s40 m40 s41 m41 s50 m50 s51 m51 s60 m60 s61 m61 s70 m70 s71 m71 + + // Combine mins and scales for sub-blocks: 0-1, 2-3, 4-5, 6-7 in the sb loop + const __m128i mins_and_scales_01 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + sb * 64)); + const __m128i mins_and_scales_23 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + 16 + sb * 64)); + const __m128i mins_and_scales_45 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + 32 + sb * 64)); + const __m128i mins_and_scales_67 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + 48 + sb * 64)); + + // Extract scales which is lower half from mins_and_scales + const __m128i scales_01 = _mm_and_si128(mins_and_scales_01, m4b_sse); + const __m128i scales_23 = _mm_and_si128(mins_and_scales_23, m4b_sse); + const __m128i scales_45 = _mm_and_si128(mins_and_scales_45, m4b_sse); + const __m128i scales_67 = _mm_and_si128(mins_and_scales_67, m4b_sse); + + // Extract mins which is upper half from mins_and_scales + const __m256i mins_01 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_01, 4), m4b_sse)); + const __m256i mins_23 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_23, 4), m4b_sse)); + const __m256i mins_45 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_45, 4), m4b_sse)); + const __m256i mins_67 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_67, 4), m4b_sse)); + + const __m256i scales_0 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_01, scalesmask1_sse)); + const __m256i scales_1 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_01, scalesmask2_sse)); + + const __m256i scales_2 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_23, scalesmask1_sse)); + const __m256i scales_3 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_23, scalesmask2_sse)); + + const __m256i scales_4 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_45, scalesmask1_sse)); + const __m256i scales_5 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_45, scalesmask2_sse)); + + const __m256i scales_6 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_67, scalesmask1_sse)); + const __m256i scales_7 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_67, scalesmask2_sse)); + + const __m256i scale_0145_0 = _mm256_shuffle_epi32(scales_0, 68); + const __m256i scale_2367_0 = _mm256_shuffle_epi32(scales_0, 238); + + const __m256i scale_0145_1 = _mm256_shuffle_epi32(scales_1, 68); + const __m256i scale_2367_1 = _mm256_shuffle_epi32(scales_1, 238); + + const __m256i scale_0145_2 = _mm256_shuffle_epi32(scales_2, 68); + const __m256i scale_2367_2 = _mm256_shuffle_epi32(scales_2, 238); + + const __m256i scale_0145_3 = _mm256_shuffle_epi32(scales_3, 68); + const __m256i scale_2367_3 = _mm256_shuffle_epi32(scales_3, 238); + + const __m256i scale_0145_4 = _mm256_shuffle_epi32(scales_4, 68); + const __m256i scale_2367_4 = _mm256_shuffle_epi32(scales_4, 238); + + const __m256i scale_0145_5 = _mm256_shuffle_epi32(scales_5, 68); + const __m256i scale_2367_5 = _mm256_shuffle_epi32(scales_5, 238); + + const __m256i scale_0145_6 = _mm256_shuffle_epi32(scales_6, 68); + const __m256i scale_2367_6 = _mm256_shuffle_epi32(scales_6, 238); + + const __m256i scale_0145_7 = _mm256_shuffle_epi32(scales_7, 68); + const __m256i scale_2367_7 = _mm256_shuffle_epi32(scales_7, 238); + + + for (int rp = 0; rp < 4; rp++) { + + // Load the four block_q8_k quantized values interleaved with each other in chunks of eight bytes - A0,A1,A2,A3 + // Loaded as set of 128 bit vectors and repeated into a 256 bit vector + __m256i lhs_mat_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 512 * sb))); + __m256i lhs_mat_01_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 0); + __m256i lhs_mat_23_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 17); + __m256i lhs_mat_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 32 + 512 * sb))); + __m256i lhs_mat_01_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 0); + __m256i lhs_mat_23_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 17); + __m256i lhs_mat_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 64 + 512 * sb))); + __m256i lhs_mat_01_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 0); + __m256i lhs_mat_23_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 17); + __m256i lhs_mat_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 96 + 512 * sb))); + __m256i lhs_mat_01_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 0); + __m256i lhs_mat_23_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 17); + __m256i lhs_mat_0123_20 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 128 + 512 * sb))); + __m256i lhs_mat_01_20 = _mm256_permute2f128_si256(lhs_mat_0123_20, lhs_mat_0123_20, 0); + __m256i lhs_mat_23_20 = _mm256_permute2f128_si256(lhs_mat_0123_20, lhs_mat_0123_20, 17); + __m256i lhs_mat_0123_21 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 160 + 512 * sb))); + __m256i lhs_mat_01_21 = _mm256_permute2f128_si256(lhs_mat_0123_21, lhs_mat_0123_21, 0); + __m256i lhs_mat_23_21 = _mm256_permute2f128_si256(lhs_mat_0123_21, lhs_mat_0123_21, 17); + __m256i lhs_mat_0123_30 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 192 + 512 * sb))); + __m256i lhs_mat_01_30 = _mm256_permute2f128_si256(lhs_mat_0123_30, lhs_mat_0123_30, 0); + __m256i lhs_mat_23_30 = _mm256_permute2f128_si256(lhs_mat_0123_30, lhs_mat_0123_30, 17); + __m256i lhs_mat_0123_31 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 224 + 512 * sb))); + __m256i lhs_mat_01_31 = _mm256_permute2f128_si256(lhs_mat_0123_31, lhs_mat_0123_31, 0); + __m256i lhs_mat_23_31 = _mm256_permute2f128_si256(lhs_mat_0123_31, lhs_mat_0123_31, 17); + + __m256i lhs_mat_0123_40 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 256 + 512 * sb))); + __m256i lhs_mat_01_40 = _mm256_permute2f128_si256(lhs_mat_0123_40, lhs_mat_0123_40, 0); + __m256i lhs_mat_23_40 = _mm256_permute2f128_si256(lhs_mat_0123_40, lhs_mat_0123_40, 17); + __m256i lhs_mat_0123_41 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 288 + 512 * sb))); + __m256i lhs_mat_01_41 = _mm256_permute2f128_si256(lhs_mat_0123_41, lhs_mat_0123_41, 0); + __m256i lhs_mat_23_41 = _mm256_permute2f128_si256(lhs_mat_0123_41, lhs_mat_0123_41, 17); + __m256i lhs_mat_0123_50 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 320 + 512 * sb))); + __m256i lhs_mat_01_50 = _mm256_permute2f128_si256(lhs_mat_0123_50, lhs_mat_0123_50, 0); + __m256i lhs_mat_23_50 = _mm256_permute2f128_si256(lhs_mat_0123_50, lhs_mat_0123_50, 17); + __m256i lhs_mat_0123_51 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 352 + 512 * sb))); + __m256i lhs_mat_01_51 = _mm256_permute2f128_si256(lhs_mat_0123_51, lhs_mat_0123_51, 0); + __m256i lhs_mat_23_51 = _mm256_permute2f128_si256(lhs_mat_0123_51, lhs_mat_0123_51, 17); + __m256i lhs_mat_0123_60 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 384 + 512 * sb))); + __m256i lhs_mat_01_60 = _mm256_permute2f128_si256(lhs_mat_0123_60, lhs_mat_0123_60, 0); + __m256i lhs_mat_23_60 = _mm256_permute2f128_si256(lhs_mat_0123_60, lhs_mat_0123_60, 17); + __m256i lhs_mat_0123_61 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 416 + 512 * sb))); + __m256i lhs_mat_01_61 = _mm256_permute2f128_si256(lhs_mat_0123_61, lhs_mat_0123_61, 0); + __m256i lhs_mat_23_61 = _mm256_permute2f128_si256(lhs_mat_0123_61, lhs_mat_0123_61, 17); + __m256i lhs_mat_0123_70 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 448 + 512 * sb))); + __m256i lhs_mat_01_70 = _mm256_permute2f128_si256(lhs_mat_0123_70, lhs_mat_0123_70, 0); + __m256i lhs_mat_23_70 = _mm256_permute2f128_si256(lhs_mat_0123_70, lhs_mat_0123_70, 17); + __m256i lhs_mat_0123_71 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 480 + 512 * sb))); + __m256i lhs_mat_01_71 = _mm256_permute2f128_si256(lhs_mat_0123_71, lhs_mat_0123_71, 0); + __m256i lhs_mat_23_71 = _mm256_permute2f128_si256(lhs_mat_0123_71, lhs_mat_0123_71, 17); + + // Bsums are loaded for the different Q8_K blocks + __m128i lhs_raw_bsums_01_0123 = _mm_loadu_si128((const __m128i *)((a_ptrs[rp][b].bsums + 32 * sb))); + __m128i lhs_raw_bsums_23_0123 = _mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].bsums + 8 + 32 * sb)); + __m128i lhs_raw_bsums_01_4567 = _mm_loadu_si128((const __m128i *)((a_ptrs[rp][b].bsums + 16 + 32 * sb))); + __m128i lhs_raw_bsums_23_4567 = _mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].bsums + 24 + 32 * sb)); + + // Shuffle pattern one - left side input + const __m256i lhs_mat_01_00_sp1 = _mm256_shuffle_epi32(lhs_mat_01_00, 160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) + const __m256i lhs_mat_23_00_sp1 = _mm256_shuffle_epi32(lhs_mat_23_00, 160); //A02(0-3) A03(0-3) A02(0-3) A03(0-3) A02(0-3) A03(0-3) A02(0-3) A03(0-3) + + const __m256i lhs_mat_01_01_sp1 = _mm256_shuffle_epi32(lhs_mat_01_01, 160); //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) + const __m256i lhs_mat_23_01_sp1 = _mm256_shuffle_epi32(lhs_mat_23_01, 160); //A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11) + + const __m256i lhs_mat_01_10_sp1 = _mm256_shuffle_epi32(lhs_mat_01_10, 160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) + const __m256i lhs_mat_23_10_sp1 = _mm256_shuffle_epi32(lhs_mat_23_10, 160); //A12(0-3) A13(0-3) A12(0-3) A13(0-3) A12(0-3) A13(0-3) A12(0-3) A13(0-3) + + const __m256i lhs_mat_01_11_sp1 = _mm256_shuffle_epi32(lhs_mat_01_11, 160); //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) + const __m256i lhs_mat_23_11_sp1 = _mm256_shuffle_epi32(lhs_mat_23_11, 160); //A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11) + + const __m256i lhs_mat_01_20_sp1 = _mm256_shuffle_epi32(lhs_mat_01_20, 160); //A20(0-3) A20(0-3) A21(0-3) A21(0-3) A20(0-3) A20(0-3) A21(0-3) A21(0-3) + const __m256i lhs_mat_23_20_sp1 = _mm256_shuffle_epi32(lhs_mat_23_20, 160); //A22(0-3) A23(0-3) A22(0-3) A23(0-3) A22(0-3) A23(0-3) A22(0-3) A23(0-3) + + const __m256i lhs_mat_01_21_sp1 = _mm256_shuffle_epi32(lhs_mat_01_21, 160); //A20(8-11) A20(8-11) A21(8-11) A21(8-11) A20(8-11) A20(8-11) A21(8-11) A21(8-11) + const __m256i lhs_mat_23_21_sp1 = _mm256_shuffle_epi32(lhs_mat_23_21, 160); //A22(8-11) A23(8-11) A22(8-11) A23(8-11) A22(8-11) A23(8-11) A22(8-11) A23(8-11) + + const __m256i lhs_mat_01_30_sp1 = _mm256_shuffle_epi32(lhs_mat_01_30, 160); //A30(0-3) A30(0-3) A31(0-3) A31(0-3) A30(0-3) A30(0-3) A31(0-3) A31(0-3) + const __m256i lhs_mat_23_30_sp1 = _mm256_shuffle_epi32(lhs_mat_23_30, 160); //A32(0-3) A33(0-3) A32(0-3) A33(0-3) A32(0-3) A33(0-3) A32(0-3) A33(0-3) + + const __m256i lhs_mat_01_31_sp1 = _mm256_shuffle_epi32(lhs_mat_01_31, 160); //A30(8-11) A30(8-11) A31(8-11) A31(8-11) A30(8-11) A30(8-11) A31(8-11) A31(8-11) + const __m256i lhs_mat_23_31_sp1 = _mm256_shuffle_epi32(lhs_mat_23_31, 160); //A32(8-11) A33(8-11) A32(8-11) A33(8-11) A32(8-11) A33(8-11) A32(8-11) A33(8-11) + + const __m256i lhs_mat_01_40_sp1 = _mm256_shuffle_epi32(lhs_mat_01_40, 160); //A40(0-3) A40(0-3) A41(0-3) A41(0-3) A40(0-3) A40(0-3) A41(0-3) A41(0-3) + const __m256i lhs_mat_23_40_sp1 = _mm256_shuffle_epi32(lhs_mat_23_40, 160); //A42(0-3) A43(0-3) A42(0-3) A43(0-3) A42(0-3) A43(0-3) A42(0-3) A43(0-3) + + const __m256i lhs_mat_01_41_sp1 = _mm256_shuffle_epi32(lhs_mat_01_41, 160); //A40(8-11) A40(8-11) A41(8-11) A41(8-11) A40(8-11) A40(8-11) A41(8-11) A41(8-11) + const __m256i lhs_mat_23_41_sp1 = _mm256_shuffle_epi32(lhs_mat_23_41, 160); //A42(8-11) A43(8-11) A42(8-11) A43(8-11) A42(8-11) A43(8-11) A42(8-11) A43(8-11) + + const __m256i lhs_mat_01_50_sp1 = _mm256_shuffle_epi32(lhs_mat_01_50, 160); //A50(0-3) A50(0-3) A51(0-3) A51(0-3) A50(0-3) A50(0-3) A51(0-3) A51(0-3) + const __m256i lhs_mat_23_50_sp1 = _mm256_shuffle_epi32(lhs_mat_23_50, 160); //A52(0-3) A53(0-3) A52(0-3) A53(0-3) A52(0-3) A53(0-3) A52(0-3) A53(0-3) + + const __m256i lhs_mat_01_51_sp1 = _mm256_shuffle_epi32(lhs_mat_01_51, 160); //A50(8-11) A50(8-11) A51(8-11) A51(8-11) A50(8-11) A50(8-11) A51(8-11) A51(8-11) + const __m256i lhs_mat_23_51_sp1 = _mm256_shuffle_epi32(lhs_mat_23_51, 160); //A52(8-11) A53(8-11) A52(8-11) A53(8-11) A52(8-11) A53(8-11) A52(8-11) A53(8-11) + + const __m256i lhs_mat_01_60_sp1 = _mm256_shuffle_epi32(lhs_mat_01_60, 160); //A60(0-3) A60(0-3) A61(0-3) A61(0-3) A60(0-3) A60(0-3) A61(0-3) A61(0-3) + const __m256i lhs_mat_23_60_sp1 = _mm256_shuffle_epi32(lhs_mat_23_60, 160); //A62(0-3) A63(0-3) A62(0-3) A63(0-3) A62(0-3) A63(0-3) A62(0-3) A63(0-3) + + const __m256i lhs_mat_01_61_sp1 = _mm256_shuffle_epi32(lhs_mat_01_61, 160); //A60(8-11) A60(8-11) A61(8-11) A61(8-11) A60(8-11) A60(8-11) A61(8-11) A61(8-11) + const __m256i lhs_mat_23_61_sp1 = _mm256_shuffle_epi32(lhs_mat_23_61, 160); //A62(8-11) A63(8-11) A62(8-11) A63(8-11) A62(8-11) A63(8-11) A62(8-11) A63(8-11) + + const __m256i lhs_mat_01_70_sp1 = _mm256_shuffle_epi32(lhs_mat_01_70, 160); //A70(0-3) A70(0-3) A71(0-3) A71(0-3) A70(0-3) A70(0-3) A71(0-3) A71(0-3) + const __m256i lhs_mat_23_70_sp1 = _mm256_shuffle_epi32(lhs_mat_23_70, 160); //A72(0-3) A73(0-3) A72(0-3) A73(0-3) A72(0-3) A73(0-3) A72(0-3) A73(0-3) + + const __m256i lhs_mat_01_71_sp1 = _mm256_shuffle_epi32(lhs_mat_01_71, 160); //A70(8-11) A70(8-11) A71(8-11) A71(8-11) A70(8-11) A70(8-11) A71(8-11) A71(8-11) + const __m256i lhs_mat_23_71_sp1 = _mm256_shuffle_epi32(lhs_mat_23_71, 160); //A72(8-11) A73(8-11) A72(8-11) A73(8-11) A72(8-11) A73(8-11) A72(8-11) A73(8-11) + + // Shuffle pattern two- left side input + const __m256i lhs_mat_01_00_sp2 = _mm256_shuffle_epi32(lhs_mat_01_00, 245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) + const __m256i lhs_mat_23_00_sp2 = _mm256_shuffle_epi32(lhs_mat_23_00, 245); //A02(4-7) A03(4-7) A02(4-7) A03(4-7) A02(4-7) A03(4-7) A02(4-7) A03(4-7) + + const __m256i lhs_mat_01_01_sp2 = _mm256_shuffle_epi32(lhs_mat_01_01, 245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) + const __m256i lhs_mat_23_01_sp2 = _mm256_shuffle_epi32(lhs_mat_23_01, 245); //A02(12-15) A03(12-15) A02(12-15) A03(12-15) A02(12-15) A03(12-15) A02(12-15) A03(12-15) + + const __m256i lhs_mat_01_10_sp2 = _mm256_shuffle_epi32(lhs_mat_01_10, 245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) + const __m256i lhs_mat_23_10_sp2 = _mm256_shuffle_epi32(lhs_mat_23_10, 245); //A12(4-7) A13(4-7) A12(4-7) A13(4-7) A12(4-7) A13(4-7) A12(4-7) A13(4-7) + + const __m256i lhs_mat_01_11_sp2 = _mm256_shuffle_epi32(lhs_mat_01_11, 245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) + const __m256i lhs_mat_23_11_sp2 = _mm256_shuffle_epi32(lhs_mat_23_11, 245); //A12(12-15) A13(12-15) A12(12-15) A13(12-15) A12(12-15) A13(12-15) A12(12-15) A13(12-15) + + const __m256i lhs_mat_01_20_sp2 = _mm256_shuffle_epi32(lhs_mat_01_20, 245); //A20(4-7) A20(4-7) A21(4-7) A21(4-7) A20(4-7) A20(4-7) A21(4-7) A21(4-7) + const __m256i lhs_mat_23_20_sp2 = _mm256_shuffle_epi32(lhs_mat_23_20, 245); //A22(4-7) A23(4-7) A22(4-7) A23(4-7) A22(4-7) A23(4-7) A22(4-7) A23(4-7) + + const __m256i lhs_mat_01_21_sp2 = _mm256_shuffle_epi32(lhs_mat_01_21, 245); //A20(12-15) A20(12-15) A21(12-15) A21(12-15) A20(12-15) A20(12-15) A21(12-15) A21(12-15) + const __m256i lhs_mat_23_21_sp2 = _mm256_shuffle_epi32(lhs_mat_23_21, 245); //A22(12-15) A23(12-15) A22(12-15) A23(12-15) A22(12-15) A23(12-15) A22(12-15) A23(12-15) + + const __m256i lhs_mat_01_30_sp2 = _mm256_shuffle_epi32(lhs_mat_01_30, 245); //A30(4-7) A30(4-7) A31(4-7) A31(4-7) A30(4-7) A30(4-7) A31(4-7) A31(4-7) + const __m256i lhs_mat_23_30_sp2 = _mm256_shuffle_epi32(lhs_mat_23_30, 245); //A32(4-7) A33(4-7) A32(4-7) A33(4-7) A32(4-7) A33(4-7) A32(4-7) A33(4-7) + + const __m256i lhs_mat_01_31_sp2 = _mm256_shuffle_epi32(lhs_mat_01_31, 245); //A30(12-15) A30(12-15) A31(12-15) A31(12-15) A30(12-15) A30(12-15) A31(12-15) A31(12-15) + const __m256i lhs_mat_23_31_sp2 = _mm256_shuffle_epi32(lhs_mat_23_31, 245); //A32(12-15) A33(12-15) A32(12-15) A33(12-15) A32(12-15) A33(12-15) A32(12-15) A33(12-15) + + const __m256i lhs_mat_01_40_sp2 = _mm256_shuffle_epi32(lhs_mat_01_40, 245); //A40(4-7) A40(4-7) A41(4-7) A41(4-7) A40(4-7) A40(4-7) A41(4-7) A41(4-7) + const __m256i lhs_mat_23_40_sp2 = _mm256_shuffle_epi32(lhs_mat_23_40, 245); //A42(4-7) A43(4-7) A42(4-7) A43(4-7) A42(4-7) A43(4-7) A42(4-7) A43(4-7) + + const __m256i lhs_mat_01_41_sp2 = _mm256_shuffle_epi32(lhs_mat_01_41, 245); //A40(12-15) A40(12-15) A41(12-15) A41(12-15) A40(12-15) A40(12-15) A41(12-15) A41(12-15) + const __m256i lhs_mat_23_41_sp2 = _mm256_shuffle_epi32(lhs_mat_23_41, 245); //A42(12-15) A43(12-15) A42(12-15) A43(12-15) A42(12-15) A43(12-15) A42(12-15) A43(12-15) + + const __m256i lhs_mat_01_50_sp2 = _mm256_shuffle_epi32(lhs_mat_01_50, 245); //A50(4-7) A50(4-7) A51(4-7) A51(4-7) A50(4-7) A50(4-7) A51(4-7) A51(4-7) + const __m256i lhs_mat_23_50_sp2 = _mm256_shuffle_epi32(lhs_mat_23_50, 245); //A52(4-7) A53(4-7) A52(4-7) A53(4-7) A52(4-7) A53(4-7) A52(4-7) A53(4-7) + + const __m256i lhs_mat_01_51_sp2 = _mm256_shuffle_epi32(lhs_mat_01_51, 245); //A50(12-15) A50(12-15) A51(12-15) A51(12-15) A50(12-15) A50(12-15) A51(12-15) A51(12-15) + const __m256i lhs_mat_23_51_sp2 = _mm256_shuffle_epi32(lhs_mat_23_51, 245); //A52(12-15) A53(12-15) A52(12-15) A53(12-15) A52(12-15) A53(12-15) A52(12-15) A53(12-15) + + const __m256i lhs_mat_01_60_sp2 = _mm256_shuffle_epi32(lhs_mat_01_60, 245); //A60(4-7) A60(4-7) A61(4-7) A61(4-7) A60(4-7) A60(4-7) A61(4-7) A61(4-7) + const __m256i lhs_mat_23_60_sp2 = _mm256_shuffle_epi32(lhs_mat_23_60, 245); //A62(4-7) A63(4-7) A62(4-7) A63(4-7) A62(4-7) A63(4-7) A62(4-7) A63(4-7) + + const __m256i lhs_mat_01_61_sp2 = _mm256_shuffle_epi32(lhs_mat_01_61, 245); //A60(12-15) A60(12-15) A61(12-15) A61(12-15) A60(12-15) A60(12-15) A61(12-15) A61(12-15) + const __m256i lhs_mat_23_61_sp2 = _mm256_shuffle_epi32(lhs_mat_23_61, 245); //A62(12-15) A63(12-15) A62(12-15) A63(12-15) A62(12-15) A63(12-15) A62(12-15) A63(12-15) + + const __m256i lhs_mat_01_70_sp2 = _mm256_shuffle_epi32(lhs_mat_01_70, 245); //A70(4-7) A70(4-7) A71(4-7) A71(4-7) A70(4-7) A70(4-7) A71(4-7) A71(4-7) + const __m256i lhs_mat_23_70_sp2 = _mm256_shuffle_epi32(lhs_mat_23_70, 245); //A72(4-7) A73(4-7) A72(4-7) A73(4-7) A72(4-7) A73(4-7) A72(4-7) A73(4-7) + + const __m256i lhs_mat_01_71_sp2 = _mm256_shuffle_epi32(lhs_mat_01_71, 245); //A70(12-15) A70(12-15) A71(12-15) A71(12-15) A70(12-15) A70(12-15) A71(12-15) A71(12-15) + const __m256i lhs_mat_23_71_sp2 = _mm256_shuffle_epi32(lhs_mat_23_71, 245); //A72(12-15) A73(12-15) A72(12-15) A73(12-15) A72(12-15) A73(12-15) A72(12-15) A73(12-15) + + // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane + __m256i iacc_mat_00_0_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_01_00_sp1),_mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_01_01_sp1)); + __m256i iacc_mat_01_0_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_01_00_sp1),_mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_01_01_sp1)); + + __m256i iacc_mat_10_0_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_23_00_sp1),_mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_23_01_sp1)); + __m256i iacc_mat_11_0_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_23_00_sp1),_mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_23_01_sp1)); + + __m256i iacc_mat_00_1_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_01_10_sp1),_mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_01_11_sp1)); + __m256i iacc_mat_01_1_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_01_10_sp1),_mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_01_11_sp1)); + + __m256i iacc_mat_10_1_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_23_10_sp1),_mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_23_11_sp1)); + __m256i iacc_mat_11_1_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_23_10_sp1),_mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_23_11_sp1)); + + __m256i iacc_mat_00_2_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_20_sp1, lhs_mat_01_20_sp1),_mm256_maddubs_epi16(rhs_mat_0145_21_sp1, lhs_mat_01_21_sp1)); + __m256i iacc_mat_01_2_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_20_sp1, lhs_mat_01_20_sp1),_mm256_maddubs_epi16(rhs_mat_2367_21_sp1, lhs_mat_01_21_sp1)); + + __m256i iacc_mat_10_2_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_20_sp1, lhs_mat_23_20_sp1),_mm256_maddubs_epi16(rhs_mat_0145_21_sp1, lhs_mat_23_21_sp1)); + __m256i iacc_mat_11_2_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_20_sp1, lhs_mat_23_20_sp1),_mm256_maddubs_epi16(rhs_mat_2367_21_sp1, lhs_mat_23_21_sp1)); + + __m256i iacc_mat_00_3_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_30_sp1, lhs_mat_01_30_sp1),_mm256_maddubs_epi16(rhs_mat_0145_31_sp1, lhs_mat_01_31_sp1)); + __m256i iacc_mat_01_3_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_30_sp1, lhs_mat_01_30_sp1),_mm256_maddubs_epi16(rhs_mat_2367_31_sp1, lhs_mat_01_31_sp1)); + + __m256i iacc_mat_10_3_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_30_sp1, lhs_mat_23_30_sp1),_mm256_maddubs_epi16(rhs_mat_0145_31_sp1, lhs_mat_23_31_sp1)); + __m256i iacc_mat_11_3_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_30_sp1, lhs_mat_23_30_sp1),_mm256_maddubs_epi16(rhs_mat_2367_31_sp1, lhs_mat_23_31_sp1)); + + __m256i iacc_mat_00_4_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_40_sp1, lhs_mat_01_40_sp1),_mm256_maddubs_epi16(rhs_mat_0145_41_sp1, lhs_mat_01_41_sp1)); + __m256i iacc_mat_01_4_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_40_sp1, lhs_mat_01_40_sp1),_mm256_maddubs_epi16(rhs_mat_2367_41_sp1, lhs_mat_01_41_sp1)); + + __m256i iacc_mat_10_4_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_40_sp1, lhs_mat_23_40_sp1),_mm256_maddubs_epi16(rhs_mat_0145_41_sp1, lhs_mat_23_41_sp1)); + __m256i iacc_mat_11_4_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_40_sp1, lhs_mat_23_40_sp1),_mm256_maddubs_epi16(rhs_mat_2367_41_sp1, lhs_mat_23_41_sp1)); + + __m256i iacc_mat_00_5_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_50_sp1, lhs_mat_01_50_sp1),_mm256_maddubs_epi16(rhs_mat_0145_51_sp1, lhs_mat_01_51_sp1)); + __m256i iacc_mat_01_5_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_50_sp1, lhs_mat_01_50_sp1),_mm256_maddubs_epi16(rhs_mat_2367_51_sp1, lhs_mat_01_51_sp1)); + + __m256i iacc_mat_10_5_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_50_sp1, lhs_mat_23_50_sp1),_mm256_maddubs_epi16(rhs_mat_0145_51_sp1, lhs_mat_23_51_sp1)); + __m256i iacc_mat_11_5_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_50_sp1, lhs_mat_23_50_sp1),_mm256_maddubs_epi16(rhs_mat_2367_51_sp1, lhs_mat_23_51_sp1)); + + __m256i iacc_mat_00_6_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_60_sp1, lhs_mat_01_60_sp1),_mm256_maddubs_epi16(rhs_mat_0145_61_sp1, lhs_mat_01_61_sp1)); + __m256i iacc_mat_01_6_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_60_sp1, lhs_mat_01_60_sp1),_mm256_maddubs_epi16(rhs_mat_2367_61_sp1, lhs_mat_01_61_sp1)); + + __m256i iacc_mat_10_6_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_60_sp1, lhs_mat_23_60_sp1),_mm256_maddubs_epi16(rhs_mat_0145_61_sp1, lhs_mat_23_61_sp1)); + __m256i iacc_mat_11_6_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_60_sp1, lhs_mat_23_60_sp1),_mm256_maddubs_epi16(rhs_mat_2367_61_sp1, lhs_mat_23_61_sp1)); + + __m256i iacc_mat_00_7_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_70_sp1, lhs_mat_01_70_sp1),_mm256_maddubs_epi16(rhs_mat_0145_71_sp1, lhs_mat_01_71_sp1)); + __m256i iacc_mat_01_7_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_70_sp1, lhs_mat_01_70_sp1),_mm256_maddubs_epi16(rhs_mat_2367_71_sp1, lhs_mat_01_71_sp1)); + + __m256i iacc_mat_10_7_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_70_sp1, lhs_mat_23_70_sp1),_mm256_maddubs_epi16(rhs_mat_0145_71_sp1, lhs_mat_23_71_sp1)); + __m256i iacc_mat_11_7_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_70_sp1, lhs_mat_23_70_sp1),_mm256_maddubs_epi16(rhs_mat_2367_71_sp1, lhs_mat_23_71_sp1)); + + + __m256i iacc_mat_00_0_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_01_00_sp2),_mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_01_01_sp2)); + __m256i iacc_mat_01_0_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_01_00_sp2),_mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_01_01_sp2)); + + __m256i iacc_mat_10_0_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_23_00_sp2),_mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_23_01_sp2)); + __m256i iacc_mat_11_0_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_23_00_sp2),_mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_23_01_sp2)); + + __m256i iacc_mat_00_1_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_01_10_sp2),_mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_01_11_sp2)); + __m256i iacc_mat_01_1_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_01_10_sp2),_mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_01_11_sp2)); + + __m256i iacc_mat_10_1_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_23_10_sp2),_mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_23_11_sp2)); + __m256i iacc_mat_11_1_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_23_10_sp2),_mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_23_11_sp2)); + + __m256i iacc_mat_00_2_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_20_sp2, lhs_mat_01_20_sp2),_mm256_maddubs_epi16(rhs_mat_0145_21_sp2, lhs_mat_01_21_sp2)); + __m256i iacc_mat_01_2_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_20_sp2, lhs_mat_01_20_sp2),_mm256_maddubs_epi16(rhs_mat_2367_21_sp2, lhs_mat_01_21_sp2)); + + __m256i iacc_mat_10_2_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_20_sp2, lhs_mat_23_20_sp2),_mm256_maddubs_epi16(rhs_mat_0145_21_sp2, lhs_mat_23_21_sp2)); + __m256i iacc_mat_11_2_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_20_sp2, lhs_mat_23_20_sp2),_mm256_maddubs_epi16(rhs_mat_2367_21_sp2, lhs_mat_23_21_sp2)); + + __m256i iacc_mat_00_3_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_30_sp2, lhs_mat_01_30_sp2),_mm256_maddubs_epi16(rhs_mat_0145_31_sp2, lhs_mat_01_31_sp2)); + __m256i iacc_mat_01_3_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_30_sp2, lhs_mat_01_30_sp2),_mm256_maddubs_epi16(rhs_mat_2367_31_sp2, lhs_mat_01_31_sp2)); + + __m256i iacc_mat_10_3_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_30_sp2, lhs_mat_23_30_sp2),_mm256_maddubs_epi16(rhs_mat_0145_31_sp2, lhs_mat_23_31_sp2)); + __m256i iacc_mat_11_3_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_30_sp2, lhs_mat_23_30_sp2),_mm256_maddubs_epi16(rhs_mat_2367_31_sp2, lhs_mat_23_31_sp2)); + + __m256i iacc_mat_00_4_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_40_sp2, lhs_mat_01_40_sp2),_mm256_maddubs_epi16(rhs_mat_0145_41_sp2, lhs_mat_01_41_sp2)); + __m256i iacc_mat_01_4_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_40_sp2, lhs_mat_01_40_sp2),_mm256_maddubs_epi16(rhs_mat_2367_41_sp2, lhs_mat_01_41_sp2)); + + __m256i iacc_mat_10_4_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_40_sp2, lhs_mat_23_40_sp2),_mm256_maddubs_epi16(rhs_mat_0145_41_sp2, lhs_mat_23_41_sp2)); + __m256i iacc_mat_11_4_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_40_sp2, lhs_mat_23_40_sp2),_mm256_maddubs_epi16(rhs_mat_2367_41_sp2, lhs_mat_23_41_sp2)); + + __m256i iacc_mat_00_5_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_50_sp2, lhs_mat_01_50_sp2),_mm256_maddubs_epi16(rhs_mat_0145_51_sp2, lhs_mat_01_51_sp2)); + __m256i iacc_mat_01_5_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_50_sp2, lhs_mat_01_50_sp2),_mm256_maddubs_epi16(rhs_mat_2367_51_sp2, lhs_mat_01_51_sp2)); + + __m256i iacc_mat_10_5_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_50_sp2, lhs_mat_23_50_sp2),_mm256_maddubs_epi16(rhs_mat_0145_51_sp2, lhs_mat_23_51_sp2)); + __m256i iacc_mat_11_5_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_50_sp2, lhs_mat_23_50_sp2),_mm256_maddubs_epi16(rhs_mat_2367_51_sp2, lhs_mat_23_51_sp2)); + + __m256i iacc_mat_00_6_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_60_sp2, lhs_mat_01_60_sp2),_mm256_maddubs_epi16(rhs_mat_0145_61_sp2, lhs_mat_01_61_sp2)); + __m256i iacc_mat_01_6_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_60_sp2, lhs_mat_01_60_sp2),_mm256_maddubs_epi16(rhs_mat_2367_61_sp2, lhs_mat_01_61_sp2)); + + __m256i iacc_mat_10_6_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_60_sp2, lhs_mat_23_60_sp2),_mm256_maddubs_epi16(rhs_mat_0145_61_sp2, lhs_mat_23_61_sp2)); + __m256i iacc_mat_11_6_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_60_sp2, lhs_mat_23_60_sp2),_mm256_maddubs_epi16(rhs_mat_2367_61_sp2, lhs_mat_23_61_sp2)); + + __m256i iacc_mat_00_7_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_70_sp2, lhs_mat_01_70_sp2),_mm256_maddubs_epi16(rhs_mat_0145_71_sp2, lhs_mat_01_71_sp2)); + __m256i iacc_mat_01_7_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_70_sp2, lhs_mat_01_70_sp2),_mm256_maddubs_epi16(rhs_mat_2367_71_sp2, lhs_mat_01_71_sp2)); + + __m256i iacc_mat_10_7_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_70_sp2, lhs_mat_23_70_sp2),_mm256_maddubs_epi16(rhs_mat_0145_71_sp2, lhs_mat_23_71_sp2)); + __m256i iacc_mat_11_7_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_70_sp2, lhs_mat_23_70_sp2),_mm256_maddubs_epi16(rhs_mat_2367_71_sp2, lhs_mat_23_71_sp2)); + + // Combine results from both shuffle patterns for each output block + __m256i iacc_mat_00_0 = _mm256_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2); + __m256i iacc_mat_01_0 = _mm256_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2); + __m256i iacc_mat_10_0 = _mm256_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2); + __m256i iacc_mat_11_0 = _mm256_add_epi16(iacc_mat_11_0_sp1, iacc_mat_11_0_sp2); + + __m256i iacc_mat_00_1 = _mm256_add_epi16(iacc_mat_00_1_sp1, iacc_mat_00_1_sp2); + __m256i iacc_mat_01_1 = _mm256_add_epi16(iacc_mat_01_1_sp1, iacc_mat_01_1_sp2); + __m256i iacc_mat_10_1 = _mm256_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2); + __m256i iacc_mat_11_1 = _mm256_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2); + + __m256i iacc_mat_00_2 = _mm256_add_epi16(iacc_mat_00_2_sp1, iacc_mat_00_2_sp2); + __m256i iacc_mat_01_2 = _mm256_add_epi16(iacc_mat_01_2_sp1, iacc_mat_01_2_sp2); + __m256i iacc_mat_10_2 = _mm256_add_epi16(iacc_mat_10_2_sp1, iacc_mat_10_2_sp2); + __m256i iacc_mat_11_2 = _mm256_add_epi16(iacc_mat_11_2_sp1, iacc_mat_11_2_sp2); + + __m256i iacc_mat_00_3 = _mm256_add_epi16(iacc_mat_00_3_sp1, iacc_mat_00_3_sp2); + __m256i iacc_mat_01_3 = _mm256_add_epi16(iacc_mat_01_3_sp1, iacc_mat_01_3_sp2); + __m256i iacc_mat_10_3 = _mm256_add_epi16(iacc_mat_10_3_sp1, iacc_mat_10_3_sp2); + __m256i iacc_mat_11_3 = _mm256_add_epi16(iacc_mat_11_3_sp1, iacc_mat_11_3_sp2); + + __m256i iacc_mat_00_4 = _mm256_add_epi16(iacc_mat_00_4_sp1, iacc_mat_00_4_sp2); + __m256i iacc_mat_01_4 = _mm256_add_epi16(iacc_mat_01_4_sp1, iacc_mat_01_4_sp2); + __m256i iacc_mat_10_4 = _mm256_add_epi16(iacc_mat_10_4_sp1, iacc_mat_10_4_sp2); + __m256i iacc_mat_11_4 = _mm256_add_epi16(iacc_mat_11_4_sp1, iacc_mat_11_4_sp2); + + __m256i iacc_mat_00_5 = _mm256_add_epi16(iacc_mat_00_5_sp1, iacc_mat_00_5_sp2); + __m256i iacc_mat_01_5 = _mm256_add_epi16(iacc_mat_01_5_sp1, iacc_mat_01_5_sp2); + __m256i iacc_mat_10_5 = _mm256_add_epi16(iacc_mat_10_5_sp1, iacc_mat_10_5_sp2); + __m256i iacc_mat_11_5 = _mm256_add_epi16(iacc_mat_11_5_sp1, iacc_mat_11_5_sp2); + + __m256i iacc_mat_00_6 = _mm256_add_epi16(iacc_mat_00_6_sp1, iacc_mat_00_6_sp2); + __m256i iacc_mat_01_6 = _mm256_add_epi16(iacc_mat_01_6_sp1, iacc_mat_01_6_sp2); + __m256i iacc_mat_10_6 = _mm256_add_epi16(iacc_mat_10_6_sp1, iacc_mat_10_6_sp2); + __m256i iacc_mat_11_6 = _mm256_add_epi16(iacc_mat_11_6_sp1, iacc_mat_11_6_sp2); + + __m256i iacc_mat_00_7 = _mm256_add_epi16(iacc_mat_00_7_sp1, iacc_mat_00_7_sp2); + __m256i iacc_mat_01_7 = _mm256_add_epi16(iacc_mat_01_7_sp1, iacc_mat_01_7_sp2); + __m256i iacc_mat_10_7 = _mm256_add_epi16(iacc_mat_10_7_sp1, iacc_mat_10_7_sp2); + __m256i iacc_mat_11_7 = _mm256_add_epi16(iacc_mat_11_7_sp1, iacc_mat_11_7_sp2); + + // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block + iacc_mat_00_0 = _mm256_madd_epi16(iacc_mat_00_0, scale_0145_0); + iacc_mat_01_0 = _mm256_madd_epi16(iacc_mat_01_0, scale_2367_0); + iacc_mat_10_0 = _mm256_madd_epi16(iacc_mat_10_0, scale_0145_0); + iacc_mat_11_0 = _mm256_madd_epi16(iacc_mat_11_0, scale_2367_0); + + iacc_mat_00_1 = _mm256_madd_epi16(iacc_mat_00_1, scale_0145_1); + iacc_mat_01_1 = _mm256_madd_epi16(iacc_mat_01_1, scale_2367_1); + iacc_mat_10_1 = _mm256_madd_epi16(iacc_mat_10_1, scale_0145_1); + iacc_mat_11_1 = _mm256_madd_epi16(iacc_mat_11_1, scale_2367_1); + + iacc_mat_00_2 = _mm256_madd_epi16(iacc_mat_00_2, scale_0145_2); + iacc_mat_01_2 = _mm256_madd_epi16(iacc_mat_01_2, scale_2367_2); + iacc_mat_10_2 = _mm256_madd_epi16(iacc_mat_10_2, scale_0145_2); + iacc_mat_11_2 = _mm256_madd_epi16(iacc_mat_11_2, scale_2367_2); + + iacc_mat_00_3 = _mm256_madd_epi16(iacc_mat_00_3, scale_0145_3); + iacc_mat_01_3 = _mm256_madd_epi16(iacc_mat_01_3, scale_2367_3); + iacc_mat_10_3 = _mm256_madd_epi16(iacc_mat_10_3, scale_0145_3); + iacc_mat_11_3 = _mm256_madd_epi16(iacc_mat_11_3, scale_2367_3); + + iacc_mat_00_4 = _mm256_madd_epi16(iacc_mat_00_4, scale_0145_4); + iacc_mat_01_4 = _mm256_madd_epi16(iacc_mat_01_4, scale_2367_4); + iacc_mat_10_4 = _mm256_madd_epi16(iacc_mat_10_4, scale_0145_4); + iacc_mat_11_4 = _mm256_madd_epi16(iacc_mat_11_4, scale_2367_4); + + iacc_mat_00_5 = _mm256_madd_epi16(iacc_mat_00_5, scale_0145_5); + iacc_mat_01_5 = _mm256_madd_epi16(iacc_mat_01_5, scale_2367_5); + iacc_mat_10_5 = _mm256_madd_epi16(iacc_mat_10_5, scale_0145_5); + iacc_mat_11_5 = _mm256_madd_epi16(iacc_mat_11_5, scale_2367_5); + + iacc_mat_00_6 = _mm256_madd_epi16(iacc_mat_00_6, scale_0145_6); + iacc_mat_01_6 = _mm256_madd_epi16(iacc_mat_01_6, scale_2367_6); + iacc_mat_10_6 = _mm256_madd_epi16(iacc_mat_10_6, scale_0145_6); + iacc_mat_11_6 = _mm256_madd_epi16(iacc_mat_11_6, scale_2367_6); + + iacc_mat_00_7 = _mm256_madd_epi16(iacc_mat_00_7, scale_0145_7); + iacc_mat_01_7 = _mm256_madd_epi16(iacc_mat_01_7, scale_2367_7); + iacc_mat_10_7 = _mm256_madd_epi16(iacc_mat_10_7, scale_0145_7); + iacc_mat_11_7 = _mm256_madd_epi16(iacc_mat_11_7, scale_2367_7); + + __m256i iacc_mat_00 = _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(iacc_mat_00_0, iacc_mat_00_1), _mm256_add_epi32(iacc_mat_00_2, iacc_mat_00_3)), _mm256_add_epi32(_mm256_add_epi32(iacc_mat_00_4, iacc_mat_00_5), _mm256_add_epi32(iacc_mat_00_6, iacc_mat_00_7))); + __m256i iacc_mat_01 = _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(iacc_mat_01_0, iacc_mat_01_1), _mm256_add_epi32(iacc_mat_01_2, iacc_mat_01_3)), _mm256_add_epi32(_mm256_add_epi32(iacc_mat_01_4, iacc_mat_01_5), _mm256_add_epi32(iacc_mat_01_6, iacc_mat_01_7))); + __m256i iacc_mat_10 = _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(iacc_mat_10_0, iacc_mat_10_1), _mm256_add_epi32(iacc_mat_10_2, iacc_mat_10_3)), _mm256_add_epi32(_mm256_add_epi32(iacc_mat_10_4, iacc_mat_10_5), _mm256_add_epi32(iacc_mat_10_6, iacc_mat_10_7))); + __m256i iacc_mat_11 = _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(iacc_mat_11_0, iacc_mat_11_1), _mm256_add_epi32(iacc_mat_11_2, iacc_mat_11_3)), _mm256_add_epi32(_mm256_add_epi32(iacc_mat_11_4, iacc_mat_11_5), _mm256_add_epi32(iacc_mat_11_6, iacc_mat_11_7))); + + // Straighten out to make 4 row vectors + __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204); + __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204); + __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204); + __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204); + + // Load the scale(d) values for all the 4 Q8_k blocks and repeat it across lanes + const __m128 row_scale_f32_sse = _mm_load_ps(a_ptrs[rp][b].d); + const __m256 row_scale_f32 = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse); + + // Multiply with appropiate scales and accumulate (for both d and dmin) below + acc_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]); + acc_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]); + acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]); + acc_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]); + + __m256i lhs_bsums_01_0123 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_01_0123), lhs_raw_bsums_01_0123, 1); + __m256i lhs_bsums_23_0123 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_23_0123), lhs_raw_bsums_23_0123, 1); + __m256i lhs_bsums_01_4567 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_01_4567), lhs_raw_bsums_01_4567, 1); + __m256i lhs_bsums_23_4567 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_23_4567), lhs_raw_bsums_23_4567, 1); + + // Take two bsums from two Q8_Ks at a time and multiply with corresponding mins values from each Q2_K + __m256i iacc_row_min_0_01 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_0123, 0), mins_01); + __m256i iacc_row_min_1_01 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_0123, 170), mins_01); + __m256i iacc_row_min_2_01 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_0123, 0), mins_01); + __m256i iacc_row_min_3_01 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_0123, 170), mins_01); + + __m256i iacc_row_min_0_23 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_0123, 85), mins_23); + __m256i iacc_row_min_1_23 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_0123, 255), mins_23); + __m256i iacc_row_min_2_23 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_0123, 85), mins_23); + __m256i iacc_row_min_3_23 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_0123, 255), mins_23); + + __m256i iacc_row_min_0_45 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_4567, 0), mins_45); + __m256i iacc_row_min_1_45 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_4567, 170), mins_45); + __m256i iacc_row_min_2_45 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_4567, 0), mins_45); + __m256i iacc_row_min_3_45 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_4567, 170), mins_45); + + __m256i iacc_row_min_0_67 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_4567, 85), mins_67); + __m256i iacc_row_min_1_67 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_4567, 255), mins_67); + __m256i iacc_row_min_2_67 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_4567, 85), mins_67); + __m256i iacc_row_min_3_67 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_4567, 255), mins_67); + + __m256i iacc_row_min_0 = _mm256_add_epi32(_mm256_add_epi32(iacc_row_min_0_01, iacc_row_min_0_23), _mm256_add_epi32(iacc_row_min_0_45,iacc_row_min_0_67)); + __m256i iacc_row_min_1 = _mm256_add_epi32(_mm256_add_epi32(iacc_row_min_1_01, iacc_row_min_1_23), _mm256_add_epi32(iacc_row_min_1_45,iacc_row_min_1_67)); + __m256i iacc_row_min_2 = _mm256_add_epi32(_mm256_add_epi32(iacc_row_min_2_01, iacc_row_min_2_23), _mm256_add_epi32(iacc_row_min_2_45,iacc_row_min_2_67)); + __m256i iacc_row_min_3 = _mm256_add_epi32(_mm256_add_epi32(iacc_row_min_3_01, iacc_row_min_3_23), _mm256_add_epi32(iacc_row_min_3_45,iacc_row_min_3_67)); + + acc_min_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_0), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[rp * 4]); + acc_min_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_1), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[rp * 4 + 1]); + acc_min_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_2), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[rp * 4 + 2]); + acc_min_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_3), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[rp * 4 + 3]); + + } + } + } + // Store the accumulated values + for (int i = 0; i < 16; i++) { + _mm256_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm256_sub_ps(acc_rows[i], acc_min_rows[i])); + + } + } + } + + for (; y < nr / 4; y ++) { + + const block_q8_Kx4 * a_ptr = a_ptr_start + (y * nb); + + // Take group of eight block_q2_kx8 structures at each pass of the loop and perform dot product operation + for (int64_t x = xstart; x < nc / 8; x++) { + + const block_q2_Kx8 * b_ptr = b_ptr_start + (x * b_nb); + + // Master FP accumulators + __m256 acc_rows[4]; + for (int i = 0; i < 4; i++) { + acc_rows[i] = _mm256_setzero_ps(); + } + + __m256 acc_min_rows[4]; + for (int i = 0; i < 4; i++) { + acc_min_rows[i] = _mm256_setzero_ps(); + } + + for (int64_t b = 0; b < nb; b++) { + // Delta values - Load the eight scale values of block_q2_kx8 + const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d); + + // dmin values - Load the eight dmin values of block_q2_kx8 + const __m256 col_dmin_f32 = GGML_F32Cx8_LOAD(b_ptr[b].dmin); + + // Loop to iterate over the sixteen sub blocks of a super block - eight sub blocks are processed per iteration + for (int sb = 0; sb < QK_K / 128; sb++) { + + // Load the eight block_q2_k for eight sub blocks quantized values interleaved with each other in chunks of eight bytes - B0,B1 ....B6,B7 + const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + sb * 256)); + const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32 + sb * 256)); + const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64 + sb * 256)); + const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96 + sb * 256)); + const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 128 + sb * 256)); + const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 160 + sb * 256)); + const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 192 + sb * 256)); + const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 224 + sb * 256)); + + // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values + //superblock sub block which part of sub block + const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240); + + const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240); + + const __m256i rhs_raw_mat_0145_2 = _mm256_blend_epi32(rhs_raw_mat_0123_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_2, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_2, requiredOrder), rhs_raw_mat_4567_2, 240); + + const __m256i rhs_raw_mat_0145_3 = _mm256_blend_epi32(rhs_raw_mat_0123_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_3, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_3, requiredOrder), rhs_raw_mat_4567_3, 240); + + // 2-bit -> 8-bit + // First sub block of the eight sub blocks processed in the iteration + const __m256i rhs_mat_0145_00 = _mm256_and_si256(rhs_raw_mat_0145_0, m3b); //B00(0-7) B01(0-7) B04(0-7) B05(0-7) + const __m256i rhs_mat_2367_00 = _mm256_and_si256(rhs_raw_mat_2367_0, m3b); //B02(0-7) B03(0-7) B06(0-7) B07(0-7) + + const __m256i rhs_mat_0145_01 = _mm256_and_si256(rhs_raw_mat_0145_1, m3b); //B00(8-15) B01(8-15) B04(8-15) B05(8-15) + const __m256i rhs_mat_2367_01 = _mm256_and_si256(rhs_raw_mat_2367_1, m3b); //B02(8-15) B03(8-15) B06(8-15) B07(8-15) + + // Second sub block of the eight sub blocks processed in the iteration + const __m256i rhs_mat_0145_10 = _mm256_and_si256(rhs_raw_mat_0145_2, m3b); //B10(0-7) B11(0-7) B14(0-7) B15(0-7) + const __m256i rhs_mat_2367_10 = _mm256_and_si256(rhs_raw_mat_2367_2, m3b); //B12(0-7) B13(0-7) B16(0-7) B17(0-7) + + const __m256i rhs_mat_0145_11 = _mm256_and_si256(rhs_raw_mat_0145_3, m3b); //B10(8-15) B11(8-15) B14(8-15) B15(8-15) + const __m256i rhs_mat_2367_11 = _mm256_and_si256(rhs_raw_mat_2367_3, m3b); //B12(8-15) B13(8-15) B16(8-15) B17(8-15) + + // Third sub block of the eight sub blocks processed in the iteration + const __m256i rhs_mat_0145_20 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 2), m3b); //B20(0-7) B21(0-7) B24(0-7) B25(0-7) + const __m256i rhs_mat_2367_20 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 2), m3b); //B22(0-7) B23(0-7) B26(0-7) B27(0-7) + + const __m256i rhs_mat_0145_21 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 2), m3b); //B20(8-15) B21(8-15) B24(8-15) B25(8-15) + const __m256i rhs_mat_2367_21 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 2), m3b); //B22(8-15) B23(8-15) B26(8-15) B27(8-15) + + // Fourth sub block of the eight sub blocks processed in the iteration + const __m256i rhs_mat_0145_30 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_2, 2), m3b); //B30(0-7) B31(0-7) B34(0-7) B35(0-7) + const __m256i rhs_mat_2367_30 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_2, 2), m3b); //B32(0-7) B33(0-7) B36(0-7) B37(0-7) + + const __m256i rhs_mat_0145_31 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_3, 2), m3b); //B30(8-15) B31(8-15) B34(8-15) B35(8-15) + const __m256i rhs_mat_2367_31 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_3, 2), m3b); //B32(8-15) B33(8-15) B36(8-15) B37(8-15) + + // Fifth sub block of the eight sub blocks processed in the iteration + const __m256i rhs_mat_0145_40 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m3b); //B40(0-7) B41(0-7) B44(0-7) B45(0-7) + const __m256i rhs_mat_2367_40 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m3b); //B42(0-7) B43(0-7) B46(0-7) B47(0-7) + + const __m256i rhs_mat_0145_41 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m3b); //B40(8-15) B41(8-15) B44(8-15) B45(8-15) + const __m256i rhs_mat_2367_41 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m3b); //B42(8-15) B43(8-15) B46(8-15) B47(8-15) + + // Sixth sub block of the eight sub blocks processed in the iteration + const __m256i rhs_mat_0145_50 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_2, 4), m3b); //B50(0-7) B51(0-7) B54(0-7) B55(0-7) + const __m256i rhs_mat_2367_50 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_2, 4), m3b); //B52(0-7) B53(0-7) B56(0-7) B57(0-7) + + const __m256i rhs_mat_0145_51 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_3, 4), m3b); //B50(8-15) B51(8-15) B54(8-15) B55(8-15) + const __m256i rhs_mat_2367_51 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_3, 4), m3b); //B52(8-15) B53(8-15) B56(8-15) B57(8-15) + + // Seventh sub block of the eight sub blocks processed in the iteration + const __m256i rhs_mat_0145_60 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 6), m3b); //B60(0-7) B61(0-7) B64(0-7) B65(0-7) + const __m256i rhs_mat_2367_60 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 6), m3b); //B62(0-7) B63(0-7) B66(0-7) B67(0-7) + + const __m256i rhs_mat_0145_61 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 6), m3b); //B60(8-15) B61(8-15) B64(8-15) B65(8-15) + const __m256i rhs_mat_2367_61 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 6), m3b); //B62(8-15) B63(8-15) B66(8-15) B67(8-15) + + // Eighth sub block of the eight sub blocks processed in the iteration + const __m256i rhs_mat_0145_70 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_2, 6), m3b); //B70(0-7) B71(0-7) B74(0-7) B75(0-7) + const __m256i rhs_mat_2367_70 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_2, 6), m3b); //B72(0-7) B73(0-7) B76(0-7) B77(0-7) + + const __m256i rhs_mat_0145_71 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_3, 6), m3b); //B70(8-15) B71(8-15) B74(8-15) B75(8-15) + const __m256i rhs_mat_2367_71 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_3, 6), m3b); //B72(8-15) B73(8-15) B76(8-15) B77(8-15) + + // Shuffle pattern one - right side input + const __m256i rhs_mat_0145_00_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_00, 136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3) + const __m256i rhs_mat_2367_00_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_00, 136); //B02(0-3) B03(0-3) B02(0-3) B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3) + + const __m256i rhs_mat_0145_01_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_01, 136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11) + const __m256i rhs_mat_2367_01_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_01, 136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11) + + const __m256i rhs_mat_0145_10_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_10, 136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3) + const __m256i rhs_mat_2367_10_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_10, 136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) B16(0-3) B17(0-3) + + const __m256i rhs_mat_0145_11_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_11, 136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11) + const __m256i rhs_mat_2367_11_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_11, 136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11) + + const __m256i rhs_mat_0145_20_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_20, 136); //B20(0-3) B21(0-3) B20(0-3) B21(0-3) B24(0-3) B25(0-3) B24(0-3) B25(0-3) + const __m256i rhs_mat_2367_20_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_20, 136); //B22(0-3) B23(0-3) B22(0-3) B23(0-3) B26(0-3) B27(0-3) B26(0-3) B27(0-3) + + const __m256i rhs_mat_0145_21_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_21, 136); //B20(8-11) B21(8-11) B20(8-11) B21(8-11) B24(8-11) B25(8-11) B24(8-11) B25(8-11) + const __m256i rhs_mat_2367_21_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_21, 136); //B22(8-11) B23(8-11) B22(8-11) B23(8-11) B26(8-11) B27(8-11) B26(8-11) B27(8-11) + + const __m256i rhs_mat_0145_30_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_30, 136); //B30(0-3) B31(0-3) B30(0-3) B31(0-3) B34(0-3) B35(0-3) B34(0-3) B35(0-3) + const __m256i rhs_mat_2367_30_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_30, 136); //B32(0-3) B33(0-3) B32(0-3) B33(0-3) B36(0-3) B37(0-3) B36(0-3) B37(0-3) + + const __m256i rhs_mat_0145_31_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_31, 136); //B30(8-11) B31(8-11) B30(8-11) B31(8-11) B34(8-11) B35(8-11) B34(8-11) B35(8-11 + const __m256i rhs_mat_2367_31_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_31, 136); //B32(8-11) B33(8-11) B32(8-11) B33(8-11) B36(8-11) B37(8-11) B36(8-11) B37(8-11) + + const __m256i rhs_mat_0145_40_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_40, 136); //B40(0-3) B41(0-3) B40(0-3) B41(0-3) B44(0-3) B45(0-3) B44(0-3) B45(0-3) + const __m256i rhs_mat_2367_40_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_40, 136); //B42(0-3) B43(0-3) B42(0-3) B43(0-3) B46(0-3) B47(0-3) B46(0-3) B47(0-3) + + const __m256i rhs_mat_0145_41_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_41, 136); //B40(8-11) B41(8-11) B40(8-11) B41(8-11) B44(8-11) B45(8-11) B44(8-11) B45(8-11) + const __m256i rhs_mat_2367_41_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_41, 136); //B42(8-11) B43(8-11) B42(8-11) B43(8-11) B46(8-11) B47(8-11) B46(8-11) B47(8-11) + + const __m256i rhs_mat_0145_50_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_50, 136); //B50(0-3) B51(0-3) B50(0-3) B51(0-3) B54(0-3) B55(0-3) B54(0-3) B55(0-3) + const __m256i rhs_mat_2367_50_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_50, 136); //B52(0-3) B53(0-3) B52(0-3) B53(0-3) B56(0-3) B57(0-3) B56(0-3) B57(0-3) + + const __m256i rhs_mat_0145_51_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_51, 136); //B50(8-11) B51(8-11) B50(8-11) B51(8-11) B54(8-11) B55(8-11) B54(8-11) B55(8-11) + const __m256i rhs_mat_2367_51_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_51, 136); //B52(8-11) B53(8-11) B52(8-11) B53(8-11) B56(8-11) B57(8-11) B56(8-11) B57(8-11) + + const __m256i rhs_mat_0145_60_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_60, 136); //B60(0-3) B61(0-3) B60(0-3) B61(0-3) B64(0-3) B65(0-3) B64(0-3) B65(0-3) + const __m256i rhs_mat_2367_60_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_60, 136); //B62(0-3) B63(0-3) B62(0-3) B63(0-3) B66(0-3) B67(0-3) B66(0-3) B67(0-3) + + const __m256i rhs_mat_0145_61_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_61, 136); //B60(8-11) B61(8-11) B60(8-11) B61(8-11) B64(8-11) B65(8-11) B64(8-11) B65(8-11) + const __m256i rhs_mat_2367_61_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_61, 136); //B62(8-11) B63(8-11) B62(8-11) B63(8-11) B66(8-11) B67(8-11) B66(8-11) B67(8-11) + + const __m256i rhs_mat_0145_70_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_70, 136); //B70(0-3) B71(0-3) B70(0-3) B71(0-3) B74(0-3) B75(0-3) B74(0-3) B75(0-3) + const __m256i rhs_mat_2367_70_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_70, 136); //B72(0-3) B73(0-3) B72(0-3) B73(0-3) B76(0-3) B77(0-3) B76(0-3) B77(0-3) + + const __m256i rhs_mat_0145_71_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_71, 136); //B70(8-11) B71(8-11) B70(8-11) B71(8-11) B74(8-11) B75(8-11) B74(8-11) B75(8-11) + const __m256i rhs_mat_2367_71_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_71, 136); //B72(8-11) B73(8-11) B72(8-11) B73(8-11) B76(8-11) B77(8-11) B76(8-11) B77(8-11) + + + // Shuffle pattern two - right side input + const __m256i rhs_mat_0145_00_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_00, 221); //B00(4-7) B01(4-7) B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7) + const __m256i rhs_mat_2367_00_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_00, 221); //B02(4-7) B03(4-7) B02(4-7) B03(4-7) B06(4-7) B07(4-7) B06(4-7) B07(4-7) + + const __m256i rhs_mat_0145_01_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_01, 221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15) + const __m256i rhs_mat_2367_01_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_01, 221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15) + + const __m256i rhs_mat_0145_10_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_10, 221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7) + const __m256i rhs_mat_2367_10_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_10, 221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7) + + const __m256i rhs_mat_0145_11_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_11, 221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15) + const __m256i rhs_mat_2367_11_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_11, 221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15) + + const __m256i rhs_mat_0145_20_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_20, 221); //B20(4-7) B21(4-7) B20(4-7) B21(4-7) B24(4-7) B25(4-7) B24(4-7) B25(4-7) + const __m256i rhs_mat_2367_20_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_20, 221); //B22(4-7) B23(4-7) B22(4-7) B23(4-7) B26(4-7) B27(4-7) B26(4-7) B27(4-7) + + const __m256i rhs_mat_0145_21_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_21, 221); //B20(12-15) B21(12-15) B20(12-15) B21(12-15) B24(12-15) B25(12-15) B24(12-15) B25(12-15) + const __m256i rhs_mat_2367_21_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_21, 221); //B22(12-15) B23(12-15) B22(12-15) B23(12-15) B26(12-15) B27(12-15) B26(12-15) B27(12-15) + + const __m256i rhs_mat_0145_30_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_30, 221); //B30(4-7) B31(4-7) B30(4-7) B31(4-7) B34(4-7) B35(4-7) B34(4-7) B35(4-7) + const __m256i rhs_mat_2367_30_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_30, 221); //B32(4-7) B33(4-7) B32(4-7) B33(4-7) B36(4-7) B37(4-7) B36(4-7) B37(4-7) + + const __m256i rhs_mat_0145_31_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_31, 221); //B30(12-15) B31(12-15) B30(12-15) B31(12-15) B34(12-15) B35(12-15) B34(12-15) B35(12-15) + const __m256i rhs_mat_2367_31_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_31, 221); //B32(12-15) B33(12-15) B32(12-15) B33(12-15) B36(12-15) B37(12-15) B36(12-15) B37(12-15) + + const __m256i rhs_mat_0145_40_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_40, 221); //B40(4-7) B41(4-7) B40(4-7) B41(4-7) B44(4-7) B45(4-7) B44(4-7) B45(4-7) + const __m256i rhs_mat_2367_40_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_40, 221); //B42(4-7) B43(4-7) B42(4-7) B43(4-7) B46(4-7) B47(4-7) B46(4-7) B47(4-7) + + const __m256i rhs_mat_0145_41_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_41, 221); //B40(12-15) B41(12-15) B40(12-15) B41(12-15) B44(12-15) B45(12-15) B44(12-15) B45(12-15) + const __m256i rhs_mat_2367_41_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_41, 221); //B42(12-15) B43(12-15) B42(12-15) B43(12-15) B46(12-15) B47(12-15) B46(12-15) B47(12-15) + + const __m256i rhs_mat_0145_50_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_50, 221); //B50(4-7) B51(4-7) B50(4-7) B51(4-7) B54(4-7) B55(4-7) B54(4-7) B55(4-7) + const __m256i rhs_mat_2367_50_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_50, 221); //B52(4-7) B53(4-7) B52(4-7) B53(4-7) B56(4-7) B57(4-7) B56(4-7) B57(4-7) + + const __m256i rhs_mat_0145_51_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_51, 221); //B50(12-15) B51(12-15) B50(12-15) B51(12-15) B54(12-15) B55(12-15) B54(12-15) B55(12-15) + const __m256i rhs_mat_2367_51_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_51, 221); //B52(12-15) B53(12-15) B52(12-15) B53(12-15) B56(12-15) B57(12-15) B56(12-15) B57(12-15) + + const __m256i rhs_mat_0145_60_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_60, 221); //B60(4-7) B61(4-7) B60(4-7) B61(4-7) B64(4-7) B65(4-7) B64(4-7) B65(4-7) + const __m256i rhs_mat_2367_60_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_60, 221); //B62(4-7) B63(4-7) B62(4-7) B63(4-7) B66(4-7) B67(4-7) B66(4-7) B67(4-7) + + const __m256i rhs_mat_0145_61_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_61, 221); //B60(12-15) B61(12-15) B60(12-15) B61(12-15) B64(12-15) B65(12-15) B64(12-15) B65(12-15) + const __m256i rhs_mat_2367_61_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_61, 221); //B62(12-15) B63(12-15) B62(12-15) B63(12-15) B66(12-15) B67(12-15) B66(12-15) B67(12-15) + + const __m256i rhs_mat_0145_70_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_70, 221); //B70(4-7) B71(4-7) B70(4-7) B71(4-7) B74(4-7) B75(4-7) B74(4-7) B75(4-7) + const __m256i rhs_mat_2367_70_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_70, 221); //B72(4-7) B73(4-7) B72(4-7) B73(4-7) B76(4-7) B77(4-7) B76(4-7) B77(4-7) + + const __m256i rhs_mat_0145_71_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_71, 221); //B70(12-15) B71(12-15) B70(12-15) B71(12-15) B74(12-15) B75(12-15) B74(12-15) B75(12-15) + const __m256i rhs_mat_2367_71_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_71, 221); //B72(12-15) B73(12-15) B72(12-15) B73(12-15) B76(12-15) B77(12-15) B76(12-15) B77(12-15) + + + //Scales and Mins of corresponding sub blocks from different Q2_K structures are stored together + //s00 m00 s01 m01 s10 m10 s11 m11 s20 m20 s21 m21 s30 m30 s31 m31 s40 m40 s41 m41 s50 m50 s51 m51 s60 m60 s61 m61 s70 m70 s71 m71 + + // Combine mins and scales for sub-blocks: 0-1, 2-3, 4-5, 6-7 in the sb loop + const __m128i mins_and_scales_01 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + sb * 64)); + const __m128i mins_and_scales_23 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + 16 + sb * 64)); + const __m128i mins_and_scales_45 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + 32 + sb * 64)); + const __m128i mins_and_scales_67 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + 48 + sb * 64)); + + // Extract scales which is lower half from mins_and_scales + const __m128i scales_01 = _mm_and_si128(mins_and_scales_01, m4b_sse); + const __m128i scales_23 = _mm_and_si128(mins_and_scales_23, m4b_sse); + const __m128i scales_45 = _mm_and_si128(mins_and_scales_45, m4b_sse); + const __m128i scales_67 = _mm_and_si128(mins_and_scales_67, m4b_sse); + + // Extract mins which is upper half from mins_and_scales + const __m256i mins_01 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_01, 4), m4b_sse)); + const __m256i mins_23 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_23, 4), m4b_sse)); + const __m256i mins_45 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_45, 4), m4b_sse)); + const __m256i mins_67 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_67, 4), m4b_sse)); + + const __m256i scales_0 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_01, scalesmask1_sse)); + const __m256i scales_1 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_01, scalesmask2_sse)); + + const __m256i scales_2 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_23, scalesmask1_sse)); + const __m256i scales_3 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_23, scalesmask2_sse)); + + const __m256i scales_4 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_45, scalesmask1_sse)); + const __m256i scales_5 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_45, scalesmask2_sse)); + + const __m256i scales_6 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_67, scalesmask1_sse)); + const __m256i scales_7 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_67, scalesmask2_sse)); + + const __m256i scale_0145_0 = _mm256_shuffle_epi32(scales_0, 68); + const __m256i scale_2367_0 = _mm256_shuffle_epi32(scales_0, 238); + + const __m256i scale_0145_1 = _mm256_shuffle_epi32(scales_1, 68); + const __m256i scale_2367_1 = _mm256_shuffle_epi32(scales_1, 238); + + const __m256i scale_0145_2 = _mm256_shuffle_epi32(scales_2, 68); + const __m256i scale_2367_2 = _mm256_shuffle_epi32(scales_2, 238); + + const __m256i scale_0145_3 = _mm256_shuffle_epi32(scales_3, 68); + const __m256i scale_2367_3 = _mm256_shuffle_epi32(scales_3, 238); + + const __m256i scale_0145_4 = _mm256_shuffle_epi32(scales_4, 68); + const __m256i scale_2367_4 = _mm256_shuffle_epi32(scales_4, 238); + + const __m256i scale_0145_5 = _mm256_shuffle_epi32(scales_5, 68); + const __m256i scale_2367_5 = _mm256_shuffle_epi32(scales_5, 238); + + const __m256i scale_0145_6 = _mm256_shuffle_epi32(scales_6, 68); + const __m256i scale_2367_6 = _mm256_shuffle_epi32(scales_6, 238); + + const __m256i scale_0145_7 = _mm256_shuffle_epi32(scales_7, 68); + const __m256i scale_2367_7 = _mm256_shuffle_epi32(scales_7, 238); + + // Load the four block_q8_k quantized values interleaved with each other in chunks of eight bytes - A0,A1,A2,A3 + // Loaded as set of 128 bit vectors and repeated into a 256 bit vector + __m256i lhs_mat_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 512 * sb))); + __m256i lhs_mat_01_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 0); + __m256i lhs_mat_23_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 17); + __m256i lhs_mat_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 32 + 512 * sb))); + __m256i lhs_mat_01_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 0); + __m256i lhs_mat_23_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 17); + __m256i lhs_mat_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 64 + 512 * sb))); + __m256i lhs_mat_01_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 0); + __m256i lhs_mat_23_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 17); + __m256i lhs_mat_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 96 + 512 * sb))); + __m256i lhs_mat_01_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 0); + __m256i lhs_mat_23_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 17); + __m256i lhs_mat_0123_20 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 128 + 512 * sb))); + __m256i lhs_mat_01_20 = _mm256_permute2f128_si256(lhs_mat_0123_20, lhs_mat_0123_20, 0); + __m256i lhs_mat_23_20 = _mm256_permute2f128_si256(lhs_mat_0123_20, lhs_mat_0123_20, 17); + __m256i lhs_mat_0123_21 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 160 + 512 * sb))); + __m256i lhs_mat_01_21 = _mm256_permute2f128_si256(lhs_mat_0123_21, lhs_mat_0123_21, 0); + __m256i lhs_mat_23_21 = _mm256_permute2f128_si256(lhs_mat_0123_21, lhs_mat_0123_21, 17); + __m256i lhs_mat_0123_30 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 192 + 512 * sb))); + __m256i lhs_mat_01_30 = _mm256_permute2f128_si256(lhs_mat_0123_30, lhs_mat_0123_30, 0); + __m256i lhs_mat_23_30 = _mm256_permute2f128_si256(lhs_mat_0123_30, lhs_mat_0123_30, 17); + __m256i lhs_mat_0123_31 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 224 + 512 * sb))); + __m256i lhs_mat_01_31 = _mm256_permute2f128_si256(lhs_mat_0123_31, lhs_mat_0123_31, 0); + __m256i lhs_mat_23_31 = _mm256_permute2f128_si256(lhs_mat_0123_31, lhs_mat_0123_31, 17); + + __m256i lhs_mat_0123_40 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 256 + 512 * sb))); + __m256i lhs_mat_01_40 = _mm256_permute2f128_si256(lhs_mat_0123_40, lhs_mat_0123_40, 0); + __m256i lhs_mat_23_40 = _mm256_permute2f128_si256(lhs_mat_0123_40, lhs_mat_0123_40, 17); + __m256i lhs_mat_0123_41 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 288 + 512 * sb))); + __m256i lhs_mat_01_41 = _mm256_permute2f128_si256(lhs_mat_0123_41, lhs_mat_0123_41, 0); + __m256i lhs_mat_23_41 = _mm256_permute2f128_si256(lhs_mat_0123_41, lhs_mat_0123_41, 17); + __m256i lhs_mat_0123_50 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 320 + 512 * sb))); + __m256i lhs_mat_01_50 = _mm256_permute2f128_si256(lhs_mat_0123_50, lhs_mat_0123_50, 0); + __m256i lhs_mat_23_50 = _mm256_permute2f128_si256(lhs_mat_0123_50, lhs_mat_0123_50, 17); + __m256i lhs_mat_0123_51 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 352 + 512 * sb))); + __m256i lhs_mat_01_51 = _mm256_permute2f128_si256(lhs_mat_0123_51, lhs_mat_0123_51, 0); + __m256i lhs_mat_23_51 = _mm256_permute2f128_si256(lhs_mat_0123_51, lhs_mat_0123_51, 17); + __m256i lhs_mat_0123_60 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 384 + 512 * sb))); + __m256i lhs_mat_01_60 = _mm256_permute2f128_si256(lhs_mat_0123_60, lhs_mat_0123_60, 0); + __m256i lhs_mat_23_60 = _mm256_permute2f128_si256(lhs_mat_0123_60, lhs_mat_0123_60, 17); + __m256i lhs_mat_0123_61 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 416 + 512 * sb))); + __m256i lhs_mat_01_61 = _mm256_permute2f128_si256(lhs_mat_0123_61, lhs_mat_0123_61, 0); + __m256i lhs_mat_23_61 = _mm256_permute2f128_si256(lhs_mat_0123_61, lhs_mat_0123_61, 17); + __m256i lhs_mat_0123_70 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 448 + 512 * sb))); + __m256i lhs_mat_01_70 = _mm256_permute2f128_si256(lhs_mat_0123_70, lhs_mat_0123_70, 0); + __m256i lhs_mat_23_70 = _mm256_permute2f128_si256(lhs_mat_0123_70, lhs_mat_0123_70, 17); + __m256i lhs_mat_0123_71 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 480 + 512 * sb))); + __m256i lhs_mat_01_71 = _mm256_permute2f128_si256(lhs_mat_0123_71, lhs_mat_0123_71, 0); + __m256i lhs_mat_23_71 = _mm256_permute2f128_si256(lhs_mat_0123_71, lhs_mat_0123_71, 17); + + // Bsums are loaded for the different Q8_K blocks + __m128i lhs_raw_bsums_01_0123 = _mm_loadu_si128((const __m128i *)((a_ptr[b].bsums + 32 * sb))); + __m128i lhs_raw_bsums_23_0123 = _mm_loadu_si128((const __m128i *)(a_ptr[b].bsums + 8 + 32 * sb)); + __m128i lhs_raw_bsums_01_4567 = _mm_loadu_si128((const __m128i *)((a_ptr[b].bsums + 16 + 32 * sb))); + __m128i lhs_raw_bsums_23_4567 = _mm_loadu_si128((const __m128i *)(a_ptr[b].bsums + 24 + 32 * sb)); + + // Shuffle pattern one - left side input + const __m256i lhs_mat_01_00_sp1 = _mm256_shuffle_epi32(lhs_mat_01_00, 160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) + const __m256i lhs_mat_23_00_sp1 = _mm256_shuffle_epi32(lhs_mat_23_00, 160); //A02(0-3) A03(0-3) A02(0-3) A03(0-3) A02(0-3) A03(0-3) A02(0-3) A03(0-3) + + const __m256i lhs_mat_01_01_sp1 = _mm256_shuffle_epi32(lhs_mat_01_01, 160); //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) + const __m256i lhs_mat_23_01_sp1 = _mm256_shuffle_epi32(lhs_mat_23_01, 160); //A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11) + + const __m256i lhs_mat_01_10_sp1 = _mm256_shuffle_epi32(lhs_mat_01_10, 160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) + const __m256i lhs_mat_23_10_sp1 = _mm256_shuffle_epi32(lhs_mat_23_10, 160); //A12(0-3) A13(0-3) A12(0-3) A13(0-3) A12(0-3) A13(0-3) A12(0-3) A13(0-3) + + const __m256i lhs_mat_01_11_sp1 = _mm256_shuffle_epi32(lhs_mat_01_11, 160); //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) + const __m256i lhs_mat_23_11_sp1 = _mm256_shuffle_epi32(lhs_mat_23_11, 160); //A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11) + + const __m256i lhs_mat_01_20_sp1 = _mm256_shuffle_epi32(lhs_mat_01_20, 160); //A20(0-3) A20(0-3) A21(0-3) A21(0-3) A20(0-3) A20(0-3) A21(0-3) A21(0-3) + const __m256i lhs_mat_23_20_sp1 = _mm256_shuffle_epi32(lhs_mat_23_20, 160); //A22(0-3) A23(0-3) A22(0-3) A23(0-3) A22(0-3) A23(0-3) A22(0-3) A23(0-3) + + const __m256i lhs_mat_01_21_sp1 = _mm256_shuffle_epi32(lhs_mat_01_21, 160); //A20(8-11) A20(8-11) A21(8-11) A21(8-11) A20(8-11) A20(8-11) A21(8-11) A21(8-11) + const __m256i lhs_mat_23_21_sp1 = _mm256_shuffle_epi32(lhs_mat_23_21, 160); //A22(8-11) A23(8-11) A22(8-11) A23(8-11) A22(8-11) A23(8-11) A22(8-11) A23(8-11) + + const __m256i lhs_mat_01_30_sp1 = _mm256_shuffle_epi32(lhs_mat_01_30, 160); //A30(0-3) A30(0-3) A31(0-3) A31(0-3) A30(0-3) A30(0-3) A31(0-3) A31(0-3) + const __m256i lhs_mat_23_30_sp1 = _mm256_shuffle_epi32(lhs_mat_23_30, 160); //A32(0-3) A33(0-3) A32(0-3) A33(0-3) A32(0-3) A33(0-3) A32(0-3) A33(0-3) + + const __m256i lhs_mat_01_31_sp1 = _mm256_shuffle_epi32(lhs_mat_01_31, 160); //A30(8-11) A30(8-11) A31(8-11) A31(8-11) A30(8-11) A30(8-11) A31(8-11) A31(8-11) + const __m256i lhs_mat_23_31_sp1 = _mm256_shuffle_epi32(lhs_mat_23_31, 160); //A32(8-11) A33(8-11) A32(8-11) A33(8-11) A32(8-11) A33(8-11) A32(8-11) A33(8-11) + + const __m256i lhs_mat_01_40_sp1 = _mm256_shuffle_epi32(lhs_mat_01_40, 160); //A40(0-3) A40(0-3) A41(0-3) A41(0-3) A40(0-3) A40(0-3) A41(0-3) A41(0-3) + const __m256i lhs_mat_23_40_sp1 = _mm256_shuffle_epi32(lhs_mat_23_40, 160); //A42(0-3) A43(0-3) A42(0-3) A43(0-3) A42(0-3) A43(0-3) A42(0-3) A43(0-3) + + const __m256i lhs_mat_01_41_sp1 = _mm256_shuffle_epi32(lhs_mat_01_41, 160); //A40(8-11) A40(8-11) A41(8-11) A41(8-11) A40(8-11) A40(8-11) A41(8-11) A41(8-11) + const __m256i lhs_mat_23_41_sp1 = _mm256_shuffle_epi32(lhs_mat_23_41, 160); //A42(8-11) A43(8-11) A42(8-11) A43(8-11) A42(8-11) A43(8-11) A42(8-11) A43(8-11) + + const __m256i lhs_mat_01_50_sp1 = _mm256_shuffle_epi32(lhs_mat_01_50, 160); //A50(0-3) A50(0-3) A51(0-3) A51(0-3) A50(0-3) A50(0-3) A51(0-3) A51(0-3) + const __m256i lhs_mat_23_50_sp1 = _mm256_shuffle_epi32(lhs_mat_23_50, 160); //A52(0-3) A53(0-3) A52(0-3) A53(0-3) A52(0-3) A53(0-3) A52(0-3) A53(0-3) + + const __m256i lhs_mat_01_51_sp1 = _mm256_shuffle_epi32(lhs_mat_01_51, 160); //A50(8-11) A50(8-11) A51(8-11) A51(8-11) A50(8-11) A50(8-11) A51(8-11) A51(8-11) + const __m256i lhs_mat_23_51_sp1 = _mm256_shuffle_epi32(lhs_mat_23_51, 160); //A52(8-11) A53(8-11) A52(8-11) A53(8-11) A52(8-11) A53(8-11) A52(8-11) A53(8-11) + + const __m256i lhs_mat_01_60_sp1 = _mm256_shuffle_epi32(lhs_mat_01_60, 160); //A60(0-3) A60(0-3) A61(0-3) A61(0-3) A60(0-3) A60(0-3) A61(0-3) A61(0-3) + const __m256i lhs_mat_23_60_sp1 = _mm256_shuffle_epi32(lhs_mat_23_60, 160); //A62(0-3) A63(0-3) A62(0-3) A63(0-3) A62(0-3) A63(0-3) A62(0-3) A63(0-3) + + const __m256i lhs_mat_01_61_sp1 = _mm256_shuffle_epi32(lhs_mat_01_61, 160); //A60(8-11) A60(8-11) A61(8-11) A61(8-11) A60(8-11) A60(8-11) A61(8-11) A61(8-11) + const __m256i lhs_mat_23_61_sp1 = _mm256_shuffle_epi32(lhs_mat_23_61, 160); //A62(8-11) A63(8-11) A62(8-11) A63(8-11) A62(8-11) A63(8-11) A62(8-11) A63(8-11) + + const __m256i lhs_mat_01_70_sp1 = _mm256_shuffle_epi32(lhs_mat_01_70, 160); //A70(0-3) A70(0-3) A71(0-3) A71(0-3) A70(0-3) A70(0-3) A71(0-3) A71(0-3) + const __m256i lhs_mat_23_70_sp1 = _mm256_shuffle_epi32(lhs_mat_23_70, 160); //A72(0-3) A73(0-3) A72(0-3) A73(0-3) A72(0-3) A73(0-3) A72(0-3) A73(0-3) + + const __m256i lhs_mat_01_71_sp1 = _mm256_shuffle_epi32(lhs_mat_01_71, 160); //A70(8-11) A70(8-11) A71(8-11) A71(8-11) A70(8-11) A70(8-11) A71(8-11) A71(8-11) + const __m256i lhs_mat_23_71_sp1 = _mm256_shuffle_epi32(lhs_mat_23_71, 160); //A72(8-11) A73(8-11) A72(8-11) A73(8-11) A72(8-11) A73(8-11) A72(8-11) A73(8-11) + + // Shuffle pattern two- left side input + const __m256i lhs_mat_01_00_sp2 = _mm256_shuffle_epi32(lhs_mat_01_00, 245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) + const __m256i lhs_mat_23_00_sp2 = _mm256_shuffle_epi32(lhs_mat_23_00, 245); //A02(4-7) A03(4-7) A02(4-7) A03(4-7) A02(4-7) A03(4-7) A02(4-7) A03(4-7) + + const __m256i lhs_mat_01_01_sp2 = _mm256_shuffle_epi32(lhs_mat_01_01, 245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) + const __m256i lhs_mat_23_01_sp2 = _mm256_shuffle_epi32(lhs_mat_23_01, 245); //A02(12-15) A03(12-15) A02(12-15) A03(12-15) A02(12-15) A03(12-15) A02(12-15) A03(12-15) + + const __m256i lhs_mat_01_10_sp2 = _mm256_shuffle_epi32(lhs_mat_01_10, 245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) + const __m256i lhs_mat_23_10_sp2 = _mm256_shuffle_epi32(lhs_mat_23_10, 245); //A12(4-7) A13(4-7) A12(4-7) A13(4-7) A12(4-7) A13(4-7) A12(4-7) A13(4-7) + + const __m256i lhs_mat_01_11_sp2 = _mm256_shuffle_epi32(lhs_mat_01_11, 245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) + const __m256i lhs_mat_23_11_sp2 = _mm256_shuffle_epi32(lhs_mat_23_11, 245); //A12(12-15) A13(12-15) A12(12-15) A13(12-15) A12(12-15) A13(12-15) A12(12-15) A13(12-15) + + const __m256i lhs_mat_01_20_sp2 = _mm256_shuffle_epi32(lhs_mat_01_20, 245); //A20(4-7) A20(4-7) A21(4-7) A21(4-7) A20(4-7) A20(4-7) A21(4-7) A21(4-7) + const __m256i lhs_mat_23_20_sp2 = _mm256_shuffle_epi32(lhs_mat_23_20, 245); //A22(4-7) A23(4-7) A22(4-7) A23(4-7) A22(4-7) A23(4-7) A22(4-7) A23(4-7) + + const __m256i lhs_mat_01_21_sp2 = _mm256_shuffle_epi32(lhs_mat_01_21, 245); //A20(12-15) A20(12-15) A21(12-15) A21(12-15) A20(12-15) A20(12-15) A21(12-15) A21(12-15) + const __m256i lhs_mat_23_21_sp2 = _mm256_shuffle_epi32(lhs_mat_23_21, 245); //A22(12-15) A23(12-15) A22(12-15) A23(12-15) A22(12-15) A23(12-15) A22(12-15) A23(12-15) + + const __m256i lhs_mat_01_30_sp2 = _mm256_shuffle_epi32(lhs_mat_01_30, 245); //A30(4-7) A30(4-7) A31(4-7) A31(4-7) A30(4-7) A30(4-7) A31(4-7) A31(4-7) + const __m256i lhs_mat_23_30_sp2 = _mm256_shuffle_epi32(lhs_mat_23_30, 245); //A32(4-7) A33(4-7) A32(4-7) A33(4-7) A32(4-7) A33(4-7) A32(4-7) A33(4-7) + + const __m256i lhs_mat_01_31_sp2 = _mm256_shuffle_epi32(lhs_mat_01_31, 245); //A30(12-15) A30(12-15) A31(12-15) A31(12-15) A30(12-15) A30(12-15) A31(12-15) A31(12-15) + const __m256i lhs_mat_23_31_sp2 = _mm256_shuffle_epi32(lhs_mat_23_31, 245); //A32(12-15) A33(12-15) A32(12-15) A33(12-15) A32(12-15) A33(12-15) A32(12-15) A33(12-15) + + const __m256i lhs_mat_01_40_sp2 = _mm256_shuffle_epi32(lhs_mat_01_40, 245); //A40(4-7) A40(4-7) A41(4-7) A41(4-7) A40(4-7) A40(4-7) A41(4-7) A41(4-7) + const __m256i lhs_mat_23_40_sp2 = _mm256_shuffle_epi32(lhs_mat_23_40, 245); //A42(4-7) A43(4-7) A42(4-7) A43(4-7) A42(4-7) A43(4-7) A42(4-7) A43(4-7) + + const __m256i lhs_mat_01_41_sp2 = _mm256_shuffle_epi32(lhs_mat_01_41, 245); //A40(12-15) A40(12-15) A41(12-15) A41(12-15) A40(12-15) A40(12-15) A41(12-15) A41(12-15) + const __m256i lhs_mat_23_41_sp2 = _mm256_shuffle_epi32(lhs_mat_23_41, 245); //A42(12-15) A43(12-15) A42(12-15) A43(12-15) A42(12-15) A43(12-15) A42(12-15) A43(12-15) + + const __m256i lhs_mat_01_50_sp2 = _mm256_shuffle_epi32(lhs_mat_01_50, 245); //A50(4-7) A50(4-7) A51(4-7) A51(4-7) A50(4-7) A50(4-7) A51(4-7) A51(4-7) + const __m256i lhs_mat_23_50_sp2 = _mm256_shuffle_epi32(lhs_mat_23_50, 245); //A52(4-7) A53(4-7) A52(4-7) A53(4-7) A52(4-7) A53(4-7) A52(4-7) A53(4-7) + + const __m256i lhs_mat_01_51_sp2 = _mm256_shuffle_epi32(lhs_mat_01_51, 245); //A50(12-15) A50(12-15) A51(12-15) A51(12-15) A50(12-15) A50(12-15) A51(12-15) A51(12-15) + const __m256i lhs_mat_23_51_sp2 = _mm256_shuffle_epi32(lhs_mat_23_51, 245); //A52(12-15) A53(12-15) A52(12-15) A53(12-15) A52(12-15) A53(12-15) A52(12-15) A53(12-15) + + const __m256i lhs_mat_01_60_sp2 = _mm256_shuffle_epi32(lhs_mat_01_60, 245); //A60(4-7) A60(4-7) A61(4-7) A61(4-7) A60(4-7) A60(4-7) A61(4-7) A61(4-7) + const __m256i lhs_mat_23_60_sp2 = _mm256_shuffle_epi32(lhs_mat_23_60, 245); //A62(4-7) A63(4-7) A62(4-7) A63(4-7) A62(4-7) A63(4-7) A62(4-7) A63(4-7) + + const __m256i lhs_mat_01_61_sp2 = _mm256_shuffle_epi32(lhs_mat_01_61, 245); //A60(12-15) A60(12-15) A61(12-15) A61(12-15) A60(12-15) A60(12-15) A61(12-15) A61(12-15) + const __m256i lhs_mat_23_61_sp2 = _mm256_shuffle_epi32(lhs_mat_23_61, 245); //A62(12-15) A63(12-15) A62(12-15) A63(12-15) A62(12-15) A63(12-15) A62(12-15) A63(12-15) + + const __m256i lhs_mat_01_70_sp2 = _mm256_shuffle_epi32(lhs_mat_01_70, 245); //A70(4-7) A70(4-7) A71(4-7) A71(4-7) A70(4-7) A70(4-7) A71(4-7) A71(4-7) + const __m256i lhs_mat_23_70_sp2 = _mm256_shuffle_epi32(lhs_mat_23_70, 245); //A72(4-7) A73(4-7) A72(4-7) A73(4-7) A72(4-7) A73(4-7) A72(4-7) A73(4-7) + + const __m256i lhs_mat_01_71_sp2 = _mm256_shuffle_epi32(lhs_mat_01_71, 245); //A70(12-15) A70(12-15) A71(12-15) A71(12-15) A70(12-15) A70(12-15) A71(12-15) A71(12-15) + const __m256i lhs_mat_23_71_sp2 = _mm256_shuffle_epi32(lhs_mat_23_71, 245); //A72(12-15) A73(12-15) A72(12-15) A73(12-15) A72(12-15) A73(12-15) A72(12-15) A73(12-15) + + // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane + __m256i iacc_mat_00_0_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_01_00_sp1),_mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_01_01_sp1)); + __m256i iacc_mat_01_0_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_01_00_sp1),_mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_01_01_sp1)); + + __m256i iacc_mat_10_0_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_23_00_sp1),_mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_23_01_sp1)); + __m256i iacc_mat_11_0_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_23_00_sp1),_mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_23_01_sp1)); + + __m256i iacc_mat_00_1_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_01_10_sp1),_mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_01_11_sp1)); + __m256i iacc_mat_01_1_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_01_10_sp1),_mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_01_11_sp1)); + + __m256i iacc_mat_10_1_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_23_10_sp1),_mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_23_11_sp1)); + __m256i iacc_mat_11_1_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_23_10_sp1),_mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_23_11_sp1)); + + __m256i iacc_mat_00_2_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_20_sp1, lhs_mat_01_20_sp1),_mm256_maddubs_epi16(rhs_mat_0145_21_sp1, lhs_mat_01_21_sp1)); + __m256i iacc_mat_01_2_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_20_sp1, lhs_mat_01_20_sp1),_mm256_maddubs_epi16(rhs_mat_2367_21_sp1, lhs_mat_01_21_sp1)); + + __m256i iacc_mat_10_2_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_20_sp1, lhs_mat_23_20_sp1),_mm256_maddubs_epi16(rhs_mat_0145_21_sp1, lhs_mat_23_21_sp1)); + __m256i iacc_mat_11_2_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_20_sp1, lhs_mat_23_20_sp1),_mm256_maddubs_epi16(rhs_mat_2367_21_sp1, lhs_mat_23_21_sp1)); + + __m256i iacc_mat_00_3_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_30_sp1, lhs_mat_01_30_sp1),_mm256_maddubs_epi16(rhs_mat_0145_31_sp1, lhs_mat_01_31_sp1)); + __m256i iacc_mat_01_3_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_30_sp1, lhs_mat_01_30_sp1),_mm256_maddubs_epi16(rhs_mat_2367_31_sp1, lhs_mat_01_31_sp1)); + + __m256i iacc_mat_10_3_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_30_sp1, lhs_mat_23_30_sp1),_mm256_maddubs_epi16(rhs_mat_0145_31_sp1, lhs_mat_23_31_sp1)); + __m256i iacc_mat_11_3_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_30_sp1, lhs_mat_23_30_sp1),_mm256_maddubs_epi16(rhs_mat_2367_31_sp1, lhs_mat_23_31_sp1)); + + __m256i iacc_mat_00_4_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_40_sp1, lhs_mat_01_40_sp1),_mm256_maddubs_epi16(rhs_mat_0145_41_sp1, lhs_mat_01_41_sp1)); + __m256i iacc_mat_01_4_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_40_sp1, lhs_mat_01_40_sp1),_mm256_maddubs_epi16(rhs_mat_2367_41_sp1, lhs_mat_01_41_sp1)); + + __m256i iacc_mat_10_4_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_40_sp1, lhs_mat_23_40_sp1),_mm256_maddubs_epi16(rhs_mat_0145_41_sp1, lhs_mat_23_41_sp1)); + __m256i iacc_mat_11_4_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_40_sp1, lhs_mat_23_40_sp1),_mm256_maddubs_epi16(rhs_mat_2367_41_sp1, lhs_mat_23_41_sp1)); + + __m256i iacc_mat_00_5_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_50_sp1, lhs_mat_01_50_sp1),_mm256_maddubs_epi16(rhs_mat_0145_51_sp1, lhs_mat_01_51_sp1)); + __m256i iacc_mat_01_5_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_50_sp1, lhs_mat_01_50_sp1),_mm256_maddubs_epi16(rhs_mat_2367_51_sp1, lhs_mat_01_51_sp1)); + + __m256i iacc_mat_10_5_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_50_sp1, lhs_mat_23_50_sp1),_mm256_maddubs_epi16(rhs_mat_0145_51_sp1, lhs_mat_23_51_sp1)); + __m256i iacc_mat_11_5_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_50_sp1, lhs_mat_23_50_sp1),_mm256_maddubs_epi16(rhs_mat_2367_51_sp1, lhs_mat_23_51_sp1)); + + __m256i iacc_mat_00_6_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_60_sp1, lhs_mat_01_60_sp1),_mm256_maddubs_epi16(rhs_mat_0145_61_sp1, lhs_mat_01_61_sp1)); + __m256i iacc_mat_01_6_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_60_sp1, lhs_mat_01_60_sp1),_mm256_maddubs_epi16(rhs_mat_2367_61_sp1, lhs_mat_01_61_sp1)); + + __m256i iacc_mat_10_6_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_60_sp1, lhs_mat_23_60_sp1),_mm256_maddubs_epi16(rhs_mat_0145_61_sp1, lhs_mat_23_61_sp1)); + __m256i iacc_mat_11_6_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_60_sp1, lhs_mat_23_60_sp1),_mm256_maddubs_epi16(rhs_mat_2367_61_sp1, lhs_mat_23_61_sp1)); + + __m256i iacc_mat_00_7_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_70_sp1, lhs_mat_01_70_sp1),_mm256_maddubs_epi16(rhs_mat_0145_71_sp1, lhs_mat_01_71_sp1)); + __m256i iacc_mat_01_7_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_70_sp1, lhs_mat_01_70_sp1),_mm256_maddubs_epi16(rhs_mat_2367_71_sp1, lhs_mat_01_71_sp1)); + + __m256i iacc_mat_10_7_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_70_sp1, lhs_mat_23_70_sp1),_mm256_maddubs_epi16(rhs_mat_0145_71_sp1, lhs_mat_23_71_sp1)); + __m256i iacc_mat_11_7_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_70_sp1, lhs_mat_23_70_sp1),_mm256_maddubs_epi16(rhs_mat_2367_71_sp1, lhs_mat_23_71_sp1)); + + + __m256i iacc_mat_00_0_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_01_00_sp2),_mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_01_01_sp2)); + __m256i iacc_mat_01_0_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_01_00_sp2),_mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_01_01_sp2)); + + __m256i iacc_mat_10_0_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_23_00_sp2),_mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_23_01_sp2)); + __m256i iacc_mat_11_0_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_23_00_sp2),_mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_23_01_sp2)); + + __m256i iacc_mat_00_1_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_01_10_sp2),_mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_01_11_sp2)); + __m256i iacc_mat_01_1_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_01_10_sp2),_mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_01_11_sp2)); + + __m256i iacc_mat_10_1_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_23_10_sp2),_mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_23_11_sp2)); + __m256i iacc_mat_11_1_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_23_10_sp2),_mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_23_11_sp2)); + + __m256i iacc_mat_00_2_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_20_sp2, lhs_mat_01_20_sp2),_mm256_maddubs_epi16(rhs_mat_0145_21_sp2, lhs_mat_01_21_sp2)); + __m256i iacc_mat_01_2_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_20_sp2, lhs_mat_01_20_sp2),_mm256_maddubs_epi16(rhs_mat_2367_21_sp2, lhs_mat_01_21_sp2)); + + __m256i iacc_mat_10_2_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_20_sp2, lhs_mat_23_20_sp2),_mm256_maddubs_epi16(rhs_mat_0145_21_sp2, lhs_mat_23_21_sp2)); + __m256i iacc_mat_11_2_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_20_sp2, lhs_mat_23_20_sp2),_mm256_maddubs_epi16(rhs_mat_2367_21_sp2, lhs_mat_23_21_sp2)); + + __m256i iacc_mat_00_3_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_30_sp2, lhs_mat_01_30_sp2),_mm256_maddubs_epi16(rhs_mat_0145_31_sp2, lhs_mat_01_31_sp2)); + __m256i iacc_mat_01_3_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_30_sp2, lhs_mat_01_30_sp2),_mm256_maddubs_epi16(rhs_mat_2367_31_sp2, lhs_mat_01_31_sp2)); + + __m256i iacc_mat_10_3_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_30_sp2, lhs_mat_23_30_sp2),_mm256_maddubs_epi16(rhs_mat_0145_31_sp2, lhs_mat_23_31_sp2)); + __m256i iacc_mat_11_3_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_30_sp2, lhs_mat_23_30_sp2),_mm256_maddubs_epi16(rhs_mat_2367_31_sp2, lhs_mat_23_31_sp2)); + + __m256i iacc_mat_00_4_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_40_sp2, lhs_mat_01_40_sp2),_mm256_maddubs_epi16(rhs_mat_0145_41_sp2, lhs_mat_01_41_sp2)); + __m256i iacc_mat_01_4_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_40_sp2, lhs_mat_01_40_sp2),_mm256_maddubs_epi16(rhs_mat_2367_41_sp2, lhs_mat_01_41_sp2)); + + __m256i iacc_mat_10_4_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_40_sp2, lhs_mat_23_40_sp2),_mm256_maddubs_epi16(rhs_mat_0145_41_sp2, lhs_mat_23_41_sp2)); + __m256i iacc_mat_11_4_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_40_sp2, lhs_mat_23_40_sp2),_mm256_maddubs_epi16(rhs_mat_2367_41_sp2, lhs_mat_23_41_sp2)); + + __m256i iacc_mat_00_5_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_50_sp2, lhs_mat_01_50_sp2),_mm256_maddubs_epi16(rhs_mat_0145_51_sp2, lhs_mat_01_51_sp2)); + __m256i iacc_mat_01_5_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_50_sp2, lhs_mat_01_50_sp2),_mm256_maddubs_epi16(rhs_mat_2367_51_sp2, lhs_mat_01_51_sp2)); + + __m256i iacc_mat_10_5_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_50_sp2, lhs_mat_23_50_sp2),_mm256_maddubs_epi16(rhs_mat_0145_51_sp2, lhs_mat_23_51_sp2)); + __m256i iacc_mat_11_5_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_50_sp2, lhs_mat_23_50_sp2),_mm256_maddubs_epi16(rhs_mat_2367_51_sp2, lhs_mat_23_51_sp2)); + + __m256i iacc_mat_00_6_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_60_sp2, lhs_mat_01_60_sp2),_mm256_maddubs_epi16(rhs_mat_0145_61_sp2, lhs_mat_01_61_sp2)); + __m256i iacc_mat_01_6_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_60_sp2, lhs_mat_01_60_sp2),_mm256_maddubs_epi16(rhs_mat_2367_61_sp2, lhs_mat_01_61_sp2)); + + __m256i iacc_mat_10_6_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_60_sp2, lhs_mat_23_60_sp2),_mm256_maddubs_epi16(rhs_mat_0145_61_sp2, lhs_mat_23_61_sp2)); + __m256i iacc_mat_11_6_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_60_sp2, lhs_mat_23_60_sp2),_mm256_maddubs_epi16(rhs_mat_2367_61_sp2, lhs_mat_23_61_sp2)); + + __m256i iacc_mat_00_7_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_70_sp2, lhs_mat_01_70_sp2),_mm256_maddubs_epi16(rhs_mat_0145_71_sp2, lhs_mat_01_71_sp2)); + __m256i iacc_mat_01_7_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_70_sp2, lhs_mat_01_70_sp2),_mm256_maddubs_epi16(rhs_mat_2367_71_sp2, lhs_mat_01_71_sp2)); + + __m256i iacc_mat_10_7_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_70_sp2, lhs_mat_23_70_sp2),_mm256_maddubs_epi16(rhs_mat_0145_71_sp2, lhs_mat_23_71_sp2)); + __m256i iacc_mat_11_7_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_70_sp2, lhs_mat_23_70_sp2),_mm256_maddubs_epi16(rhs_mat_2367_71_sp2, lhs_mat_23_71_sp2)); + + // Combine results from both shuffle patterns for each output block. + __m256i iacc_mat_00_0 = _mm256_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2); + __m256i iacc_mat_01_0 = _mm256_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2); + __m256i iacc_mat_10_0 = _mm256_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2); + __m256i iacc_mat_11_0 = _mm256_add_epi16(iacc_mat_11_0_sp1, iacc_mat_11_0_sp2); + + __m256i iacc_mat_00_1 = _mm256_add_epi16(iacc_mat_00_1_sp1, iacc_mat_00_1_sp2); + __m256i iacc_mat_01_1 = _mm256_add_epi16(iacc_mat_01_1_sp1, iacc_mat_01_1_sp2); + __m256i iacc_mat_10_1 = _mm256_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2); + __m256i iacc_mat_11_1 = _mm256_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2); + + __m256i iacc_mat_00_2 = _mm256_add_epi16(iacc_mat_00_2_sp1, iacc_mat_00_2_sp2); + __m256i iacc_mat_01_2 = _mm256_add_epi16(iacc_mat_01_2_sp1, iacc_mat_01_2_sp2); + __m256i iacc_mat_10_2 = _mm256_add_epi16(iacc_mat_10_2_sp1, iacc_mat_10_2_sp2); + __m256i iacc_mat_11_2 = _mm256_add_epi16(iacc_mat_11_2_sp1, iacc_mat_11_2_sp2); + + __m256i iacc_mat_00_3 = _mm256_add_epi16(iacc_mat_00_3_sp1, iacc_mat_00_3_sp2); + __m256i iacc_mat_01_3 = _mm256_add_epi16(iacc_mat_01_3_sp1, iacc_mat_01_3_sp2); + __m256i iacc_mat_10_3 = _mm256_add_epi16(iacc_mat_10_3_sp1, iacc_mat_10_3_sp2); + __m256i iacc_mat_11_3 = _mm256_add_epi16(iacc_mat_11_3_sp1, iacc_mat_11_3_sp2); + + __m256i iacc_mat_00_4 = _mm256_add_epi16(iacc_mat_00_4_sp1, iacc_mat_00_4_sp2); + __m256i iacc_mat_01_4 = _mm256_add_epi16(iacc_mat_01_4_sp1, iacc_mat_01_4_sp2); + __m256i iacc_mat_10_4 = _mm256_add_epi16(iacc_mat_10_4_sp1, iacc_mat_10_4_sp2); + __m256i iacc_mat_11_4 = _mm256_add_epi16(iacc_mat_11_4_sp1, iacc_mat_11_4_sp2); + + __m256i iacc_mat_00_5 = _mm256_add_epi16(iacc_mat_00_5_sp1, iacc_mat_00_5_sp2); + __m256i iacc_mat_01_5 = _mm256_add_epi16(iacc_mat_01_5_sp1, iacc_mat_01_5_sp2); + __m256i iacc_mat_10_5 = _mm256_add_epi16(iacc_mat_10_5_sp1, iacc_mat_10_5_sp2); + __m256i iacc_mat_11_5 = _mm256_add_epi16(iacc_mat_11_5_sp1, iacc_mat_11_5_sp2); + + __m256i iacc_mat_00_6 = _mm256_add_epi16(iacc_mat_00_6_sp1, iacc_mat_00_6_sp2); + __m256i iacc_mat_01_6 = _mm256_add_epi16(iacc_mat_01_6_sp1, iacc_mat_01_6_sp2); + __m256i iacc_mat_10_6 = _mm256_add_epi16(iacc_mat_10_6_sp1, iacc_mat_10_6_sp2); + __m256i iacc_mat_11_6 = _mm256_add_epi16(iacc_mat_11_6_sp1, iacc_mat_11_6_sp2); + + __m256i iacc_mat_00_7 = _mm256_add_epi16(iacc_mat_00_7_sp1, iacc_mat_00_7_sp2); + __m256i iacc_mat_01_7 = _mm256_add_epi16(iacc_mat_01_7_sp1, iacc_mat_01_7_sp2); + __m256i iacc_mat_10_7 = _mm256_add_epi16(iacc_mat_10_7_sp1, iacc_mat_10_7_sp2); + __m256i iacc_mat_11_7 = _mm256_add_epi16(iacc_mat_11_7_sp1, iacc_mat_11_7_sp2); + + // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block + iacc_mat_00_0 = _mm256_madd_epi16(iacc_mat_00_0, scale_0145_0); + iacc_mat_01_0 = _mm256_madd_epi16(iacc_mat_01_0, scale_2367_0); + iacc_mat_10_0 = _mm256_madd_epi16(iacc_mat_10_0, scale_0145_0); + iacc_mat_11_0 = _mm256_madd_epi16(iacc_mat_11_0, scale_2367_0); + + iacc_mat_00_1 = _mm256_madd_epi16(iacc_mat_00_1, scale_0145_1); + iacc_mat_01_1 = _mm256_madd_epi16(iacc_mat_01_1, scale_2367_1); + iacc_mat_10_1 = _mm256_madd_epi16(iacc_mat_10_1, scale_0145_1); + iacc_mat_11_1 = _mm256_madd_epi16(iacc_mat_11_1, scale_2367_1); + + iacc_mat_00_2 = _mm256_madd_epi16(iacc_mat_00_2, scale_0145_2); + iacc_mat_01_2 = _mm256_madd_epi16(iacc_mat_01_2, scale_2367_2); + iacc_mat_10_2 = _mm256_madd_epi16(iacc_mat_10_2, scale_0145_2); + iacc_mat_11_2 = _mm256_madd_epi16(iacc_mat_11_2, scale_2367_2); + + iacc_mat_00_3 = _mm256_madd_epi16(iacc_mat_00_3, scale_0145_3); + iacc_mat_01_3 = _mm256_madd_epi16(iacc_mat_01_3, scale_2367_3); + iacc_mat_10_3 = _mm256_madd_epi16(iacc_mat_10_3, scale_0145_3); + iacc_mat_11_3 = _mm256_madd_epi16(iacc_mat_11_3, scale_2367_3); + + iacc_mat_00_4 = _mm256_madd_epi16(iacc_mat_00_4, scale_0145_4); + iacc_mat_01_4 = _mm256_madd_epi16(iacc_mat_01_4, scale_2367_4); + iacc_mat_10_4 = _mm256_madd_epi16(iacc_mat_10_4, scale_0145_4); + iacc_mat_11_4 = _mm256_madd_epi16(iacc_mat_11_4, scale_2367_4); + + iacc_mat_00_5 = _mm256_madd_epi16(iacc_mat_00_5, scale_0145_5); + iacc_mat_01_5 = _mm256_madd_epi16(iacc_mat_01_5, scale_2367_5); + iacc_mat_10_5 = _mm256_madd_epi16(iacc_mat_10_5, scale_0145_5); + iacc_mat_11_5 = _mm256_madd_epi16(iacc_mat_11_5, scale_2367_5); + + iacc_mat_00_6 = _mm256_madd_epi16(iacc_mat_00_6, scale_0145_6); + iacc_mat_01_6 = _mm256_madd_epi16(iacc_mat_01_6, scale_2367_6); + iacc_mat_10_6 = _mm256_madd_epi16(iacc_mat_10_6, scale_0145_6); + iacc_mat_11_6 = _mm256_madd_epi16(iacc_mat_11_6, scale_2367_6); + + iacc_mat_00_7 = _mm256_madd_epi16(iacc_mat_00_7, scale_0145_7); + iacc_mat_01_7 = _mm256_madd_epi16(iacc_mat_01_7, scale_2367_7); + iacc_mat_10_7 = _mm256_madd_epi16(iacc_mat_10_7, scale_0145_7); + iacc_mat_11_7 = _mm256_madd_epi16(iacc_mat_11_7, scale_2367_7); + + __m256i iacc_mat_00 = _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(iacc_mat_00_0, iacc_mat_00_1), _mm256_add_epi32(iacc_mat_00_2, iacc_mat_00_3)), _mm256_add_epi32(_mm256_add_epi32(iacc_mat_00_4, iacc_mat_00_5), _mm256_add_epi32(iacc_mat_00_6, iacc_mat_00_7))); + __m256i iacc_mat_01 = _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(iacc_mat_01_0, iacc_mat_01_1), _mm256_add_epi32(iacc_mat_01_2, iacc_mat_01_3)), _mm256_add_epi32(_mm256_add_epi32(iacc_mat_01_4, iacc_mat_01_5), _mm256_add_epi32(iacc_mat_01_6, iacc_mat_01_7))); + __m256i iacc_mat_10 = _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(iacc_mat_10_0, iacc_mat_10_1), _mm256_add_epi32(iacc_mat_10_2, iacc_mat_10_3)), _mm256_add_epi32(_mm256_add_epi32(iacc_mat_10_4, iacc_mat_10_5), _mm256_add_epi32(iacc_mat_10_6, iacc_mat_10_7))); + __m256i iacc_mat_11 = _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(iacc_mat_11_0, iacc_mat_11_1), _mm256_add_epi32(iacc_mat_11_2, iacc_mat_11_3)), _mm256_add_epi32(_mm256_add_epi32(iacc_mat_11_4, iacc_mat_11_5), _mm256_add_epi32(iacc_mat_11_6, iacc_mat_11_7))); + + // Straighten out to make 4 row vectors + __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204); + __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204); + __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204); + __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204); + + // Load the scale(d) values for all the 4 Q8_k blocks and repeat it across lanes + const __m128 row_scale_f32_sse = _mm_load_ps(a_ptr[b].d); + const __m256 row_scale_f32 = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse); + + // Multiply with appropiate scales and accumulate (for both d and dmin) below + acc_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]); + acc_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]); + acc_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]); + acc_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]); + + __m256i lhs_bsums_01_0123 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_01_0123), lhs_raw_bsums_01_0123, 1); + __m256i lhs_bsums_23_0123 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_23_0123), lhs_raw_bsums_23_0123, 1); + __m256i lhs_bsums_01_4567 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_01_4567), lhs_raw_bsums_01_4567, 1); + __m256i lhs_bsums_23_4567 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_23_4567), lhs_raw_bsums_23_4567, 1); + + // Take two bsums from two Q8_Ks at a time and multiply with corresponding mins values from each Q2_K + __m256i iacc_row_min_0_01 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_0123, 0), mins_01); + __m256i iacc_row_min_1_01 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_0123, 170), mins_01); + __m256i iacc_row_min_2_01 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_0123, 0), mins_01); + __m256i iacc_row_min_3_01 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_0123, 170), mins_01); + + __m256i iacc_row_min_0_23 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_0123, 85), mins_23); + __m256i iacc_row_min_1_23 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_0123, 255), mins_23); + __m256i iacc_row_min_2_23 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_0123, 85), mins_23); + __m256i iacc_row_min_3_23 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_0123, 255), mins_23); + + __m256i iacc_row_min_0_45 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_4567, 0), mins_45); + __m256i iacc_row_min_1_45 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_4567, 170), mins_45); + __m256i iacc_row_min_2_45 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_4567, 0), mins_45); + __m256i iacc_row_min_3_45 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_4567, 170), mins_45); + + __m256i iacc_row_min_0_67 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_4567, 85), mins_67); + __m256i iacc_row_min_1_67 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_4567, 255), mins_67); + __m256i iacc_row_min_2_67 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_4567, 85), mins_67); + __m256i iacc_row_min_3_67 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_4567, 255), mins_67); + + __m256i iacc_row_min_0 = _mm256_add_epi32(_mm256_add_epi32(iacc_row_min_0_01, iacc_row_min_0_23), _mm256_add_epi32(iacc_row_min_0_45,iacc_row_min_0_67)); + __m256i iacc_row_min_1 = _mm256_add_epi32(_mm256_add_epi32(iacc_row_min_1_01, iacc_row_min_1_23), _mm256_add_epi32(iacc_row_min_1_45,iacc_row_min_1_67)); + __m256i iacc_row_min_2 = _mm256_add_epi32(_mm256_add_epi32(iacc_row_min_2_01, iacc_row_min_2_23), _mm256_add_epi32(iacc_row_min_2_45,iacc_row_min_2_67)); + __m256i iacc_row_min_3 = _mm256_add_epi32(_mm256_add_epi32(iacc_row_min_3_01, iacc_row_min_3_23), _mm256_add_epi32(iacc_row_min_3_45,iacc_row_min_3_67)); + + acc_min_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_0), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[0]); + acc_min_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_1), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[1]); + acc_min_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_2), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[2]); + acc_min_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_3), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[3]); + } + } + // Store the accumulated values + for (int i = 0; i < 4; i++) { + _mm256_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm256_sub_ps(acc_rows[i], acc_min_rows[i])); + } + } + } +#else + + ggml_gemm_q2_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); + + +#endif +} diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 44952aea9..f3a88f79c 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -412,6 +412,82 @@ void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, } } +void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + float sumf[8]; + float sum_minf[8]; + int sumi1,sumi2,sumi3,sumi4; + int sumi; + + const block_q8_K * a_ptr = (const block_q8_K *)vy; + for(int x = 0; x < nc / ncols_interleaved; x++) { + const block_q2_Kx8 * b_ptr = (const block_q2_Kx8 *) vx + (x * nb); + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0; + sum_minf[j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (4 * blocklen)); k++) { + const uint8_t *scales_0 = b_ptr[l].scales + (k / 4) * 64 ; + const uint8_t *scales_1 = b_ptr[l].scales + (k / 4) * 64 + 16; + const uint8_t *scales_2 = b_ptr[l].scales + (k / 4) * 64 + 32; + const uint8_t *scales_3 = b_ptr[l].scales + (k / 4) * 64 + 48; + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi3 = 0; + sumi4 = 0; + sumi = 0; + int offset = ((k / 2) % 2) + j * 2; + for (int i = 0; i < blocklen; ++i){ + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 3); + const int v1 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 2 ) & 3); + const int v2 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4 ) & 3); + const int v3 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 6 ) & 3); + sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i]); + sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 32]); + sumi3 = (v2 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 64]); + sumi4 = (v3 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 96]); + + sumi1 = sumi1 * (scales_0[offset] & 0xF); + sumi2 = sumi2 * (scales_1[offset] & 0xF); + sumi3 = sumi3 * (scales_2[offset] & 0xF); + sumi4 = sumi4 * (scales_3[offset] & 0xF); + sumi += sumi1 + sumi2 + sumi3 + sumi4; + } + sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; + } + } + for(int sb = 0; sb < 8; sb++) { + const uint8_t *mins = b_ptr[l].scales + sb * 16; + for(int j = 0; j < ncols_interleaved; j++){ + sum_minf[j] += ((mins[j * 2] >> 4) * a_ptr[l].bsums[sb * 2] + (mins[(j * 2)+ 1] >> 4) * a_ptr[l].bsums[sb * 2 + 1]) * GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d; + } + } + } + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j]; + } + } +} + void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -711,6 +787,97 @@ void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, } } +void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + float sumf[4][8]; + float sum_minf[4][8]; + int sumi1, sumi2, sumi3, sumi4; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q2_Kx8 * b_ptr = (const block_q2_Kx8 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0; + sum_minf[m][j] = 0.0; + } + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (4 * blocklen)); k++) { + + const uint8_t *scales_0 = b_ptr[l].scales + (k / 4) * 64 ; + const uint8_t *scales_1 = b_ptr[l].scales + (k / 4) * 64 + 16; + const uint8_t *scales_2 = b_ptr[l].scales + (k / 4) * 64 + 32; + const uint8_t *scales_3 = b_ptr[l].scales + (k / 4) * 64 + 48; + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi3 = 0; + sumi4 = 0; + sumi = 0; + int offset = ((k / 2) % 2) + j * 2; + for (int i = 0; i < blocklen; ++i){ + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 3); + const int v1 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 2 ) & 3); + const int v2 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4 ) & 3); + const int v3 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 6 ) & 3); + sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i]); + sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 128]); + sumi3 = (v2 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 256]); + sumi4 = (v3 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 384]); + sumi1 = sumi1 * (scales_0[offset] & 0xF); + sumi2 = sumi2 * (scales_1[offset] & 0xF); + sumi3 = sumi3 * (scales_2[offset] & 0xF); + sumi4 = sumi4 * (scales_3[offset] & 0xF); + sumi += sumi1 + sumi2 + sumi3 + sumi4; + } + sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m]; + } + } + } + for(int sb = 0; sb < 8; sb++) { + const uint8_t *mins = b_ptr[l].scales + sb * 16; + for(int m = 0; m < 4; m++) { + const int16_t *bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6); + for(int j = 0; j < ncols_interleaved; j++) { + int mins_prod = ((mins[j * 2] >> 4) * bsums[0] + (mins[(j * 2)+ 1] >> 4) * bsums[1]); + sum_minf[m][j] += (mins_prod) * GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m]; + } + } + } + } + + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j]; + } + } + } + } +} + + void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -916,6 +1083,50 @@ static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_in return out; } +static block_q2_Kx8 make_block_q2_Kx8(block_q2_K * in, unsigned int blck_size_interleave) { + block_q2_Kx8 out; + + // Delta(scale) and dmin values of the eight Q2_K structures are copied onto the output interleaved structure + for (int i = 0; i < 8; i++) { + out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; + } + + for (int i = 0; i < 8; i++) { + out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin; + } + + const int end = QK_K * 2 / blck_size_interleave; + + // Interleave Q2_K quants by taking 8 bytes at a time + for (int i = 0; i < end; ++i) { + int src_id = i % 8; + int src_offset = (i / 8) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + uint64_t elems; + memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t)); + memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t)); + } + + // The below logic is designed so as to unpack and rearrange scales and mins values in Q2_K + // Currently the Q2_K structure has 16 scales and 16 mins packed in 16 bytes ( 4 bits for each value) + // The output Q2_Kx8 structure has 128 bytes for storing scales and mins + // Every 16 byte is packed such that it contains scales and mins for corresponding sub blocks from Q2_K structure + // For eg - First 16 bytes contains 16 scales and 16 mins - each of first and second sub blocks from different Q2_K structures + + for(int i = 0; i < 128; i++){ + + // Index for selecting which q2k super block + int src1 = (i % 16) / 2; + // Index for selecting scale + int src2 = ((i / 16) * 2) + (i % 2); + + out.scales[i] = in[src1].scales[src2]; + } + return out; + +} + static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { if(kcpp_q_already_repacked) //using legacy prepacked quant, so just copy it { @@ -982,6 +1193,37 @@ static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block GGML_UNUSED(data_size); } +static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q2_K); + GGML_ASSERT(interleave_block == 8); + constexpr int nrows_interleaved = 8; + + block_q2_Kx8 * dst = (block_q2_Kx8*)t->data; + const block_q2_K * src = (const block_q2_K*) data; + block_q2_K dst_tmp[8]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q2_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++ ) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q2_Kx8(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { if(kcpp_q_already_repacked) //using legacy prepacked quant, so just copy it { @@ -1112,6 +1354,10 @@ template <> int repack(struct ggml_tensor * t, const void * da return repack_q4_K_to_q4_K_8_bl(t, 8, data, data_size); } +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size); +} + template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size); } @@ -1141,6 +1387,10 @@ template <> void gemv(int n, float * s, size_t ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); +} + template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc); } @@ -1165,6 +1415,10 @@ template <> void gemm(int n, float * s, size_t ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); +} + template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc); } @@ -1447,12 +1701,14 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons static const ggml::cpu::repack::tensor_traits q4_0_8x8_q8_0; static const ggml::cpu::repack::tensor_traits q4_K_8x8_q8_K; + // instance for Q2 + static const ggml::cpu::repack::tensor_traits q2_K_8x8_q8_K; + // instance for IQ4 static const ggml::cpu::repack::tensor_traits iq4_nl_4x4_q8_0; if (cur->type == GGML_TYPE_Q4_0) { - //we shall just use the regular avx2 handling, no repacking - if (/*ggml_cpu_has_avx2() ||*/ (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) { + if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) { if (cur->ne[1] % 8 == 0) { return &q4_0_8x8_q8_0; } @@ -1468,11 +1724,17 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons } } } else if (cur->type == GGML_TYPE_Q4_K) { - // if (ggml_cpu_has_avx2()) { - // if (cur->ne[1] % 8 == 0) { - // return &q4_K_8x8_q8_K; - // } - // } + if (ggml_cpu_has_avx2()) { + if (cur->ne[1] % 8 == 0) { + return &q4_K_8x8_q8_K; + } + } + } else if (cur->type == GGML_TYPE_Q2_K) { + if (ggml_cpu_has_avx512()) { + if (cur->ne[1] % 8 == 0) { + return &q2_K_8x8_q8_K; + } + } } else if (cur->type == GGML_TYPE_IQ4_NL) { if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { if (cur->ne[1] % 4 == 0) { diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h index 4421e5f8e..cd322e743 100644 --- a/ggml/src/ggml-cpu/repack.h +++ b/ggml/src/ggml-cpu/repack.h @@ -44,7 +44,14 @@ struct block_q4_Kx8 { }; static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding"); +struct block_q2_Kx8 { + ggml_half d[8]; // super-block scale for quantized scales + ggml_half dmin[8]; // super-block scale for quantized mins + uint8_t scales[128]; // scales and mins, quantized with 4 bits + uint8_t qs[512]; // 2--bit quants +}; +static_assert(sizeof(block_q2_Kx8) == sizeof(ggml_half) * 16 + QK_K/2 + QK_K * 2, "wrong q2_K block size/padding"); struct block_q8_Kx4 { float d[4]; // delta int8_t qs[QK_K * 4]; // quants @@ -71,11 +78,13 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); // Native implementations @@ -86,11 +95,13 @@ void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); #if defined(__cplusplus) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 8f04ca855..b330527a7 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -231,9 +231,9 @@ typedef float2 dfloat2; #define FP16_MMA_AVAILABLE #endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4))) -#if defined(GGML_USE_HIP) && defined(CDNA3) && !defined(GGML_HIP_NO_MMQ_MFMA) +#if defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA) #define AMD_MFMA_AVAILABLE -#endif // defined(GGML_USE_HIP) && defined(CDNA3) && !defined(GGML_HIP_NO_MMQ_MFMA) +#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA) #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING #define NEW_MMA_AVAILABLE @@ -297,10 +297,9 @@ static bool fp32_mma_hardware_available(const int cc) { return GGML_CUDA_CC_IS_CDNA(cc); } -// AMD CDNA3 matrix cores.. Will add support for other CDNA generations later. static bool amd_mfma_available(const int cc) { #if !defined(GGML_HIP_NO_MMQ_MFMA) - return GGML_CUDA_CC_IS_CDNA3(cc); + return GGML_CUDA_CC_IS_CDNA(cc); #else return false; #endif //!defined(GGML_HIP_NO_MMQ_MFMA) @@ -436,6 +435,20 @@ static __global__ void reduce_rows_f32(const float * x, float * dst, const int n dst[row] = norm ? sum / ncols : sum; } +template +static __device__ __forceinline__ int warp_reduce_all(int x) { +#ifdef GGML_USE_HIP +#pragma unroll + for (int offset = width/2; offset > 0; offset >>= 1) { + x = x && __shfl_xor_sync(0xffffffff, x, offset, width); + } + return x; +#else + static_assert(width == WARP_SIZE, "width != WARP_SIZE not implemented"); + return __all_sync(0xffffffff, x); +#endif // GGML_USE_HIP +} + template static __device__ __forceinline__ float warp_reduce_max(float x) { #pragma unroll diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 0cc74f284..b6db446c6 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -15,6 +15,7 @@ typedef void (* fattn_kernel_t)( const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, + const int * __restrict__ KV_max, float * __restrict__ dst, float2 * __restrict__ dst_meta, const float scale, @@ -500,6 +501,55 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) { nullptr; } +template +__launch_bounds__(FATTN_KQ_STRIDE/2, 1) +static __global__ void flash_attn_mask_to_KV_max( + const half2 * __restrict__ mask, int * __restrict__ KV_max, const int ne30, const int s31, const int s33) { + const int ne31 = gridDim.x; + const int tid = threadIdx.x; + const int sequence = blockIdx.y; + const int jt = blockIdx.x; + + mask += sequence*s33 + jt*ncols1*s31; + + __shared__ int buf_iw[WARP_SIZE]; + if (tid < WARP_SIZE) { + buf_iw[tid] = 1; + } + __syncthreads(); + + int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE; + for (; KV_max_sj >= 0; KV_max_sj -= FATTN_KQ_STRIDE) { + int all_inf = 1; + +#pragma unroll + for (int j = 0; j < ncols1; ++j) { + const float2 tmp = __half22float2(mask[j*s31 + KV_max_sj/2 + tid]); + all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y)); + } + + all_inf = warp_reduce_all(all_inf); + if (tid % WARP_SIZE == 0) { + buf_iw[tid / WARP_SIZE] = all_inf; + } + __syncthreads(); + all_inf = buf_iw[tid % WARP_SIZE]; + __syncthreads(); + all_inf = warp_reduce_all(all_inf); + + if (!all_inf) { + KV_max_sj += FATTN_KQ_STRIDE; + break; + } + } + + if (threadIdx.x != 0) { + return; + } + + KV_max[sequence*ne31 + jt] = KV_max_sj; +} + template // D == head size __launch_bounds__(D, 1) static __global__ void flash_attn_stream_k_fixup( @@ -711,6 +761,7 @@ void launch_fattn( ggml_cuda_pool_alloc K_f16(pool); ggml_cuda_pool_alloc V_f16(pool); + ggml_cuda_pool_alloc KV_max(pool); ggml_cuda_pool_alloc dst_tmp(pool); ggml_cuda_pool_alloc dst_tmp_meta(pool); @@ -779,11 +830,30 @@ void launch_fattn( V_data = (char *) V_f16.ptr; } - int parallel_blocks = 1; - const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3]; + // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped. + // Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or + // multiple sequences of possibly different lengths. + if (mask && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) { + const int s31 = mask->nb[1] / sizeof(half2); + const int s33 = mask->nb[3] / sizeof(half2); + + const dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1); + const dim3 block_dim_KV_max(FATTN_KQ_STRIDE/2, 1, 1); + + const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y; + const int iter_k = K->ne[1] / FATTN_KQ_STRIDE; + + KV_max.alloc(ne_KV_max); + flash_attn_mask_to_KV_max<<>> + ((const half2 *) mask->data, KV_max.ptr, iter_k, s31, s33); + CUDA_CHECK(cudaGetLastError()); + } + + int parallel_blocks = 1; + const dim3 block_dim(warp_size, nwarps, 1); int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy. CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared)); @@ -870,6 +940,7 @@ void launch_fattn( K_data, V_data, mask ? ((const char *) mask->data) : nullptr, + KV_max.ptr, !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr, scale, max_bias, m0, m1, n_head_log2, logit_softcap, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3], diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 8e847d361..a86b95428 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -392,7 +392,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( } } -template +template static __device__ __forceinline__ void flash_attn_ext_f16_iter( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, @@ -922,7 +923,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } // Iterate over ne11 == previous tokens: - for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) { + int kb0 = kb0_start; + for (; kb0 < kb0_stop-1; ++kb0) { constexpr bool last_iter = false; flash_attn_ext_f16_iter (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, @@ -932,7 +934,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr bool last_iter = true; flash_attn_ext_f16_iter (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, - ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1); + ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); } // With multi-stage loading there is no __syncthreads at the end of the iter, @@ -1204,6 +1206,7 @@ static __global__ void flash_attn_ext_f16( const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, + const int * __restrict__ KV_max, float * __restrict__ dst, float2 * __restrict__ dst_meta, const float scale, @@ -1280,7 +1283,11 @@ static __global__ void flash_attn_ext_f16( const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f; const int kb0_start_kernel = kb0_start * kb_niter; - const int kb0_stop_kernel = kb0_stop * kb_niter; + int kb0_stop_kernel = kb0_stop * kb_niter; + + if (KV_max) { + kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa); + } constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. if (kb0_start == 0) { @@ -1321,7 +1328,11 @@ static __global__ void flash_attn_ext_f16( const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f; const int kb0_start_kernel = kb0_start * kb_niter; - const int kb0_stop_kernel = kb0_stop * kb_niter; + int kb0_stop_kernel = kb0_stop * kb_niter; + + if (KV_max) { + kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa); + } constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. constexpr bool needs_fixup = false; diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu index afa7a2324..4595b62a1 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu @@ -13,6 +13,7 @@ static __global__ void flash_attn_tile_ext_f16( const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, + const int * __restrict__ KV_max, float * __restrict__ dst, float2 * __restrict__ dst_meta, const float scale, @@ -90,7 +91,8 @@ static __global__ void flash_attn_tile_ext_f16( __syncthreads(); - for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) { + const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; + for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) { // Calculate KQ tile and keep track of new maximum KQ values: half kqmax_new[ncols/nwarps]; diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index bc283d9a7..be72f76fb 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -13,6 +13,7 @@ static __global__ void flash_attn_tile_ext_f32( const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, + const int * __restrict__ KV_max, float * __restrict__ dst, float2 * __restrict__ dst_meta, const float scale, @@ -99,7 +100,8 @@ static __global__ void flash_attn_tile_ext_f32( __syncthreads(); - for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F32; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F32) { + const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; + for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F32; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F32) { // Calculate KQ tile and keep track of new maximum KQ values: float kqmax_new[ncols/nwarps]; diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index 658010b5c..03b80bc03 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f16( const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, + const int * __restrict__ KV_max, float * __restrict__ dst, float2 * __restrict__ dst_meta, const float scale, @@ -177,10 +178,11 @@ static __global__ void flash_attn_vec_ext_f16( half2 VKQ[ncols] = {{0.0f, 0.0f}}; + const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; K += blockIdx.y*D * nb11; V += blockIdx.y*D * nb21; maskh += blockIdx.y*D; - for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D, + for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*D, // Increment pointers after each loop: K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) { @@ -191,29 +193,7 @@ static __global__ void flash_attn_vec_ext_f16( for (int j = 0; j < ncols; ++j) { maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + tid]; } - __syncthreads(); - - // When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out. - // In such cases, skip the KV slice. - // On AMD __all_sync would not work correctly because it assumes a warp size of 64. -#ifndef GGML_USE_HIP - bool skip = true; -#pragma unroll - for (int j = 0; j < ncols; ++j) { -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - const float2 tmp = __half22float2(((const half2 *) maskh_shared)[j*(D/2) + i]); - skip = skip && isinf(tmp.x) && isinf(tmp.y); - } - } - if (__all_sync(0xFFFFFFFF, skip)) { - __syncthreads(); - continue; - } -#endif // GGML_USE_HIP } // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression, diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index 3595e2969..9ab0fc133 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f32( const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, + const int * __restrict__ KV_max, float * __restrict__ dst, float2 * __restrict__ dst_meta, const float scale, @@ -183,10 +184,11 @@ static __global__ void flash_attn_vec_ext_f32( float VKQ[ncols] = {0.0f}; + const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; K += blockIdx.y*D * nb11; V += blockIdx.y*D * nb21; maskh += blockIdx.y*D; - for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D, + for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*D, // Increment pointers after each loop: K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) { @@ -197,28 +199,7 @@ static __global__ void flash_attn_vec_ext_f32( for (int j = 0; j < ncols; ++j) { maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + tid]); } - __syncthreads(); - - // When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out. - // In such cases, skip the KV slice. - // On AMD __all_sync would not work correctly because it assumes a warp size of 64. -#ifndef GGML_USE_HIP - bool skip = true; -#pragma unroll - for (int j = 0; j < ncols; ++j) { -#pragma unroll - for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - skip = skip && isinf(maskf_shared[j*D + i]); - } - } - if (__all_sync(0xFFFFFFFF, skip)) { - __syncthreads(); - continue; - } -#endif // GGML_USE_HIP } float kqmax_new_arr[ncols]; diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 1884eb091..a83c87d5e 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -29,6 +29,7 @@ static __global__ void flash_attn_ext_f16( const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, + const int * __restrict__ KV_max, float * __restrict__ dst, float2 * __restrict__ dst_meta, const float scale, @@ -165,7 +166,8 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); // Iterate over ne11 == previous tokens: - for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE) { + const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; + for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE) { // Calculate tile of KQ: #pragma unroll for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) { diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 07a7a6dfb..ebf5ce1fe 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -315,7 +315,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16; - const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion; + const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && + (Q->ne[3] > 1 || cc < GGML_CUDA_CC_ADA_LOVELACE) && !mma_needs_data_conversion; const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0; if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) { if (prec == GGML_PREC_DEFAULT) { diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 0bb7d6e61..9068ce775 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -109,8 +109,8 @@ void ggml_cuda_mul_mat_q( const int64_t s03 = src0->nb[3] / ts_src0; const int64_t s3 = dst->nb[3] / ts_dst; - const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) - || (GGML_CUDA_CC_IS_AMD(cc) && GGML_CUDA_CC_IS_CDNA3(cc))); + const bool use_stream_k = (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) + || GGML_CUDA_CC_IS_CDNA(cc); if (!ids) { const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 + @@ -252,7 +252,7 @@ void ggml_cuda_op_mul_mat_q( // Also its fixup needs to allocate a temporary buffer in the memory pool. // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer. const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) - || (GGML_CUDA_CC_IS_AMD(cc) && GGML_CUDA_CC_IS_CDNA3(cc))) + || GGML_CUDA_CC_IS_CDNA(cc)) && src1_ncols == ne11; const mmq_args args = { src0_dd_i, src0->type, (const int *) src1_ddq_i, nullptr, nullptr, dst_dd_i, @@ -308,7 +308,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { return false; } - if (new_mma_available(cc) || amd_mfma_available(cc)) { + if (new_mma_available(cc)) { return true; } @@ -324,5 +324,21 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; } + if (amd_mfma_available(cc)) { + // As of ROCM 7.0 rocblas/tensile performs very poorly on CDNA3 and hipblaslt (via ROCBLAS_USE_HIPBLASLT) + // performs better but is currently suffering from a crash on this architecture. + // TODO: Revisit when hipblaslt is fixed on CDNA3 + if (GGML_CUDA_CC_IS_CDNA3(cc)) { + return true; + } + if (ne11 <= 128 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) { + return true; + } + if (ne11 <= 256 && (type == GGML_TYPE_Q4_K || type == GGML_TYPE_Q5_K)) { + return true; + } + return false; + } + return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; } diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 984496a68..a809c68c8 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -252,25 +252,21 @@ static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/) #endif // AMD_MFMA_AVAILABLE #if defined(GGML_USE_HIP) -static int mmq_get_nwarps_host(const int cc) { - return amd_mfma_available(cc) ? 8 : 4; +static int mmq_get_nwarps_host(const int cc, const int warp_size) { + return amd_mfma_available(cc) ? 8 : 256/warp_size; } #else -static int mmq_get_nwarps_host(const int /*cc*/) { - return 8; +static int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) { + return 256/warp_size; } #endif // (GGML_USE_HIP) static constexpr __device__ int mmq_get_nwarps_device() { -#if defined(GGML_USE_HIP) #if defined(AMD_MFMA_AVAILABLE) return 8; #else - return 4; + return 256/ggml_cuda_get_physical_warp_size(); #endif // AMD_MFMA_AVAILABLE -#else - return 8; -#endif // defined(GGML_USE_HIP) } // ------------------------------------------------------------ @@ -3097,8 +3093,8 @@ static __global__ void mul_mat_q( } __syncthreads(); - // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead: -#if (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA + // On non-CDNA AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead: +#if (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA { const int wt = blockIdx.z / nchannels_y; const int zt = blockIdx.z - wt*nchannels_y; @@ -3473,7 +3469,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a const int cc = ggml_cuda_info().devices[id].cc; const int nsm = ggml_cuda_info().devices[id].nsm; const int warp_size = ggml_cuda_info().devices[id].warp_size; - const int nwarps = mmq_get_nwarps_host(cc); + const int nwarps = mmq_get_nwarps_host(cc, warp_size); const int mmq_y = get_mmq_y_host(cc); const dim3 block_dims(warp_size, nwarps, 1); @@ -3560,7 +3556,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda const int cc = ggml_cuda_info().devices[id].cc; const size_t smpbo = ggml_cuda_info().devices[id].smpbo; const int warp_size = ggml_cuda_info().devices[id].warp_size; - const int nwarps = mmq_get_nwarps_host(cc); + const int nwarps = mmq_get_nwarps_host(cc, warp_size); const int mmq_x_max = get_mmq_x_max_host(cc); const int mmq_y = get_mmq_y_host(cc); diff --git a/ggml/src/ggml-opencl/kernels/div.cl b/ggml/src/ggml-opencl/kernels/div.cl index d453ad99b..6d9b4ade9 100644 --- a/ggml/src/ggml-opencl/kernels/div.cl +++ b/ggml/src/ggml-opencl/kernels/div.cl @@ -70,3 +70,69 @@ kernel void kernel_div_row( uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne dst[gid] = src0[gid] / src1[idx1]; } + +kernel void kernel_div_f16( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + int ne13, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int i13 = i03 % ne13; + int i12 = i02 % ne12; + int i11 = i01 % ne11; + + global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const int i10 = i0 % ne10; + *((global half *)(dst_ptr + i0*nb0)) = *((global half *)(src0_ptr + i0*nb00)) / *((global half *)(src1_ptr + i10*nb10)); + } +} + +kernel void kernel_div_row_f16( + global half4 * src0, + ulong offset0, + global half4 * src1, + ulong offset1, + global half4 * dst, + ulong offsetd, + int ne +) { + src0 = (global half4*)((global char*)src0 + offset0); + src1 = (global half4*)((global char*)src1 + offset1); + dst = (global half4*)((global char*)dst + offsetd); + + // This performs better than using %. + uint gid = get_global_id(0); + uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne + dst[gid] = src0[gid] / src1[idx1]; +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl new file mode 100644 index 000000000..9599a0e15 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl @@ -0,0 +1,132 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 4 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 16 +#define TM 4 +#define TN 8 + +kernel void kernel_mul_mm_f16_f32_l4_lm( + global half4 * src0, + ulong offset0, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src0 = (global half4*)((global char*)src0 + offset0); + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + local half buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + half cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0; + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1; + buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2; + buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3; + } + + for (int l = 0; l < BN; l += loadstride_b) { + const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(convert_float(cache_a[cr]), cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl new file mode 100644 index 000000000..58c5178e3 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl @@ -0,0 +1,133 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 4 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 16 +#define TM 4 +#define TN 8 + +kernel void kernel_mul_mm_f32_f32_l4_lm( + global float4 * src0, + ulong offset0, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src0 = (global float4*)((global char*)src0 + offset0); + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0; + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1; + buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2; + buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3; + } + + for (int l = 0; l < BN; l += loadstride_b) { + const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/sub.cl b/ggml/src/ggml-opencl/kernels/sub.cl index 041e88ad3..423ed595c 100644 --- a/ggml/src/ggml-opencl/kernels/sub.cl +++ b/ggml/src/ggml-opencl/kernels/sub.cl @@ -70,3 +70,69 @@ kernel void kernel_sub_row( uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne dst[gid] = src0[gid] - src1[idx1]; } + +kernel void kernel_sub_f16( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + int ne13, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int i13 = i03 % ne13; + int i12 = i02 % ne12; + int i11 = i01 % ne11; + + global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const int i10 = i0 % ne10; + *((global half *)(dst_ptr + i0*nb0)) = *((global half *)(src0_ptr + i0*nb00)) - *((global half *)(src1_ptr + i10*nb10)); + } +} + +kernel void kernel_sub_row_f16( + global half4 * src0, + ulong offset0, + global half4 * src1, + ulong offset1, + global half4 * dst, + ulong offsetd, + int ne +) { + src0 = (global half4*)((global char*)src0 + offset0); + src1 = (global half4*)((global char*)src1 + offset1); + dst = (global half4*)((global char*)dst + offsetd); + + // This performs better than using %. + uint gid = get_global_id(0); + uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne + dst[gid] = src0[gid] - src1[idx1]; +} diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 6bb09ae8a..91bf7ca05 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1357,7 +1357,7 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin vk::DebugUtilsObjectNameInfoEXT duoni; duoni.objectType = vk::ObjectType::ePipeline; duoni.pObjectName = pipeline->name.c_str(); - duoni.objectHandle = reinterpret_cast(static_cast(pipeline->pipeline)); + duoni.objectHandle = /*reinterpret_cast*/(uint64_t)(static_cast(pipeline->pipeline)); vk_instance.pfn_vkSetDebugUtilsObjectNameEXT(device->device, &static_cast(duoni)); } @@ -5249,9 +5249,9 @@ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& sub } static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { - VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; - std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; - std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << ggml_type_name(src0->type) << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << ggml_type_name(src1->type) << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << ggml_type_name(dst->type) << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; std::cerr << "), " << (dryrun ? "dryrun" : "") << ")"); GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT @@ -11192,7 +11192,7 @@ size_t comp_nb[GGML_MAX_DIMS]; size_t check_counter = 0; static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) { ggml_tensor * tensor = cgraph->nodes[tensor_idx]; - if (tensor->op == GGML_OP_TRANSPOSE) { + if (tensor->op == GGML_OP_TRANSPOSE || tensor->op == GGML_OP_SET_ROWS) { return; } @@ -11312,7 +11312,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], (ggml_scale_mode) tensor->op_params[0]); } else if (tensor->op == GGML_OP_SCALE) { const float * params = (const float *)tensor->op_params; - tensor_clone = ggml_scale(ggml_ctx, src_clone[0], params[0]); + tensor_clone = ggml_scale_bias(ggml_ctx, src_clone[0], params[0], params[1]); } else if (tensor->op == GGML_OP_SQR) { tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_SIN) { @@ -11423,8 +11423,6 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * } else { tensor_clone = ggml_cpy(ggml_ctx, src_clone[0], src_clone[1]); } - } else if (tensor->op == GGML_OP_SET_ROWS) { - tensor_clone = ggml_set_rows(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_CONT) { tensor_clone = ggml_cont_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); } else if (tensor->op == GGML_OP_RESHAPE) { @@ -11532,7 +11530,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) { ggml_tensor * tensor = cgraph->nodes[tensor_idx]; - if (tensor->op == GGML_OP_TRANSPOSE) { + if (tensor->op == GGML_OP_TRANSPOSE || tensor->op == GGML_OP_SET_ROWS) { return; } bool fused_rms_norm_mul = false; @@ -11592,6 +11590,9 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * } else if (tensor->type == GGML_TYPE_F16) { correct = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0])); result = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0])); + } else if (tensor->type == GGML_TYPE_BF16) { + correct = ggml_bf16_to_fp32(*(ggml_bf16_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0])); + result = ggml_bf16_to_fp32(*(ggml_bf16_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0])); } else if (tensor->type == GGML_TYPE_I32) { correct = *(int32_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]); result = *(int32_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]); diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index c97b61d09..5707085cb 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -279,6 +279,9 @@ class Keys: class Projector: STACK_FACTOR = "clip.audio.projector.stack_factor" + class Diffusion: + SHIFT_LOGITS = "diffusion.shift_logits" + # # recommended mapping of model tensor names for storage in gguf # @@ -373,10 +376,12 @@ class MODEL_ARCH(IntEnum): ERNIE4_5 = auto() ERNIE4_5_MOE = auto() HUNYUAN_MOE = auto() + HUNYUAN_DENSE = auto() SMOLLM3 = auto() LFM2 = auto() DREAM = auto() SMALLTHINKER = auto() + LLADA = auto() class VISION_PROJECTOR_TYPE(IntEnum): @@ -693,10 +698,12 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.ERNIE4_5_MOE: "ernie4_5-moe", MODEL_ARCH.FALCON_H1: "falcon-h1", MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe", + MODEL_ARCH.HUNYUAN_DENSE: "hunyuan-dense", MODEL_ARCH.SMOLLM3: "smollm3", MODEL_ARCH.LFM2: "lfm2", MODEL_ARCH.DREAM: "dream", MODEL_ARCH.SMALLTHINKER: "smallthinker", + MODEL_ARCH.LLADA: "llada", } VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { @@ -1318,6 +1325,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.LLADA: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], MODEL_ARCH.QWEN2VL: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -2451,6 +2473,22 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN_SHEXP, MODEL_TENSOR.FFN_UP_SHEXP, ], + MODEL_ARCH.HUNYUAN_DENSE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], MODEL_ARCH.SMOLLM3: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 4f23f9b02..f4fd64ad8 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -1047,6 +1047,11 @@ class GGUFWriter: def add_audio_stack_factor(self, value: int) -> None: self.add_uint32(Keys.ClipAudio.Projector.STACK_FACTOR, value) + # diffusion models + + def add_diffusion_shift_logits(self, value: bool) -> None: + self.add_bool(Keys.Diffusion.SHIFT_LOGITS, value) + def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes: pack_prefix = '' if not skip_pack_prefix: diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index bfd4fd37a..df490fc80 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -32,6 +32,7 @@ class TensorNameMap: "model.word_embeddings", # bailingmoe "language_model.model.embed_tokens", # llama4 "encoder", # neobert + "model.transformer.wte", # llada ), # Token type embeddings @@ -71,6 +72,7 @@ class TensorNameMap: "head", # rwkv "head.out", # wavtokenizer "lm_head", # llama4 + "model.transformer.ff_out", # llada ), # Output norm @@ -94,6 +96,7 @@ class TensorNameMap: "model.ln_out", # rwkv7 "backbone.final_layer_norm", # wavtokenizer "model.norm", # llama4 + "model.transformer.ln_f", # llada ), # Rope frequencies @@ -139,6 +142,7 @@ class TensorNameMap: "model.layers.{bid}.input_layernorm", # llama4 "transformer_encoder.{bid}.attention_norm", # neobert "model.layers.{bid}.operator_norm", # lfm2 + "model.transformer.blocks.{bid}.attn_norm", # llada ), # Attention norm 2 @@ -183,6 +187,7 @@ class TensorNameMap: "transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok "transformer.h.{bid}.attn.attention.q_proj", # exaone "model.layers.{bid}.self_attn.q_proj", # llama4 + "model.transformer.blocks.{bid}.q_proj", # llada ), # Attention key @@ -199,6 +204,7 @@ class TensorNameMap: "transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok "transformer.h.{bid}.attn.attention.k_proj", # exaone "model.layers.{bid}.self_attn.k_proj", # llama4 + "model.transformer.blocks.{bid}.k_proj", # llada ), # Attention value @@ -214,6 +220,7 @@ class TensorNameMap: "transformer.decoder_layer.{bid}.multi_head_attention.value",# Grok "transformer.h.{bid}.attn.attention.v_proj", # exaone "model.layers.{bid}.self_attn.v_proj", # llama4 + "model.transformer.blocks.{bid}.v_proj", # llada ), # Attention output @@ -246,6 +253,7 @@ class TensorNameMap: "transformer.h.{bid}.attn.attention.out_proj", # exaone "model.layers.{bid}.self_attn.o_proj", # llama4 "transformer_encoder.{bid}.wo", # neobert + "model.transformer.blocks.{bid}.attn_out", # llada ), # Attention output norm @@ -291,6 +299,7 @@ class TensorNameMap: "model.layers.{bid}.post_attention_layernorm", # llama4 "transformer_encoder.{bid}.ffn_norm", # neobert "model.layers.layers.{bid}.pre_mlp_norm", # plamo2 + "model.transformer.blocks.{bid}.ff_norm", # llada ), # Post feed-forward norm @@ -364,6 +373,7 @@ class TensorNameMap: "model.layers.{bid}.feed_forward.up_proj", # llama4 jamba granite-hybrid "transformer_encoder.{bid}.ffn.w12", # neobert "model.layers.{bid}.block_sparse_moe.up", # smallthinker + "model.transformer.blocks.{bid}.up_proj", # llada ), MODEL_TENSOR.FFN_UP_EXP: ( @@ -405,6 +415,7 @@ class TensorNameMap: "transformer.h.{bid}.mlp.c_fc_0", # exaone "model.layers.{bid}.feed_forward.gate_proj", # llama4 jamba granite-hybrid "model.layers.{bid}.block_sparse_moe.gate", # smallthinker + "model.transformer.blocks.{bid}.ff_proj", # llada ), MODEL_TENSOR.FFN_GATE_EXP: ( @@ -454,6 +465,7 @@ class TensorNameMap: "model.layers.{bid}.feed_forward.down_proj", # llama4 jamba granite-hybrid "transformer_encoder.{bid}.ffn.w3", # neobert "model.layers.{bid}.block_sparse_moe.down", # smallthinker + "model.transformer.blocks.{bid}.ff_out", # llada ), MODEL_TENSOR.FFN_DOWN_EXP: ( @@ -604,6 +616,7 @@ class TensorNameMap: ), MODEL_TENSOR.SSM_DT_NORM: ( + "model.layers.layers.{bid}.mixer.dt_norm.weight", # plamo2 "model.layers.{bid}.mamba.dt_layernorm", # jamba ), @@ -633,10 +646,6 @@ class TensorNameMap: "model.layers.layers.{bid}.mixer.D", # plamo2 ), - MODEL_TENSOR.SSM_DT_NORM: ( - "model.layers.layers.{bid}.mixer.dt_norm.weight", # plamo2 - ), - MODEL_TENSOR.SSM_NORM: ( "model.layers.{bid}.mamba.norm", # falcon-h1 granite-hybrid "backbone.layers.{bid}.mixer.norm", # mamba2 diff --git a/include/llama.h b/include/llama.h index b46416b13..0e6373526 100644 --- a/include/llama.h +++ b/include/llama.h @@ -287,10 +287,11 @@ extern "C" { const struct llama_model_kv_override * kv_overrides; // Keep the booleans together to avoid misalignment during copy-by-value. - bool vocab_only; // only load the vocabulary, no weights - bool use_mmap; // use mmap if possible - bool use_mlock; // force system to keep model in RAM - bool check_tensors; // validate model tensor data + bool vocab_only; // only load the vocabulary, no weights + bool use_mmap; // use mmap if possible + bool use_mlock; // force system to keep model in RAM + bool check_tensors; // validate model tensor data + bool use_extra_bufts; // use extra buffer types (used for weight repacking) }; // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations @@ -540,6 +541,9 @@ extern "C" { // Returns true if the model is recurrent (like Mamba, RWKV, etc.) LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model); + // Returns true if the model is diffusion-based (like LLaDA, Dream, etc.) + LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model); + // Returns 0 on success LLAMA_API uint32_t llama_model_quantize( const char * fname_inp, diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index dbf977443..ba7bf9598 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -85,10 +85,12 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_ERNIE4_5, "ernie4_5" }, { LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" }, { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" }, + { LLM_ARCH_HUNYUAN_DENSE, "hunyuan-dense" }, { LLM_ARCH_SMOLLM3, "smollm3" }, { LLM_ARCH_LFM2, "lfm2" }, { LLM_ARCH_DREAM, "dream" }, { LLM_ARCH_SMALLTHINKER, "smallthinker" }, + { LLM_ARCH_LLADA, "llada" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -1896,6 +1898,26 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, }, }, + { + LLM_ARCH_HUNYUAN_DENSE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + + }, + }, { LLM_ARCH_SMOLLM3, { @@ -1972,6 +1994,23 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_LLADA, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -2224,6 +2263,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { bool llm_arch_is_diffusion(const llm_arch & arch) { switch (arch) { case LLM_ARCH_DREAM: + case LLM_ARCH_LLADA: return true; default: return false; diff --git a/src/llama-arch.h b/src/llama-arch.h index 8267a8d3a..9b8bd65b2 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -89,10 +89,12 @@ enum llm_arch { LLM_ARCH_ERNIE4_5, LLM_ARCH_ERNIE4_5_MOE, LLM_ARCH_HUNYUAN_MOE, + LLM_ARCH_HUNYUAN_DENSE, LLM_ARCH_SMOLLM3, LLM_ARCH_LFM2, LLM_ARCH_DREAM, LLM_ARCH_SMALLTHINKER, + LLM_ARCH_LLADA, LLM_ARCH_UNKNOWN, }; diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp index d34bb2687..c4576e242 100644 --- a/src/llama-chat.cpp +++ b/src/llama-chat.cpp @@ -66,6 +66,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 }, { "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM }, { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE }, + { "hunyuan-dense", LLM_CHAT_TEMPLATE_HUNYUAN_DENSE }, { "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 }, }; @@ -193,6 +194,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_DOTS1; } else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) { return LLM_CHAT_TEMPLATE_HUNYUAN_MOE; + } else if (tmpl_contains("<|hy_place▁holder▁no▁2|>") && tmpl_contains("<|hy_place▁holder▁no▁3|>")) { + return LLM_CHAT_TEMPLATE_HUNYUAN_DENSE; } else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) { return LLM_CHAT_TEMPLATE_KIMI_K2; } @@ -698,11 +701,27 @@ int32_t llm_chat_apply_template( if (role == "system") { ss << "<|startoftext|>" << message->content << "<|extra_4|>"; } else if (role == "assistant") { - ss << "<|startoftext|>" << message->content << "<|eos|>"; + ss << message->content << "<|eos|>"; } else { ss << "<|startoftext|>" << message->content << "<|extra_0|>"; } } + } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_DENSE) { + // tencent/Hunyuan-4B-Instruct + for (size_t i = 0; i < chat.size(); i++) { + std::string role(chat[i]->role); + if (i == 0) { + if (role == "system") { + ss << chat[i]->content << "<|hy_place▁holder▁no▁3|>"; + } + } + + if (role == "assistant") { + ss << "<|hy_Assistant|>" << chat[i]->content << "<|hy_place▁holder▁no▁2|>"; + } else if (role == "user") { + ss << "<|hy_User|>" << chat[i]->content << "<|hy_Assistant|>"; + } + } } else if (tmpl == LLM_CHAT_TEMPLATE_KIMI_K2) { // moonshotai/Kimi-K2-Instruct for (auto message : chat) { diff --git a/src/llama-chat.h b/src/llama-chat.h index 6968a19fb..4cf77fd28 100644 --- a/src/llama-chat.h +++ b/src/llama-chat.h @@ -46,6 +46,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_SMOLVLM, LLM_CHAT_TEMPLATE_DOTS1, LLM_CHAT_TEMPLATE_HUNYUAN_MOE, + LLM_CHAT_TEMPLATE_HUNYUAN_DENSE, LLM_CHAT_TEMPLATE_KIMI_K2, LLM_CHAT_TEMPLATE_UNKNOWN, }; diff --git a/src/llama-context.cpp b/src/llama-context.cpp index bc127259d..9bf77c169 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -113,6 +113,15 @@ llama_context::llama_context( } } + { + const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE"); + graph_reuse_disable = LLAMA_GRAPH_REUSE_DISABLE ? (atoi(LLAMA_GRAPH_REUSE_DISABLE) != 0) : graph_reuse_disable; + + if (graph_reuse_disable) { + LLAMA_LOG_WARN("%s: graph reuse disabled\n", __func__); + } + } + const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max; LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max); @@ -716,7 +725,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll // in order to correctly reuse a graph, it's full topology has to be uniquely determined by these parameters const auto gparams = graph_params(res, ubatch, mctx, gtype); - if (res->can_reuse(gparams)) { + if (!graph_reuse_disable && res->can_reuse(gparams)) { //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__); n_reused++; diff --git a/src/llama-context.h b/src/llama-context.h index 5c3a1c098..7cfdc6a51 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -291,6 +291,9 @@ private: // ref: https://github.com/ggml-org/llama.cpp/pull/14285 bool supports_set_rows = false; + // env: LLAMA_GRAPH_REUSE_DISABLE + bool graph_reuse_disable = false; + // perf mutable int64_t t_start_us = 0; mutable int64_t t_load_us = 0; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 702192b79..491a26b63 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -785,13 +785,20 @@ ggml_tensor * llm_graph_context::build_moe_ffn( bool scale_w, float w_scale, llama_expert_gating_func_type gating_op, - int il) const { + int il, + ggml_tensor * probs_in) const { const int64_t n_embd = cur->ne[0]; const int64_t n_tokens = cur->ne[1]; const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN - ggml_tensor * logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens] - cb(logits, "ffn_moe_logits", il); + ggml_tensor * logits = nullptr; + + if (probs_in == nullptr) { + logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens] + cb(logits, "ffn_moe_logits", il); + } else { + logits = probs_in; + } ggml_tensor * probs = nullptr; switch (gating_op) { @@ -884,6 +891,14 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cur = ggml_gelu(ctx0, cur); cb(cur, "ffn_moe_gelu", il); } break; + case LLM_FFN_RELU: + if (gate_exps) { + cur = ggml_reglu_split(ctx0, cur, up); + cb(cur, "ffn_moe_reglu", il); + } else { + cur = ggml_relu(ctx0, cur); + cb(cur, "ffn_moe_relu", il); + } break; default: GGML_ABORT("fatal error"); } @@ -927,100 +942,6 @@ ggml_tensor * llm_graph_context::build_moe_ffn( return moe_out; } -ggml_tensor * llm_graph_context::build_moe_ffn_from_probs( - ggml_tensor * cur, - ggml_tensor * probs, - ggml_tensor * up_exps, - ggml_tensor * gate_exps, - ggml_tensor * down_exps, - ggml_tensor * exp_probs_b, - int64_t n_expert, - int64_t n_expert_used, - llama_expert_gating_func_type gating_op, - int il) const { - const int64_t n_embd = cur->ne[0]; - const int64_t n_tokens = cur->ne[1]; - - // add experts selection bias - introduced in DeepSeek V3 - // leave probs unbiased as it's later used to get expert weights - ggml_tensor * selection_probs = probs; - if (exp_probs_b != nullptr) { - selection_probs = ggml_add(ctx0, probs, exp_probs_b); - cb(selection_probs, "ffn_moe_probs_biased", il); - } - - // select experts - ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens] - cb(selected_experts->src[0], "ffn_moe_argsort", il); - cb(selected_experts, "ffn_moe_topk", il); - - ggml_tensor * weights = ggml_get_rows(ctx0, - ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] - cb(weights, "ffn_moe_weights", il); - - weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); - if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX) { - weights = ggml_soft_max(ctx0, weights); - } else { - weights = ggml_sigmoid(ctx0, weights); - ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens] - cb(weights_sum, "ffn_moe_weights_sum", il); - - weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens] - cb(weights, "ffn_moe_weights_norm", il); - } - - weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens); - - cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens); - - ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] - cb(up, "ffn_moe_up", il); - - ggml_tensor * experts = nullptr; - cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] - cb(cur, "ffn_moe_gate", il); - - cur = ggml_reglu_split(ctx0, cur, up); - cb(cur, "ffn_moe_reglu", il); - - experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens] - cb(experts, "ffn_moe_down", il); - - experts = ggml_mul(ctx0, experts, weights); - cb(cur, "ffn_moe_weighted", il); - - ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr }; - - assert(n_expert_used > 0); - - // order the views before the adds - for (uint32_t i = 0; i < hparams.n_expert_used; ++i) { - cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]); - - ggml_build_forward_expand(gf, cur_experts[i]); - } - - // aggregate experts - // note: here we explicitly use hparams.n_expert_used instead of n_expert_used - // to avoid potentially a large number of add nodes during warmup - // ref: https://github.com/ggml-org/llama.cpp/pull/14753 - ggml_tensor * moe_out = cur_experts[0]; - - for (uint32_t i = 1; i < hparams.n_expert_used; ++i) { - moe_out = ggml_add(ctx0, moe_out, cur_experts[i]); - } - - if (n_expert_used == 1) { - // avoid returning a non-contiguous tensor - moe_out = ggml_cont(ctx0, moe_out); - } - - cb(moe_out, "ffn_moe_out", il); - - return moe_out; -} - // input embeddings with optional lora ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { const int64_t n_embd = hparams.n_embd; @@ -1644,16 +1565,17 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif ggml_tensor * llm_graph_context::build_rs( ggml_tensor * s, - ggml_tensor * state_copy, + ggml_tensor * state_copy_main, + ggml_tensor * state_copy_extra, int32_t state_size, int32_t n_seqs, - uint32_t n_kv, - uint32_t kv_head, - uint32_t kv_size, + uint32_t n_rs, + uint32_t rs_head, + uint32_t rs_size, int32_t rs_zero, const llm_graph_get_rows_fn & get_state_rows) const { - ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size); + ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size); // Clear a single state which will then be copied to the other cleared states. // Note that this is a no-op when the view is zero-sized. @@ -1661,39 +1583,44 @@ ggml_tensor * llm_graph_context::build_rs( ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0)); // copy states - // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv - // {state_size, kv_size} -> {state_size, n_seqs} - ggml_tensor * output_states = get_state_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0)); + // NOTE: assuming the copy destinations are ALL contained between rs_head and rs_head + n_rs + // {state_size, rs_size} -> {state_size, n_seqs} + ggml_tensor * output_states = get_state_rows(ctx0, states, state_copy_main); ggml_build_forward_expand(gf, output_states); - // copy extra states which won't be changed further (between n_seqs and n_kv) - ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0])); + // copy extra states which won't be changed further (between n_seqs and n_rs) + ggml_tensor * states_extra = ggml_get_rows(ctx0, states, state_copy_extra); ggml_build_forward_expand(gf, ggml_cpy(ctx0, states_extra, - ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s)))); + ggml_view_1d(ctx0, s, state_size*(n_rs - n_seqs), (rs_head + n_seqs)*state_size*ggml_element_size(s)))); return output_states; } static std::unique_ptr build_rs_inp_impl( ggml_context * ctx0, + const llama_ubatch & ubatch, const llama_memory_recurrent_context * mctx_cur) { auto inp = std::make_unique(mctx_cur); - const auto n_rs = mctx_cur->get_n_rs(); + const int64_t n_rs = mctx_cur->get_n_rs(); + const int64_t n_seqs = ubatch.n_seqs; inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs); ggml_set_input(inp->s_copy); + inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0); + inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]); + return inp; } llm_graph_input_rs * llm_graph_context::build_rs_inp() const { const auto * mctx_cur = static_cast(mctx); - auto inp = build_rs_inp_impl(ctx0, mctx_cur); + auto inp = build_rs_inp_impl(ctx0, ubatch, mctx_cur); return (llm_graph_input_rs *) res->add_input(std::move(inp)); } @@ -1706,7 +1633,9 @@ ggml_tensor * llm_graph_context::build_rs( const llm_graph_get_rows_fn & get_state_rows) const { const auto * kv_state = inp->mctx; - return build_rs(s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows); + return build_rs(s, inp->s_copy_main, inp->s_copy_extra, state_size, n_seqs, + kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), + get_state_rows); } ggml_tensor * llm_graph_context::build_rwkv_token_shift_load( @@ -1753,7 +1682,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store( llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { const auto * mctx_cur = static_cast(mctx); - auto inp_rs = build_rs_inp_impl(ctx0, mctx_cur->get_recr()); + auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr()); auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn()); auto inp = std::make_unique(std::move(inp_attn), std::move(inp_rs), mctx_cur); diff --git a/src/llama-graph.h b/src/llama-graph.h index 8eae4f551..8614d4967 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -214,7 +214,12 @@ public: void set_input(const llama_ubatch * ubatch) override; - ggml_tensor * s_copy; // I32 [kv_size] + ggml_tensor * s_copy; // I32 [n_rs] + + // views of s_copy, computed once per graph + // and shared across layers which use build_rs + ggml_tensor * s_copy_main; // I32 [n_seqs] + ggml_tensor * s_copy_extra; // I32 [n_rs - n_seqs] const llama_memory_recurrent_context * mctx; }; @@ -418,7 +423,9 @@ struct llm_graph_params { (!ubatch.embd && !other.ubatch.embd) ); - if (can_reuse_ubatch && !ubatch.equal_seqs()) { + // when we split the batch using "equal_seqs" we have to verify that the participating sequences are the same + // the reason is because the set of attention streams would be different for different sequences + if (can_reuse_ubatch && ubatch.equal_seqs()) { if (!ubatch.data) { // if the old ubatch does not own it's data, then we cannot guarantee that it is still alive, and // therefore we cannot perform the sequence id check. normally should never happen @@ -626,19 +633,8 @@ struct llm_graph_context { bool scale_w, float w_scale, llama_expert_gating_func_type gating_op, - int il) const; - - ggml_tensor * build_moe_ffn_from_probs( - ggml_tensor * cur, - ggml_tensor * probs, - ggml_tensor * up_exps, - ggml_tensor * gate_exps, - ggml_tensor * down_exps, - ggml_tensor * exp_probs_b, - int64_t n_expert, - int64_t n_expert_used, - llama_expert_gating_func_type gating_op, - int il) const; + int il, + ggml_tensor * probs_in = nullptr) const; // // inputs @@ -730,7 +726,6 @@ struct llm_graph_context { // recurrent // - // TODO: avoid notion of "kv" // TODO: move this implementation to llama_memory_recurrent. // this is analogous to llama_kv_cache_unified::cpy_k / cpy_v // when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the @@ -738,12 +733,13 @@ struct llm_graph_context { // `llama_memory_recurrent` ggml_tensor * build_rs( ggml_tensor * s, - ggml_tensor * state_copy, + ggml_tensor * state_copy_main, + ggml_tensor * state_copy_extra, int32_t state_size, int32_t n_seqs, - uint32_t n_kv, - uint32_t kv_head, - uint32_t kv_size, + uint32_t n_rs, + uint32_t rs_head, + uint32_t rs_size, int32_t rs_zero, const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 09ee8c326..16b49cd6a 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -295,7 +295,7 @@ static ggml_backend_buffer_type_t select_weight_buft(const llama_hparams & hpara } // CPU: ACCEL -> GPU host -> CPU extra -> CPU -static buft_list_t make_cpu_buft_list(const std::vector & devices) { +static buft_list_t make_cpu_buft_list(const std::vector & devices, bool use_extra_bufts) { buft_list_t buft_list; // add ACCEL buffer types @@ -324,21 +324,22 @@ static buft_list_t make_cpu_buft_list(const std::vector & de } } - // add extra buffer types, only if no GPU device is present - // ref: https://github.com/ggml-org/llama.cpp/issues/12481#issuecomment-2743136094 - auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); - if (cpu_dev == nullptr) { - throw std::runtime_error(format("%s: no CPU backend found", __func__)); - } + // add extra buffer types + if (use_extra_bufts) { + auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (cpu_dev == nullptr) { + throw std::runtime_error(format("%s: no CPU backend found", __func__)); + } - auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); - auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) - ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts"); - if (ggml_backend_dev_get_extra_bufts_fn) { - ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev); - while (extra_bufts && *extra_bufts) { - buft_list.emplace_back(cpu_dev, *extra_bufts); - ++extra_bufts; + auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); + auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) + ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts"); + if (ggml_backend_dev_get_extra_bufts_fn) { + ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev); + while (extra_bufts && *extra_bufts) { + buft_list.emplace_back(cpu_dev, *extra_bufts); + ++extra_bufts; + } } } @@ -874,6 +875,21 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.causal_attn = false; } break; + case LLM_ARCH_LLADA: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // LLaDA-8B has 32 layers, similar to LLaMA but for diffusion + switch (hparams.n_layer) { + case 32: + type = LLM_TYPE_8B; + break; + default: + type = LLM_TYPE_UNKNOWN; + } + // Set non-causal attention for diffusion models + hparams.causal_attn = false; + } + break; case LLM_ARCH_QWEN2MOE: { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); @@ -1749,6 +1765,18 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_HUNYUAN_DENSE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_embd) { + case 1024: type = LLM_TYPE_0_5B; break; + case 2048: type = LLM_TYPE_1_8B; break; + case 3072: type = LLM_TYPE_4B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_SMOLLM3: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -1829,7 +1857,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { LLAMA_LOG_INFO("%s: loading model tensors, this can take a while... (mmap = %s)\n", __func__, ml.use_mmap ? "true" : "false"); // build a list of buffer types for the CPU and GPU devices - pimpl->cpu_buft_list = make_cpu_buft_list(devices); + pimpl->cpu_buft_list = make_cpu_buft_list(devices, params.use_extra_bufts); for (auto * dev : devices) { buft_list_t buft_list = make_gpu_buft_list(dev, split_mode, tensor_split); // add CPU buffer types as a fallback @@ -2045,7 +2073,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (const auto * overrides = ml.tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) { std::regex pattern(overrides->pattern); if (std::regex_search(tensor_name, pattern)) { - buft = overrides->buft; + if (overrides->buft == ggml_backend_cpu_buffer_type()) { + // when overriding to a CPU buffer, consider the extra buffer types + buft = select_weight_buft(hparams, t_meta, op, pimpl->cpu_buft_list); + } else { + buft = overrides->buft; + } + LLAMA_LOG_DEBUG("tensor %s (%zu MiB %s) buffer type overridden to %s\n", tensor_name.c_str(), ggml_nbytes(t_meta) / 1024 / 1024, ggml_type_name(t_meta->type), @@ -2207,6 +2241,53 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } } break; + case LLM_ARCH_LLADA: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = + create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + + // Use separate Q, K, V projections without bias, matching LLaDALlamaBlock + layer.wq = + create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + // No bias for QKV projections as per config: include_bias=false, include_qkv_bias=false + layer.wo = + create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_rot / 2 }, + TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); + + // optional MLP bias + layer.ffn_gate_b = + create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), { n_ff }, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = + create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), { n_ff }, TENSOR_NOT_REQUIRED); + } + } + break; case LLM_ARCH_LLAMA4: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -5222,6 +5303,39 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); } } break; + case LLM_ARCH_HUNYUAN_DENSE: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + } + } break; case LLM_ARCH_SMOLLM3: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -8142,6 +8256,106 @@ struct llm_build_dream : public llm_graph_context { } }; +struct llm_build_llada : public llm_graph_context { + llm_build_llada(const llama_model & model, const llm_graph_params & params) : + llm_graph_context(params) { + // LLaDA is similar to LLaMA but uses non-causal attention for diffusion + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + // Non-causal attention for diffusion + auto * inp_attn = build_attn_inp_no_cache(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute separate Q, K, V projections without bias, matching LLaDALlamaBlock + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, + 1.0f / sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + struct llm_build_qwen2vl : public llm_graph_context { llm_build_qwen2vl(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -16861,6 +17075,144 @@ struct llm_build_hunyuan_moe : public llm_graph_context { } }; +struct llm_build_hunyuan_dense : public llm_graph_context { + llm_build_hunyuan_dense(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + const float kq_scale = 1.0f / sqrtf(float(n_embd_head)); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + // self-attention + { + // rope freq factors for llama3; may return nullptr for llama2 and other models + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = build_norm(Kcur, + model.layers[il].attn_k_norm, nullptr, + LLM_NORM_RMS, il); + cb(Kcur, "Kcur_norm", il); + + Qcur = build_norm(Qcur, + model.layers[il].attn_q_norm, nullptr, + LLM_NORM_RMS, il); + cb(Qcur, "Qcur_norm", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + // feed-forward network (non-MoE) + ggml_tensor * cur_mlp = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur_mlp, "ffn_out", il); + + cur = ggml_add(ctx0, cur_mlp, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + // lm_head + cur = build_lora_mm(model.output, cur); + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + struct llm_build_smollm3 : public llm_graph_context { llm_build_smollm3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -17258,10 +17610,18 @@ struct llm_build_smallthinker : public llm_graph_context{ cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "ffn_norm", il); - ggml_tensor * ffn_out = build_moe_ffn_from_probs(cur, probs, model.layers[il].ffn_up_exps, - model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, - nullptr, n_expert, n_expert_used, - static_cast(hparams.expert_gating_func), il); + ggml_tensor * ffn_out = + build_moe_ffn(cur, + nullptr, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_RELU, true, + false, 0.0, + static_cast(hparams.expert_gating_func), + il, probs); cb(ffn_out, "ffn_out", il); cur = ffn_out; @@ -17301,6 +17661,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, case LLM_ARCH_NEO_BERT: case LLM_ARCH_WAVTOKENIZER_DEC: case LLM_ARCH_DREAM: + case LLM_ARCH_LLADA: { res = nullptr; } break; @@ -17467,6 +17828,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_LLADA: + { + llm = std::make_unique(*this, params); + } + break; case LLM_ARCH_QWEN2VL: { llm = std::make_unique(*this, params); @@ -17714,6 +18080,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_HUNYUAN_DENSE: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_SMOLLM3: { llm = std::make_unique(*this, params); @@ -17763,6 +18133,7 @@ llama_model_params llama_model_default_params() { /*.use_mmap =*/ true, /*.use_mlock =*/ false, /*.check_tensors =*/ false, + /*.use_extra_bufts =*/ true, }; #ifdef GGML_USE_METAL @@ -17865,6 +18236,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { // use what we call a normal RoPE, operating on pairs of consecutive head values case LLM_ARCH_LLAMA: + case LLM_ARCH_LLADA: case LLM_ARCH_LLAMA4: case LLM_ARCH_DECI: case LLM_ARCH_BAICHUAN: @@ -17931,6 +18303,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_MINICPM3: case LLM_ARCH_DOTS1: case LLM_ARCH_HUNYUAN_MOE: + case LLM_ARCH_HUNYUAN_DENSE: case LLM_ARCH_LFM2: case LLM_ARCH_SMALLTHINKER: return LLAMA_ROPE_TYPE_NEOX; @@ -18043,6 +18416,10 @@ bool llama_model_is_recurrent(const llama_model * model) { return llm_arch_is_recurrent(model->arch); } +bool llama_model_is_diffusion(const llama_model * model) { + return llm_arch_is_diffusion(model->arch); +} + const std::vector> & llama_internal_get_tensor_map(const llama_model * model) { return model->tensors_by_name; } diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index 986dabc31..a784060ef 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -878,9 +878,10 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // get more optimal quantization type based on the tensor shape, layer, etc. if (!params->pure && ggml_is_quantized(default_type)) { + int fallback = qs.n_fallback; new_type = llama_tensor_get_type(qs, new_type, tensor, ftype); - // unless the user specifies a type - if (params->tensor_types) { + // unless the user specifies a type, and the tensor geometry will not require fallback quantisation + if (params->tensor_types && qs.n_fallback - fallback == 0) { const std::vector & tensor_types = *static_cast *>(params->tensor_types); const std::string tensor_name(tensor->name); for (const auto & [tname, qtype] : tensor_types) { @@ -893,7 +894,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } } } - if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) { new_type = params->token_embedding_type; } diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index e5ff54da0..db2a57820 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -532,6 +532,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { }; break; case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM: + case LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE: regex_exprs = { "\\p{N}{1,3}", "[一-龥぀-ゟ゠-ヿ]+", @@ -2200,6 +2201,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "hunyuan") { pre_type = LLAMA_VOCAB_PRE_TYPE_HUNYUAN; clean_spaces = false; + } else if ( + tokenizer_pre == "hunyuan-dense") { + pre_type = LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE; + clean_spaces = false; } else if ( tokenizer_pre == "kimi-k2") { pre_type = LLAMA_VOCAB_PRE_TYPE_KIMI_K2; diff --git a/src/llama-vocab.h b/src/llama-vocab.h index 221d04d90..1719a78bb 100644 --- a/src/llama-vocab.h +++ b/src/llama-vocab.h @@ -47,6 +47,7 @@ enum llama_vocab_pre_type { LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35, LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36, LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37, + LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38, }; struct LLM_KV; diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 27355a9d9..1309729a1 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -893,10 +893,16 @@ struct clip_graph { int n_head = n_embd/d_head; int num_query = 96; if (ctx->model.hparams.minicpmv_version == 2) { + // MiniCPM-V 2.5 num_query = 96; } else if (ctx->model.hparams.minicpmv_version == 3) { + // MiniCPM-V 2.6 num_query = 64; } else if (ctx->model.hparams.minicpmv_version == 4) { + // MiniCPM-o 2.6 + num_query = 64; + } else if (ctx->model.hparams.minicpmv_version == 5) { + // MiniCPM-V 4.0 num_query = 64; } @@ -3727,10 +3733,16 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im case PROJECTOR_TYPE_MINICPMV: { if (params.minicpmv_version == 2) { + // MiniCPM-V 2.5 n_patches_sq = 96; } else if (params.minicpmv_version == 3) { + // MiniCPM-V 2.6 n_patches_sq = 64; } else if (params.minicpmv_version == 4) { + // MiniCPM-o 2.6 + n_patches_sq = 64; + } else if (params.minicpmv_version == 5) { + // MiniCPM-V 4.0 n_patches_sq = 64; } else { GGML_ABORT("Unknown minicpmv version"); @@ -4459,11 +4471,17 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { return ctx->model.mm_3_b->ne[0]; case PROJECTOR_TYPE_MINICPMV: if (hparams.minicpmv_version == 2) { + // MiniCPM-V 2.5 return 4096; } else if (hparams.minicpmv_version == 3) { + // MiniCPM-V 2.6 return 3584; } else if (hparams.minicpmv_version == 4) { + // MiniCPM-o 2.6 return 3584; + } else if (hparams.minicpmv_version == 5) { + // MiniCPM-V 4.0 + return 2560; } GGML_ABORT("Unknown minicpmv version"); case PROJECTOR_TYPE_GLM_EDGE: diff --git a/tools/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py b/tools/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py index cfe0961f9..3c6020954 100644 --- a/tools/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py +++ b/tools/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py @@ -497,11 +497,11 @@ ap.add_argument("--projector-type", help="Type of projector. Possible values: ml ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None) # Example --image_mean 0.48145466 0.4578275 0.40821073 --image_std 0.26862954 0.26130258 0.27577711 # Example --image_mean 0.5 0.5 0.5 --image_std 0.5 0.5 0.5 -default_image_mean = [0.48145466, 0.4578275, 0.40821073] -default_image_std = [0.26862954, 0.26130258, 0.27577711] +default_image_mean = [0.5, 0.5, 0.5] +default_image_std = [0.5, 0.5, 0.5] ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None) ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None) -ap.add_argument('--minicpmv_version', type=int, help='minicpmv_version: MiniCPM-V-2 use 1; MiniCPM-V-2.5 use 2; MiniCPM-V-2.6 use 3; MiniCPM-o-2.6 use 4', default=2) +ap.add_argument('--minicpmv_version', type=int, help='minicpmv_version: MiniCPM-V-2 use 1; MiniCPM-V-2.5 use 2; MiniCPM-V-2.6 use 3; MiniCPM-o-2.6 use 4; MiniCPM-V 4.0 use 5; MiniCPM-o-4.0 use 6', default=2) # with proper args = ap.parse_args() @@ -517,6 +517,17 @@ if args.use_f32: # output in the same directory as the model if output_dir is None dir_model = args.model_dir +# If minicpmv_projector is not specified but the default path exists, use the default path +if args.minicpmv_projector is None: + default_projector_path = os.path.join(dir_model, "minicpmv.projector") + if os.path.isfile(default_projector_path): + args.minicpmv_projector = default_projector_path + print(f"Found default projector file: {default_projector_path}") + +# If output_dir is not specified, use model_dir as the default value +if args.output_dir is None: + args.output_dir = dir_model + if args.clip_model_is_vision or not os.path.exists(dir_model + "/vocab.json") or args.clip_model_is_openclip: vocab = None tokens = None @@ -546,18 +557,21 @@ if args.use_f32: minicpmv_version = args.minicpmv_version emb_dim = 4096 block_count = 26 -if minicpmv_version == 1: +if minicpmv_version == 1: # MiniCPM-V 2.0 emb_dim = 2304 block_count = 26 -elif minicpmv_version == 2: +elif minicpmv_version == 2: # MiniCPM-V 2.5 emb_dim = 4096 block_count = 27 -elif minicpmv_version == 3: +elif minicpmv_version == 3: # MiniCPM-V 2.6 emb_dim = 3584 block_count = 27 -elif minicpmv_version == 4: +elif minicpmv_version == 4: # MiniCPM-o 2.6 emb_dim = 3584 block_count = 27 +elif minicpmv_version == 5: # MiniCPM-V 4.0 + emb_dim = 2560 + block_count = 27 default_vision_config = { "hidden_size": 1152, @@ -577,6 +591,10 @@ if minicpmv_version == 3: elif minicpmv_version == 4: vision_config = SiglipVisionConfig(**default_vision_config) model = SiglipVisionTransformer(vision_config) +elif minicpmv_version == 5: + default_vision_config["model_type"] = "siglip_vision_model" + vision_config = SiglipVisionConfig(**default_vision_config) + model = SiglipVisionTransformer(vision_config) processor = None # if model.attn_pool is not None: @@ -603,7 +621,7 @@ elif args.vision_only: else: fname_middle = "" -output_dir = args.output_dir if args.output_dir is not None else dir_model +output_dir = args.output_dir os.makedirs(output_dir, exist_ok=True) output_prefix = os.path.basename(output_dir).replace("ggml_", "") fname_out = os.path.join(output_dir, f"{fname_middle}model-{ftype_str[ftype]}.gguf") diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index 45b2f1f25..a05373d5b 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -207,7 +207,7 @@ struct mtmd_context { tok_row_end_trail = false; // no trailing end-of-row token ov_img_first = true; - } else if (minicpmv_version == 3 || minicpmv_version == 4) { + } else if (minicpmv_version == 3 || minicpmv_version == 4 || minicpmv_version == 5) { // minicpmv 2.6 format: // (overview) (slice) (slice) \n ... slice_tmpl = MTMD_SLICE_TMPL_MINICPMV_2_6; diff --git a/tools/quantize/quantize.cpp b/tools/quantize/quantize.cpp index 7c6e009be..26457680c 100644 --- a/tools/quantize/quantize.cpp +++ b/tools/quantize/quantize.cpp @@ -312,7 +312,7 @@ static int load_imatrix(const std::string & imatrix_file, std::vector2=p-norm) + json to_json() const { std::vector samplers; samplers.reserve(sampling.samplers.size()); @@ -470,6 +473,33 @@ struct server_task { } } } + } else if (logit_bias != data.end() && logit_bias->is_object()) { + const int n_vocab = llama_vocab_n_tokens(vocab); + for (const auto & el : logit_bias->items()) { + float bias; + const auto & key = el.key(); + const auto & value = el.value(); + if (value.is_number()) { + bias = value.get(); + } else if (value.is_boolean() && !value.get()) { + bias = -INFINITY; + } else { + continue; + } + + char *end; + llama_token tok = strtol(key.c_str(), &end, 10); + if (*end == 0) { + if (tok >= 0 && tok < n_vocab) { + params.sampling.logit_bias.push_back({tok, bias}); + } + } else { + auto toks = common_tokenize(vocab, key, false); + for (auto tok : toks) { + params.sampling.logit_bias.push_back({tok, bias}); + } + } + } } params.sampling.ignore_eos = json_value(data, "ignore_eos", params_base.sampling.ignore_eos); @@ -1899,6 +1929,7 @@ struct server_context { mtmd_context * mctx = nullptr; const llama_vocab * vocab = nullptr; + bool vocab_dft_compatible = true; llama_model * model_dft = nullptr; @@ -1989,10 +2020,9 @@ struct server_context { return false; } - if (!common_speculative_are_compatible(ctx, llama_init_dft.context.get())) { - SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str()); - - return false; + vocab_dft_compatible = common_speculative_are_compatible(ctx, llama_init_dft.context.get()); + if (!vocab_dft_compatible) { + SRV_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str()); } const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get()); @@ -2082,11 +2112,14 @@ struct server_context { return; } - slot.spec = common_speculative_init(slot.ctx_dft); + slot.spec = common_speculative_init(slot.ctx, slot.ctx_dft); if (slot.spec == nullptr) { SRV_ERR("%s", "failed to create speculator\n"); return; } + for (auto &pair : params_base.speculative.replacements) { + common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str()); + } } SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx); @@ -2601,7 +2634,7 @@ struct server_context { // normalize only when there is pooling if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) { - common_embd_normalize(embd, embd_res.data(), n_embd, 2); + common_embd_normalize(embd, embd_res.data(), n_embd, slot.params.embd_normalize); res->embedding.push_back(embd_res); break; } else { @@ -4614,6 +4647,14 @@ int main(int argc, char ** argv) { } } + int embd_normalize = 2; // default to Euclidean/L2 norm + if (body.count("embd_normalize") != 0) { + embd_normalize = body.at("embd_normalize"); + if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { + SRV_DBG("embd_normalize is not supported by pooling type %d, ignoring it\n", llama_pooling_type(ctx_server.ctx)); + } + } + // create and queue the task json responses = json::array(); bool error = false; @@ -4629,6 +4670,7 @@ int main(int argc, char ** argv) { // OAI-compat task.params.oaicompat = oaicompat; + task.params.embd_normalize = embd_normalize; tasks.push_back(std::move(task)); } diff --git a/tools/server/tests/unit/test_chat_completion.py b/tools/server/tests/unit/test_chat_completion.py index 7ee9a1651..6c6f64f5e 100644 --- a/tools/server/tests/unit/test_chat_completion.py +++ b/tools/server/tests/unit/test_chat_completion.py @@ -351,3 +351,32 @@ def test_logprobs_stream(): assert token.top_logprobs is not None assert len(token.top_logprobs) > 0 assert aggregated_text == output_text + + +def test_logit_bias(): + global server + server.start() + + exclude = ["i", "I", "the", "The", "to", "a", "an", "be", "is", "was", "but", "But", "and", "And", "so", "So", "you", "You", "he", "He", "she", "She", "we", "We", "they", "They", "it", "It", "his", "His", "her", "Her", "book", "Book"] + + res = server.make_request("POST", "/tokenize", data={ + "content": " " + " ".join(exclude) + " ", + }) + assert res.status_code == 200 + tokens = res.body["tokens"] + logit_bias = {tok: -100 for tok in tokens} + + client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1") + res = client.chat.completions.create( + model="gpt-3.5-turbo-instruct", + temperature=0.0, + messages=[ + {"role": "system", "content": "Book"}, + {"role": "user", "content": "What is the best book"}, + ], + max_tokens=64, + logit_bias=logit_bias + ) + output_text = res.choices[0].message.content + assert output_text + assert all(output_text.find(" " + tok + " ") == -1 for tok in exclude) diff --git a/tools/server/tests/unit/test_completion.py b/tools/server/tests/unit/test_completion.py index f6909e9ae..be3a0052c 100644 --- a/tools/server/tests/unit/test_completion.py +++ b/tools/server/tests/unit/test_completion.py @@ -444,6 +444,39 @@ def test_n_probs_post_sampling(): assert any(prob["prob"] == 1.0 for prob in tok["top_probs"]) +@pytest.mark.parametrize("tokenize,openai_style", [(False, False), (False, True), (True, False), (True, True)]) +def test_logit_bias(tokenize, openai_style): + global server + server.start() + + exclude = ["i", "I", "the", "The", "to", "a", "an", "be", "is", "was", "but", "But", "and", "And", "so", "So", "you", "You", "he", "He", "she", "She", "we", "We", "they", "They", "it", "It", "his", "His", "her", "Her", "book", "Book"] + + logit_bias = [] + if tokenize: + res = server.make_request("POST", "/tokenize", data={ + "content": " " + " ".join(exclude) + " ", + }) + assert res.status_code == 200 + tokens = res.body["tokens"] + logit_bias = [[tok, -100] for tok in tokens] + + else: + logit_bias = [[" " + tok + " ", -100] for tok in exclude] + + if openai_style: + logit_bias = {el[0]: -100 for el in logit_bias} + + res = server.make_request("POST", "/completion", data={ + "n_predict": 64, + "prompt": "What is the best book", + "logit_bias": logit_bias, + "temperature": 0.0 + }) + assert res.status_code == 200 + output_text = res.body["content"] + assert all(output_text.find(" " + tok + " ") == -1 for tok in exclude) + + def test_cancel_request(): global server server.n_ctx = 4096 diff --git a/vendor/minja/chat-template.hpp b/vendor/minja/chat-template.hpp index ab5b521dd..cf113bf22 100644 --- a/vendor/minja/chat-template.hpp +++ b/vendor/minja/chat-template.hpp @@ -162,10 +162,15 @@ class chat_template { }), false); caps_.supports_tools = contains(out, "some_tool"); + auto out_empty = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", ""}}}), {}, false); + auto out_null = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", nullptr}}}), {}, false); + caps_.requires_non_null_content = contains(out_empty, user_needle) && !contains(out_null, user_needle); + + json j_null; auto make_tool_calls_msg = [&](const json & tool_calls) { return json { {"role", "assistant"}, - {"content", nullptr}, + {"content", caps_.requires_non_null_content? "" : j_null}, {"tool_calls", tool_calls}, }; }; @@ -195,9 +200,6 @@ class chat_template { caps_.supports_tool_calls = tool_call_renders_str_arguments || tool_call_renders_obj_arguments; caps_.requires_object_arguments = !tool_call_renders_str_arguments && tool_call_renders_obj_arguments; - auto out_empty = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", ""}}}), {}, false); - auto out_null = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", nullptr}}}), {}, false); - caps_.requires_non_null_content = contains(out_empty, user_needle) && !contains(out_null, user_needle); if (caps_.supports_tool_calls) { auto dummy_args = caps_.requires_object_arguments ? dummy_args_obj : json(dummy_args_obj.dump()); @@ -234,7 +236,7 @@ class chat_template { }; const json tool_call_msg { {"role", "assistant"}, - {"content", nullptr}, + {"content", caps_.requires_non_null_content ? "" : j_null}, {"tool_calls", json::array({ { // TODO: detect if requires numerical id or fixed length == 6 like Nemo diff --git a/vendor/minja/minja.hpp b/vendor/minja/minja.hpp index f9658ddc0..dd107dccd 100644 --- a/vendor/minja/minja.hpp +++ b/vendor/minja/minja.hpp @@ -1355,8 +1355,13 @@ public: case Op::Gt: return l > r; case Op::Le: return l <= r; case Op::Ge: return l >= r; - case Op::In: return (r.is_array() || r.is_object()) && r.contains(l); - case Op::NotIn: return !(r.is_array() && r.contains(l)); + case Op::In: return (((r.is_array() || r.is_object()) && r.contains(l)) || + (l.is_string() && r.is_string() && + r.to_str().find(l.to_str()) != std::string::npos)); + case Op::NotIn: + return !(((r.is_array() || r.is_object()) && r.contains(l)) || + (l.is_string() && r.is_string() && + r.to_str().find(l.to_str()) != std::string::npos)); default: break; } throw std::runtime_error("Unknown binary operator"); @@ -1552,6 +1557,19 @@ public: else res[i] = std::tolower(res[i]); } return res; + } else if (method->get_name() == "replace") { + vargs.expectArgs("replace method", {2, 3}, {0, 0}); + auto before = vargs.args[0].get(); + auto after = vargs.args[1].get(); + auto count = vargs.args.size() == 3 ? vargs.args[2].get() + : str.length(); + size_t start_pos = 0; + while ((start_pos = str.find(before, start_pos)) != std::string::npos && + count-- > 0) { + str.replace(start_pos, before.length(), after); + start_pos += after.length(); + } + return str; } } throw std::runtime_error("Unknown method: " + method->get_name()); @@ -2128,7 +2146,7 @@ private: } } - if ((has_first_colon || has_second_colon) && (start || end || step)) { + if ((has_first_colon || has_second_colon)) { index = std::make_shared(slice_loc, std::move(start), std::move(end), std::move(step)); } else { index = std::move(start);