diff --git a/common/arg.cpp b/common/arg.cpp
index 732e1a853..efbac3668 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -692,7 +692,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
[](gpt_params & params) {
params.ctx_shift = false;
}
- ).set_examples({LLAMA_EXAMPLE_MAIN}));
+ ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_NO_CONTEXT_SHIFT"));
add_opt(llama_arg(
{"--chunks"}, "N",
format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks),
@@ -1103,7 +1103,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; }
else { throw std::invalid_argument("invalid value"); }
}
- ).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
+ ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_POOLING"));
add_opt(llama_arg(
{"--attention"}, "{causal,non,causal}",
"attention type for embeddings, use model default if unspecified",
@@ -1122,77 +1122,77 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; }
else { throw std::invalid_argument("invalid value"); }
}
- ));
+ ).set_env("LLAMA_ARG_ROPE_SCALING_TYPE"));
add_opt(llama_arg(
{"--rope-scale"}, "N",
"RoPE context scaling factor, expands context by a factor of N",
[](gpt_params & params, const std::string & value) {
params.rope_freq_scale = 1.0f / std::stof(value);
}
- ));
+ ).set_env("LLAMA_ARG_ROPE_SCALE"));
add_opt(llama_arg(
{"--rope-freq-base"}, "N",
"RoPE base frequency, used by NTK-aware scaling (default: loaded from model)",
[](gpt_params & params, const std::string & value) {
params.rope_freq_base = std::stof(value);
}
- ));
+ ).set_env("LLAMA_ARG_ROPE_FREQ_BASE"));
add_opt(llama_arg(
{"--rope-freq-scale"}, "N",
"RoPE frequency scaling factor, expands context by a factor of 1/N",
[](gpt_params & params, const std::string & value) {
params.rope_freq_scale = std::stof(value);
}
- ));
+ ).set_env("LLAMA_ARG_ROPE_FREQ_SCALE"));
add_opt(llama_arg(
{"--yarn-orig-ctx"}, "N",
format("YaRN: original context size of model (default: %d = model training context size)", params.yarn_orig_ctx),
[](gpt_params & params, int value) {
params.yarn_orig_ctx = value;
}
- ));
+ ).set_env("LLAMA_ARG_YARN_ORIG_CTX"));
add_opt(llama_arg(
{"--yarn-ext-factor"}, "N",
format("YaRN: extrapolation mix factor (default: %.1f, 0.0 = full interpolation)", (double)params.yarn_ext_factor),
[](gpt_params & params, const std::string & value) {
params.yarn_ext_factor = std::stof(value);
}
- ));
+ ).set_env("LLAMA_ARG_YARN_EXT_FACTOR"));
add_opt(llama_arg(
{"--yarn-attn-factor"}, "N",
format("YaRN: scale sqrt(t) or attention magnitude (default: %.1f)", (double)params.yarn_attn_factor),
[](gpt_params & params, const std::string & value) {
params.yarn_attn_factor = std::stof(value);
}
- ));
+ ).set_env("LLAMA_ARG_YARN_ATTN_FACTOR"));
add_opt(llama_arg(
{"--yarn-beta-slow"}, "N",
format("YaRN: high correction dim or alpha (default: %.1f)", (double)params.yarn_beta_slow),
[](gpt_params & params, const std::string & value) {
params.yarn_beta_slow = std::stof(value);
}
- ));
+ ).set_env("LLAMA_ARG_YARN_BETA_SLOW"));
add_opt(llama_arg(
{"--yarn-beta-fast"}, "N",
format("YaRN: low correction dim or beta (default: %.1f)", (double)params.yarn_beta_fast),
[](gpt_params & params, const std::string & value) {
params.yarn_beta_fast = std::stof(value);
}
- ));
+ ).set_env("LLAMA_ARG_YARN_BETA_FAST"));
add_opt(llama_arg(
{"-gan", "--grp-attn-n"}, "N",
format("group-attention factor (default: %d)", params.grp_attn_n),
[](gpt_params & params, int value) {
params.grp_attn_n = value;
}
- ));
+ ).set_env("LLAMA_ARG_GRP_ATTN_N"));
add_opt(llama_arg(
{"-gaw", "--grp-attn-w"}, "N",
format("group-attention width (default: %.1f)", (double)params.grp_attn_w),
[](gpt_params & params, int value) {
params.grp_attn_w = value;
}
- ));
+ ).set_env("LLAMA_ARG_GRP_ATTN_W"));
add_opt(llama_arg(
{"-dkvc", "--dump-kv-cache"},
"verbose print of the KV cache",
@@ -1206,7 +1206,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
[](gpt_params & params) {
params.no_kv_offload = true;
}
- ));
+ ).set_env("LLAMA_ARG_NO_KV_OFFLOAD"));
add_opt(llama_arg(
{"-ctk", "--cache-type-k"}, "TYPE",
format("KV cache data type for K (default: %s)", params.cache_type_k.c_str()),
@@ -1214,7 +1214,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
// TODO: get the type right here
params.cache_type_k = value;
}
- ));
+ ).set_env("LLAMA_ARG_CACHE_TYPE_K"));
add_opt(llama_arg(
{"-ctv", "--cache-type-v"}, "TYPE",
format("KV cache data type for V (default: %s)", params.cache_type_v.c_str()),
@@ -1222,7 +1222,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
// TODO: get the type right here
params.cache_type_v = value;
}
- ));
+ ).set_env("LLAMA_ARG_CACHE_TYPE_V"));
add_opt(llama_arg(
{"--perplexity", "--all-logits"},
format("return logits for all tokens in the batch (default: %s)", params.logits_all ? "true" : "false"),
@@ -1356,7 +1356,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
[](gpt_params & params, const std::string & value) {
params.rpc_servers = value;
}
- ));
+ ).set_env("LLAMA_ARG_RPC"));
#endif
add_opt(llama_arg(
{"--mlock"},
@@ -1364,14 +1364,14 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
[](gpt_params & params) {
params.use_mlock = true;
}
- ));
+ ).set_env("LLAMA_ARG_MLOCK"));
add_opt(llama_arg(
{"--no-mmap"},
"do not memory-map model (slower load but may reduce pageouts if not using mlock)",
[](gpt_params & params) {
params.use_mmap = false;
}
- ));
+ ).set_env("LLAMA_ARG_NO_MMAP"));
add_opt(llama_arg(
{"--numa"}, "TYPE",
"attempt optimizations that help on some NUMA systems\n"
@@ -1386,7 +1386,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
else if (value == "numactl") { params.numa = GGML_NUMA_STRATEGY_NUMACTL; }
else { throw std::invalid_argument("invalid value"); }
}
- ));
+ ).set_env("LLAMA_ARG_NUMA"));
add_opt(llama_arg(
{"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N",
"number of layers to store in VRAM",
@@ -1434,7 +1434,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
fprintf(stderr, "warning: llama.cpp was compiled without support for GPU offload. Setting the split mode has no effect.\n");
}
}
- ));
+ ).set_env("LLAMA_ARG_SPLIT_MODE"));
add_opt(llama_arg(
{"-ts", "--tensor-split"}, "N0,N1,N2,...",
"fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1",
@@ -1461,7 +1461,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
fprintf(stderr, "warning: llama.cpp was compiled without support for GPU offload. Setting a tensor split has no effect.\n");
}
}
- ));
+ ).set_env("LLAMA_ARG_TENSOR_SPLIT"));
add_opt(llama_arg(
{"-mg", "--main-gpu"}, "INDEX",
format("the GPU to use for the model (with split-mode = none), or for intermediate results and KV (with split-mode = row) (default: %d)", params.main_gpu),
@@ -1471,7 +1471,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
fprintf(stderr, "warning: llama.cpp was compiled without support for GPU offload. Setting the main GPU has no effect.\n");
}
}
- ));
+ ).set_env("LLAMA_ARG_MAIN_GPU"));
add_opt(llama_arg(
{"--check-tensors"},
format("check model tensor data for invalid values (default: %s)", params.check_tensors ? "true" : "false"),
@@ -1534,7 +1534,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
[](gpt_params & params, const std::string & value) {
params.model_alias = value;
}
- ).set_examples({LLAMA_EXAMPLE_SERVER}));
+ ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ALIAS"));
add_opt(llama_arg(
{"-m", "--model"}, "FNAME",
ex == LLAMA_EXAMPLE_EXPORT_LORA
@@ -1742,7 +1742,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
[](gpt_params & params, const std::string & value) {
params.public_path = value;
}
- ).set_examples({LLAMA_EXAMPLE_SERVER}));
+ ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_STATIC_PATH"));
add_opt(llama_arg(
{"--embedding", "--embeddings"},
format("restrict to only support embedding use case; use only with dedicated embedding models (default: %s)", params.embedding ? "enabled" : "disabled"),
@@ -1780,14 +1780,14 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
[](gpt_params & params, const std::string & value) {
params.ssl_file_key = value;
}
- ).set_examples({LLAMA_EXAMPLE_SERVER}));
+ ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_SSL_KEY_FILE"));
add_opt(llama_arg(
{"--ssl-cert-file"}, "FNAME",
"path to file a PEM-encoded SSL certificate",
[](gpt_params & params, const std::string & value) {
params.ssl_file_cert = value;
}
- ).set_examples({LLAMA_EXAMPLE_SERVER}));
+ ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_SSL_CERT_FILE"));
add_opt(llama_arg(
{"-to", "--timeout"}, "N",
format("server read/write timeout in seconds (default: %d)", params.timeout_read),
@@ -1795,7 +1795,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
params.timeout_read = value;
params.timeout_write = value;
}
- ).set_examples({LLAMA_EXAMPLE_SERVER}));
+ ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_TIMEOUT"));
add_opt(llama_arg(
{"--threads-http"}, "N",
format("number of threads used to process HTTP requests (default: %d)", params.n_threads_http),
diff --git a/common/log.cpp b/common/log.cpp
index 2825a227e..5a844ed59 100644
--- a/common/log.cpp
+++ b/common/log.cpp
@@ -82,7 +82,7 @@ struct gpt_log_entry {
}
}
- if (level != GGML_LOG_LEVEL_NONE && prefix) {
+ if (level != GGML_LOG_LEVEL_NONE && level != GGML_LOG_LEVEL_CONT && prefix) {
if (timestamp) {
// [M.s.ms.us]
fprintf(fcur, "%s%d.%02d.%03d.%03d%s ",
diff --git a/common/log.h b/common/log.h
index d13f72d89..84f9b3ed7 100644
--- a/common/log.h
+++ b/common/log.h
@@ -83,8 +83,10 @@ void gpt_log_set_timestamps(struct gpt_log * log, bool timestamps); // w
#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN, 0, __VA_ARGS__)
#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, 0, __VA_ARGS__)
#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, LOG_DEFAULT_DEBUG, __VA_ARGS__)
+#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT, 0, __VA_ARGS__)
#define LOG_INFV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_INFO, verbosity, __VA_ARGS__)
#define LOG_WRNV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_WARN, verbosity, __VA_ARGS__)
#define LOG_ERRV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, verbosity, __VA_ARGS__)
#define LOG_DBGV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, verbosity, __VA_ARGS__)
+#define LOG_CNTV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_CONT, verbosity, __VA_ARGS__)
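For reference, a minimal sketch of how the new CONT level is intended to be used (assuming the LOG_* macros declared above): the first call emits the usual prefix/timestamp at INFO level, and subsequent LOG_CNT calls append to the same logical line without re-emitting the prefix, which is exactly how the main.cpp hunk below uses it.

LOG_INF("%s: static prompt: '", __func__);   // prefixed once, at INFO level
LOG_CNT("%s", "Once upon a time");           // CONT: appended to the same line, no new prefix or timestamp
LOG_CNT("'\n");                              // close the quote and end the line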
diff --git a/common/sampling.cpp b/common/sampling.cpp
index e51d07611..3dc7f1120 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -209,7 +209,15 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
GGML_ASSERT(false && "unknown mirostat version");
}
} else {
- llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
+ if (params.n_probs > 0) {
+            // some use cases require greedy sampling, while still obtaining the probabilities of the top tokens
+ // ref: https://github.com/ggerganov/llama.cpp/pull/9605
+ //
+            // the following will not produce exactly the same probs as applying softmax to the full vocabulary, but
+ // it is much faster, since we avoid sorting all tokens and should give a good approximation
+ llama_sampler_chain_add(result->chain, llama_sampler_init_top_k(params.n_probs));
+ llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
+ }
llama_sampler_chain_add(result->chain, llama_sampler_init_greedy());
}
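For reference, a minimal sketch of the chain the greedy branch above ends up building when n_probs is, say, 5. The llama_sampler_chain_add and llama_sampler_init_* calls are the ones already used in this hunk; llama_sampler_chain_init and llama_sampler_chain_default_params are assumed from llama.h.

struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

llama_sampler_chain_add(chain, llama_sampler_init_top_k(5));  // keep only the 5 most likely candidates
llama_sampler_chain_add(chain, llama_sampler_init_softmax()); // normalize probabilities over those 5
llama_sampler_chain_add(chain, llama_sampler_init_greedy());  // still pick the single best token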
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index ff4c9226f..7be609054 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -4102,16 +4102,45 @@ class GraniteModel(LlamaModel):
# consistency
if attention_scale := self.hparams.get("attention_multiplier"):
self.gguf_writer.add_attention_scale(attention_scale)
+ logger.info("gguf: (granite) attention_scale = %s", attention_scale)
if embedding_scale := self.hparams.get("embedding_multiplier"):
self.gguf_writer.add_embedding_scale(embedding_scale)
+ logger.info("gguf: (granite) embedding_scale = %s", embedding_scale)
if residual_scale := self.hparams.get("residual_multiplier"):
self.gguf_writer.add_residual_scale(residual_scale)
- if logits_scaling := self.hparams.get("logits_scaling"):
- self.gguf_writer.add_logit_scale(logits_scaling)
+ logger.info("gguf: (granite) residual_scale = %s", residual_scale)
+ if logits_scale := self.hparams.get("logits_scaling"):
+ self.gguf_writer.add_logit_scale(logits_scale)
+ logger.info("gguf: (granite) logits_scale = %s", logits_scale)
+
+
+@Model.register("GraniteMoeForCausalLM")
+class GraniteMoeModel(GraniteModel):
+ """Conversion for IBM's GraniteMoeForCausalLM"""
+ model_arch = gguf.MODEL_ARCH.GRANITE_MOE
+
+ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+ """In modeling_granitemoe, the JetMoe implementation of parallel experts
+ is used. This essentially merges w1 and w3 into a single tensor with 2x
+ the hidden size that is then split during forward. To keep compatibility
+ with existing mixtral support, we pull them apart here.
+ """
+
+ if name.endswith("block_sparse_moe.input_linear.weight"):
+ ffn_dim = self.hparams["intermediate_size"]
+ assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * intermediate_size"
+ gate, up = data_torch[..., :ffn_dim, :], data_torch[..., ffn_dim:, :]
+ return [
+ (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_EXP, bid), gate),
+ (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), up),
+ ]
+
+ return super().modify_tensors(data_torch, name, bid)
###### CONVERSION LOGIC ######
+
# tree of lazy tensors
class LazyTorchTensor(gguf.LazyBase):
_tensor_type = torch.Tensor
diff --git a/examples/gen-docs/gen-docs.cpp b/examples/gen-docs/gen-docs.cpp
index b6d4725fd..4b19a9dc2 100644
--- a/examples/gen-docs/gen-docs.cpp
+++ b/examples/gen-docs/gen-docs.cpp
@@ -6,42 +6,73 @@
// Export usage message (-h) to markdown format
+static void write_table_header(std::ofstream & file) {
+ file << "| Argument | Explanation |\n";
+ file << "| -------- | ----------- |\n";
+}
+
+static void write_table_entry(std::ofstream & file, const llama_arg & opt) {
+ file << "| `";
+ // args
+ for (const auto & arg : opt.args) {
+ if (arg == opt.args.front()) {
+ file << arg;
+ if (opt.args.size() > 1) file << ", ";
+ } else {
+ file << arg << (arg != opt.args.back() ? ", " : "");
+ }
+ }
+ // value hint
+ if (opt.value_hint) {
+ std::string md_value_hint(opt.value_hint);
+ string_replace_all(md_value_hint, "|", "\\|");
+ file << " " << md_value_hint;
+ }
+ if (opt.value_hint_2) {
+ std::string md_value_hint_2(opt.value_hint_2);
+ string_replace_all(md_value_hint_2, "|", "\\|");
+ file << " " << md_value_hint_2;
+ }
+ // help text
+ std::string md_help(opt.help);
+    string_replace_all(md_help, "\n", "<br/> ");
+ string_replace_all(md_help, "|", "\\|");
+ file << "` | " << md_help << " |\n";
+}
+
+static void write_table(std::ofstream & file, std::vector<llama_arg *> & opts) {
+ write_table_header(file);
+ for (const auto & opt : opts) {
+ write_table_entry(file, *opt);
+ }
+}
+
static void export_md(std::string fname, llama_example ex) {
std::ofstream file(fname, std::ofstream::out | std::ofstream::trunc);
gpt_params params;
auto ctx_arg = gpt_params_parser_init(params, ex);
- file << "| Argument | Explanation |\n";
- file << "| -------- | ----------- |\n";
+    std::vector<llama_arg *> common_options;
+    std::vector<llama_arg *> sparam_options;
+    std::vector<llama_arg *> specific_options;
for (auto & opt : ctx_arg.options) {
- file << "| `";
- // args
- for (const auto & arg : opt.args) {
- if (arg == opt.args.front()) {
- file << arg;
- if (opt.args.size() > 1) file << ", ";
- } else {
- file << arg << (arg != opt.args.back() ? ", " : "");
- }
+        // in case multiple LLAMA_EXAMPLE_* are set, we prioritize the LLAMA_EXAMPLE_* matching the current example
+ if (opt.is_sparam) {
+ sparam_options.push_back(&opt);
+ } else if (opt.in_example(ctx_arg.ex)) {
+ specific_options.push_back(&opt);
+ } else {
+ common_options.push_back(&opt);
}
- // value hint
- if (opt.value_hint) {
- std::string md_value_hint(opt.value_hint);
- string_replace_all(md_value_hint, "|", "\\|");
- file << " " << md_value_hint;
- }
- if (opt.value_hint_2) {
- std::string md_value_hint_2(opt.value_hint_2);
- string_replace_all(md_value_hint_2, "|", "\\|");
- file << " " << md_value_hint_2;
- }
- // help text
- std::string md_help(opt.help);
-        string_replace_all(md_help, "\n", "<br/> ");
- string_replace_all(md_help, "|", "\\|");
- file << "` | " << md_help << " |\n";
}
+
+ file << "**Common params**\n\n";
+ write_table(file, common_options);
+ file << "\n\n**Sampling params**\n\n";
+ write_table(file, sparam_options);
+ file << "\n\n**Example-specific params**\n\n";
+ write_table(file, specific_options);
}
int main(int, char **) {
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 8ceec393d..a4bf4abaf 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -386,9 +386,9 @@ int main(int argc, char ** argv) {
if (params.n_keep > add_bos) {
LOG_INF("%s: static prompt based on n_keep: '", __func__);
for (int i = 0; i < params.n_keep; i++) {
- LOG("%s", llama_token_to_piece(ctx, embd_inp[i]).c_str());
+ LOG_CNT("%s", llama_token_to_piece(ctx, embd_inp[i]).c_str());
}
- LOG("'\n");
+ LOG_CNT("'\n");
}
LOG_INF("\n");
}
@@ -410,40 +410,40 @@ int main(int argc, char ** argv) {
}
if (params.interactive) {
- LOG("%s: interactive mode on.\n", __func__);
+ LOG_INF("%s: interactive mode on.\n", __func__);
if (!params.antiprompt.empty()) {
for (const auto & antiprompt : params.antiprompt) {
- LOG("Reverse prompt: '%s'\n", antiprompt.c_str());
+ LOG_INF("Reverse prompt: '%s'\n", antiprompt.c_str());
if (params.verbose_prompt) {
auto tmp = ::llama_tokenize(ctx, antiprompt, false, true);
for (int i = 0; i < (int) tmp.size(); i++) {
- LOG("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str());
+ LOG_INF("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str());
}
}
}
}
if (params.input_prefix_bos) {
- LOG("Input prefix with BOS\n");
+ LOG_INF("Input prefix with BOS\n");
}
if (!params.input_prefix.empty()) {
- LOG("Input prefix: '%s'\n", params.input_prefix.c_str());
+ LOG_INF("Input prefix: '%s'\n", params.input_prefix.c_str());
if (params.verbose_prompt) {
auto tmp = ::llama_tokenize(ctx, params.input_prefix, true, true);
for (int i = 0; i < (int) tmp.size(); i++) {
- LOG("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str());
+ LOG_INF("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str());
}
}
}
if (!params.input_suffix.empty()) {
- LOG("Input suffix: '%s'\n", params.input_suffix.c_str());
+ LOG_INF("Input suffix: '%s'\n", params.input_suffix.c_str());
if (params.verbose_prompt) {
auto tmp = ::llama_tokenize(ctx, params.input_suffix, false, true);
for (int i = 0; i < (int) tmp.size(); i++) {
- LOG("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str());
+ LOG_INF("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str());
}
}
}
@@ -475,7 +475,7 @@ int main(int argc, char ** argv) {
//GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * grp_attn_n"); // NOLINT
LOG_INF("self-extend: n_ctx_train = %d, grp_attn_n = %d, grp_attn_w = %d\n", n_ctx_train, ga_n, ga_w);
}
- LOG("\n");
+ LOG_INF("\n");
if (params.interactive) {
const char * control_message;
@@ -487,11 +487,11 @@ int main(int argc, char ** argv) {
" - To return control without starting a new line, end your input with '/'.\n"
" - If you want to submit another line, end your input with '\\'.\n";
}
- LOG("== Running in interactive mode. ==\n");
+ LOG_INF("== Running in interactive mode. ==\n");
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
- LOG( " - Press Ctrl+C to interject at any time.\n");
+ LOG_INF( " - Press Ctrl+C to interject at any time.\n");
#endif
- LOG( "%s\n", control_message);
+ LOG_INF( "%s\n", control_message);
is_interacting = params.interactive_first;
}
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index e07cc65fa..38bd20bff 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1181,6 +1181,15 @@ struct server_context {
SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
}
+ // if context shift is disabled, we stop when it reaches the context limit
+ if (slot.n_decoded >= slot.n_ctx) {
+ slot.truncated = true;
+ slot.stopped_limit = true;
+ slot.has_next_token = false;
+
+ SLT_DBG(slot, "stopped due to running out of context capacity, n_decoded = %d, n_ctx = %d\n", slot.n_decoded, slot.n_ctx);
+ }
+
if (llama_token_is_eog(model, result.tok)) {
slot.stopped_eos = true;
slot.has_next_token = false;
@@ -1481,7 +1490,7 @@ struct server_context {
if (result.error) {
error_handler(result.data);
cancel_tasks(id_tasks);
- break;
+ return;
}
size_t idx = result.data["index"];
@@ -1828,6 +1837,14 @@ struct server_context {
for (server_slot & slot : slots) {
if (slot.ga_n == 1) {
if (slot.is_processing() && (int) system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {
+ if (!params.ctx_shift) {
+                        // this check is redundant (for good measure)
+                        // we should never get here, because generation should have already stopped in process_token()
+ slot.release();
+ send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
+ continue;
+ }
+
// Shift context
const int n_keep = slot.params.n_keep + add_bos_token;
const int n_left = (int) system_tokens.size() + slot.n_past - n_keep;
@@ -1962,6 +1979,14 @@ struct server_context {
continue;
}
} else {
+ if (!params.ctx_shift) {
+ // if context shift is disabled, we make sure prompt size is smaller than KV size
+ if ((int) system_tokens.size() + slot.n_prompt_tokens >= slot.n_ctx) {
+ slot.release();
+ send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
+ continue;
+ }
+ }
if (slot.params.n_keep < 0) {
slot.params.n_keep = slot.n_prompt_tokens;
}
@@ -2332,6 +2357,10 @@ int main(int argc, char ** argv) {
svr.reset(new httplib::Server());
}
#else
+ if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
+ LOG_ERR("Server is built without SSL support\n");
+ return 1;
+ }
svr.reset(new httplib::Server());
#endif
@@ -3155,7 +3184,7 @@ int main(int argc, char ** argv) {
}
// print sample chat example to make it clear which template is used
- LOG_INF("%s: chat template, built_in: %d, chat_example: '%s\n'", __func__, params.chat_template.empty(), llama_chat_format_example(ctx_server.model, params.chat_template).c_str());
+ LOG_INF("%s: chat template, built_in: %d, chat_example: '%s'\n", __func__, params.chat_template.empty(), llama_chat_format_example(ctx_server.model, params.chat_template).c_str());
ctx_server.queue_tasks.on_new_task(std::bind(
&server_context::process_single_task, &ctx_server, std::placeholders::_1));
diff --git a/examples/server/tests/features/ctx_shift.feature b/examples/server/tests/features/ctx_shift.feature
new file mode 100644
index 000000000..ba3afcf06
--- /dev/null
+++ b/examples/server/tests/features/ctx_shift.feature
@@ -0,0 +1,62 @@
+@llama.cpp
+@ctx_shift
+Feature: llama.cpp server
+
+ Background: Server startup
+ Given a server listening on localhost:8080
+ And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
+ And a model file test-model.gguf
+ And a model alias tinyllama-2
+ And BOS token is 1
+ And 42 as server seed
+ And 256 KV cache size
+ And 32 as batch size
+ And 2 slots
+
+ Scenario: Inference with context shift
+ And 64 server max tokens to predict
+ Then the server is starting
+ Then the server is healthy
+ Given a prompt:
+ """
+ Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
+ Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
+ Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
+ Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
+ """
+ And a completion request with no api error
+ Then 64 tokens are predicted matching fun|Annaks|popcorns|pictry|bowl
+ And the completion is truncated
+ And 109 prompt tokens are processed
+
+ Scenario Outline: Inference without context shift
+    And <n_predict> server max tokens to predict
+ And disable context shifting
+ Then the server is starting
+ Then the server is healthy
+ Given a prompt:
+ """
+ Hi how are you
+ """
+ And a completion request with no api error
+    Then <n_token_output> tokens are predicted matching twind|Anna
+    And the completion is <truncated> truncated
+ And 8 prompt tokens are processed
+ Examples:
+ | n_predict | n_token_output | truncated |
+ | 64 | 64 | not |
+ | -1 | 120 | |
+
+ Scenario: Inference without context shift (expected error: prompt too long)
+ And disable context shifting
+ Then the server is starting
+ Then the server is healthy
+ Given a prompt:
+ """
+ Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
+ Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
+ Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
+ Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
+ """
+ And a completion request with 400 api error
+
diff --git a/examples/server/tests/features/embeddings.feature b/examples/server/tests/features/embeddings.feature
index e1eade6cd..818ea3beb 100644
--- a/examples/server/tests/features/embeddings.feature
+++ b/examples/server/tests/features/embeddings.feature
@@ -10,11 +10,11 @@ Feature: llama.cpp server
And 42 as server seed
And 2 slots
# the bert-bge-small model has context size of 512
- # since the generated prompts are as big as the batch size, we need to set the batch size to 512
+ # since the generated prompts are as big as the batch size, we need to set the batch size to <= 512
# ref: https://huggingface.co/BAAI/bge-small-en-v1.5/blob/5c38ec7c405ec4b44b94cc5a9bb96e735b38267a/config.json#L20
- And 512 as batch size
- And 512 as ubatch size
- And 2048 KV cache size
+ And 128 as batch size
+ And 128 as ubatch size
+ And 512 KV cache size
And embeddings extraction
Then the server is starting
Then the server is healthy
@@ -26,6 +26,20 @@ Feature: llama.cpp server
"""
Then embeddings are generated
+ Scenario: Embedding (error: prompt too long)
+ When embeddings are computed for:
+ """
+ Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
+ Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
+ Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
+ Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
+ Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
+ Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
+ Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
+ Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
+ """
+ And embeddings request with 500 api error
+
Scenario: OAI Embeddings compatibility
Given a model bert-bge-small
When an OAI compatible embeddings computation request for:
diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py
index 062f084be..0fea0fe87 100644
--- a/examples/server/tests/features/steps/steps.py
+++ b/examples/server/tests/features/steps/steps.py
@@ -77,6 +77,7 @@ def step_server_config(context, server_fqdn: str, server_port: str):
context.response_format = None
context.temperature = None
context.lora_file = None
+ context.disable_ctx_shift = False
context.tasks_result = []
context.concurrent_tasks = []
@@ -148,7 +149,7 @@ def step_n_slots(context, n_slots: int):
@step('{n_predict:d} server max tokens to predict')
def step_server_n_predict(context, n_predict: int):
- context.n_server_predict = n_predict
+ context.n_server_predict = n_predict if n_predict > 0 else None
@step('{slot_save_path} as slot save path')
@@ -180,6 +181,9 @@ def step_server_embeddings(context):
def step_server_metrics(context):
context.server_metrics = True
+@step('disable context shifting')
+def step_server_disable_ctx_shift(context):
+ context.disable_ctx_shift = True
@step("the server is starting")
def step_start_server(context):
@@ -257,7 +261,7 @@ async def step_all_slots_status(context, expected_slot_status_string: Literal['i
@step('a completion request with {api_error} api error')
@async_run_until_complete
async def step_request_completion(context, api_error: Literal['raised'] | str):
- expect_api_error = api_error == 'raised'
+ expect_api_error = api_error == 'raised' or api_error != 'no'
seeds = await completions_seed(context, num_seeds=1)
completion = await request_completion(context.prompts.pop(),
seeds[0] if seeds is not None else seeds,
@@ -272,8 +276,11 @@ async def step_request_completion(context, api_error: Literal['raised'] | str):
context.tasks_result.append(completion)
if context.debug:
print(f"Completion response: {completion}")
- if expect_api_error:
+ if api_error == 'raised':
assert completion == 401, f"completion must be an 401 status code: {completion}"
+ elif api_error.isdigit():
+ api_error_code = int(api_error)
+ assert completion == api_error_code, f"completion must be an {api_error_code} status code: {completion}"
@step('{predicted_n:d} tokens are predicted matching {re_content}')
@@ -645,6 +652,9 @@ def step_assert_embeddings(context):
for embedding in context.embeddings:
assert_embeddings(embedding)
+@step('embeddings request with {api_error_code:d} api error')
+def step_assert_embeddings(context, api_error_code: int):
+ assert context.embeddings == api_error_code, f"embeddings request must return code {api_error_code}, but got {context.embeddings}"
@step('an OAI compatible embeddings computation request for')
@async_run_until_complete
@@ -1089,15 +1099,17 @@ async def oai_chat_completions(user_prompt,
return completion_response
-async def request_embedding(content, seed, base_url=None) -> list[list[float]]:
+async def request_embedding(content, seed, base_url=None) -> list[list[float]] | int:
async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
async with session.post(f'{base_url}/embedding',
json={
"content": content,
}) as response:
- assert response.status == 200
- response_json = await response.json()
- return [response_json['embedding']]
+ if response.status == 200:
+ response_json = await response.json()
+ return [response_json['embedding']]
+ else:
+ return response.status
async def request_oai_embeddings(input, seed,
@@ -1372,6 +1384,8 @@ def start_server_background(context):
server_args.append('--verbose')
if context.lora_file:
server_args.extend(['--lora', context.lora_file])
+ if context.disable_ctx_shift:
+ server_args.extend(['--no-context-shift'])
args = [str(arg) for arg in [context.server_path, *server_args]]
print(f"bench: starting server with: {' '.join(args)}")
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 94e40b632..ef45ce16d 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -576,6 +576,7 @@ extern "C" {
GGML_LOG_LEVEL_WARN = 2,
GGML_LOG_LEVEL_ERROR = 3,
GGML_LOG_LEVEL_DEBUG = 4,
+ GGML_LOG_LEVEL_CONT = 5, // continue previous log
};
// this tensor...
@@ -1985,6 +1986,9 @@ extern "C" {
typedef void (*ggml_custom2_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata);
typedef void (*ggml_custom3_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, const struct ggml_tensor * c, int ith, int nth, void * userdata);
+#define GGML_N_TASKS_MAX (-1)
+ // n_tasks == GGML_N_TASKS_MAX means to use max number of tasks
+
GGML_API struct ggml_tensor * ggml_map_custom1(
struct ggml_context * ctx,
struct ggml_tensor * a,
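For reference, a minimal sketch of what the new constant is for: passing GGML_N_TASKS_MAX instead of a fixed task count lets ggml decide how many tasks to split a custom op into. The ggml_custom1_op_t signature (dst, a, ith, nth, userdata) is inferred from the custom2/custom3 typedefs shown above; my_rows_op, ctx and a are placeholders.

static void my_rows_op(struct ggml_tensor * dst, const struct ggml_tensor * a, int ith, int nth, void * userdata) {
    // ith/nth partition the work across however many tasks were actually launched
    (void) dst; (void) a; (void) ith; (void) nth; (void) userdata;
}

// inside a graph-building function, with ctx and a already created:
struct ggml_tensor * out = ggml_map_custom1(ctx, a, my_rows_op, GGML_N_TASKS_MAX, /*userdata=*/NULL);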
diff --git a/ggml/src/ggml-aarch64.c b/ggml/src/ggml-aarch64.c
index 27375d0d7..8912de63d 100644
--- a/ggml/src/ggml-aarch64.c
+++ b/ggml/src/ggml-aarch64.c
@@ -1,4 +1,7 @@
-// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd.
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+// SPDX-License-Identifier: MIT
+//
+
#define GGML_COMMON_IMPL_C
#include "ggml-common.h"
@@ -39,11 +42,44 @@
//
#if defined(__AVX__)
#if defined(__F16C__)
+#if defined(__AVX512F__)
+#define GGML_F32Cx8x2_LOAD(x, y) _mm512_cvtph_ps(_mm256_set_m128i(_mm_loadu_si128((const __m128i *)(y)), _mm_loadu_si128((const __m128i *)(x))))
+#define GGML_F32Cx16_REPEAT_LOAD(x) _mm512_cvtph_ps(_mm256_set_m128i(x, x))
+#endif
// the _mm256_cvt intrinsics require F16C
#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
#define GGML_F32Cx8_REPEAT_LOAD(x, loadMask) _mm256_cvtph_ps(_mm_shuffle_epi32(_mm_maskload_epi32((int const*)(x), loadMask), 68))
#define GGML_F32Cx8_REARRANGE_LOAD(x, arrangeMask) _mm256_cvtph_ps(_mm_shuffle_epi8(_mm_loadu_si128((const __m128i *) x), arrangeMask))
#else
+#if defined(__AVX512F__)
+static inline __m512 __avx512_f32cx8x2_load(ggml_fp16_t *x, ggml_fp16_t *y) {
+ float tmp[16];
+
+ for (int i = 0; i < 8; i++) {
+ tmp[i] = GGML_FP16_TO_FP32(x[i]);
+ }
+
+ for (int i = 0; i < 8; i++) {
+ tmp[i + 8] = GGML_FP16_TO_FP32(y[i]);
+ }
+
+ return _mm512_loadu_ps(tmp);
+}
+static inline __m512 __avx512_repeat_f32cx16_load(__m128i x) {
+ float tmp[16];
+ uint16_t tmphalf[8];
+ _mm_storeu_si128((__m128i*)tmphalf, x);
+
+ for (int i = 0; i < 4; i++) {
+ tmp[i] = GGML_FP16_TO_FP32(tmphalf[i]);
+ tmp[i + 4] = GGML_FP16_TO_FP32(tmphalf[i]);
+ tmp[i + 8] = GGML_FP16_TO_FP32(tmphalf[i]);
+ tmp[i + 12] = GGML_FP16_TO_FP32(tmphalf[i]);
+ }
+
+ return _mm512_loadu_ps(tmp);
+}
+#endif
static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
float tmp[8];
@@ -78,30 +114,65 @@ static inline __m256 __avx_rearranged_f32cx8_load(ggml_fp16_t *x, __m128i arrang
#define GGML_F32Cx8_LOAD(x) __avx_f32cx8_load(x)
#define GGML_F32Cx8_REPEAT_LOAD(x, loadMask) __avx_repeat_f32cx8_load(x)
#define GGML_F32Cx8_REARRANGE_LOAD(x, arrangeMask) __avx_rearranged_f32cx8_load(x, arrangeMask)
+#if defined(__AVX512F__)
+#define GGML_F32Cx8x2_LOAD(x, y) __avx512_f32cx8x2_load(x, y)
+#define GGML_F32Cx16_REPEAT_LOAD(x) __avx512_repeat_f32cx16_load(x)
+#endif
#endif
#endif
#if defined(__AVX2__) || defined(__AVX512F__)
-static inline __m256i sum_i16_pairs_int(const __m256i x) {
+#if defined(__AVX512F__)
+// add int16_t pairwise and return as 512 bit int vector
+static inline __m512i sum_i16_pairs_int_32x16(const __m512i x) {
+ const __m512i ones = _mm512_set1_epi16(1);
+ return _mm512_madd_epi16(ones, x);
+}
+
+static inline __m512i mul_sum_us8_pairs_int32x16(const __m512i ax, const __m512i sy) {
+#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
+ const __m512i zero = _mm512_setzero_si512();
+ return _mm512_dpbusd_epi32(zero, ax, sy);
+#else
+ // Perform multiplication and create 16-bit values
+ const __m512i dot = _mm512_maddubs_epi16(ax, sy);
+ return sum_i16_pairs_int_32x16(dot);
+#endif
+}
+
+// multiply int8_t, add results pairwise twice and return as 512 bit int vector
+static inline __m512i mul_sum_i8_pairs_int32x16(const __m512i x, const __m512i y) {
+ const __m512i zero = _mm512_setzero_si512();
+ // Get absolute values of x vectors
+ const __m512i ax = _mm512_abs_epi8(x);
+ // Sign the values of the y vectors
+ __mmask64 blt0 = _mm512_movepi8_mask(x);
+ const __m512i sy = _mm512_mask_sub_epi8(y, blt0, zero, y);
+ return mul_sum_us8_pairs_int32x16(ax, sy);
+}
+#endif
+
+// add int16_t pairwise and return as 256 bit int vector
+static inline __m256i sum_i16_pairs_int32x8(const __m256i x) {
const __m256i ones = _mm256_set1_epi16(1);
return _mm256_madd_epi16(ones, x);
}
-static inline __m256i mul_sum_us8_pairs_int(const __m256i ax, const __m256i sy) {
+static inline __m256i mul_sum_us8_pairs_int32x8(const __m256i ax, const __m256i sy) {
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
const __m256i zero = _mm256_setzero_si256();
return _mm256_dpbusd_epi32(zero, ax, sy);
#else
// Perform multiplication and create 16-bit values
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
- return sum_i16_pairs_int(dot);
+ return sum_i16_pairs_int32x8(dot);
#endif
}
// Integer variant of the function defined in ggml-quants.c
-// multiply int8_t, add results pairwise twice and return as float vector
-static inline __m256i mul_sum_i8_pairs_int(const __m256i x, const __m256i y) {
+// multiply int8_t, add results pairwise twice and return as 256 bit int vector
+static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y) {
#if __AVXVNNIINT8__
const __m256i zero = _mm256_setzero_si256();
return _mm256_dpbssd_epi32(zero, x, y);
@@ -110,7 +181,7 @@ static inline __m256i mul_sum_i8_pairs_int(const __m256i x, const __m256i y) {
const __m256i ax = _mm256_sign_epi8(x, x);
// Sign the values of the y vectors
const __m256i sy = _mm256_sign_epi8(y, x);
- return mul_sum_us8_pairs_int(ax, sy);
+ return mul_sum_us8_pairs_int32x8(ax, sy);
#endif
}
#endif
@@ -527,6 +598,15 @@ size_t quantize_q4_0_8x8(const float * restrict src, void * restrict dst, int64_
return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 8, 8);
}
+// Returns the number of byte lanes in the SVE vector if SVE is supported; otherwise, returns 0.
+static int sve_lane_count(void) {
+#if defined(__ARM_FEATURE_SVE)
+ return ggml_sve_cnt_b;
+#else
+ return 0;
+#endif
+}
+
void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
const int qk = QK8_0;
const int nb = n / qk;
@@ -546,73 +626,67 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
UNUSED(ncols_interleaved);
UNUSED(blocklen);
-#if defined(__ARM_FEATURE_SVE)
- if (ggml_sve_cnt_b == QK8_0) {
- GGML_ASSERT(!(ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) &&
- "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance");
- }
-#endif
-#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
- GGML_ASSERT(!(ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) &&
- "__ARM_NEON and __ARM_FEATURE_MATMUL_INT8 defined, use the Q4_0_4_8 quantization format for optimal performance");
-#elif defined(__ARM_NEON) && defined(__aarch64__) && ! ((defined(_MSC_VER)) && ! defined(__clang__))
- const void * b_ptr = vx;
- const void * a_ptr = vy;
- float * res_ptr = s;
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+ if (ggml_cpu_has_neon()) {
+ const void * b_ptr = vx;
+ const void * a_ptr = vy;
+ float * res_ptr = s;
- __asm__ __volatile__(
- "movi v31.16b, #0x4\n"
- "movi v30.16b, #0xf0\n"
- "add %x[b_ptr], %x[b_ptr], #0x8\n"
- "1:" // Column loop
- "add x22, %x[a_ptr], #0x2\n"
- "movi v29.16b, #0x0\n"
- "mov x21, %x[nb]\n"
- "2:" // Block loop
- "ldr q28, [%x[b_ptr], #0x0]\n"
- "ldr q27, [x22, #0x0]\n"
- "movi v26.4s, #0x0\n"
- "sub x20, x22, #0x2\n"
- "ldr q25, [x22, #0x10]\n"
- "ldr q24, [%x[b_ptr], #0x10]\n"
- "sub x21, x21, #0x1\n"
- "add x22, x22, #0x22\n"
- "ldr q23, [%x[b_ptr], #0x20]\n"
- "ldr q22, [%x[b_ptr], #0x30]\n"
- "ld1r { v21.8h }, [x20]\n"
- "ldr q20, [%x[b_ptr], #-0x8]\n"
- "sshl v16.16b, v28.16b, v31.16b\n"
- "and v28.16b, v28.16b, v30.16b\n"
- "sshl v19.16b, v24.16b, v31.16b\n"
- "and v24.16b, v24.16b, v30.16b\n"
- "add %x[b_ptr], %x[b_ptr], #0x48\n"
- "sshl v18.16b, v23.16b, v31.16b\n"
- "and v23.16b, v23.16b, v30.16b\n"
- ".inst 0x4f9be21a // sdot v26.4s, v16.16b, v27.4b[0]\n"
- "sshl v17.16b, v22.16b, v31.16b\n"
- "and v22.16b, v22.16b, v30.16b\n"
- "fcvtl v21.4s, v21.4h\n"
- "fcvtl v16.4s, v20.4h\n"
- ".inst 0x4f99e39a // sdot v26.4s, v28.16b, v25.4b[0]\n"
- "fmul v16.4s, v16.4s, v21.4s\n"
- ".inst 0x4fbbe27a // sdot v26.4s, v19.16b, v27.4b[1]\n"
- ".inst 0x4fb9e31a // sdot v26.4s, v24.16b, v25.4b[1]\n"
- ".inst 0x4f9bea5a // sdot v26.4s, v18.16b, v27.4b[2]\n"
- ".inst 0x4f99eafa // sdot v26.4s, v23.16b, v25.4b[2]\n"
- ".inst 0x4fbbea3a // sdot v26.4s, v17.16b, v27.4b[3]\n"
- ".inst 0x4fb9eada // sdot v26.4s, v22.16b, v25.4b[3]\n"
- "scvtf v26.4s, v26.4s, #0x4\n"
- "fmla v29.4s, v26.4s, v16.4s\n"
- "cbnz x21, 2b\n"
- "sub %x[nc], %x[nc], #0x4\n"
- "str q29, [%x[res_ptr], #0x0]\n"
- "add %x[res_ptr], %x[res_ptr], #0x10\n"
- "cbnz %x[nc], 1b\n"
- : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
- : [a_ptr] "r" (a_ptr), [nb] "r" (nb)
- : "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22"
- );
-#else
+ __asm__ __volatile__(
+ "movi v31.16b, #0x4\n"
+ "movi v30.16b, #0xf0\n"
+ "add %x[b_ptr], %x[b_ptr], #0x8\n"
+ "1:" // Column loop
+ "add x22, %x[a_ptr], #0x2\n"
+ "movi v29.16b, #0x0\n"
+ "mov x21, %x[nb]\n"
+ "2:" // Block loop
+ "ldr q28, [%x[b_ptr], #0x0]\n"
+ "ldr q27, [x22, #0x0]\n"
+ "movi v26.4s, #0x0\n"
+ "sub x20, x22, #0x2\n"
+ "ldr q25, [x22, #0x10]\n"
+ "ldr q24, [%x[b_ptr], #0x10]\n"
+ "sub x21, x21, #0x1\n"
+ "add x22, x22, #0x22\n"
+ "ldr q23, [%x[b_ptr], #0x20]\n"
+ "ldr q22, [%x[b_ptr], #0x30]\n"
+ "ld1r { v21.8h }, [x20]\n"
+ "ldr q20, [%x[b_ptr], #-0x8]\n"
+ "sshl v16.16b, v28.16b, v31.16b\n"
+ "and v28.16b, v28.16b, v30.16b\n"
+ "sshl v19.16b, v24.16b, v31.16b\n"
+ "and v24.16b, v24.16b, v30.16b\n"
+ "add %x[b_ptr], %x[b_ptr], #0x48\n"
+ "sshl v18.16b, v23.16b, v31.16b\n"
+ "and v23.16b, v23.16b, v30.16b\n"
+ ".inst 0x4f9be21a // sdot v26.4s, v16.16b, v27.4b[0]\n"
+ "sshl v17.16b, v22.16b, v31.16b\n"
+ "and v22.16b, v22.16b, v30.16b\n"
+ "fcvtl v21.4s, v21.4h\n"
+ "fcvtl v16.4s, v20.4h\n"
+ ".inst 0x4f99e39a // sdot v26.4s, v28.16b, v25.4b[0]\n"
+ "fmul v16.4s, v16.4s, v21.4s\n"
+ ".inst 0x4fbbe27a // sdot v26.4s, v19.16b, v27.4b[1]\n"
+ ".inst 0x4fb9e31a // sdot v26.4s, v24.16b, v25.4b[1]\n"
+ ".inst 0x4f9bea5a // sdot v26.4s, v18.16b, v27.4b[2]\n"
+ ".inst 0x4f99eafa // sdot v26.4s, v23.16b, v25.4b[2]\n"
+ ".inst 0x4fbbea3a // sdot v26.4s, v17.16b, v27.4b[3]\n"
+ ".inst 0x4fb9eada // sdot v26.4s, v22.16b, v25.4b[3]\n"
+ "scvtf v26.4s, v26.4s, #0x4\n"
+ "fmla v29.4s, v26.4s, v16.4s\n"
+ "cbnz x21, 2b\n"
+ "sub %x[nc], %x[nc], #0x4\n"
+ "str q29, [%x[res_ptr], #0x0]\n"
+ "add %x[res_ptr], %x[res_ptr], #0x10\n"
+ "cbnz %x[nc], 1b\n"
+ : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
+ : [a_ptr] "r" (a_ptr), [nb] "r" (nb)
+ : "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22"
+ );
+ return;
+ }
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
float sumf[4];
int sumi;
@@ -636,7 +710,6 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
}
for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
}
-#endif
}
void ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
@@ -658,79 +731,72 @@ void ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void *
UNUSED(ncols_interleaved);
UNUSED(blocklen);
-#if defined(__ARM_FEATURE_SVE)
- if (ggml_sve_cnt_b == QK8_0) {
- GGML_ASSERT(!(ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) &&
- "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance");
- }
-#endif
-#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) && ! ((defined(_MSC_VER)) && ! defined(__clang__))
- const void * b_ptr = vx;
- const void * a_ptr = vy;
- float * res_ptr = s;
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+ if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
+ const void * b_ptr = vx;
+ const void * a_ptr = vy;
+ float * res_ptr = s;
- __asm__ __volatile__(
- "movi v2.16b, #0x4\n"
- "movi v1.16b, #0xf0\n"
- "add %x[b_ptr], %x[b_ptr], #0x8\n"
- "1:" // Column loop
- "add x23, %x[a_ptr], #0x2\n"
- "movi v0.16b, #0x0\n"
- "mov x22, %x[nb]\n"
- "2:" // Block loop
- "ldr q31, [%x[b_ptr], #0x0]\n"
- "ldr q30, [%x[b_ptr], #0x10]\n"
- "mov x21, x23\n"
- "movi v29.4s, #0x0\n"
- "ldr q28, [%x[b_ptr], #0x20]\n"
- "ldr q27, [%x[b_ptr], #0x30]\n"
- "movi v26.4s, #0x0\n"
- "sub x20, x23, #0x2\n"
- "ld1r { v25.8h }, [x20]\n"
- "ldr q24, [%x[b_ptr], #-0x8]\n"
- "sub x22, x22, #0x1\n"
- "add x23, x23, #0x22\n"
- "ld1r { v23.2d }, [x21], #0x8\n"
- "sshl v22.16b, v31.16b, v2.16b\n"
- "sshl v16.16b, v30.16b, v2.16b\n"
- "add %x[b_ptr], %x[b_ptr], #0x48\n"
- "ld1r { v21.2d }, [x21], #0x8\n"
- "sshl v20.16b, v28.16b, v2.16b\n"
- "sshl v19.16b, v27.16b, v2.16b\n"
- "ld1r { v18.2d }, [x21], #0x8\n"
- "ld1r { v17.2d }, [x21], #0x8\n"
- "and v31.16b, v31.16b, v1.16b\n"
- "and v30.16b, v30.16b, v1.16b\n"
- ".inst 0x4e9796dd // sdot v29.4s, v22.16b, v23.16b\n"
- ".inst 0x4e97961a // sdot v26.4s, v16.16b, v23.16b\n"
- "and v28.16b, v28.16b, v1.16b\n"
- "and v27.16b, v27.16b, v1.16b\n"
- "fcvtl v25.4s, v25.4h\n"
- "fcvtl v16.4s, v24.4h\n"
- ".inst 0x4e95969d // sdot v29.4s, v20.16b, v21.16b\n"
- ".inst 0x4e95967a // sdot v26.4s, v19.16b, v21.16b\n"
- "fmul v16.4s, v16.4s, v25.4s\n"
- ".inst 0x4e9297fd // sdot v29.4s, v31.16b, v18.16b\n"
- ".inst 0x4e9297da // sdot v26.4s, v30.16b, v18.16b\n"
- ".inst 0x4e91979d // sdot v29.4s, v28.16b, v17.16b\n"
- ".inst 0x4e91977a // sdot v26.4s, v27.16b, v17.16b\n"
- "addp v29.4s, v29.4s, v26.4s\n"
- "scvtf v29.4s, v29.4s, #0x4\n"
- "fmla v0.4s, v29.4s, v16.4s\n"
- "cbnz x22, 2b\n"
- "sub %x[nc], %x[nc], #0x4\n"
- "str q0, [%x[res_ptr], #0x0]\n"
- "add %x[res_ptr], %x[res_ptr], #0x10\n"
- "cbnz %x[nc], 1b\n"
- : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
- : [a_ptr] "r" (a_ptr), [nb] "r" (nb)
- : "memory", "v0", "v1", "v2", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23"
- );
-#elif defined(__ARM_NEON) && defined(__aarch64__)
- GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) &&
- "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal "
- "performance");
-#else
+ __asm__ __volatile__(
+ "movi v2.16b, #0x4\n"
+ "movi v1.16b, #0xf0\n"
+ "add %x[b_ptr], %x[b_ptr], #0x8\n"
+ "1:" // Column loop
+ "add x23, %x[a_ptr], #0x2\n"
+ "movi v0.16b, #0x0\n"
+ "mov x22, %x[nb]\n"
+ "2:" // Block loop
+ "ldr q31, [%x[b_ptr], #0x0]\n"
+ "ldr q30, [%x[b_ptr], #0x10]\n"
+ "mov x21, x23\n"
+ "movi v29.4s, #0x0\n"
+ "ldr q28, [%x[b_ptr], #0x20]\n"
+ "ldr q27, [%x[b_ptr], #0x30]\n"
+ "movi v26.4s, #0x0\n"
+ "sub x20, x23, #0x2\n"
+ "ld1r { v25.8h }, [x20]\n"
+ "ldr q24, [%x[b_ptr], #-0x8]\n"
+ "sub x22, x22, #0x1\n"
+ "add x23, x23, #0x22\n"
+ "ld1r { v23.2d }, [x21], #0x8\n"
+ "sshl v22.16b, v31.16b, v2.16b\n"
+ "sshl v16.16b, v30.16b, v2.16b\n"
+ "add %x[b_ptr], %x[b_ptr], #0x48\n"
+ "ld1r { v21.2d }, [x21], #0x8\n"
+ "sshl v20.16b, v28.16b, v2.16b\n"
+ "sshl v19.16b, v27.16b, v2.16b\n"
+ "ld1r { v18.2d }, [x21], #0x8\n"
+ "ld1r { v17.2d }, [x21], #0x8\n"
+ "and v31.16b, v31.16b, v1.16b\n"
+ "and v30.16b, v30.16b, v1.16b\n"
+ ".inst 0x4e9796dd // sdot v29.4s, v22.16b, v23.16b\n"
+ ".inst 0x4e97961a // sdot v26.4s, v16.16b, v23.16b\n"
+ "and v28.16b, v28.16b, v1.16b\n"
+ "and v27.16b, v27.16b, v1.16b\n"
+ "fcvtl v25.4s, v25.4h\n"
+ "fcvtl v16.4s, v24.4h\n"
+ ".inst 0x4e95969d // sdot v29.4s, v20.16b, v21.16b\n"
+ ".inst 0x4e95967a // sdot v26.4s, v19.16b, v21.16b\n"
+ "fmul v16.4s, v16.4s, v25.4s\n"
+ ".inst 0x4e9297fd // sdot v29.4s, v31.16b, v18.16b\n"
+ ".inst 0x4e9297da // sdot v26.4s, v30.16b, v18.16b\n"
+ ".inst 0x4e91979d // sdot v29.4s, v28.16b, v17.16b\n"
+ ".inst 0x4e91977a // sdot v26.4s, v27.16b, v17.16b\n"
+ "addp v29.4s, v29.4s, v26.4s\n"
+ "scvtf v29.4s, v29.4s, #0x4\n"
+ "fmla v0.4s, v29.4s, v16.4s\n"
+ "cbnz x22, 2b\n"
+ "sub %x[nc], %x[nc], #0x4\n"
+ "str q0, [%x[res_ptr], #0x0]\n"
+ "add %x[res_ptr], %x[res_ptr], #0x10\n"
+ "cbnz %x[nc], 1b\n"
+ : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
+ : [a_ptr] "r" (a_ptr), [nb] "r" (nb)
+ : "memory", "v0", "v1", "v2", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23"
+ );
+ return;
+ }
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
float sumf[4];
int sumi;
@@ -754,7 +820,6 @@ void ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void *
}
for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
}
-#endif
}
void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
@@ -776,8 +841,9 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
UNUSED(ncols_interleaved);
UNUSED(blocklen);
-#if defined(__ARM_FEATURE_SVE) && ! ((defined(_MSC_VER)) && ! defined(__clang__))
- if (ggml_sve_cnt_b == QK8_0) {
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
+#if defined(__ARM_FEATURE_SVE)
+ if (ggml_cpu_has_sve() && sve_lane_count() == QK8_0) {
const void * b_ptr = vx;
const void * a_ptr = vy;
float * res_ptr = s;
@@ -842,24 +908,7 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
);
return;
}
- else if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
- GGML_ASSERT((ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) &&
- "__ARM_FEATURE_SVE for vector size of 256-bits not defined, use the Q4_0_4_8 quantization format for optimal "
- "performance");
- }
- else if (ggml_cpu_has_neon()) {
- GGML_ASSERT(((ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) || ggml_cpu_has_matmul_int8()) &&
- "__ARM_FEATURE_SVE for vector size of 256-bits and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 "
- "quantization format for optimal performance");
- }
-#endif
-#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
- GGML_ASSERT(ggml_cpu_has_sve() &&
- "__ARM_FEATURE_SVE not defined, use the Q4_0_4_8 quantization format for optimal performance");
-#elif defined(__ARM_NEON) && defined(__aarch64__)
- GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) &&
- "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal "
- "performance");
+#endif // #if defined(__ARM_FEATURE_SVE)
#elif defined(__AVX2__)
// Lookup table to convert signed nibbles to signed bytes
__m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
@@ -929,17 +978,17 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
// ...........................................................................
// B0(28-31) B4(28-31) B1(28-31) B5(28-31) B2(28-31) B6(28-31) B3(28-31) B7(28-31) with A0(28-31)
- iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(rhs_vec_0123_0 ,_mm256_shuffle_epi32(rhs_vec_4567_0, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 0)));
- iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_0, 177) ,rhs_vec_4567_0, 170), _mm256_shuffle_epi32(lhs_vec_0, 85)));
+ iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_0 ,_mm256_shuffle_epi32(rhs_vec_4567_0, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 0)));
+ iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_0, 177) ,rhs_vec_4567_0, 170), _mm256_shuffle_epi32(lhs_vec_0, 85)));
- iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(rhs_vec_0123_1 ,_mm256_shuffle_epi32(rhs_vec_4567_1, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 170)));
- iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_1, 177) ,rhs_vec_4567_1, 170), _mm256_shuffle_epi32(lhs_vec_0, 255)));
+ iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_1 ,_mm256_shuffle_epi32(rhs_vec_4567_1, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 170)));
+ iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_1, 177) ,rhs_vec_4567_1, 170), _mm256_shuffle_epi32(lhs_vec_0, 255)));
- iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(rhs_vec_0123_2 ,_mm256_shuffle_epi32(rhs_vec_4567_2, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 0)));
- iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_2, 177) ,rhs_vec_4567_2, 170), _mm256_shuffle_epi32(lhs_vec_1, 85)));
+ iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_2 ,_mm256_shuffle_epi32(rhs_vec_4567_2, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 0)));
+ iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_2, 177) ,rhs_vec_4567_2, 170), _mm256_shuffle_epi32(lhs_vec_1, 85)));
- iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(rhs_vec_0123_3 ,_mm256_shuffle_epi32(rhs_vec_4567_3, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 170)));
- iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_3, 177) ,rhs_vec_4567_3, 170), _mm256_shuffle_epi32(lhs_vec_1, 255)));
+ iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_3 ,_mm256_shuffle_epi32(rhs_vec_4567_3, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 170)));
+ iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_3, 177) ,rhs_vec_4567_3, 170), _mm256_shuffle_epi32(lhs_vec_1, 255)));
             // Accumulated values multiplied with appropriate scales
acc_row = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc), _mm256_mul_ps(col_scale_f32, row_scale_f32), acc_row);
@@ -950,31 +999,33 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
_mm256_storeu_ps(s + (y * nr + x * 8), acc_row);
}
}
-#else
- float sumf[8];
- int sumi;
+ return;
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
+ {
+ float sumf[8];
+ int sumi;
- const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
- for (int x = 0; x < nc / ncols_interleaved; x++) {
- const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
+ const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
- for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
- for (int l = 0; l < nb; l++) {
- for (int k = 0; k < (qk / (2 * blocklen)); k++) {
- for (int j = 0; j < ncols_interleaved; j++) {
- sumi = 0;
- for (int i = 0; i < blocklen; ++i) {
- const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
- const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
- sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
+ for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+ for (int l = 0; l < nb; l++) {
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+ for (int j = 0; j < ncols_interleaved; j++) {
+ sumi = 0;
+ for (int i = 0; i < blocklen; ++i) {
+ const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
+ const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
+ sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
+ }
+ sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
}
- sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
}
}
+ for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
}
- for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
}
-#endif
}
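
(Note on the hunk above: the old `#else` compile-time fallback is replaced by a runtime-dispatch layout — the architecture-guarded fast path now ends in an early `return`, and the generic scalar loop follows unconditionally in its own block. A minimal sketch of that pattern, with an illustrative function name that is not part of the patch:

// Sketch only: dispatch layout introduced by this patch, not the actual kernel.
void kernel_sketch(void) {
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
    if (ggml_cpu_has_neon()) {
        // optimized NEON / inline-asm path
        return;          // skip the generic code below at runtime
    }
#endif
    {
        // generic scalar fallback: always compiled, reached when the
        // feature check fails or the preprocessor guard is not defined
    }
}

This keeps a single binary usable on CPUs without the feature, instead of asserting at build time as the removed `GGML_ASSERT` branches did.)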
void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
@@ -997,505 +1048,500 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
UNUSED(ncols_interleaved);
UNUSED(blocklen);
-#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
- if (ggml_sve_cnt_b == QK8_0) {
- GGML_ASSERT(!(ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) &&
- "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance");
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+ if (ggml_cpu_has_neon()) {
+ const void * b_ptr = vx;
+ const void * a_ptr = vy;
+ float * res_ptr = s;
+ size_t res_stride = bs * sizeof(float);
+
+ __asm__ __volatile__(
+ "mov x10, %x[nr]\n"
+ "mov x9, #0x88\n"
+ "cmp x10, #0x10\n"
+ "mul x9, %x[nb], x9\n"
+ "blt 4f\n"
+ "1:" // Row loop
+ "add x28, %x[b_ptr], #0x8\n"
+ "mov x27, %x[nc]\n"
+ "add x26, %x[res_ptr], %x[res_stride], LSL #4\n"
+ "2:" // Column loop
+ "add x25, %x[a_ptr], #0x8\n"
+ "movi v15.16b, #0x0\n"
+ "movi v19.16b, #0x0\n"
+ "mov x24, %x[nb]\n"
+ "add x23, x25, x9\n"
+ "movi v18.16b, #0x0\n"
+ "movi v14.16b, #0x0\n"
+ "add x22, x23, x9\n"
+ "movi v11.16b, #0x0\n"
+ "movi v13.16b, #0x0\n"
+ "add x21, x22, x9\n"
+ "movi v23.16b, #0x0\n"
+ "movi v16.16b, #0x0\n"
+ "movi v25.16b, #0x0\n"
+ "movi v7.16b, #0x0\n"
+ "movi v0.16b, #0x0\n"
+ "movi v4.16b, #0x0\n"
+ "movi v5.16b, #0x0\n"
+ "movi v21.16b, #0x0\n"
+ "movi v8.16b, #0x0\n"
+ "movi v1.16b, #0x0\n"
+ "3:" // Block loop
+ "ldr q3, [x28, #0x0]\n"
+ "ldr q31, [x25, #0x0]\n"
+ "movi v28.16b, #0x4\n"
+ "movi v10.4s, #0x0\n"
+ "ldr q22, [x28, #0x10]\n"
+ "ldr q6, [x25, #0x10]\n"
+ "movi v29.4s, #0x0\n"
+ "movi v9.4s, #0x0\n"
+ "ldr q27, [x28, #0x20]\n"
+ "ldr q30, [x28, #0x30]\n"
+ "movi v20.4s, #0x0\n"
+ "movi v24.16b, #0xf0\n"
+ "ldr d2, [x25, #-0x8]\n"
+ "ldr d26, [x23, #-0x8]\n"
+ "sshl v12.16b, v3.16b, v28.16b\n"
+ "sub x20, x28, #0x8\n"
+ "ldr d17, [x20, #0x0]\n"
+ "and v3.16b, v3.16b, v24.16b\n"
+ "subs x24, x24, #0x1\n"
+ "add x28, x28, #0x48\n"
+ ".inst 0x4f9fe18a // sdot v10.4s, v12.16b, v31.4b[0]\n"
+ ".inst 0x4fbfe19d // sdot v29.4s, v12.16b, v31.4b[1]\n"
+ ".inst 0x4f9fe989 // sdot v9.4s, v12.16b, v31.4b[2]\n"
+ ".inst 0x4fbfe994 // sdot v20.4s, v12.16b, v31.4b[3]\n"
+ "sshl v31.16b, v22.16b, v28.16b\n"
+ "and v22.16b, v22.16b, v24.16b\n"
+ "fcvtl v17.4s, v17.4h\n"
+ "fcvtl v2.4s, v2.4h\n"
+ "fcvtl v26.4s, v26.4h\n"
+ ".inst 0x4f86e3ea // sdot v10.4s, v31.16b, v6.4b[0]\n"
+ ".inst 0x4fa6e3fd // sdot v29.4s, v31.16b, v6.4b[1]\n"
+ ".inst 0x4f86ebe9 // sdot v9.4s, v31.16b, v6.4b[2]\n"
+ ".inst 0x4fa6ebf4 // sdot v20.4s, v31.16b, v6.4b[3]\n"
+ "sshl v6.16b, v27.16b, v28.16b\n"
+ "sshl v28.16b, v30.16b, v28.16b\n"
+ "and v27.16b, v27.16b, v24.16b\n"
+ "and v30.16b, v30.16b, v24.16b\n"
+ "ldr q24, [x25, #0x20]\n"
+ ".inst 0x4f98e0ca // sdot v10.4s, v6.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n"
+ ".inst 0x4f98e8c9 // sdot v9.4s, v6.16b, v24.4b[2]\n"
+ ".inst 0x4fb8e8d4 // sdot v20.4s, v6.16b, v24.4b[3]\n"
+ "ldr q24, [x25, #0x30]\n"
+ ".inst 0x4f98e38a // sdot v10.4s, v28.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e39d // sdot v29.4s, v28.16b, v24.4b[1]\n"
+ ".inst 0x4f98eb89 // sdot v9.4s, v28.16b, v24.4b[2]\n"
+ ".inst 0x4fb8eb94 // sdot v20.4s, v28.16b, v24.4b[3]\n"
+ "ldr q24, [x25, #0x40]\n"
+ ".inst 0x4f98e06a // sdot v10.4s, v3.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n"
+ ".inst 0x4f98e869 // sdot v9.4s, v3.16b, v24.4b[2]\n"
+ ".inst 0x4fb8e874 // sdot v20.4s, v3.16b, v24.4b[3]\n"
+ "ldr q24, [x25, #0x50]\n"
+ ".inst 0x4f98e2ca // sdot v10.4s, v22.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e2dd // sdot v29.4s, v22.16b, v24.4b[1]\n"
+ ".inst 0x4f98eac9 // sdot v9.4s, v22.16b, v24.4b[2]\n"
+ ".inst 0x4fb8ead4 // sdot v20.4s, v22.16b, v24.4b[3]\n"
+ "ldr q24, [x25, #0x60]\n"
+ ".inst 0x4f98e36a // sdot v10.4s, v27.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n"
+ ".inst 0x4f98eb69 // sdot v9.4s, v27.16b, v24.4b[2]\n"
+ ".inst 0x4fb8eb74 // sdot v20.4s, v27.16b, v24.4b[3]\n"
+ "ldr q24, [x25, #0x70]\n"
+ "add x25, x25, #0x88\n"
+ ".inst 0x4f98e3ca // sdot v10.4s, v30.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e3dd // sdot v29.4s, v30.16b, v24.4b[1]\n"
+ ".inst 0x4f98ebc9 // sdot v9.4s, v30.16b, v24.4b[2]\n"
+ ".inst 0x4fb8ebd4 // sdot v20.4s, v30.16b, v24.4b[3]\n"
+ "fmul v24.4s, v17.4s, v2.s[0]\n"
+ "scvtf v10.4s, v10.4s, #0x4\n"
+ "scvtf v29.4s, v29.4s, #0x4\n"
+ "scvtf v9.4s, v9.4s, #0x4\n"
+ "scvtf v20.4s, v20.4s, #0x4\n"
+ "fmla v15.4s, v10.4s, v24.4s\n"
+ "ldr q24, [x23, #0x0]\n"
+ "fmul v10.4s, v17.4s, v2.s[1]\n"
+ "fmla v19.4s, v29.4s, v10.4s\n"
+ "ldr q10, [x23, #0x10]\n"
+ "fmul v29.4s, v17.4s, v2.s[2]\n"
+ "fmul v2.4s, v17.4s, v2.s[3]\n"
+ "fmla v18.4s, v9.4s, v29.4s\n"
+ "movi v9.4s, #0x0\n"
+ "movi v29.4s, #0x0\n"
+ ".inst 0x4f98e189 // sdot v9.4s, v12.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e19d // sdot v29.4s, v12.16b, v24.4b[1]\n"
+ "fmla v14.4s, v20.4s, v2.4s\n"
+ "movi v20.4s, #0x0\n"
+ "movi v2.4s, #0x0\n"
+ ".inst 0x4f98e994 // sdot v20.4s, v12.16b, v24.4b[2]\n"
+ ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n"
+ "ldr q24, [x23, #0x20]\n"
+ ".inst 0x4f8ae3e9 // sdot v9.4s, v31.16b, v10.4b[0]\n"
+ ".inst 0x4faae3fd // sdot v29.4s, v31.16b, v10.4b[1]\n"
+ ".inst 0x4f8aebf4 // sdot v20.4s, v31.16b, v10.4b[2]\n"
+ ".inst 0x4faaebe2 // sdot v2.4s, v31.16b, v10.4b[3]\n"
+ "ldr q10, [x23, #0x30]\n"
+ ".inst 0x4f98e0c9 // sdot v9.4s, v6.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n"
+ ".inst 0x4f98e8d4 // sdot v20.4s, v6.16b, v24.4b[2]\n"
+ ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n"
+ "ldr q24, [x23, #0x40]\n"
+ ".inst 0x4f8ae389 // sdot v9.4s, v28.16b, v10.4b[0]\n"
+ ".inst 0x4faae39d // sdot v29.4s, v28.16b, v10.4b[1]\n"
+ ".inst 0x4f8aeb94 // sdot v20.4s, v28.16b, v10.4b[2]\n"
+ ".inst 0x4faaeb82 // sdot v2.4s, v28.16b, v10.4b[3]\n"
+ "ldr q10, [x23, #0x50]\n"
+ ".inst 0x4f98e069 // sdot v9.4s, v3.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n"
+ ".inst 0x4f98e874 // sdot v20.4s, v3.16b, v24.4b[2]\n"
+ ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n"
+ "ldr q24, [x23, #0x60]\n"
+ ".inst 0x4f8ae2c9 // sdot v9.4s, v22.16b, v10.4b[0]\n"
+ ".inst 0x4faae2dd // sdot v29.4s, v22.16b, v10.4b[1]\n"
+ ".inst 0x4f8aead4 // sdot v20.4s, v22.16b, v10.4b[2]\n"
+ ".inst 0x4faaeac2 // sdot v2.4s, v22.16b, v10.4b[3]\n"
+ "ldr q10, [x23, #0x70]\n"
+ "add x23, x23, #0x88\n"
+ ".inst 0x4f98e369 // sdot v9.4s, v27.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n"
+ ".inst 0x4f98eb74 // sdot v20.4s, v27.16b, v24.4b[2]\n"
+ ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n"
+ "ldr q24, [x22, #0x0]\n"
+ ".inst 0x4f8ae3c9 // sdot v9.4s, v30.16b, v10.4b[0]\n"
+ ".inst 0x4faae3dd // sdot v29.4s, v30.16b, v10.4b[1]\n"
+ ".inst 0x4f8aebd4 // sdot v20.4s, v30.16b, v10.4b[2]\n"
+ ".inst 0x4faaebc2 // sdot v2.4s, v30.16b, v10.4b[3]\n"
+ "fmul v10.4s, v17.4s, v26.s[0]\n"
+ "scvtf v9.4s, v9.4s, #0x4\n"
+ "scvtf v29.4s, v29.4s, #0x4\n"
+ "scvtf v20.4s, v20.4s, #0x4\n"
+ "scvtf v2.4s, v2.4s, #0x4\n"
+ "fmla v11.4s, v9.4s, v10.4s\n"
+ "ldr q9, [x22, #0x10]\n"
+ "fmul v10.4s, v17.4s, v26.s[1]\n"
+ "fmla v13.4s, v29.4s, v10.4s\n"
+ "ldr d29, [x22, #-0x8]\n"
+ "fmul v10.4s, v17.4s, v26.s[2]\n"
+ "fmul v26.4s, v17.4s, v26.s[3]\n"
+ "fcvtl v29.4s, v29.4h\n"
+ "fmla v23.4s, v20.4s, v10.4s\n"
+ "movi v20.4s, #0x0\n"
+ "movi v10.4s, #0x0\n"
+ "fmla v16.4s, v2.4s, v26.4s\n"
+ "movi v26.4s, #0x0\n"
+ "movi v2.4s, #0x0\n"
+ ".inst 0x4f98e194 // sdot v20.4s, v12.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n"
+ ".inst 0x4f98e99a // sdot v26.4s, v12.16b, v24.4b[2]\n"
+ ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n"
+ "ldr q24, [x22, #0x20]\n"
+ ".inst 0x4f89e3f4 // sdot v20.4s, v31.16b, v9.4b[0]\n"
+ ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n"
+ ".inst 0x4f89ebfa // sdot v26.4s, v31.16b, v9.4b[2]\n"
+ ".inst 0x4fa9ebe2 // sdot v2.4s, v31.16b, v9.4b[3]\n"
+ "ldr q9, [x22, #0x30]\n"
+ ".inst 0x4f98e0d4 // sdot v20.4s, v6.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e0ca // sdot v10.4s, v6.16b, v24.4b[1]\n"
+ ".inst 0x4f98e8da // sdot v26.4s, v6.16b, v24.4b[2]\n"
+ ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n"
+ "ldr q24, [x22, #0x40]\n"
+ ".inst 0x4f89e394 // sdot v20.4s, v28.16b, v9.4b[0]\n"
+ ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n"
+ ".inst 0x4f89eb9a // sdot v26.4s, v28.16b, v9.4b[2]\n"
+ ".inst 0x4fa9eb82 // sdot v2.4s, v28.16b, v9.4b[3]\n"
+ "ldr q9, [x22, #0x50]\n"
+ ".inst 0x4f98e074 // sdot v20.4s, v3.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e06a // sdot v10.4s, v3.16b, v24.4b[1]\n"
+ ".inst 0x4f98e87a // sdot v26.4s, v3.16b, v24.4b[2]\n"
+ ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n"
+ "ldr q24, [x22, #0x60]\n"
+ ".inst 0x4f89e2d4 // sdot v20.4s, v22.16b, v9.4b[0]\n"
+ ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n"
+ ".inst 0x4f89eada // sdot v26.4s, v22.16b, v9.4b[2]\n"
+ ".inst 0x4fa9eac2 // sdot v2.4s, v22.16b, v9.4b[3]\n"
+ "ldr q9, [x22, #0x70]\n"
+ "add x22, x22, #0x88\n"
+ ".inst 0x4f98e374 // sdot v20.4s, v27.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e36a // sdot v10.4s, v27.16b, v24.4b[1]\n"
+ ".inst 0x4f98eb7a // sdot v26.4s, v27.16b, v24.4b[2]\n"
+ ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n"
+ "ldr q24, [x21, #0x0]\n"
+ ".inst 0x4f89e3d4 // sdot v20.4s, v30.16b, v9.4b[0]\n"
+ ".inst 0x4fa9e3ca // sdot v10.4s, v30.16b, v9.4b[1]\n"
+ ".inst 0x4f89ebda // sdot v26.4s, v30.16b, v9.4b[2]\n"
+ ".inst 0x4fa9ebc2 // sdot v2.4s, v30.16b, v9.4b[3]\n"
+ "fmul v9.4s, v17.4s, v29.s[0]\n"
+ "scvtf v20.4s, v20.4s, #0x4\n"
+ "scvtf v10.4s, v10.4s, #0x4\n"
+ "scvtf v26.4s, v26.4s, #0x4\n"
+ "scvtf v2.4s, v2.4s, #0x4\n"
+ "fmla v25.4s, v20.4s, v9.4s\n"
+ "ldr q9, [x21, #0x10]\n"
+ "fmul v20.4s, v17.4s, v29.s[1]\n"
+ "fmla v7.4s, v10.4s, v20.4s\n"
+ "ldr d20, [x21, #-0x8]\n"
+ "fmul v10.4s, v17.4s, v29.s[2]\n"
+ "fmul v29.4s, v17.4s, v29.s[3]\n"
+ "fcvtl v20.4s, v20.4h\n"
+ "fmla v0.4s, v26.4s, v10.4s\n"
+ "movi v26.4s, #0x0\n"
+ "movi v10.4s, #0x0\n"
+ "fmla v4.4s, v2.4s, v29.4s\n"
+ "movi v2.4s, #0x0\n"
+ "movi v29.4s, #0x0\n"
+ ".inst 0x4f98e19a // sdot v26.4s, v12.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n"
+ ".inst 0x4f98e982 // sdot v2.4s, v12.16b, v24.4b[2]\n"
+ ".inst 0x4fb8e99d // sdot v29.4s, v12.16b, v24.4b[3]\n"
+ "ldr q12, [x21, #0x20]\n"
+ "fmul v24.4s, v17.4s, v20.s[0]\n"
+ ".inst 0x4f89e3fa // sdot v26.4s, v31.16b, v9.4b[0]\n"
+ ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n"
+ ".inst 0x4f89ebe2 // sdot v2.4s, v31.16b, v9.4b[2]\n"
+ ".inst 0x4fa9ebfd // sdot v29.4s, v31.16b, v9.4b[3]\n"
+ "ldr q9, [x21, #0x30]\n"
+ "fmul v31.4s, v17.4s, v20.s[1]\n"
+ ".inst 0x4f8ce0da // sdot v26.4s, v6.16b, v12.4b[0]\n"
+ ".inst 0x4face0ca // sdot v10.4s, v6.16b, v12.4b[1]\n"
+ ".inst 0x4f8ce8c2 // sdot v2.4s, v6.16b, v12.4b[2]\n"
+ ".inst 0x4face8dd // sdot v29.4s, v6.16b, v12.4b[3]\n"
+ "ldr q12, [x21, #0x40]\n"
+ "fmul v6.4s, v17.4s, v20.s[2]\n"
+ "fmul v20.4s, v17.4s, v20.s[3]\n"
+ ".inst 0x4f89e39a // sdot v26.4s, v28.16b, v9.4b[0]\n"
+ ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n"
+ ".inst 0x4f89eb82 // sdot v2.4s, v28.16b, v9.4b[2]\n"
+ ".inst 0x4fa9eb9d // sdot v29.4s, v28.16b, v9.4b[3]\n"
+ "ldr q9, [x21, #0x50]\n"
+ ".inst 0x4f8ce07a // sdot v26.4s, v3.16b, v12.4b[0]\n"
+ ".inst 0x4face06a // sdot v10.4s, v3.16b, v12.4b[1]\n"
+ ".inst 0x4f8ce862 // sdot v2.4s, v3.16b, v12.4b[2]\n"
+ ".inst 0x4face87d // sdot v29.4s, v3.16b, v12.4b[3]\n"
+ "ldr q12, [x21, #0x60]\n"
+ ".inst 0x4f89e2da // sdot v26.4s, v22.16b, v9.4b[0]\n"
+ ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n"
+ ".inst 0x4f89eac2 // sdot v2.4s, v22.16b, v9.4b[2]\n"
+ ".inst 0x4fa9eadd // sdot v29.4s, v22.16b, v9.4b[3]\n"
+ "ldr q17, [x21, #0x70]\n"
+ "add x21, x21, #0x88\n"
+ ".inst 0x4f8ce37a // sdot v26.4s, v27.16b, v12.4b[0]\n"
+ ".inst 0x4face36a // sdot v10.4s, v27.16b, v12.4b[1]\n"
+ ".inst 0x4f8ceb62 // sdot v2.4s, v27.16b, v12.4b[2]\n"
+ ".inst 0x4faceb7d // sdot v29.4s, v27.16b, v12.4b[3]\n"
+ ".inst 0x4f91e3da // sdot v26.4s, v30.16b, v17.4b[0]\n"
+ ".inst 0x4fb1e3ca // sdot v10.4s, v30.16b, v17.4b[1]\n"
+ ".inst 0x4f91ebc2 // sdot v2.4s, v30.16b, v17.4b[2]\n"
+ ".inst 0x4fb1ebdd // sdot v29.4s, v30.16b, v17.4b[3]\n"
+ "scvtf v26.4s, v26.4s, #0x4\n"
+ "scvtf v10.4s, v10.4s, #0x4\n"
+ "fmla v5.4s, v26.4s, v24.4s\n"
+ "scvtf v2.4s, v2.4s, #0x4\n"
+ "scvtf v29.4s, v29.4s, #0x4\n"
+ "fmla v21.4s, v10.4s, v31.4s\n"
+ "fmla v8.4s, v2.4s, v6.4s\n"
+ "fmla v1.4s, v29.4s, v20.4s\n"
+ "bgt 3b\n"
+ "mov x20, %x[res_ptr]\n"
+ "subs x27, x27, #0x4\n"
+ "add %x[res_ptr], %x[res_ptr], #0x10\n"
+ "str q15, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q19, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q18, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q14, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q11, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q13, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q23, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q16, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q25, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q7, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q0, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q4, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q5, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q21, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q8, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q1, [x20, #0x0]\n"
+ "bne 2b\n"
+ "mov x20, #0x4\n"
+ "sub x10, x10, #0x10\n"
+ "cmp x10, #0x10\n"
+ "mov %x[res_ptr], x26\n"
+ "madd %x[a_ptr], x20, x9, %x[a_ptr]\n"
+ "bge 1b\n"
+ "4:" // Row loop skip
+ "cbz x10, 9f\n"
+ "5:" // Row tail: Row loop
+ "add x24, %x[b_ptr], #0x8\n"
+ "mov x23, %x[nc]\n"
+ "add x22, %x[res_ptr], %x[res_stride], LSL #2\n"
+ "6:" // Row tail: Column loop
+ "movi v15.16b, #0x0\n"
+ "movi v19.16b, #0x0\n"
+ "add x25, %x[a_ptr], #0x8\n"
+ "mov x21, %x[nb]\n"
+ "movi v18.16b, #0x0\n"
+ "movi v14.16b, #0x0\n"
+ "7:" // Row tail: Block loop
+ "ldr q7, [x24, #0x0]\n"
+ "ldr q5, [x25, #0x0]\n"
+ "movi v9.16b, #0x4\n"
+ "movi v4.4s, #0x0\n"
+ "ldr q3, [x24, #0x10]\n"
+ "ldr q2, [x25, #0x10]\n"
+ "movi v1.4s, #0x0\n"
+ "movi v0.4s, #0x0\n"
+ "ldr q13, [x24, #0x20]\n"
+ "ldr q31, [x25, #0x20]\n"
+ "movi v30.4s, #0x0\n"
+ "movi v29.16b, #0xf0\n"
+ "ldr q28, [x24, #0x30]\n"
+ "ldr q27, [x25, #0x30]\n"
+ "sshl v20.16b, v7.16b, v9.16b\n"
+ "sub x20, x24, #0x8\n"
+ "ldr q26, [x25, #0x40]\n"
+ "ldr q25, [x25, #0x50]\n"
+ "sshl v17.16b, v3.16b, v9.16b\n"
+ "and v7.16b, v7.16b, v29.16b\n"
+ "ldr q24, [x25, #0x60]\n"
+ "ldr q16, [x25, #0x70]\n"
+ "sshl v22.16b, v13.16b, v9.16b\n"
+ "and v3.16b, v3.16b, v29.16b\n"
+ "ldr d21, [x20, #0x0]\n"
+ "ldr d12, [x25, #-0x8]\n"
+ ".inst 0x4f85e284 // sdot v4.4s, v20.16b, v5.4b[0]\n"
+ ".inst 0x4fa5e281 // sdot v1.4s, v20.16b, v5.4b[1]\n"
+ ".inst 0x4f85ea80 // sdot v0.4s, v20.16b, v5.4b[2]\n"
+ ".inst 0x4fa5ea9e // sdot v30.4s, v20.16b, v5.4b[3]\n"
+ "sshl v9.16b, v28.16b, v9.16b\n"
+ "subs x21, x21, #0x1\n"
+ "and v13.16b, v13.16b, v29.16b\n"
+ "and v28.16b, v28.16b, v29.16b\n"
+ "add x25, x25, #0x88\n"
+ "add x24, x24, #0x48\n"
+ "fcvtl v21.4s, v21.4h\n"
+ "fcvtl v12.4s, v12.4h\n"
+ ".inst 0x4f82e224 // sdot v4.4s, v17.16b, v2.4b[0]\n"
+ ".inst 0x4fa2e221 // sdot v1.4s, v17.16b, v2.4b[1]\n"
+ ".inst 0x4f82ea20 // sdot v0.4s, v17.16b, v2.4b[2]\n"
+ ".inst 0x4fa2ea3e // sdot v30.4s, v17.16b, v2.4b[3]\n"
+ "fmul v11.4s, v21.4s, v12.s[0]\n"
+ "fmul v23.4s, v21.4s, v12.s[1]\n"
+ "fmul v17.4s, v21.4s, v12.s[2]\n"
+ ".inst 0x4f9fe2c4 // sdot v4.4s, v22.16b, v31.4b[0]\n"
+ "fmul v6.4s, v21.4s, v12.s[3]\n"
+ ".inst 0x4fbfe2c1 // sdot v1.4s, v22.16b, v31.4b[1]\n"
+ ".inst 0x4f9feac0 // sdot v0.4s, v22.16b, v31.4b[2]\n"
+ ".inst 0x4fbfeade // sdot v30.4s, v22.16b, v31.4b[3]\n"
+ ".inst 0x4f9be124 // sdot v4.4s, v9.16b, v27.4b[0]\n"
+ ".inst 0x4fbbe121 // sdot v1.4s, v9.16b, v27.4b[1]\n"
+ ".inst 0x4f9be920 // sdot v0.4s, v9.16b, v27.4b[2]\n"
+ ".inst 0x4fbbe93e // sdot v30.4s, v9.16b, v27.4b[3]\n"
+ ".inst 0x4f9ae0e4 // sdot v4.4s, v7.16b, v26.4b[0]\n"
+ ".inst 0x4fbae0e1 // sdot v1.4s, v7.16b, v26.4b[1]\n"
+ ".inst 0x4f9ae8e0 // sdot v0.4s, v7.16b, v26.4b[2]\n"
+ ".inst 0x4fbae8fe // sdot v30.4s, v7.16b, v26.4b[3]\n"
+ ".inst 0x4f99e064 // sdot v4.4s, v3.16b, v25.4b[0]\n"
+ ".inst 0x4fb9e061 // sdot v1.4s, v3.16b, v25.4b[1]\n"
+ ".inst 0x4f99e860 // sdot v0.4s, v3.16b, v25.4b[2]\n"
+ ".inst 0x4fb9e87e // sdot v30.4s, v3.16b, v25.4b[3]\n"
+ ".inst 0x4f98e1a4 // sdot v4.4s, v13.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e1a1 // sdot v1.4s, v13.16b, v24.4b[1]\n"
+ ".inst 0x4f98e9a0 // sdot v0.4s, v13.16b, v24.4b[2]\n"
+ ".inst 0x4fb8e9be // sdot v30.4s, v13.16b, v24.4b[3]\n"
+ ".inst 0x4f90e384 // sdot v4.4s, v28.16b, v16.4b[0]\n"
+ ".inst 0x4fb0e381 // sdot v1.4s, v28.16b, v16.4b[1]\n"
+ ".inst 0x4f90eb80 // sdot v0.4s, v28.16b, v16.4b[2]\n"
+ ".inst 0x4fb0eb9e // sdot v30.4s, v28.16b, v16.4b[3]\n"
+ "scvtf v4.4s, v4.4s, #0x4\n"
+ "scvtf v1.4s, v1.4s, #0x4\n"
+ "scvtf v0.4s, v0.4s, #0x4\n"
+ "fmla v15.4s, v4.4s, v11.4s\n"
+ "scvtf v30.4s, v30.4s, #0x4\n"
+ "fmla v19.4s, v1.4s, v23.4s\n"
+ "fmla v18.4s, v0.4s, v17.4s\n"
+ "fmla v14.4s, v30.4s, v6.4s\n"
+ "bgt 7b\n"
+ "mov x20, %x[res_ptr]\n"
+ "cmp x10, #0x1\n"
+ "str q15, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "ble 8f\n"
+ "cmp x10, #0x2\n"
+ "str q19, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "ble 8f\n"
+ "cmp x10, #0x3\n"
+ "str q18, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "ble 8f\n"
+ "str q14, [x20, #0x0]\n"
+ "8:" // Row tail: Accumulator store skip
+ "subs x23, x23, #0x4\n"
+ "add %x[res_ptr], %x[res_ptr], #0x10\n"
+ "bne 6b\n"
+ "subs x10, x10, #0x4\n"
+ "add %x[a_ptr], %x[a_ptr], x9\n"
+ "mov %x[res_ptr], x22\n"
+ "bgt 5b\n"
+ "9:" // Row tail: Row loop skip
+ : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr)
+ : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
+ : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
+ );
+ return;
}
-#endif
-#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
- GGML_ASSERT(!(ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) &&
- "__ARM_NEON and __ARM_FEATURE_MATMUL_INT8 defined, use the Q4_0_4_8 quantization format for optimal performance");
-#elif defined(__ARM_NEON) && defined(__aarch64__) && ! ((defined(_MSC_VER)) && ! defined(__clang__))
- const void * b_ptr = vx;
- const void * a_ptr = vy;
- float * res_ptr = s;
- size_t res_stride = bs * sizeof(float);
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+ {
+ float sumf[4][4];
+ int sumi;
- __asm__ __volatile__(
- "mov x10, %x[nr]\n"
- "mov x9, #0x88\n"
- "cmp x10, #0x10\n"
- "mul x9, %x[nb], x9\n"
- "blt 4f\n"
- "1:" // Row loop
- "add x28, %x[b_ptr], #0x8\n"
- "mov x27, %x[nc]\n"
- "add x26, %x[res_ptr], %x[res_stride], LSL #4\n"
- "2:" // Column loop
- "add x25, %x[a_ptr], #0x8\n"
- "movi v15.16b, #0x0\n"
- "movi v19.16b, #0x0\n"
- "mov x24, %x[nb]\n"
- "add x23, x25, x9\n"
- "movi v18.16b, #0x0\n"
- "movi v14.16b, #0x0\n"
- "add x22, x23, x9\n"
- "movi v11.16b, #0x0\n"
- "movi v13.16b, #0x0\n"
- "add x21, x22, x9\n"
- "movi v23.16b, #0x0\n"
- "movi v16.16b, #0x0\n"
- "movi v25.16b, #0x0\n"
- "movi v7.16b, #0x0\n"
- "movi v0.16b, #0x0\n"
- "movi v4.16b, #0x0\n"
- "movi v5.16b, #0x0\n"
- "movi v21.16b, #0x0\n"
- "movi v8.16b, #0x0\n"
- "movi v1.16b, #0x0\n"
- "3:" // Block loop
- "ldr q3, [x28, #0x0]\n"
- "ldr q31, [x25, #0x0]\n"
- "movi v28.16b, #0x4\n"
- "movi v10.4s, #0x0\n"
- "ldr q22, [x28, #0x10]\n"
- "ldr q6, [x25, #0x10]\n"
- "movi v29.4s, #0x0\n"
- "movi v9.4s, #0x0\n"
- "ldr q27, [x28, #0x20]\n"
- "ldr q30, [x28, #0x30]\n"
- "movi v20.4s, #0x0\n"
- "movi v24.16b, #0xf0\n"
- "ldr d2, [x25, #-0x8]\n"
- "ldr d26, [x23, #-0x8]\n"
- "sshl v12.16b, v3.16b, v28.16b\n"
- "sub x20, x28, #0x8\n"
- "ldr d17, [x20, #0x0]\n"
- "and v3.16b, v3.16b, v24.16b\n"
- "subs x24, x24, #0x1\n"
- "add x28, x28, #0x48\n"
- ".inst 0x4f9fe18a // sdot v10.4s, v12.16b, v31.4b[0]\n"
- ".inst 0x4fbfe19d // sdot v29.4s, v12.16b, v31.4b[1]\n"
- ".inst 0x4f9fe989 // sdot v9.4s, v12.16b, v31.4b[2]\n"
- ".inst 0x4fbfe994 // sdot v20.4s, v12.16b, v31.4b[3]\n"
- "sshl v31.16b, v22.16b, v28.16b\n"
- "and v22.16b, v22.16b, v24.16b\n"
- "fcvtl v17.4s, v17.4h\n"
- "fcvtl v2.4s, v2.4h\n"
- "fcvtl v26.4s, v26.4h\n"
- ".inst 0x4f86e3ea // sdot v10.4s, v31.16b, v6.4b[0]\n"
- ".inst 0x4fa6e3fd // sdot v29.4s, v31.16b, v6.4b[1]\n"
- ".inst 0x4f86ebe9 // sdot v9.4s, v31.16b, v6.4b[2]\n"
- ".inst 0x4fa6ebf4 // sdot v20.4s, v31.16b, v6.4b[3]\n"
- "sshl v6.16b, v27.16b, v28.16b\n"
- "sshl v28.16b, v30.16b, v28.16b\n"
- "and v27.16b, v27.16b, v24.16b\n"
- "and v30.16b, v30.16b, v24.16b\n"
- "ldr q24, [x25, #0x20]\n"
- ".inst 0x4f98e0ca // sdot v10.4s, v6.16b, v24.4b[0]\n"
- ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n"
- ".inst 0x4f98e8c9 // sdot v9.4s, v6.16b, v24.4b[2]\n"
- ".inst 0x4fb8e8d4 // sdot v20.4s, v6.16b, v24.4b[3]\n"
- "ldr q24, [x25, #0x30]\n"
- ".inst 0x4f98e38a // sdot v10.4s, v28.16b, v24.4b[0]\n"
- ".inst 0x4fb8e39d // sdot v29.4s, v28.16b, v24.4b[1]\n"
- ".inst 0x4f98eb89 // sdot v9.4s, v28.16b, v24.4b[2]\n"
- ".inst 0x4fb8eb94 // sdot v20.4s, v28.16b, v24.4b[3]\n"
- "ldr q24, [x25, #0x40]\n"
- ".inst 0x4f98e06a // sdot v10.4s, v3.16b, v24.4b[0]\n"
- ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n"
- ".inst 0x4f98e869 // sdot v9.4s, v3.16b, v24.4b[2]\n"
- ".inst 0x4fb8e874 // sdot v20.4s, v3.16b, v24.4b[3]\n"
- "ldr q24, [x25, #0x50]\n"
- ".inst 0x4f98e2ca // sdot v10.4s, v22.16b, v24.4b[0]\n"
- ".inst 0x4fb8e2dd // sdot v29.4s, v22.16b, v24.4b[1]\n"
- ".inst 0x4f98eac9 // sdot v9.4s, v22.16b, v24.4b[2]\n"
- ".inst 0x4fb8ead4 // sdot v20.4s, v22.16b, v24.4b[3]\n"
- "ldr q24, [x25, #0x60]\n"
- ".inst 0x4f98e36a // sdot v10.4s, v27.16b, v24.4b[0]\n"
- ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n"
- ".inst 0x4f98eb69 // sdot v9.4s, v27.16b, v24.4b[2]\n"
- ".inst 0x4fb8eb74 // sdot v20.4s, v27.16b, v24.4b[3]\n"
- "ldr q24, [x25, #0x70]\n"
- "add x25, x25, #0x88\n"
- ".inst 0x4f98e3ca // sdot v10.4s, v30.16b, v24.4b[0]\n"
- ".inst 0x4fb8e3dd // sdot v29.4s, v30.16b, v24.4b[1]\n"
- ".inst 0x4f98ebc9 // sdot v9.4s, v30.16b, v24.4b[2]\n"
- ".inst 0x4fb8ebd4 // sdot v20.4s, v30.16b, v24.4b[3]\n"
- "fmul v24.4s, v17.4s, v2.s[0]\n"
- "scvtf v10.4s, v10.4s, #0x4\n"
- "scvtf v29.4s, v29.4s, #0x4\n"
- "scvtf v9.4s, v9.4s, #0x4\n"
- "scvtf v20.4s, v20.4s, #0x4\n"
- "fmla v15.4s, v10.4s, v24.4s\n"
- "ldr q24, [x23, #0x0]\n"
- "fmul v10.4s, v17.4s, v2.s[1]\n"
- "fmla v19.4s, v29.4s, v10.4s\n"
- "ldr q10, [x23, #0x10]\n"
- "fmul v29.4s, v17.4s, v2.s[2]\n"
- "fmul v2.4s, v17.4s, v2.s[3]\n"
- "fmla v18.4s, v9.4s, v29.4s\n"
- "movi v9.4s, #0x0\n"
- "movi v29.4s, #0x0\n"
- ".inst 0x4f98e189 // sdot v9.4s, v12.16b, v24.4b[0]\n"
- ".inst 0x4fb8e19d // sdot v29.4s, v12.16b, v24.4b[1]\n"
- "fmla v14.4s, v20.4s, v2.4s\n"
- "movi v20.4s, #0x0\n"
- "movi v2.4s, #0x0\n"
- ".inst 0x4f98e994 // sdot v20.4s, v12.16b, v24.4b[2]\n"
- ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n"
- "ldr q24, [x23, #0x20]\n"
- ".inst 0x4f8ae3e9 // sdot v9.4s, v31.16b, v10.4b[0]\n"
- ".inst 0x4faae3fd // sdot v29.4s, v31.16b, v10.4b[1]\n"
- ".inst 0x4f8aebf4 // sdot v20.4s, v31.16b, v10.4b[2]\n"
- ".inst 0x4faaebe2 // sdot v2.4s, v31.16b, v10.4b[3]\n"
- "ldr q10, [x23, #0x30]\n"
- ".inst 0x4f98e0c9 // sdot v9.4s, v6.16b, v24.4b[0]\n"
- ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n"
- ".inst 0x4f98e8d4 // sdot v20.4s, v6.16b, v24.4b[2]\n"
- ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n"
- "ldr q24, [x23, #0x40]\n"
- ".inst 0x4f8ae389 // sdot v9.4s, v28.16b, v10.4b[0]\n"
- ".inst 0x4faae39d // sdot v29.4s, v28.16b, v10.4b[1]\n"
- ".inst 0x4f8aeb94 // sdot v20.4s, v28.16b, v10.4b[2]\n"
- ".inst 0x4faaeb82 // sdot v2.4s, v28.16b, v10.4b[3]\n"
- "ldr q10, [x23, #0x50]\n"
- ".inst 0x4f98e069 // sdot v9.4s, v3.16b, v24.4b[0]\n"
- ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n"
- ".inst 0x4f98e874 // sdot v20.4s, v3.16b, v24.4b[2]\n"
- ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n"
- "ldr q24, [x23, #0x60]\n"
- ".inst 0x4f8ae2c9 // sdot v9.4s, v22.16b, v10.4b[0]\n"
- ".inst 0x4faae2dd // sdot v29.4s, v22.16b, v10.4b[1]\n"
- ".inst 0x4f8aead4 // sdot v20.4s, v22.16b, v10.4b[2]\n"
- ".inst 0x4faaeac2 // sdot v2.4s, v22.16b, v10.4b[3]\n"
- "ldr q10, [x23, #0x70]\n"
- "add x23, x23, #0x88\n"
- ".inst 0x4f98e369 // sdot v9.4s, v27.16b, v24.4b[0]\n"
- ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n"
- ".inst 0x4f98eb74 // sdot v20.4s, v27.16b, v24.4b[2]\n"
- ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n"
- "ldr q24, [x22, #0x0]\n"
- ".inst 0x4f8ae3c9 // sdot v9.4s, v30.16b, v10.4b[0]\n"
- ".inst 0x4faae3dd // sdot v29.4s, v30.16b, v10.4b[1]\n"
- ".inst 0x4f8aebd4 // sdot v20.4s, v30.16b, v10.4b[2]\n"
- ".inst 0x4faaebc2 // sdot v2.4s, v30.16b, v10.4b[3]\n"
- "fmul v10.4s, v17.4s, v26.s[0]\n"
- "scvtf v9.4s, v9.4s, #0x4\n"
- "scvtf v29.4s, v29.4s, #0x4\n"
- "scvtf v20.4s, v20.4s, #0x4\n"
- "scvtf v2.4s, v2.4s, #0x4\n"
- "fmla v11.4s, v9.4s, v10.4s\n"
- "ldr q9, [x22, #0x10]\n"
- "fmul v10.4s, v17.4s, v26.s[1]\n"
- "fmla v13.4s, v29.4s, v10.4s\n"
- "ldr d29, [x22, #-0x8]\n"
- "fmul v10.4s, v17.4s, v26.s[2]\n"
- "fmul v26.4s, v17.4s, v26.s[3]\n"
- "fcvtl v29.4s, v29.4h\n"
- "fmla v23.4s, v20.4s, v10.4s\n"
- "movi v20.4s, #0x0\n"
- "movi v10.4s, #0x0\n"
- "fmla v16.4s, v2.4s, v26.4s\n"
- "movi v26.4s, #0x0\n"
- "movi v2.4s, #0x0\n"
- ".inst 0x4f98e194 // sdot v20.4s, v12.16b, v24.4b[0]\n"
- ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n"
- ".inst 0x4f98e99a // sdot v26.4s, v12.16b, v24.4b[2]\n"
- ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n"
- "ldr q24, [x22, #0x20]\n"
- ".inst 0x4f89e3f4 // sdot v20.4s, v31.16b, v9.4b[0]\n"
- ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n"
- ".inst 0x4f89ebfa // sdot v26.4s, v31.16b, v9.4b[2]\n"
- ".inst 0x4fa9ebe2 // sdot v2.4s, v31.16b, v9.4b[3]\n"
- "ldr q9, [x22, #0x30]\n"
- ".inst 0x4f98e0d4 // sdot v20.4s, v6.16b, v24.4b[0]\n"
- ".inst 0x4fb8e0ca // sdot v10.4s, v6.16b, v24.4b[1]\n"
- ".inst 0x4f98e8da // sdot v26.4s, v6.16b, v24.4b[2]\n"
- ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n"
- "ldr q24, [x22, #0x40]\n"
- ".inst 0x4f89e394 // sdot v20.4s, v28.16b, v9.4b[0]\n"
- ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n"
- ".inst 0x4f89eb9a // sdot v26.4s, v28.16b, v9.4b[2]\n"
- ".inst 0x4fa9eb82 // sdot v2.4s, v28.16b, v9.4b[3]\n"
- "ldr q9, [x22, #0x50]\n"
- ".inst 0x4f98e074 // sdot v20.4s, v3.16b, v24.4b[0]\n"
- ".inst 0x4fb8e06a // sdot v10.4s, v3.16b, v24.4b[1]\n"
- ".inst 0x4f98e87a // sdot v26.4s, v3.16b, v24.4b[2]\n"
- ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n"
- "ldr q24, [x22, #0x60]\n"
- ".inst 0x4f89e2d4 // sdot v20.4s, v22.16b, v9.4b[0]\n"
- ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n"
- ".inst 0x4f89eada // sdot v26.4s, v22.16b, v9.4b[2]\n"
- ".inst 0x4fa9eac2 // sdot v2.4s, v22.16b, v9.4b[3]\n"
- "ldr q9, [x22, #0x70]\n"
- "add x22, x22, #0x88\n"
- ".inst 0x4f98e374 // sdot v20.4s, v27.16b, v24.4b[0]\n"
- ".inst 0x4fb8e36a // sdot v10.4s, v27.16b, v24.4b[1]\n"
- ".inst 0x4f98eb7a // sdot v26.4s, v27.16b, v24.4b[2]\n"
- ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n"
- "ldr q24, [x21, #0x0]\n"
- ".inst 0x4f89e3d4 // sdot v20.4s, v30.16b, v9.4b[0]\n"
- ".inst 0x4fa9e3ca // sdot v10.4s, v30.16b, v9.4b[1]\n"
- ".inst 0x4f89ebda // sdot v26.4s, v30.16b, v9.4b[2]\n"
- ".inst 0x4fa9ebc2 // sdot v2.4s, v30.16b, v9.4b[3]\n"
- "fmul v9.4s, v17.4s, v29.s[0]\n"
- "scvtf v20.4s, v20.4s, #0x4\n"
- "scvtf v10.4s, v10.4s, #0x4\n"
- "scvtf v26.4s, v26.4s, #0x4\n"
- "scvtf v2.4s, v2.4s, #0x4\n"
- "fmla v25.4s, v20.4s, v9.4s\n"
- "ldr q9, [x21, #0x10]\n"
- "fmul v20.4s, v17.4s, v29.s[1]\n"
- "fmla v7.4s, v10.4s, v20.4s\n"
- "ldr d20, [x21, #-0x8]\n"
- "fmul v10.4s, v17.4s, v29.s[2]\n"
- "fmul v29.4s, v17.4s, v29.s[3]\n"
- "fcvtl v20.4s, v20.4h\n"
- "fmla v0.4s, v26.4s, v10.4s\n"
- "movi v26.4s, #0x0\n"
- "movi v10.4s, #0x0\n"
- "fmla v4.4s, v2.4s, v29.4s\n"
- "movi v2.4s, #0x0\n"
- "movi v29.4s, #0x0\n"
- ".inst 0x4f98e19a // sdot v26.4s, v12.16b, v24.4b[0]\n"
- ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n"
- ".inst 0x4f98e982 // sdot v2.4s, v12.16b, v24.4b[2]\n"
- ".inst 0x4fb8e99d // sdot v29.4s, v12.16b, v24.4b[3]\n"
- "ldr q12, [x21, #0x20]\n"
- "fmul v24.4s, v17.4s, v20.s[0]\n"
- ".inst 0x4f89e3fa // sdot v26.4s, v31.16b, v9.4b[0]\n"
- ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n"
- ".inst 0x4f89ebe2 // sdot v2.4s, v31.16b, v9.4b[2]\n"
- ".inst 0x4fa9ebfd // sdot v29.4s, v31.16b, v9.4b[3]\n"
- "ldr q9, [x21, #0x30]\n"
- "fmul v31.4s, v17.4s, v20.s[1]\n"
- ".inst 0x4f8ce0da // sdot v26.4s, v6.16b, v12.4b[0]\n"
- ".inst 0x4face0ca // sdot v10.4s, v6.16b, v12.4b[1]\n"
- ".inst 0x4f8ce8c2 // sdot v2.4s, v6.16b, v12.4b[2]\n"
- ".inst 0x4face8dd // sdot v29.4s, v6.16b, v12.4b[3]\n"
- "ldr q12, [x21, #0x40]\n"
- "fmul v6.4s, v17.4s, v20.s[2]\n"
- "fmul v20.4s, v17.4s, v20.s[3]\n"
- ".inst 0x4f89e39a // sdot v26.4s, v28.16b, v9.4b[0]\n"
- ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n"
- ".inst 0x4f89eb82 // sdot v2.4s, v28.16b, v9.4b[2]\n"
- ".inst 0x4fa9eb9d // sdot v29.4s, v28.16b, v9.4b[3]\n"
- "ldr q9, [x21, #0x50]\n"
- ".inst 0x4f8ce07a // sdot v26.4s, v3.16b, v12.4b[0]\n"
- ".inst 0x4face06a // sdot v10.4s, v3.16b, v12.4b[1]\n"
- ".inst 0x4f8ce862 // sdot v2.4s, v3.16b, v12.4b[2]\n"
- ".inst 0x4face87d // sdot v29.4s, v3.16b, v12.4b[3]\n"
- "ldr q12, [x21, #0x60]\n"
- ".inst 0x4f89e2da // sdot v26.4s, v22.16b, v9.4b[0]\n"
- ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n"
- ".inst 0x4f89eac2 // sdot v2.4s, v22.16b, v9.4b[2]\n"
- ".inst 0x4fa9eadd // sdot v29.4s, v22.16b, v9.4b[3]\n"
- "ldr q17, [x21, #0x70]\n"
- "add x21, x21, #0x88\n"
- ".inst 0x4f8ce37a // sdot v26.4s, v27.16b, v12.4b[0]\n"
- ".inst 0x4face36a // sdot v10.4s, v27.16b, v12.4b[1]\n"
- ".inst 0x4f8ceb62 // sdot v2.4s, v27.16b, v12.4b[2]\n"
- ".inst 0x4faceb7d // sdot v29.4s, v27.16b, v12.4b[3]\n"
- ".inst 0x4f91e3da // sdot v26.4s, v30.16b, v17.4b[0]\n"
- ".inst 0x4fb1e3ca // sdot v10.4s, v30.16b, v17.4b[1]\n"
- ".inst 0x4f91ebc2 // sdot v2.4s, v30.16b, v17.4b[2]\n"
- ".inst 0x4fb1ebdd // sdot v29.4s, v30.16b, v17.4b[3]\n"
- "scvtf v26.4s, v26.4s, #0x4\n"
- "scvtf v10.4s, v10.4s, #0x4\n"
- "fmla v5.4s, v26.4s, v24.4s\n"
- "scvtf v2.4s, v2.4s, #0x4\n"
- "scvtf v29.4s, v29.4s, #0x4\n"
- "fmla v21.4s, v10.4s, v31.4s\n"
- "fmla v8.4s, v2.4s, v6.4s\n"
- "fmla v1.4s, v29.4s, v20.4s\n"
- "bgt 3b\n"
- "mov x20, %x[res_ptr]\n"
- "subs x27, x27, #0x4\n"
- "add %x[res_ptr], %x[res_ptr], #0x10\n"
- "str q15, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q19, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q18, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q14, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q11, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q13, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q23, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q16, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q25, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q7, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q0, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q4, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q5, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q21, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q8, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q1, [x20, #0x0]\n"
- "bne 2b\n"
- "mov x20, #0x4\n"
- "sub x10, x10, #0x10\n"
- "cmp x10, #0x10\n"
- "mov %x[res_ptr], x26\n"
- "madd %x[a_ptr], x20, x9, %x[a_ptr]\n"
- "bge 1b\n"
- "4:" // Row loop skip
- "cbz x10, 9f\n"
- "5:" // Row tail: Row loop
- "add x24, %x[b_ptr], #0x8\n"
- "mov x23, %x[nc]\n"
- "add x22, %x[res_ptr], %x[res_stride], LSL #2\n"
- "6:" // Row tail: Column loop
- "movi v15.16b, #0x0\n"
- "movi v19.16b, #0x0\n"
- "add x25, %x[a_ptr], #0x8\n"
- "mov x21, %x[nb]\n"
- "movi v18.16b, #0x0\n"
- "movi v14.16b, #0x0\n"
- "7:" // Row tail: Block loop
- "ldr q7, [x24, #0x0]\n"
- "ldr q5, [x25, #0x0]\n"
- "movi v9.16b, #0x4\n"
- "movi v4.4s, #0x0\n"
- "ldr q3, [x24, #0x10]\n"
- "ldr q2, [x25, #0x10]\n"
- "movi v1.4s, #0x0\n"
- "movi v0.4s, #0x0\n"
- "ldr q13, [x24, #0x20]\n"
- "ldr q31, [x25, #0x20]\n"
- "movi v30.4s, #0x0\n"
- "movi v29.16b, #0xf0\n"
- "ldr q28, [x24, #0x30]\n"
- "ldr q27, [x25, #0x30]\n"
- "sshl v20.16b, v7.16b, v9.16b\n"
- "sub x20, x24, #0x8\n"
- "ldr q26, [x25, #0x40]\n"
- "ldr q25, [x25, #0x50]\n"
- "sshl v17.16b, v3.16b, v9.16b\n"
- "and v7.16b, v7.16b, v29.16b\n"
- "ldr q24, [x25, #0x60]\n"
- "ldr q16, [x25, #0x70]\n"
- "sshl v22.16b, v13.16b, v9.16b\n"
- "and v3.16b, v3.16b, v29.16b\n"
- "ldr d21, [x20, #0x0]\n"
- "ldr d12, [x25, #-0x8]\n"
- ".inst 0x4f85e284 // sdot v4.4s, v20.16b, v5.4b[0]\n"
- ".inst 0x4fa5e281 // sdot v1.4s, v20.16b, v5.4b[1]\n"
- ".inst 0x4f85ea80 // sdot v0.4s, v20.16b, v5.4b[2]\n"
- ".inst 0x4fa5ea9e // sdot v30.4s, v20.16b, v5.4b[3]\n"
- "sshl v9.16b, v28.16b, v9.16b\n"
- "subs x21, x21, #0x1\n"
- "and v13.16b, v13.16b, v29.16b\n"
- "and v28.16b, v28.16b, v29.16b\n"
- "add x25, x25, #0x88\n"
- "add x24, x24, #0x48\n"
- "fcvtl v21.4s, v21.4h\n"
- "fcvtl v12.4s, v12.4h\n"
- ".inst 0x4f82e224 // sdot v4.4s, v17.16b, v2.4b[0]\n"
- ".inst 0x4fa2e221 // sdot v1.4s, v17.16b, v2.4b[1]\n"
- ".inst 0x4f82ea20 // sdot v0.4s, v17.16b, v2.4b[2]\n"
- ".inst 0x4fa2ea3e // sdot v30.4s, v17.16b, v2.4b[3]\n"
- "fmul v11.4s, v21.4s, v12.s[0]\n"
- "fmul v23.4s, v21.4s, v12.s[1]\n"
- "fmul v17.4s, v21.4s, v12.s[2]\n"
- ".inst 0x4f9fe2c4 // sdot v4.4s, v22.16b, v31.4b[0]\n"
- "fmul v6.4s, v21.4s, v12.s[3]\n"
- ".inst 0x4fbfe2c1 // sdot v1.4s, v22.16b, v31.4b[1]\n"
- ".inst 0x4f9feac0 // sdot v0.4s, v22.16b, v31.4b[2]\n"
- ".inst 0x4fbfeade // sdot v30.4s, v22.16b, v31.4b[3]\n"
- ".inst 0x4f9be124 // sdot v4.4s, v9.16b, v27.4b[0]\n"
- ".inst 0x4fbbe121 // sdot v1.4s, v9.16b, v27.4b[1]\n"
- ".inst 0x4f9be920 // sdot v0.4s, v9.16b, v27.4b[2]\n"
- ".inst 0x4fbbe93e // sdot v30.4s, v9.16b, v27.4b[3]\n"
- ".inst 0x4f9ae0e4 // sdot v4.4s, v7.16b, v26.4b[0]\n"
- ".inst 0x4fbae0e1 // sdot v1.4s, v7.16b, v26.4b[1]\n"
- ".inst 0x4f9ae8e0 // sdot v0.4s, v7.16b, v26.4b[2]\n"
- ".inst 0x4fbae8fe // sdot v30.4s, v7.16b, v26.4b[3]\n"
- ".inst 0x4f99e064 // sdot v4.4s, v3.16b, v25.4b[0]\n"
- ".inst 0x4fb9e061 // sdot v1.4s, v3.16b, v25.4b[1]\n"
- ".inst 0x4f99e860 // sdot v0.4s, v3.16b, v25.4b[2]\n"
- ".inst 0x4fb9e87e // sdot v30.4s, v3.16b, v25.4b[3]\n"
- ".inst 0x4f98e1a4 // sdot v4.4s, v13.16b, v24.4b[0]\n"
- ".inst 0x4fb8e1a1 // sdot v1.4s, v13.16b, v24.4b[1]\n"
- ".inst 0x4f98e9a0 // sdot v0.4s, v13.16b, v24.4b[2]\n"
- ".inst 0x4fb8e9be // sdot v30.4s, v13.16b, v24.4b[3]\n"
- ".inst 0x4f90e384 // sdot v4.4s, v28.16b, v16.4b[0]\n"
- ".inst 0x4fb0e381 // sdot v1.4s, v28.16b, v16.4b[1]\n"
- ".inst 0x4f90eb80 // sdot v0.4s, v28.16b, v16.4b[2]\n"
- ".inst 0x4fb0eb9e // sdot v30.4s, v28.16b, v16.4b[3]\n"
- "scvtf v4.4s, v4.4s, #0x4\n"
- "scvtf v1.4s, v1.4s, #0x4\n"
- "scvtf v0.4s, v0.4s, #0x4\n"
- "fmla v15.4s, v4.4s, v11.4s\n"
- "scvtf v30.4s, v30.4s, #0x4\n"
- "fmla v19.4s, v1.4s, v23.4s\n"
- "fmla v18.4s, v0.4s, v17.4s\n"
- "fmla v14.4s, v30.4s, v6.4s\n"
- "bgt 7b\n"
- "mov x20, %x[res_ptr]\n"
- "cmp x10, #0x1\n"
- "str q15, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "ble 8f\n"
- "cmp x10, #0x2\n"
- "str q19, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "ble 8f\n"
- "cmp x10, #0x3\n"
- "str q18, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "ble 8f\n"
- "str q14, [x20, #0x0]\n"
- "8:" // Row tail: Accumulator store skip
- "subs x23, x23, #0x4\n"
- "add %x[res_ptr], %x[res_ptr], #0x10\n"
- "bne 6b\n"
- "subs x10, x10, #0x4\n"
- "add %x[a_ptr], %x[a_ptr], x9\n"
- "mov %x[res_ptr], x22\n"
- "bgt 5b\n"
- "9:" // Row tail: Row loop skip
- : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr)
- : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
- : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
- );
-#else
- float sumf[4][4];
- int sumi;
-
- for (int y = 0; y < nr / 4; y++) {
- const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
- for (int x = 0; x < nc / ncols_interleaved; x++) {
- const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
- for (int m = 0; m < 4; m++) {
- for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
- }
- for (int l = 0; l < nb; l++) {
- for (int k = 0; k < (qk / (2 * blocklen)); k++) {
- for (int m = 0; m < 4; m++) {
- for (int j = 0; j < ncols_interleaved; j++) {
- sumi = 0;
- for (int i = 0; i < blocklen; ++i) {
- const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
- const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
- sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
- (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
+ for (int y = 0; y < nr / 4; y++) {
+ const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
+ const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
+ for (int m = 0; m < 4; m++) {
+ for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
+ }
+ for (int l = 0; l < nb; l++) {
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+ for (int m = 0; m < 4; m++) {
+ for (int j = 0; j < ncols_interleaved; j++) {
+ sumi = 0;
+ for (int i = 0; i < blocklen; ++i) {
+ const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
+ const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
+ sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
+ (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
+ }
+ sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
}
- sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
}
}
}
- }
- for (int m = 0; m < 4; m++) {
- for (int j = 0; j < ncols_interleaved; j++)
- s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+ for (int m = 0; m < 4; m++) {
+ for (int j = 0; j < ncols_interleaved; j++)
+ s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+ }
}
}
}
-#endif
}
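
(Note on the scalar fallback retained above: each byte of `qs` packs two 4-bit Q4_0 quants. `(int8_t)(q << 4)` sign-extends the low nibble into the upper bits and `(int8_t)(q & 0xF0)` keeps the high nibble in place, so both signed values come out scaled by 16; the final `>> 4` undoes that after multiplying by the int8 activations. A self-contained sketch of the decode for one packed byte, mirroring the loop body:

// Illustrative helper, not part of the patch: dot product of one packed
// Q4_0 byte against two int8 activations, as done in the scalar fallback.
#include <stdint.h>

static inline int q4_0_dot_pair(uint8_t q, int8_t a_lo, int8_t a_hi) {
    const int v0 = (int8_t)(q << 4);    // low nibble, sign-extended, x16
    const int v1 = (int8_t)(q & 0xF0);  // high nibble, already in upper bits, x16
    return ((v0 * a_lo) + (v1 * a_hi)) >> 4;  // remove the x16 scaling
}

The per-block FP16 scales `b_ptr[l].d[j]` and `a_ptr[l].d` are then applied to the accumulated integer sum, as in the loops above.)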
void ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
@@ -1518,413 +1564,406 @@ void ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void *
UNUSED(ncols_interleaved);
UNUSED(blocklen);
-#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
- if (ggml_sve_cnt_b == QK8_0) {
- GGML_ASSERT(!(ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) &&
- "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance");
- }
-#endif
-#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) && ! ((defined(_MSC_VER)) && ! defined(__clang__))
- const void * b_ptr = vx;
- const void * a_ptr = vy;
- float * res_ptr = s;
- size_t res_stride = bs * sizeof(float);
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+ if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
+ const void * b_ptr = vx;
+ const void * a_ptr = vy;
+ float * res_ptr = s;
+ size_t res_stride = bs * sizeof(float);
- __asm__ __volatile__(
- "mov x10, %x[nr]\n"
- "mov x9, #0x88\n"
- "cmp x10, #0x10\n"
- "mul x9, %x[nb], x9\n"
- "blt 4f\n"
- "1:" // Row loop
- "add x28, %x[b_ptr], #0x8\n"
- "mov x27, %x[nc]\n"
- "add x26, %x[res_ptr], %x[res_stride], LSL #4\n"
- "2:" // Column loop
- "add x25, %x[a_ptr], #0x8\n"
- "movi v2.16b, #0x0\n"
- "movi v10.16b, #0x0\n"
- "mov x24, %x[nb]\n"
- "add x23, x25, x9\n"
- "movi v12.16b, #0x0\n"
- "movi v28.16b, #0x0\n"
- "add x22, x23, x9\n"
- "movi v11.16b, #0x0\n"
- "movi v13.16b, #0x0\n"
- "add x21, x22, x9\n"
- "movi v22.16b, #0x0\n"
- "movi v23.16b, #0x0\n"
- "movi v25.16b, #0x0\n"
- "movi v5.16b, #0x0\n"
- "movi v7.16b, #0x0\n"
- "movi v4.16b, #0x0\n"
- "movi v6.16b, #0x0\n"
- "movi v30.16b, #0x0\n"
- "movi v24.16b, #0x0\n"
- "movi v14.16b, #0x0\n"
- "3:" // Block loop
- "ldr q21, [x28, #0x0]\n"
- "ldr q16, [x28, #0x10]\n"
- "movi v1.16b, #0x4\n"
- "movi v19.4s, #0x0\n"
- "ldr q27, [x25, #0x0]\n"
- "ldr q15, [x25, #0x10]\n"
- "movi v26.4s, #0x0\n"
- "movi v18.4s, #0x0\n"
- "ldr q29, [x28, #0x20]\n"
- "ldr q3, [x28, #0x30]\n"
- "movi v17.4s, #0x0\n"
- "movi v0.16b, #0xf0\n"
- "ldr d20, [x25, #-0x8]\n"
- "ldr d9, [x23, #-0x8]\n"
- "sshl v8.16b, v21.16b, v1.16b\n"
- "sshl v31.16b, v16.16b, v1.16b\n"
- "and v21.16b, v21.16b, v0.16b\n"
- "and v16.16b, v16.16b, v0.16b\n"
- "sub x20, x28, #0x8\n"
- "subs x24, x24, #0x1\n"
- "add x28, x28, #0x48\n"
- ".inst 0x4e88a773 // smmla v19.4s, v27.16b, v8.16b\n"
- ".inst 0x4e9fa77a // smmla v26.4s, v27.16b, v31.16b\n"
- "ldr q27, [x25, #0x20]\n"
- ".inst 0x4e88a5f2 // smmla v18.4s, v15.16b, v8.16b\n"
- ".inst 0x4e9fa5f1 // smmla v17.4s, v15.16b, v31.16b\n"
- "sshl v15.16b, v29.16b, v1.16b\n"
- "sshl v1.16b, v3.16b, v1.16b\n"
- "and v29.16b, v29.16b, v0.16b\n"
- "and v3.16b, v3.16b, v0.16b\n"
- "ldr q0, [x25, #0x30]\n"
- "fcvtl v20.4s, v20.4h\n"
- ".inst 0x4e8fa773 // smmla v19.4s, v27.16b, v15.16b\n"
- "fcvtl v9.4s, v9.4h\n"
- ".inst 0x4e81a77a // smmla v26.4s, v27.16b, v1.16b\n"
- "ldr q27, [x25, #0x40]\n"
- ".inst 0x4e8fa412 // smmla v18.4s, v0.16b, v15.16b\n"
- ".inst 0x4e81a411 // smmla v17.4s, v0.16b, v1.16b\n"
- "ldr q0, [x25, #0x50]\n"
- ".inst 0x4e95a773 // smmla v19.4s, v27.16b, v21.16b\n"
- ".inst 0x4e90a77a // smmla v26.4s, v27.16b, v16.16b\n"
- "ldr q27, [x25, #0x60]\n"
- ".inst 0x4e95a412 // smmla v18.4s, v0.16b, v21.16b\n"
- ".inst 0x4e90a411 // smmla v17.4s, v0.16b, v16.16b\n"
- "ldr q0, [x25, #0x70]\n"
- "add x25, x25, #0x88\n"
- ".inst 0x4e9da773 // smmla v19.4s, v27.16b, v29.16b\n"
- ".inst 0x4e83a77a // smmla v26.4s, v27.16b, v3.16b\n"
- "ldr d27, [x20, #0x0]\n"
- ".inst 0x4e9da412 // smmla v18.4s, v0.16b, v29.16b\n"
- ".inst 0x4e83a411 // smmla v17.4s, v0.16b, v3.16b\n"
- "fcvtl v27.4s, v27.4h\n"
- "uzp1 v0.2d, v19.2d, v26.2d\n"
- "uzp2 v26.2d, v19.2d, v26.2d\n"
- "fmul v19.4s, v27.4s, v20.s[0]\n"
- "scvtf v0.4s, v0.4s, #0x4\n"
- "scvtf v26.4s, v26.4s, #0x4\n"
- "fmla v2.4s, v0.4s, v19.4s\n"
- "ldr q19, [x23, #0x0]\n"
- "uzp1 v0.2d, v18.2d, v17.2d\n"
- "uzp2 v18.2d, v18.2d, v17.2d\n"
- "fmul v17.4s, v27.4s, v20.s[1]\n"
- "scvtf v0.4s, v0.4s, #0x4\n"
- "scvtf v18.4s, v18.4s, #0x4\n"
- "fmla v10.4s, v26.4s, v17.4s\n"
- "ldr q17, [x23, #0x10]\n"
- "fmul v26.4s, v27.4s, v20.s[2]\n"
- "fmul v20.4s, v27.4s, v20.s[3]\n"
- "fmla v12.4s, v0.4s, v26.4s\n"
- "ldr d0, [x22, #-0x8]\n"
- "ldr d26, [x21, #-0x8]\n"
- "fcvtl v0.4s, v0.4h\n"
- "fmla v28.4s, v18.4s, v20.4s\n"
- "movi v20.4s, #0x0\n"
- "movi v18.4s, #0x0\n"
- ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n"
- ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n"
- "ldr q19, [x23, #0x20]\n"
- "fcvtl v26.4s, v26.4h\n"
- ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n"
- ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n"
- "ldr q19, [x23, #0x40]\n"
- ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n"
- ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n"
- "ldr q19, [x23, #0x60]\n"
- ".inst 0x4e9da674 // smmla v20.4s, v19.16b, v29.16b\n"
- ".inst 0x4e83a672 // smmla v18.4s, v19.16b, v3.16b\n"
- "uzp1 v19.2d, v20.2d, v18.2d\n"
- "scvtf v19.4s, v19.4s, #0x4\n"
- "uzp2 v20.2d, v20.2d, v18.2d\n"
- "fmul v18.4s, v27.4s, v9.s[0]\n"
- "scvtf v20.4s, v20.4s, #0x4\n"
- "fmla v11.4s, v19.4s, v18.4s\n"
- "ldr q18, [x22, #0x0]\n"
- "fmul v19.4s, v27.4s, v9.s[1]\n"
- "fmla v13.4s, v20.4s, v19.4s\n"
- "movi v19.4s, #0x0\n"
- "movi v20.4s, #0x0\n"
- ".inst 0x4e88a633 // smmla v19.4s, v17.16b, v8.16b\n"
- ".inst 0x4e9fa634 // smmla v20.4s, v17.16b, v31.16b\n"
- "ldr q17, [x23, #0x30]\n"
- ".inst 0x4e8fa633 // smmla v19.4s, v17.16b, v15.16b\n"
- ".inst 0x4e81a634 // smmla v20.4s, v17.16b, v1.16b\n"
- "ldr q17, [x23, #0x50]\n"
- ".inst 0x4e95a633 // smmla v19.4s, v17.16b, v21.16b\n"
- ".inst 0x4e90a634 // smmla v20.4s, v17.16b, v16.16b\n"
- "ldr q17, [x23, #0x70]\n"
- "add x23, x23, #0x88\n"
- ".inst 0x4e9da633 // smmla v19.4s, v17.16b, v29.16b\n"
- ".inst 0x4e83a634 // smmla v20.4s, v17.16b, v3.16b\n"
- "uzp1 v17.2d, v19.2d, v20.2d\n"
- "scvtf v17.4s, v17.4s, #0x4\n"
- "uzp2 v20.2d, v19.2d, v20.2d\n"
- "fmul v19.4s, v27.4s, v9.s[2]\n"
- "fmul v9.4s, v27.4s, v9.s[3]\n"
- "scvtf v20.4s, v20.4s, #0x4\n"
- "fmla v22.4s, v17.4s, v19.4s\n"
- "ldr q17, [x22, #0x10]\n"
- "movi v19.4s, #0x0\n"
- ".inst 0x4e88a653 // smmla v19.4s, v18.16b, v8.16b\n"
- "fmla v23.4s, v20.4s, v9.4s\n"
- "movi v20.4s, #0x0\n"
- "movi v9.4s, #0x0\n"
- ".inst 0x4e9fa654 // smmla v20.4s, v18.16b, v31.16b\n"
- "ldr q18, [x22, #0x20]\n"
- ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n"
- ".inst 0x4e8fa653 // smmla v19.4s, v18.16b, v15.16b\n"
- ".inst 0x4e81a654 // smmla v20.4s, v18.16b, v1.16b\n"
- "ldr q18, [x22, #0x40]\n"
- ".inst 0x4e95a653 // smmla v19.4s, v18.16b, v21.16b\n"
- ".inst 0x4e90a654 // smmla v20.4s, v18.16b, v16.16b\n"
- "ldr q18, [x22, #0x60]\n"
- ".inst 0x4e9da653 // smmla v19.4s, v18.16b, v29.16b\n"
- ".inst 0x4e83a654 // smmla v20.4s, v18.16b, v3.16b\n"
- "movi v18.4s, #0x0\n"
- ".inst 0x4e9fa632 // smmla v18.4s, v17.16b, v31.16b\n"
- "ldr q17, [x22, #0x30]\n"
- ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n"
- ".inst 0x4e81a632 // smmla v18.4s, v17.16b, v1.16b\n"
- "ldr q17, [x22, #0x50]\n"
- ".inst 0x4e95a629 // smmla v9.4s, v17.16b, v21.16b\n"
- ".inst 0x4e90a632 // smmla v18.4s, v17.16b, v16.16b\n"
- "ldr q17, [x22, #0x70]\n"
- "add x22, x22, #0x88\n"
- ".inst 0x4e9da629 // smmla v9.4s, v17.16b, v29.16b\n"
- ".inst 0x4e83a632 // smmla v18.4s, v17.16b, v3.16b\n"
- "uzp1 v17.2d, v19.2d, v20.2d\n"
- "uzp2 v20.2d, v19.2d, v20.2d\n"
- "fmul v19.4s, v27.4s, v0.s[0]\n"
- "scvtf v17.4s, v17.4s, #0x4\n"
- "scvtf v20.4s, v20.4s, #0x4\n"
- "fmla v25.4s, v17.4s, v19.4s\n"
- "ldr q19, [x21, #0x0]\n"
- "fmul v17.4s, v27.4s, v0.s[1]\n"
- "fmla v5.4s, v20.4s, v17.4s\n"
- "ldr q17, [x21, #0x10]\n"
- "uzp1 v20.2d, v9.2d, v18.2d\n"
- "uzp2 v9.2d, v9.2d, v18.2d\n"
- "fmul v18.4s, v27.4s, v0.s[2]\n"
- "fmul v0.4s, v27.4s, v0.s[3]\n"
- "scvtf v20.4s, v20.4s, #0x4\n"
- "scvtf v9.4s, v9.4s, #0x4\n"
- "fmla v7.4s, v20.4s, v18.4s\n"
- "movi v20.4s, #0x0\n"
- "movi v18.4s, #0x0\n"
- ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n"
- ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n"
- "ldr q19, [x21, #0x20]\n"
- "fmla v4.4s, v9.4s, v0.4s\n"
- "movi v9.4s, #0x0\n"
- "movi v0.4s, #0x0\n"
- ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n"
- "fmul v8.4s, v27.4s, v26.s[0]\n"
- ".inst 0x4e9fa620 // smmla v0.4s, v17.16b, v31.16b\n"
- "ldr q17, [x21, #0x30]\n"
- ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n"
- "fmul v31.4s, v27.4s, v26.s[1]\n"
- ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n"
- "ldr q19, [x21, #0x40]\n"
- ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n"
- "fmul v15.4s, v27.4s, v26.s[2]\n"
- "fmul v27.4s, v27.4s, v26.s[3]\n"
- ".inst 0x4e81a620 // smmla v0.4s, v17.16b, v1.16b\n"
- "ldr q1, [x21, #0x50]\n"
- ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n"
- ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n"
- "ldr q26, [x21, #0x60]\n"
- ".inst 0x4e95a429 // smmla v9.4s, v1.16b, v21.16b\n"
- ".inst 0x4e90a420 // smmla v0.4s, v1.16b, v16.16b\n"
- "ldr q21, [x21, #0x70]\n"
- "add x21, x21, #0x88\n"
- ".inst 0x4e9da754 // smmla v20.4s, v26.16b, v29.16b\n"
- ".inst 0x4e83a752 // smmla v18.4s, v26.16b, v3.16b\n"
- ".inst 0x4e9da6a9 // smmla v9.4s, v21.16b, v29.16b\n"
- ".inst 0x4e83a6a0 // smmla v0.4s, v21.16b, v3.16b\n"
- "uzp1 v29.2d, v20.2d, v18.2d\n"
- "uzp2 v21.2d, v20.2d, v18.2d\n"
- "scvtf v29.4s, v29.4s, #0x4\n"
- "uzp1 v18.2d, v9.2d, v0.2d\n"
- "uzp2 v16.2d, v9.2d, v0.2d\n"
- "scvtf v21.4s, v21.4s, #0x4\n"
- "fmla v6.4s, v29.4s, v8.4s\n"
- "scvtf v18.4s, v18.4s, #0x4\n"
- "scvtf v16.4s, v16.4s, #0x4\n"
- "fmla v30.4s, v21.4s, v31.4s\n"
- "fmla v24.4s, v18.4s, v15.4s\n"
- "fmla v14.4s, v16.4s, v27.4s\n"
- "bgt 3b\n"
- "mov x20, %x[res_ptr]\n"
- "subs x27, x27, #0x4\n"
- "add %x[res_ptr], %x[res_ptr], #0x10\n"
- "str q2, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q10, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q12, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q28, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q11, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q13, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q22, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q23, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q25, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q5, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q7, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q4, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q6, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q30, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q24, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "str q14, [x20, #0x0]\n"
- "bne 2b\n"
- "mov x20, #0x4\n"
- "sub x10, x10, #0x10\n"
- "cmp x10, #0x10\n"
- "mov %x[res_ptr], x26\n"
- "madd %x[a_ptr], x20, x9, %x[a_ptr]\n"
- "bge 1b\n"
- "4:" // Row loop skip
- "cbz x10, 9f\n"
- "5:" // Row tail: Row loop
- "add x24, %x[b_ptr], #0x8\n"
- "mov x23, %x[nc]\n"
- "add x22, %x[res_ptr], %x[res_stride], LSL #2\n"
- "6:" // Row tail: Column loop
- "movi v2.16b, #0x0\n"
- "movi v10.16b, #0x0\n"
- "add x25, %x[a_ptr], #0x8\n"
- "mov x21, %x[nb]\n"
- "movi v12.16b, #0x0\n"
- "movi v28.16b, #0x0\n"
- "7:" // Row tail: Block loop
- "ldr q6, [x24, #0x0]\n"
- "ldr q5, [x24, #0x10]\n"
- "movi v17.16b, #0x4\n"
- "movi v8.4s, #0x0\n"
- "ldr q4, [x25, #0x0]\n"
- "ldr q13, [x25, #0x10]\n"
- "movi v27.4s, #0x0\n"
- "movi v0.4s, #0x0\n"
- "ldr q31, [x24, #0x20]\n"
- "ldr q14, [x24, #0x30]\n"
- "movi v29.4s, #0x0\n"
- "movi v22.16b, #0xf0\n"
- "ldr q11, [x25, #0x20]\n"
- "ldr q23, [x25, #0x30]\n"
- "sshl v21.16b, v6.16b, v17.16b\n"
- "sshl v16.16b, v5.16b, v17.16b\n"
- "ldr q20, [x25, #0x40]\n"
- "ldr q26, [x25, #0x50]\n"
- "and v6.16b, v6.16b, v22.16b\n"
- "and v5.16b, v5.16b, v22.16b\n"
- "ldr q25, [x25, #0x60]\n"
- "ldr q3, [x25, #0x70]\n"
- "sshl v19.16b, v31.16b, v17.16b\n"
- "sshl v18.16b, v14.16b, v17.16b\n"
- "ldr d17, [x25, #-0x8]\n"
- ".inst 0x4e95a488 // smmla v8.4s, v4.16b, v21.16b\n"
- ".inst 0x4e90a49b // smmla v27.4s, v4.16b, v16.16b\n"
- "and v31.16b, v31.16b, v22.16b\n"
- ".inst 0x4e95a5a0 // smmla v0.4s, v13.16b, v21.16b\n"
- ".inst 0x4e90a5bd // smmla v29.4s, v13.16b, v16.16b\n"
- "and v14.16b, v14.16b, v22.16b\n"
- "sub x20, x24, #0x8\n"
- "ldr d16, [x20, #0x0]\n"
- "subs x21, x21, #0x1\n"
- "add x25, x25, #0x88\n"
- "fcvtl v17.4s, v17.4h\n"
- "add x24, x24, #0x48\n"
- ".inst 0x4e93a568 // smmla v8.4s, v11.16b, v19.16b\n"
- ".inst 0x4e92a57b // smmla v27.4s, v11.16b, v18.16b\n"
- ".inst 0x4e93a6e0 // smmla v0.4s, v23.16b, v19.16b\n"
- ".inst 0x4e92a6fd // smmla v29.4s, v23.16b, v18.16b\n"
- "fcvtl v16.4s, v16.4h\n"
- ".inst 0x4e86a688 // smmla v8.4s, v20.16b, v6.16b\n"
- ".inst 0x4e85a69b // smmla v27.4s, v20.16b, v5.16b\n"
- "fmul v23.4s, v16.4s, v17.s[0]\n"
- "fmul v21.4s, v16.4s, v17.s[1]\n"
- "fmul v1.4s, v16.4s, v17.s[2]\n"
- "fmul v20.4s, v16.4s, v17.s[3]\n"
- ".inst 0x4e86a740 // smmla v0.4s, v26.16b, v6.16b\n"
- ".inst 0x4e85a75d // smmla v29.4s, v26.16b, v5.16b\n"
- ".inst 0x4e9fa728 // smmla v8.4s, v25.16b, v31.16b\n"
- ".inst 0x4e8ea73b // smmla v27.4s, v25.16b, v14.16b\n"
- ".inst 0x4e9fa460 // smmla v0.4s, v3.16b, v31.16b\n"
- ".inst 0x4e8ea47d // smmla v29.4s, v3.16b, v14.16b\n"
- "uzp1 v19.2d, v8.2d, v27.2d\n"
- "uzp2 v18.2d, v8.2d, v27.2d\n"
- "scvtf v19.4s, v19.4s, #0x4\n"
- "uzp1 v17.2d, v0.2d, v29.2d\n"
- "uzp2 v16.2d, v0.2d, v29.2d\n"
- "scvtf v18.4s, v18.4s, #0x4\n"
- "fmla v2.4s, v19.4s, v23.4s\n"
- "scvtf v17.4s, v17.4s, #0x4\n"
- "scvtf v16.4s, v16.4s, #0x4\n"
- "fmla v10.4s, v18.4s, v21.4s\n"
- "fmla v12.4s, v17.4s, v1.4s\n"
- "fmla v28.4s, v16.4s, v20.4s\n"
- "bgt 7b\n"
- "mov x20, %x[res_ptr]\n"
- "cmp x10, #0x1\n"
- "str q2, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "ble 8f\n"
- "cmp x10, #0x2\n"
- "str q10, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "ble 8f\n"
- "cmp x10, #0x3\n"
- "str q12, [x20, #0x0]\n"
- "add x20, x20, %x[res_stride]\n"
- "ble 8f\n"
- "str q28, [x20, #0x0]\n"
- "8:" // Row tail: Accumulator store skip
- "subs x23, x23, #0x4\n"
- "add %x[res_ptr], %x[res_ptr], #0x10\n"
- "bne 6b\n"
- "subs x10, x10, #0x4\n"
- "add %x[a_ptr], %x[a_ptr], x9\n"
- "mov %x[res_ptr], x22\n"
- "bgt 5b\n"
- "9:" // Row tail: Row loop skip
- : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr)
- : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
- : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
- );
-#elif defined(__ARM_NEON) && defined(__aarch64__)
- GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) &&
- "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal "
- "performance");
-#else
+ __asm__ __volatile__(
+ "mov x10, %x[nr]\n"
+ "mov x9, #0x88\n"
+ "cmp x10, #0x10\n"
+ "mul x9, %x[nb], x9\n"
+ "blt 4f\n"
+ "1:" // Row loop
+ "add x28, %x[b_ptr], #0x8\n"
+ "mov x27, %x[nc]\n"
+ "add x26, %x[res_ptr], %x[res_stride], LSL #4\n"
+ "2:" // Column loop
+ "add x25, %x[a_ptr], #0x8\n"
+ "movi v2.16b, #0x0\n"
+ "movi v10.16b, #0x0\n"
+ "mov x24, %x[nb]\n"
+ "add x23, x25, x9\n"
+ "movi v12.16b, #0x0\n"
+ "movi v28.16b, #0x0\n"
+ "add x22, x23, x9\n"
+ "movi v11.16b, #0x0\n"
+ "movi v13.16b, #0x0\n"
+ "add x21, x22, x9\n"
+ "movi v22.16b, #0x0\n"
+ "movi v23.16b, #0x0\n"
+ "movi v25.16b, #0x0\n"
+ "movi v5.16b, #0x0\n"
+ "movi v7.16b, #0x0\n"
+ "movi v4.16b, #0x0\n"
+ "movi v6.16b, #0x0\n"
+ "movi v30.16b, #0x0\n"
+ "movi v24.16b, #0x0\n"
+ "movi v14.16b, #0x0\n"
+ "3:" // Block loop
+ "ldr q21, [x28, #0x0]\n"
+ "ldr q16, [x28, #0x10]\n"
+ "movi v1.16b, #0x4\n"
+ "movi v19.4s, #0x0\n"
+ "ldr q27, [x25, #0x0]\n"
+ "ldr q15, [x25, #0x10]\n"
+ "movi v26.4s, #0x0\n"
+ "movi v18.4s, #0x0\n"
+ "ldr q29, [x28, #0x20]\n"
+ "ldr q3, [x28, #0x30]\n"
+ "movi v17.4s, #0x0\n"
+ "movi v0.16b, #0xf0\n"
+ "ldr d20, [x25, #-0x8]\n"
+ "ldr d9, [x23, #-0x8]\n"
+ "sshl v8.16b, v21.16b, v1.16b\n"
+ "sshl v31.16b, v16.16b, v1.16b\n"
+ "and v21.16b, v21.16b, v0.16b\n"
+ "and v16.16b, v16.16b, v0.16b\n"
+ "sub x20, x28, #0x8\n"
+ "subs x24, x24, #0x1\n"
+ "add x28, x28, #0x48\n"
+ ".inst 0x4e88a773 // smmla v19.4s, v27.16b, v8.16b\n"
+ ".inst 0x4e9fa77a // smmla v26.4s, v27.16b, v31.16b\n"
+ "ldr q27, [x25, #0x20]\n"
+ ".inst 0x4e88a5f2 // smmla v18.4s, v15.16b, v8.16b\n"
+ ".inst 0x4e9fa5f1 // smmla v17.4s, v15.16b, v31.16b\n"
+ "sshl v15.16b, v29.16b, v1.16b\n"
+ "sshl v1.16b, v3.16b, v1.16b\n"
+ "and v29.16b, v29.16b, v0.16b\n"
+ "and v3.16b, v3.16b, v0.16b\n"
+ "ldr q0, [x25, #0x30]\n"
+ "fcvtl v20.4s, v20.4h\n"
+ ".inst 0x4e8fa773 // smmla v19.4s, v27.16b, v15.16b\n"
+ "fcvtl v9.4s, v9.4h\n"
+ ".inst 0x4e81a77a // smmla v26.4s, v27.16b, v1.16b\n"
+ "ldr q27, [x25, #0x40]\n"
+ ".inst 0x4e8fa412 // smmla v18.4s, v0.16b, v15.16b\n"
+ ".inst 0x4e81a411 // smmla v17.4s, v0.16b, v1.16b\n"
+ "ldr q0, [x25, #0x50]\n"
+ ".inst 0x4e95a773 // smmla v19.4s, v27.16b, v21.16b\n"
+ ".inst 0x4e90a77a // smmla v26.4s, v27.16b, v16.16b\n"
+ "ldr q27, [x25, #0x60]\n"
+ ".inst 0x4e95a412 // smmla v18.4s, v0.16b, v21.16b\n"
+ ".inst 0x4e90a411 // smmla v17.4s, v0.16b, v16.16b\n"
+ "ldr q0, [x25, #0x70]\n"
+ "add x25, x25, #0x88\n"
+ ".inst 0x4e9da773 // smmla v19.4s, v27.16b, v29.16b\n"
+ ".inst 0x4e83a77a // smmla v26.4s, v27.16b, v3.16b\n"
+ "ldr d27, [x20, #0x0]\n"
+ ".inst 0x4e9da412 // smmla v18.4s, v0.16b, v29.16b\n"
+ ".inst 0x4e83a411 // smmla v17.4s, v0.16b, v3.16b\n"
+ "fcvtl v27.4s, v27.4h\n"
+ "uzp1 v0.2d, v19.2d, v26.2d\n"
+ "uzp2 v26.2d, v19.2d, v26.2d\n"
+ "fmul v19.4s, v27.4s, v20.s[0]\n"
+ "scvtf v0.4s, v0.4s, #0x4\n"
+ "scvtf v26.4s, v26.4s, #0x4\n"
+ "fmla v2.4s, v0.4s, v19.4s\n"
+ "ldr q19, [x23, #0x0]\n"
+ "uzp1 v0.2d, v18.2d, v17.2d\n"
+ "uzp2 v18.2d, v18.2d, v17.2d\n"
+ "fmul v17.4s, v27.4s, v20.s[1]\n"
+ "scvtf v0.4s, v0.4s, #0x4\n"
+ "scvtf v18.4s, v18.4s, #0x4\n"
+ "fmla v10.4s, v26.4s, v17.4s\n"
+ "ldr q17, [x23, #0x10]\n"
+ "fmul v26.4s, v27.4s, v20.s[2]\n"
+ "fmul v20.4s, v27.4s, v20.s[3]\n"
+ "fmla v12.4s, v0.4s, v26.4s\n"
+ "ldr d0, [x22, #-0x8]\n"
+ "ldr d26, [x21, #-0x8]\n"
+ "fcvtl v0.4s, v0.4h\n"
+ "fmla v28.4s, v18.4s, v20.4s\n"
+ "movi v20.4s, #0x0\n"
+ "movi v18.4s, #0x0\n"
+ ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n"
+ ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n"
+ "ldr q19, [x23, #0x20]\n"
+ "fcvtl v26.4s, v26.4h\n"
+ ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n"
+ ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n"
+ "ldr q19, [x23, #0x40]\n"
+ ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n"
+ ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n"
+ "ldr q19, [x23, #0x60]\n"
+ ".inst 0x4e9da674 // smmla v20.4s, v19.16b, v29.16b\n"
+ ".inst 0x4e83a672 // smmla v18.4s, v19.16b, v3.16b\n"
+ "uzp1 v19.2d, v20.2d, v18.2d\n"
+ "scvtf v19.4s, v19.4s, #0x4\n"
+ "uzp2 v20.2d, v20.2d, v18.2d\n"
+ "fmul v18.4s, v27.4s, v9.s[0]\n"
+ "scvtf v20.4s, v20.4s, #0x4\n"
+ "fmla v11.4s, v19.4s, v18.4s\n"
+ "ldr q18, [x22, #0x0]\n"
+ "fmul v19.4s, v27.4s, v9.s[1]\n"
+ "fmla v13.4s, v20.4s, v19.4s\n"
+ "movi v19.4s, #0x0\n"
+ "movi v20.4s, #0x0\n"
+ ".inst 0x4e88a633 // smmla v19.4s, v17.16b, v8.16b\n"
+ ".inst 0x4e9fa634 // smmla v20.4s, v17.16b, v31.16b\n"
+ "ldr q17, [x23, #0x30]\n"
+ ".inst 0x4e8fa633 // smmla v19.4s, v17.16b, v15.16b\n"
+ ".inst 0x4e81a634 // smmla v20.4s, v17.16b, v1.16b\n"
+ "ldr q17, [x23, #0x50]\n"
+ ".inst 0x4e95a633 // smmla v19.4s, v17.16b, v21.16b\n"
+ ".inst 0x4e90a634 // smmla v20.4s, v17.16b, v16.16b\n"
+ "ldr q17, [x23, #0x70]\n"
+ "add x23, x23, #0x88\n"
+ ".inst 0x4e9da633 // smmla v19.4s, v17.16b, v29.16b\n"
+ ".inst 0x4e83a634 // smmla v20.4s, v17.16b, v3.16b\n"
+ "uzp1 v17.2d, v19.2d, v20.2d\n"
+ "scvtf v17.4s, v17.4s, #0x4\n"
+ "uzp2 v20.2d, v19.2d, v20.2d\n"
+ "fmul v19.4s, v27.4s, v9.s[2]\n"
+ "fmul v9.4s, v27.4s, v9.s[3]\n"
+ "scvtf v20.4s, v20.4s, #0x4\n"
+ "fmla v22.4s, v17.4s, v19.4s\n"
+ "ldr q17, [x22, #0x10]\n"
+ "movi v19.4s, #0x0\n"
+ ".inst 0x4e88a653 // smmla v19.4s, v18.16b, v8.16b\n"
+ "fmla v23.4s, v20.4s, v9.4s\n"
+ "movi v20.4s, #0x0\n"
+ "movi v9.4s, #0x0\n"
+ ".inst 0x4e9fa654 // smmla v20.4s, v18.16b, v31.16b\n"
+ "ldr q18, [x22, #0x20]\n"
+ ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n"
+ ".inst 0x4e8fa653 // smmla v19.4s, v18.16b, v15.16b\n"
+ ".inst 0x4e81a654 // smmla v20.4s, v18.16b, v1.16b\n"
+ "ldr q18, [x22, #0x40]\n"
+ ".inst 0x4e95a653 // smmla v19.4s, v18.16b, v21.16b\n"
+ ".inst 0x4e90a654 // smmla v20.4s, v18.16b, v16.16b\n"
+ "ldr q18, [x22, #0x60]\n"
+ ".inst 0x4e9da653 // smmla v19.4s, v18.16b, v29.16b\n"
+ ".inst 0x4e83a654 // smmla v20.4s, v18.16b, v3.16b\n"
+ "movi v18.4s, #0x0\n"
+ ".inst 0x4e9fa632 // smmla v18.4s, v17.16b, v31.16b\n"
+ "ldr q17, [x22, #0x30]\n"
+ ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n"
+ ".inst 0x4e81a632 // smmla v18.4s, v17.16b, v1.16b\n"
+ "ldr q17, [x22, #0x50]\n"
+ ".inst 0x4e95a629 // smmla v9.4s, v17.16b, v21.16b\n"
+ ".inst 0x4e90a632 // smmla v18.4s, v17.16b, v16.16b\n"
+ "ldr q17, [x22, #0x70]\n"
+ "add x22, x22, #0x88\n"
+ ".inst 0x4e9da629 // smmla v9.4s, v17.16b, v29.16b\n"
+ ".inst 0x4e83a632 // smmla v18.4s, v17.16b, v3.16b\n"
+ "uzp1 v17.2d, v19.2d, v20.2d\n"
+ "uzp2 v20.2d, v19.2d, v20.2d\n"
+ "fmul v19.4s, v27.4s, v0.s[0]\n"
+ "scvtf v17.4s, v17.4s, #0x4\n"
+ "scvtf v20.4s, v20.4s, #0x4\n"
+ "fmla v25.4s, v17.4s, v19.4s\n"
+ "ldr q19, [x21, #0x0]\n"
+ "fmul v17.4s, v27.4s, v0.s[1]\n"
+ "fmla v5.4s, v20.4s, v17.4s\n"
+ "ldr q17, [x21, #0x10]\n"
+ "uzp1 v20.2d, v9.2d, v18.2d\n"
+ "uzp2 v9.2d, v9.2d, v18.2d\n"
+ "fmul v18.4s, v27.4s, v0.s[2]\n"
+ "fmul v0.4s, v27.4s, v0.s[3]\n"
+ "scvtf v20.4s, v20.4s, #0x4\n"
+ "scvtf v9.4s, v9.4s, #0x4\n"
+ "fmla v7.4s, v20.4s, v18.4s\n"
+ "movi v20.4s, #0x0\n"
+ "movi v18.4s, #0x0\n"
+ ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n"
+ ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n"
+ "ldr q19, [x21, #0x20]\n"
+ "fmla v4.4s, v9.4s, v0.4s\n"
+ "movi v9.4s, #0x0\n"
+ "movi v0.4s, #0x0\n"
+ ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n"
+ "fmul v8.4s, v27.4s, v26.s[0]\n"
+ ".inst 0x4e9fa620 // smmla v0.4s, v17.16b, v31.16b\n"
+ "ldr q17, [x21, #0x30]\n"
+ ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n"
+ "fmul v31.4s, v27.4s, v26.s[1]\n"
+ ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n"
+ "ldr q19, [x21, #0x40]\n"
+ ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n"
+ "fmul v15.4s, v27.4s, v26.s[2]\n"
+ "fmul v27.4s, v27.4s, v26.s[3]\n"
+ ".inst 0x4e81a620 // smmla v0.4s, v17.16b, v1.16b\n"
+ "ldr q1, [x21, #0x50]\n"
+ ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n"
+ ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n"
+ "ldr q26, [x21, #0x60]\n"
+ ".inst 0x4e95a429 // smmla v9.4s, v1.16b, v21.16b\n"
+ ".inst 0x4e90a420 // smmla v0.4s, v1.16b, v16.16b\n"
+ "ldr q21, [x21, #0x70]\n"
+ "add x21, x21, #0x88\n"
+ ".inst 0x4e9da754 // smmla v20.4s, v26.16b, v29.16b\n"
+ ".inst 0x4e83a752 // smmla v18.4s, v26.16b, v3.16b\n"
+ ".inst 0x4e9da6a9 // smmla v9.4s, v21.16b, v29.16b\n"
+ ".inst 0x4e83a6a0 // smmla v0.4s, v21.16b, v3.16b\n"
+ "uzp1 v29.2d, v20.2d, v18.2d\n"
+ "uzp2 v21.2d, v20.2d, v18.2d\n"
+ "scvtf v29.4s, v29.4s, #0x4\n"
+ "uzp1 v18.2d, v9.2d, v0.2d\n"
+ "uzp2 v16.2d, v9.2d, v0.2d\n"
+ "scvtf v21.4s, v21.4s, #0x4\n"
+ "fmla v6.4s, v29.4s, v8.4s\n"
+ "scvtf v18.4s, v18.4s, #0x4\n"
+ "scvtf v16.4s, v16.4s, #0x4\n"
+ "fmla v30.4s, v21.4s, v31.4s\n"
+ "fmla v24.4s, v18.4s, v15.4s\n"
+ "fmla v14.4s, v16.4s, v27.4s\n"
+ "bgt 3b\n"
+ "mov x20, %x[res_ptr]\n"
+ "subs x27, x27, #0x4\n"
+ "add %x[res_ptr], %x[res_ptr], #0x10\n"
+ "str q2, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q10, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q12, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q28, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q11, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q13, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q22, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q23, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q25, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q5, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q7, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q4, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q6, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q30, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q24, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q14, [x20, #0x0]\n"
+ "bne 2b\n"
+ "mov x20, #0x4\n"
+ "sub x10, x10, #0x10\n"
+ "cmp x10, #0x10\n"
+ "mov %x[res_ptr], x26\n"
+ "madd %x[a_ptr], x20, x9, %x[a_ptr]\n"
+ "bge 1b\n"
+ "4:" // Row loop skip
+ "cbz x10, 9f\n"
+ "5:" // Row tail: Row loop
+ "add x24, %x[b_ptr], #0x8\n"
+ "mov x23, %x[nc]\n"
+ "add x22, %x[res_ptr], %x[res_stride], LSL #2\n"
+ "6:" // Row tail: Column loop
+ "movi v2.16b, #0x0\n"
+ "movi v10.16b, #0x0\n"
+ "add x25, %x[a_ptr], #0x8\n"
+ "mov x21, %x[nb]\n"
+ "movi v12.16b, #0x0\n"
+ "movi v28.16b, #0x0\n"
+ "7:" // Row tail: Block loop
+ "ldr q6, [x24, #0x0]\n"
+ "ldr q5, [x24, #0x10]\n"
+ "movi v17.16b, #0x4\n"
+ "movi v8.4s, #0x0\n"
+ "ldr q4, [x25, #0x0]\n"
+ "ldr q13, [x25, #0x10]\n"
+ "movi v27.4s, #0x0\n"
+ "movi v0.4s, #0x0\n"
+ "ldr q31, [x24, #0x20]\n"
+ "ldr q14, [x24, #0x30]\n"
+ "movi v29.4s, #0x0\n"
+ "movi v22.16b, #0xf0\n"
+ "ldr q11, [x25, #0x20]\n"
+ "ldr q23, [x25, #0x30]\n"
+ "sshl v21.16b, v6.16b, v17.16b\n"
+ "sshl v16.16b, v5.16b, v17.16b\n"
+ "ldr q20, [x25, #0x40]\n"
+ "ldr q26, [x25, #0x50]\n"
+ "and v6.16b, v6.16b, v22.16b\n"
+ "and v5.16b, v5.16b, v22.16b\n"
+ "ldr q25, [x25, #0x60]\n"
+ "ldr q3, [x25, #0x70]\n"
+ "sshl v19.16b, v31.16b, v17.16b\n"
+ "sshl v18.16b, v14.16b, v17.16b\n"
+ "ldr d17, [x25, #-0x8]\n"
+ ".inst 0x4e95a488 // smmla v8.4s, v4.16b, v21.16b\n"
+ ".inst 0x4e90a49b // smmla v27.4s, v4.16b, v16.16b\n"
+ "and v31.16b, v31.16b, v22.16b\n"
+ ".inst 0x4e95a5a0 // smmla v0.4s, v13.16b, v21.16b\n"
+ ".inst 0x4e90a5bd // smmla v29.4s, v13.16b, v16.16b\n"
+ "and v14.16b, v14.16b, v22.16b\n"
+ "sub x20, x24, #0x8\n"
+ "ldr d16, [x20, #0x0]\n"
+ "subs x21, x21, #0x1\n"
+ "add x25, x25, #0x88\n"
+ "fcvtl v17.4s, v17.4h\n"
+ "add x24, x24, #0x48\n"
+ ".inst 0x4e93a568 // smmla v8.4s, v11.16b, v19.16b\n"
+ ".inst 0x4e92a57b // smmla v27.4s, v11.16b, v18.16b\n"
+ ".inst 0x4e93a6e0 // smmla v0.4s, v23.16b, v19.16b\n"
+ ".inst 0x4e92a6fd // smmla v29.4s, v23.16b, v18.16b\n"
+ "fcvtl v16.4s, v16.4h\n"
+ ".inst 0x4e86a688 // smmla v8.4s, v20.16b, v6.16b\n"
+ ".inst 0x4e85a69b // smmla v27.4s, v20.16b, v5.16b\n"
+ "fmul v23.4s, v16.4s, v17.s[0]\n"
+ "fmul v21.4s, v16.4s, v17.s[1]\n"
+ "fmul v1.4s, v16.4s, v17.s[2]\n"
+ "fmul v20.4s, v16.4s, v17.s[3]\n"
+ ".inst 0x4e86a740 // smmla v0.4s, v26.16b, v6.16b\n"
+ ".inst 0x4e85a75d // smmla v29.4s, v26.16b, v5.16b\n"
+ ".inst 0x4e9fa728 // smmla v8.4s, v25.16b, v31.16b\n"
+ ".inst 0x4e8ea73b // smmla v27.4s, v25.16b, v14.16b\n"
+ ".inst 0x4e9fa460 // smmla v0.4s, v3.16b, v31.16b\n"
+ ".inst 0x4e8ea47d // smmla v29.4s, v3.16b, v14.16b\n"
+ "uzp1 v19.2d, v8.2d, v27.2d\n"
+ "uzp2 v18.2d, v8.2d, v27.2d\n"
+ "scvtf v19.4s, v19.4s, #0x4\n"
+ "uzp1 v17.2d, v0.2d, v29.2d\n"
+ "uzp2 v16.2d, v0.2d, v29.2d\n"
+ "scvtf v18.4s, v18.4s, #0x4\n"
+ "fmla v2.4s, v19.4s, v23.4s\n"
+ "scvtf v17.4s, v17.4s, #0x4\n"
+ "scvtf v16.4s, v16.4s, #0x4\n"
+ "fmla v10.4s, v18.4s, v21.4s\n"
+ "fmla v12.4s, v17.4s, v1.4s\n"
+ "fmla v28.4s, v16.4s, v20.4s\n"
+ "bgt 7b\n"
+ "mov x20, %x[res_ptr]\n"
+ "cmp x10, #0x1\n"
+ "str q2, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "ble 8f\n"
+ "cmp x10, #0x2\n"
+ "str q10, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "ble 8f\n"
+ "cmp x10, #0x3\n"
+ "str q12, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "ble 8f\n"
+ "str q28, [x20, #0x0]\n"
+ "8:" // Row tail: Accumulator store skip
+ "subs x23, x23, #0x4\n"
+ "add %x[res_ptr], %x[res_ptr], #0x10\n"
+ "bne 6b\n"
+ "subs x10, x10, #0x4\n"
+ "add %x[a_ptr], %x[a_ptr], x9\n"
+ "mov %x[res_ptr], x22\n"
+ "bgt 5b\n"
+ "9:" // Row tail: Row loop skip
+ : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr)
+ : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
+ : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
+ );
+ return;
+ }
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
float sumf[4][4];
int sumi;
@@ -1944,7 +1983,7 @@ void ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void *
const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
- (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
+ (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
}
sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
}
@@ -1957,7 +1996,6 @@ void ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void *
}
}
}
-#endif
}
void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
@@ -1980,8 +2018,9 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
UNUSED(ncols_interleaved);
UNUSED(blocklen);
-#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) && ! ((defined(_MSC_VER)) && ! defined(__clang__))
- if (ggml_sve_cnt_b == QK8_0) {
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
+#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
+ if (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && sve_lane_count() == QK8_0) {
const void * b_ptr = vx;
const void * a_ptr = vy;
float * res_ptr = s;
@@ -2391,134 +2430,682 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
);
return;
}
- else if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
- GGML_ASSERT((ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) &&
- "__ARM_FEATURE_SVE for vector size of 256-bits not defined, use the Q4_0_4_8 quantization format for optimal "
- "performance");
- }
- else if (ggml_cpu_has_neon()) {
- GGML_ASSERT(((ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) || ggml_cpu_has_matmul_int8()) &&
- "__ARM_FEATURE_SVE for vector size of 256-bits and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 "
- "quantization format for optimal performance");
- }
-#endif
-#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
- GGML_ASSERT(ggml_cpu_has_sve() &&
- "__ARM_FEATURE_SVE not defined, use the Q4_0_4_8 quantization format for optimal performance");
-#elif defined(__ARM_NEON) && defined(__aarch64__)
- GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) &&
- "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal "
- "performance");
+#endif // #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
#elif defined(__AVX2__) || defined(__AVX512F__)
- const block_q4_0x8 * b_ptr_start = (const block_q4_0x8 *)vx;
- const block_q8_0x4 * a_ptr_start = (const block_q8_0x4 *)vy;
- int64_t b_nb = n / QK4_0;
- int64_t y = 0;
- // Mask to mask out nibbles from packed bytes
- const __m256i m4b = _mm256_set1_epi8(0x0F);
- const __m128i loadMask = _mm_blend_epi32(_mm_setzero_si128(), _mm_set1_epi32(0xFFFFFFFF), 3);
- // Lookup table to convert signed nibbles to signed bytes
- __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
- signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
- // Permute mask used for easier vector processing at later stages
- __m256i requiredOrder = _mm256_set_epi32(3 ,2 ,1 ,0, 7 ,6, 5, 4);
+ {
+ const block_q4_0x8 * b_ptr_start = (const block_q4_0x8 *)vx;
+ const block_q8_0x4 * a_ptr_start = (const block_q8_0x4 *)vy;
+ int64_t b_nb = n / QK4_0;
+ int64_t y = 0;
+ // Mask to mask out nibbles from packed bytes
+ const __m256i m4b = _mm256_set1_epi8(0x0F);
+ const __m128i loadMask = _mm_blend_epi32(_mm_setzero_si128(), _mm_set1_epi32(0xFFFFFFFF), 3);
+ // Lookup table to convert signed nibbles to signed bytes
+ __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
+ signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
+ // Permute mask used for easier vector processing at later stages
+ __m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);
+ int64_t xstart = 0;
+ int anr = nr - nr%16; // Used to align nr with boundary of 16
+ #ifdef __AVX512F__
+ int anc = nc - nc%16; // Used to align nc with boundary of 16
+ // Mask to mask out nibbles from packed bytes expanded to 512 bit length
+ const __m512i m4bexpanded = _mm512_set1_epi8(0x0F);
+ // Lookup table to convert signed nibbles to signed bytes expanded to 512 bit length
+ __m512i signextendlutexpanded = _mm512_inserti32x8(_mm512_castsi256_si512(signextendlut), signextendlut, 1);
- // Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation
- int anr = nr - nr %16; // Used to align nr with boundary of 16
+            // Take a group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation
+ for (; y < anr / 4; y += 4) {
- for (; y < anr / 4; y += 4) {
- const block_q8_0x4 * a_ptrs[4];
+ const block_q8_0x4 * a_ptrs[4];
- a_ptrs[0] = a_ptr_start + (y * nb);
- for (int i = 0; i < 3; ++i) {
- a_ptrs[i + 1] = a_ptrs[i] + nb;
- }
-
- // Take group of eight block_q4_0x8 structures at each pass of the loop and perform dot product operation
- for (int64_t x = 0; x < nc / 8; x++) {
-
- const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb);
-
- // Master FP accumulators
- __m256 acc_rows[16];
- for (int i = 0; i < 16; i++) {
- acc_rows[i] = _mm256_setzero_ps();
+ a_ptrs[0] = a_ptr_start + (y * nb);
+ for (int i = 0; i < 3; ++i) {
+ a_ptrs[i + 1] = a_ptrs[i] + nb;
}
- for (int64_t b = 0; b < nb; b++) {
- // Load the eight block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
- const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs));
- const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32));
- const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64));
- const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96));
+                // Take a group of two block_q4_0x8 structures at each pass of the loop and perform dot product operation
+ for (int64_t x = 0; x < anc / 8; x += 2) {
- // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values
- const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
- const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
- const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
- const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
+ const block_q4_0x8 * b_ptr_0 = b_ptr_start + ((x) * b_nb);
+ const block_q4_0x8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb);
- // 4-bit -> 8-bit - Sign is maintained
- const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7)
- const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7)
+ // Master FP accumulators
+ __m512 acc_rows[16];
+ for (int i = 0; i < 16; i++) {
+ acc_rows[i] = _mm512_setzero_ps();
+ }
- const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15)
- const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15)
+ for (int64_t b = 0; b < nb; b++) {
+ // Load the sixteen block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....BE,BF
+ const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs));
+ const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 32));
+ const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 64));
+ const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 96));
- const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23)
- const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23)
+ const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs));
+ const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 32));
+ const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 64));
+ const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 96));
- const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31)
- const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31)
+ // Save the values in the following vectors in the formats B0B1B4B5B8B9BCBD, B2B3B6B7BABBBEBF for further processing and storing of values
+ const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
+ const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
+ const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
+ const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
- // Shuffle pattern one - right side input
- const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3)
- const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3)
+ const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240);
+ const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240);
+ const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240);
+ const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240);
- const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11)
- const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11)
+ const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1);
+ const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1);
+ const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1);
+ const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1);
- const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19)
- const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19)
+ // 4-bit -> 8-bit - Sign is maintained
+ const __m512i rhs_mat_014589CD_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) B8(0-7) B9(0-7) BC(0-7) BD(0-7)
+ const __m512i rhs_mat_2367ABEF_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) BA(0-7) BB(0-7) BE(0-7) BF(0-7)
- const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27)
- const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27)
+ const __m512i rhs_mat_014589CD_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) B8(8-15) B9(8-15) BC(8-15) BD(8-15)
+ const __m512i rhs_mat_2367ABEF_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) BA(8-15) BB(8-15) BE(8-15) BF(8-15)
- // Shuffle pattern two - right side input
+ const __m512i rhs_mat_014589CD_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) B8(16-23) B9(16-23) BC(16-23) BD(16-23)
+ const __m512i rhs_mat_2367ABEF_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) BA(16-23) BB(16-23) BE(16-23) BF(16-23)
- const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7)
- const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7)
+ const __m512i rhs_mat_014589CD_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) B8(24-31) B9(24-31) BC(24-31) BD(24-31)
+ const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)
- const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15)
- const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15)
+ // Shuffle pattern one - right side input
+ const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
+ const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
- const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23)
- const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23)
+ const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
+ const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
- const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31)
- const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31)
+ const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
+ const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
- // Scale values - Load the wight scale values of block_q4_0x8
- const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
+ const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
+ const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
+
+ // Shuffle pattern two - right side input
+
+ const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
+ const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
+
+ const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
+ const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
+
+ const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
+ const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
+
+ const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
+ const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
+
+ // Scale values - Load the weight scale values of two block_q4_0x8
+ const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
+
+ // Process LHS in pairs of rows
+ for (int rp = 0; rp < 4; rp++) {
+
+                        // Load the four block_q8_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
+                        // Loaded as a set of 128 bit vectors, repeated into a 256 bit vector and then again into a 512 bit vector
+ __m256i lhs_mat_ymm_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs)));
+ __m256i lhs_mat_ymm_01_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 0);
+ __m256i lhs_mat_ymm_23_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 17);
+ __m256i lhs_mat_ymm_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32)));
+ __m256i lhs_mat_ymm_01_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 0);
+ __m256i lhs_mat_ymm_23_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 17);
+ __m256i lhs_mat_ymm_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64)));
+ __m256i lhs_mat_ymm_01_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 0);
+ __m256i lhs_mat_ymm_23_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 17);
+ __m256i lhs_mat_ymm_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96)));
+ __m256i lhs_mat_ymm_01_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 0);
+ __m256i lhs_mat_ymm_23_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 17);
+
+ __m512i lhs_mat_01_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_0), lhs_mat_ymm_01_0, 1);
+ __m512i lhs_mat_23_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_0), lhs_mat_ymm_23_0, 1);
+ __m512i lhs_mat_01_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_1), lhs_mat_ymm_01_1, 1);
+ __m512i lhs_mat_23_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_1), lhs_mat_ymm_23_1, 1);
+ __m512i lhs_mat_01_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_2), lhs_mat_ymm_01_2, 1);
+ __m512i lhs_mat_23_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_2), lhs_mat_ymm_23_2, 1);
+ __m512i lhs_mat_01_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_3), lhs_mat_ymm_01_3, 1);
+ __m512i lhs_mat_23_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_3), lhs_mat_ymm_23_3, 1);
+
+ // Shuffle pattern one - left side input
+
+ const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
+ const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
+
+ const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
+ const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
+
+ const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
+ const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
+
+ const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
+ const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
+
+ // Shuffle pattern two - left side input
+
+ const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
+ const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
+
+ const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
+ const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
+
+ const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
+ const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
+
+ const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
+ const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
+
+                            // The values arranged in shuffle patterns are combined with a dot product within each 32 bit lane, i.e. corresponding bytes are multiplied and added into 32 bit integers within the lane
+                            // Resembles MMLAs into 2x2 matrices in the ARM version
+ __m512i iacc_mat_00_sp1 =
+ _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1));
+ __m512i iacc_mat_01_sp1 =
+ _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1));
+ __m512i iacc_mat_10_sp1 =
+ _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1));
+ __m512i iacc_mat_11_sp1 =
+ _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1));
+ __m512i iacc_mat_00_sp2 =
+ _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2));
+ __m512i iacc_mat_01_sp2 =
+ _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2));
+ __m512i iacc_mat_10_sp2 =
+ _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2));
+ __m512i iacc_mat_11_sp2 =
+ _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2));
+
+                            // Outputs of both shuffle patterns are added in order to sum dot product outputs of all 32 values in the block
+ __m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
+ __m512i iacc_mat_01 = _mm512_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);
+ __m512i iacc_mat_10 = _mm512_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
+ __m512i iacc_mat_11 = _mm512_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
+
+
+ // Straighten out to make 4 row vectors
+ __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, 78));
+ __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01);
+ __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, 78));
+ __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11);
+
+                            // Load the scale values for all 4 Q8_0 blocks and repeat them across lanes
+ const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptrs[rp][b].d), loadMask), 68);
+ const __m512 row_scale_f32 = GGML_F32Cx16_REPEAT_LOAD(row_scale_f16);
+
+                            // Multiply with appropriate scales and accumulate
+ acc_rows[rp * 4] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]);
+ acc_rows[rp * 4 + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]);
+ acc_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
+ acc_rows[rp * 4 + 3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]);
+ }
+ }
+
+ // Store the accumulated values
+ for (int i = 0; i < 16; i++) {
+ _mm512_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
+ }
+ }
+ }
+        // Take one block_q8_0x4 structure at each pass of the loop and perform dot product operation
+ for (; y < nr / 4; y ++) {
+
+ const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb);
+
+            // Take a group of two block_q4_0x8 structures at each pass of the loop and perform dot product operation
+ for (int64_t x = 0; x < anc / 8; x += 2) {
+
+ const block_q4_0x8 * b_ptr_0 = b_ptr_start + ((x) * b_nb);
+ const block_q4_0x8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb);
+
+ // Master FP accumulators
+ __m512 acc_rows[4];
+ for (int i = 0; i < 4; i++) {
+ acc_rows[i] = _mm512_setzero_ps();
+ }
+
+ for (int64_t b = 0; b < nb; b++) {
+ // Load the sixteen block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....BE,BF
+ const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs));
+ const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 32));
+ const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 64));
+ const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 96));
+
+ const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs));
+ const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 32));
+ const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 64));
+ const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 96));
+
+                    // Save the values in the following vectors in the formats B0B1B4B5B8B9BCBD, B2B3B6B7BABBBEBF for further processing and storing of values
+ const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
+ const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
+ const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
+ const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
+
+ const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240);
+ const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240);
+ const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240);
+ const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240);
+
+ const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1);
+ const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1);
+ const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1);
+ const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1);
+
+ // 4-bit -> 8-bit - Sign is maintained
+ const __m512i rhs_mat_014589CD_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) B8(0-7) B9(0-7) BC(0-7) BD(0-7)
+ const __m512i rhs_mat_2367ABEF_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) BA(0-7) BB(0-7) BE(0-7) BF(0-7)
+
+ const __m512i rhs_mat_014589CD_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) B8(8-15) B9(8-15) BC(8-15) BD(8-15)
+ const __m512i rhs_mat_2367ABEF_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) BA(8-15) BB(8-15) BE(8-15) BF(8-15)
+
+ const __m512i rhs_mat_014589CD_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) B8(16-23) B9(16-23) BC(16-23) BD(16-23)
+ const __m512i rhs_mat_2367ABEF_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) BA(16-23) BB(16-23) BE(16-23) BF(16-23)
+
+ const __m512i rhs_mat_014589CD_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) B8(24-31) B9(24-31) BC(24-31) BD(24-31)
+ const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)
+
+ // Shuffle pattern one - right side input
+ const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
+ const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
+
+ const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
+ const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
+
+ const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
+ const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
+
+ const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
+ const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
+
+ // Shuffle pattern two - right side input
+
+ const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
+ const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
+
+ const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
+ const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
+
+ const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
+ const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
+
+ const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
+ const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
+
+
+ // Scale values - Load the weight scale values of two block_q4_0x8
+ const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
+
+                    // Load the four block_q8_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
+                    // Loaded as a set of 128 bit vectors, repeated into a 256 bit vector and then again into a 512 bit vector
+ __m256i lhs_mat_ymm_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs)));
+ __m256i lhs_mat_ymm_01_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 0);
+ __m256i lhs_mat_ymm_23_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 17);
+ __m256i lhs_mat_ymm_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32)));
+ __m256i lhs_mat_ymm_01_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 0);
+ __m256i lhs_mat_ymm_23_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 17);
+ __m256i lhs_mat_ymm_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64)));
+ __m256i lhs_mat_ymm_01_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 0);
+ __m256i lhs_mat_ymm_23_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 17);
+ __m256i lhs_mat_ymm_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96)));
+ __m256i lhs_mat_ymm_01_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 0);
+ __m256i lhs_mat_ymm_23_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 17);
+
+ __m512i lhs_mat_01_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_0), lhs_mat_ymm_01_0, 1);
+ __m512i lhs_mat_23_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_0), lhs_mat_ymm_23_0, 1);
+ __m512i lhs_mat_01_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_1), lhs_mat_ymm_01_1, 1);
+ __m512i lhs_mat_23_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_1), lhs_mat_ymm_23_1, 1);
+ __m512i lhs_mat_01_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_2), lhs_mat_ymm_01_2, 1);
+ __m512i lhs_mat_23_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_2), lhs_mat_ymm_23_2, 1);
+ __m512i lhs_mat_01_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_3), lhs_mat_ymm_01_3, 1);
+ __m512i lhs_mat_23_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_3), lhs_mat_ymm_23_3, 1);
+
+ // Shuffle pattern one - left side input
+
+ const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
+ const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
+
+ const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
+ const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
+
+ const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
+ const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
+
+ const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
+ const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
+
+ // Shuffle pattern two - left side input
+
+ const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
+ const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
+
+ const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
+ const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
+
+ const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
+ const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
+
+ const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
+ const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
+
+ // The values arranged in the shuffle patterns are combined with a dot product within each 32-bit lane: corresponding int8 pairs are multiplied and their products accumulated into one 32-bit integer per lane
+ // Resembles the MMLA 2x2 tile accumulation used in the ARM version
+ __m512i iacc_mat_00_sp1 =
+ _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1));
+ __m512i iacc_mat_01_sp1 =
+ _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1));
+ __m512i iacc_mat_10_sp1 =
+ _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1));
+ __m512i iacc_mat_11_sp1 =
+ _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1));
+ __m512i iacc_mat_00_sp2 =
+ _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2));
+ __m512i iacc_mat_01_sp2 =
+ _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2));
+ __m512i iacc_mat_10_sp2 =
+ _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2));
+ __m512i iacc_mat_11_sp2 =
+ _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2));
+
+ // Outputs of both shuffle patterns are added in order to sum the dot product outputs of all 32 values in the block
+ __m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
+ __m512i iacc_mat_01 = _mm512_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);
+ __m512i iacc_mat_10 = _mm512_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
+ __m512i iacc_mat_11 = _mm512_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
+
+
+ // Straighten out to make 4 row vectors
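+ // _mm512_shuffle_epi32 imm8 78 (0b01'00'11'10) swaps the two dword pairs of each 128-bit lane; the 0xCCCC mask then interleaves the 00/01 (and 10/11) tiles so every lane holds results for a single output row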
+ __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, 78));
+ __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01);
+ __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, 78));
+ __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11);
+
+ // Load the scale values for all 4 Q8_0 blocks and repeat them across lanes
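+ // imm8 68 (0b01'00'01'00) duplicates the two dwords holding the four f16 scales, so after the f16->f32 repeat load every 128-bit lane holds d0..d3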
+ const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptr[b].d), loadMask), 68);
+ const __m512 row_scale_f32 = GGML_F32Cx16_REPEAT_LOAD(row_scale_f16);
+
+ // Multiply with appropriate scales and accumulate
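+ // _mm512_shuffle_ps imm8 0/85/170/255 broadcasts the scale of row 0/1/2/3 across every 128-bit lane before the fused multiply-add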
+ acc_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]);
+ acc_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]);
+ acc_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);
+ acc_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);
+ }
+
+ // Store the accumulated values
+ for (int i = 0; i < 4; i++) {
+ _mm512_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
+ }
+ }
+ }
+ if (anc != nc) {
+ xstart = anc/8;
+ y = 0;
+ }
+ #endif // __AVX512F__
+
+ // Take a group of four block_q8_0x4 structures at each pass of the loop and perform the dot product operation
+
+ for (; y < anr / 4; y += 4) {
+ const block_q8_0x4 * a_ptrs[4];
+
+ a_ptrs[0] = a_ptr_start + (y * nb);
+ for (int i = 0; i < 3; ++i) {
+ a_ptrs[i + 1] = a_ptrs[i] + nb;
+ }
+
+ // Take a group of eight block_q4_0x8 structures at each pass of the loop and perform the dot product operation
+ for (int64_t x = xstart; x < nc / 8; x++) {
+
+ const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb);
+
+ // Master FP accumulators
+ __m256 acc_rows[16];
+ for (int i = 0; i < 16; i++) {
+ acc_rows[i] = _mm256_setzero_ps();
+ }
+
+ for (int64_t b = 0; b < nb; b++) {
+ // Load the eight block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
+ const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs));
+ const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32));
+ const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64));
+ const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96));
+
+ // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values
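+ // blend imm8 240 (0xF0) takes dwords 0-3 from the first operand and dwords 4-7 from the second, with requiredOrder moving the needed 128-bit half of the other vector into place to form the B0B1B4B5 / B2B3B6B7 layouts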
+ const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
+ const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
+ const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
+ const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
+
+ // 4-bit -> 8-bit - Sign is maintained
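+ // the masked low nibbles index signextendlut, a 16-entry lookup table that sign-extends each 4-bit weight to int8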
+ const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7)
+ const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7)
+
+ const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15)
+ const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15)
+
+ const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23)
+ const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23)
+
+ const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31)
+ const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31)
+
+ // Shuffle pattern one - right side input
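+ // imm8 136 (0b10'00'10'00) keeps dwords 0,2,0,2 of each 128-bit lane (the even 4-byte groups of two columns); pattern two below uses 221 (0b11'01'11'01) for the odd groups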
+ const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3)
+ const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3)
+
+ const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11)
+ const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11)
+
+ const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19)
+ const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19)
+
+ const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27)
+ const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27)
+
+ // Shuffle pattern two - right side input
+
+ const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7)
+ const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7)
+
+ const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15)
+ const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15)
+
+ const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23)
+ const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23)
+
+ const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31)
+ const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31)
+
+ // Scale values - Load the weight scale values of block_q4_0x8
+ const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
+
+ // Process LHS in groups of four
+ for (int rp = 0; rp < 4; rp++) {
+ // Load the four block_q8_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
+ // Loaded as a set of 128 bit vectors and repeated into a 256 bit vector
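+ // _mm256_permute2f128_si256 with imm8 0 duplicates the low 128 bits (rows 0,1) into both halves; imm8 17 duplicates the high 128 bits (rows 2,3)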
+ __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs)));
+ __m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0);
+ __m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17);
+ __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32)));
+ __m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0);
+ __m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17);
+ __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64)));
+ __m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0);
+ __m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17);
+ __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96)));
+ __m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0);
+ __m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17);
+
+ // Shuffle pattern one - left side input
+ const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
+ const __m256i lhs_mat_23_0_sp1 = _mm256_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
+
+ const __m256i lhs_mat_01_1_sp1 = _mm256_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
+ const __m256i lhs_mat_23_1_sp1 = _mm256_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
+
+ const __m256i lhs_mat_01_2_sp1 = _mm256_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
+ const __m256i lhs_mat_23_2_sp1 = _mm256_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
+
+ const __m256i lhs_mat_01_3_sp1 = _mm256_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
+ const __m256i lhs_mat_23_3_sp1 = _mm256_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
+
+ // Shuffle pattern two - left side input
+ const __m256i lhs_mat_01_0_sp2 = _mm256_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
+ const __m256i lhs_mat_23_0_sp2 = _mm256_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
+
+ const __m256i lhs_mat_01_1_sp2 = _mm256_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
+ const __m256i lhs_mat_23_1_sp2 = _mm256_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
+
+ const __m256i lhs_mat_01_2_sp2 = _mm256_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
+ const __m256i lhs_mat_23_2_sp2 = _mm256_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
+
+ const __m256i lhs_mat_01_3_sp2 = _mm256_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
+ const __m256i lhs_mat_23_3_sp2 = _mm256_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
+
+ // The values arranged in the shuffle patterns are combined with a dot product within each 32-bit lane: corresponding int8 pairs are multiplied and their products accumulated into one 32-bit integer per lane
+ // Resembles the MMLA 2x2 tile accumulation used in the ARM version
+ __m256i iacc_mat_00_sp1 =
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1));
+ __m256i iacc_mat_01_sp1 =
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1));
+ __m256i iacc_mat_10_sp1 =
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1));
+ __m256i iacc_mat_11_sp1 =
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1));
+ __m256i iacc_mat_00_sp2 =
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2));
+ __m256i iacc_mat_01_sp2 =
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2));
+ __m256i iacc_mat_10_sp2 =
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2));
+ __m256i iacc_mat_11_sp2 =
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2));
+
+ // Outputs of both shuffle patterns are added in order to sum the dot product outputs of all 32 values in the block
+ __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
+ __m256i iacc_mat_01 = _mm256_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);
+ __m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
+ __m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
+
+ // Straighten out to make 4 row vectors
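+ // blend imm8 204 (0b11001100) takes dwords 2,3 of each 4-dword group from the second operand; combined with the pair-swapping shuffle (imm8 78) it interleaves the 2x2 tiles into per-row vectors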
+ __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204);
+ __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204);
+ __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204);
+ __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204);
+
+ // Load the scale values for all 4 Q8_0 blocks and repeat them across lanes
+ const __m256 row_scale_f32 = GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask);
+
+ // Multiply with appropriate scales and accumulate
+ acc_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]);
+ acc_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]);
+ acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
+ acc_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]);
+ }
+ }
+
+ // Store the accumulated values
+ for (int i = 0; i < 16; i++) {
+ _mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
+ }
+ }
+ }
+
+ // Take one block_q8_0x4 structure at each pass of the loop and perform the dot product operation
+ for (; y < nr / 4; y ++) {
+
+ const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb);
+
+ // Load the eight block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
+ for (int64_t x = xstart; x < nc / 8; x++) {
+
+ const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb);
+
+ // Master FP accumulators
+ __m256 acc_rows[4];
+ for (int i = 0; i < 4; i++) {
+ acc_rows[i] = _mm256_setzero_ps();
+ }
+
+ for (int64_t b = 0; b < nb; b++) {
+ // Load the eight block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
+ const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs));
+ const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32));
+ const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64));
+ const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96));
+
+ // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values
+ const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
+ const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
+ const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
+ const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
+
+ // 4-bit -> 8-bit - Sign is maintained
+ const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7)
+ const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7)
+
+ const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15)
+ const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15)
+
+ const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23)
+ const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23)
+
+ const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31)
+ const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31)
+
+ // Shuffle pattern one - right side input
+ const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3)
+ const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3)
+
+ const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11)
+ const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11)
+
+ const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19)
+ const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19)
+
+ const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27)
+ const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27)
+
+ // Shuffle pattern two - right side input
+
+ const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7)
+ const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7)
+
+ const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15)
+ const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15)
+
+ const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23)
+ const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23)
+
+ const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31)
+ const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31)
+
+ // Scale values - Load the weight scale values of block_q4_0x8
+ const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
- // Process LHS in groups of four
- for (int rp = 0; rp < 4; rp++) {
// Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
// Loaded as set of 128 bit vectors and repeated into a 256 bit vector
- __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs)));
+ __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs)));
__m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0);
__m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17);
- __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32)));
+ __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32)));
__m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0);
__m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17);
- __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64)));
+ __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64)));
__m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0);
__m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17);
- __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96)));
+ __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96)));
__m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0);
__m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17);
// Shuffle pattern one - left side input
+
const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
const __m256i lhs_mat_23_0_sp1 = _mm256_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
@@ -2532,6 +3119,7 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
const __m256i lhs_mat_23_3_sp1 = _mm256_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
// Shuffle pattern two - left side input
+
const __m256i lhs_mat_01_0_sp2 = _mm256_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
const __m256i lhs_mat_23_0_sp2 = _mm256_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
@@ -2547,21 +3135,21 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
// The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
// Resembles MMLAs into 2x2 matrices in ARM Version
__m256i iacc_mat_00_sp1 =
- _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1));
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1));
__m256i iacc_mat_01_sp1 =
- _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1));
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1));
__m256i iacc_mat_10_sp1 =
- _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1));
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1));
__m256i iacc_mat_11_sp1 =
- _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1));
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1));
__m256i iacc_mat_00_sp2 =
- _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2));
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2));
__m256i iacc_mat_01_sp2 =
- _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2));
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2));
__m256i iacc_mat_10_sp2 =
- _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2));
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2));
__m256i iacc_mat_11_sp2 =
- _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2));
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2));
// Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
__m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
@@ -2569,6 +3157,7 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
__m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
__m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
+
// Straighten out to make 4 row vectors
__m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204);
__m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204);
@@ -2576,187 +3165,24 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
__m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204);
// Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
- const __m256 row_scale_f32 = GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask);
+ const __m256 row_scale_f32 = GGML_F32Cx8_REPEAT_LOAD(a_ptr[b].d, loadMask);
// Multiply with appropiate scales and accumulate
- acc_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]);
- acc_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]);
- acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
- acc_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]);
+ acc_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]);
+ acc_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]);
+ acc_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);
+ acc_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);
+ }
+
+ // Store the accumulated values
+ for (int i = 0; i < 4; i++) {
+ _mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
}
}
-
- // Store the accumulated values
- for (int i = 0; i < 16; i++) {
- _mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
- }
}
+ return;
}
-
- // Take a block_q8_0x4 structures at each pass of the loop and perform dot product operation
- for (; y < nr / 4; y ++) {
-
- const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb);
-
- // Load the eight block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
- for (int64_t x = 0; x < nc / 8; x++) {
-
- const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb);
-
- // Master FP accumulators
- __m256 acc_rows[4];
- for (int i = 0; i < 4; i++) {
- acc_rows[i] = _mm256_setzero_ps();
- }
-
- for (int64_t b = 0; b < nb; b++) {
- // Load the eight block_q8_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
- const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs));
- const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32));
- const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64));
- const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96));
-
- // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of valuess
- const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
- const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
- const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
- const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
-
- // 4-bit -> 8-bit - Sign is maintained
- const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7)
- const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7)
-
- const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15)
- const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15)
-
- const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23)
- const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23)
-
- const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31)
- const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31)
-
- // Shuffle pattern one - right side input
- const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3)
- const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3)
-
- const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11)
- const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11)
-
- const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19)
- const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19)
-
- const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27)
- const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27)
-
- // Shuffle pattern two - right side input
-
- const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7)
- const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7)
-
- const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15)
- const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15)
-
- const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23)
- const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23)
-
- const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31)
- const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31)
-
- // Scale values - Load the wight scale values of block_q4_0x8
- const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
-
- // Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
- // Loaded as set of 128 bit vectors and repeated into a 256 bit vector
- __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs)));
- __m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0);
- __m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17);
- __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32)));
- __m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0);
- __m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17);
- __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64)));
- __m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0);
- __m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17);
- __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96)));
- __m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0);
- __m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17);
-
- // Shuffle pattern one - left side input
-
- const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
- const __m256i lhs_mat_23_0_sp1 = _mm256_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
-
- const __m256i lhs_mat_01_1_sp1 = _mm256_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
- const __m256i lhs_mat_23_1_sp1 = _mm256_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
-
- const __m256i lhs_mat_01_2_sp1 = _mm256_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
- const __m256i lhs_mat_23_2_sp1 = _mm256_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
-
- const __m256i lhs_mat_01_3_sp1 = _mm256_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
- const __m256i lhs_mat_23_3_sp1 = _mm256_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
-
- // Shuffle pattern two - left side input
-
- const __m256i lhs_mat_01_0_sp2 = _mm256_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
- const __m256i lhs_mat_23_0_sp2 = _mm256_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
-
- const __m256i lhs_mat_01_1_sp2 = _mm256_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
- const __m256i lhs_mat_23_1_sp2 = _mm256_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
-
- const __m256i lhs_mat_01_2_sp2 = _mm256_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
- const __m256i lhs_mat_23_2_sp2 = _mm256_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
-
- const __m256i lhs_mat_01_3_sp2 = _mm256_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
- const __m256i lhs_mat_23_3_sp2 = _mm256_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
-
- // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
- // Resembles MMLAs into 2x2 matrices in ARM Version
- __m256i iacc_mat_00_sp1 =
- _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1));
- __m256i iacc_mat_01_sp1 =
- _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1));
- __m256i iacc_mat_10_sp1 =
- _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1));
- __m256i iacc_mat_11_sp1 =
- _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1));
- __m256i iacc_mat_00_sp2 =
- _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2));
- __m256i iacc_mat_01_sp2 =
- _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2));
- __m256i iacc_mat_10_sp2 =
- _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2));
- __m256i iacc_mat_11_sp2 =
- _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2));
-
- // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
- __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
- __m256i iacc_mat_01 = _mm256_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);
- __m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
- __m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
-
-
- // Straighten out to make 4 row vectors
- __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204);
- __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204);
- __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204);
- __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204);
-
- // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
- const __m256 row_scale_f32 = GGML_F32Cx8_REPEAT_LOAD(a_ptr[b].d, loadMask);
-
- // Multiply with appropiate scales and accumulate
- acc_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]);
- acc_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]);
- acc_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);
- acc_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);
- }
-
- // Store the accumulated values
- for (int i = 0; i < 4; i++) {
- _mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
- }
- }
- }
-#else
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
float sumf[4][8];
int sumi;
@@ -2789,5 +3215,4 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
}
}
}
-#endif
}
diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c
index e485326ab..70187b9b6 100644
--- a/ggml/src/ggml-alloc.c
+++ b/ggml/src/ggml-alloc.c
@@ -294,6 +294,12 @@ static void ggml_dyn_tallocr_reset(struct ggml_dyn_tallocr * alloc) {
alloc->free_blocks[0].offset = 0;
alloc->free_blocks[0].size = SIZE_MAX/2; // restrict maximum size of a measure allocator to half size_t max to avoid overflows
alloc->max_size = 0;
+
+#ifdef GGML_ALLOCATOR_DEBUG
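+ // presumably clears stale debug-tracking entries so a reused allocator does not report them as leaks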
+ for (int i = 0; i < 1024; i++) {
+ alloc->allocated_tensors[i].tensor = NULL;
+ }
+#endif
}
static struct ggml_dyn_tallocr * ggml_dyn_tallocr_new(size_t alignment) {
diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h
index e6a570107..edfa49614 100644
--- a/ggml/src/ggml-cann/common.h
+++ b/ggml/src/ggml-cann/common.h
@@ -227,6 +227,7 @@ struct ggml_backend_cann_context {
* @brief Destructor for cleaning up resources.
*/
~ggml_backend_cann_context() {
+ ggml_cann_set_device(device);
if (copy_event != nullptr) {
ACL_CHECK(aclrtDestroyEvent(copy_event));
}
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index da0094eed..127eb458b 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -36,6 +36,7 @@ bool g_mul_mat_q = false;
#include "ggml-cuda/tsembd.cuh"
#include "ggml-cuda/unary.cuh"
#include "ggml-cuda/upscale.cuh"
+#include "ggml-cuda/rwkv-wkv.cuh"
#include
#include
@@ -137,7 +138,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
return res;
#else
-#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+#if !defined(GGML_USE_HIPBLAS)
cudaError_t err;
if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
{
@@ -150,7 +151,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
return err;
#else
return cudaMalloc(ptr, size);
-#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+#endif // !defined(GGML_USE_HIPBLAS)
#endif
}
@@ -188,7 +189,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
for (int id = 0; id < info.device_count; ++id) {
int device_vmm = 0;
-#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
CUdevice device;
CU_CHECK(cuDeviceGet(&device, id));
CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));
@@ -200,7 +201,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
alloc_prop.location.id = id;
CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
}
-#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
+#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
info.devices[id].vmm = !!device_vmm;
cudaDeviceProp prop;
@@ -334,7 +335,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
};
// pool with virtual memory
-#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB
@@ -428,14 +429,14 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
GGML_ASSERT(ptr == (void *) (pool_addr + pool_used));
}
};
-#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
+#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) {
-#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
if (ggml_cuda_info().devices[device].vmm) {
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
}
-#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
+#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
}
@@ -2247,6 +2248,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_UNARY_OP_HARDSWISH:
ggml_cuda_op_hardswish(ctx, dst);
break;
+ case GGML_UNARY_OP_EXP:
+ ggml_cuda_op_exp(ctx, dst);
+ break;
default:
return false;
}
@@ -2349,6 +2353,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_CROSS_ENTROPY_LOSS:
ggml_cuda_cross_entropy_loss(ctx, dst);
break;
+ case GGML_OP_RWKV_WKV:
+ ggml_cuda_op_rwkv_wkv(ctx, dst);
+ break;
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
ggml_cuda_cross_entropy_loss_back(ctx, dst);
break;
@@ -2810,6 +2817,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_TANH:
+ case GGML_UNARY_OP_EXP:
return ggml_is_contiguous(op->src[0]);
default:
return false;
@@ -2826,6 +2834,12 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) {
return false;
}
+#ifdef GGML_USE_MUSA
+ if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
+ !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
+ return false;
+ }
+#endif // GGML_USE_MUSA
switch (a->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
@@ -2849,6 +2863,11 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
+#ifdef GGML_USE_MUSA
+ if (a->type == GGML_TYPE_Q3_K) {
+ return false;
+ }
+#endif // GGML_USE_MUSA
return true;
default:
return false;
@@ -2884,6 +2903,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
return true;
}
+ if (src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_F32) {
+ return true;
+ }
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
return true;
}
@@ -2971,20 +2993,24 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_LEAKY_RELU:
+ case GGML_OP_RWKV_WKV:
return true;
- case GGML_OP_FLASH_ATTN_EXT:
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
- return (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) || op->src[0]->ne[0] == 128;
-#else
- if (op->src[0]->ne[0] == 128) {
- return true;
- }
+ case GGML_OP_FLASH_ATTN_EXT: {
+#ifndef FLASH_ATTN_AVAILABLE
+ return false;
+#endif
if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
return true;
}
- return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA &&
- op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+ if (op->src[0]->ne[0] == 128) {
+ return true;
+ }
+ if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
+ return true;
+ }
+ const int cc = ggml_cuda_info().devices[cuda_ctx->device].cc;
+ return cc >= CC_VOLTA && cc < CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
+ }
case GGML_OP_CROSS_ENTROPY_LOSS:
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
case GGML_OP_OPT_STEP_ADAMW:
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 85eb200f0..6a4bcdba0 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -50,6 +50,8 @@
#define CC_RDNA1 (CC_OFFSET_AMD + 1010)
#define CC_RDNA2 (CC_OFFSET_AMD + 1030)
#define CC_RDNA3 (CC_OFFSET_AMD + 1100)
+#define CC_QY1 210
+#define CC_QY2 220
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
@@ -134,6 +136,10 @@ typedef float2 dfloat2;
#define INT8_MMA_AVAILABLE
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
+#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)
+#define FLASH_ATTN_AVAILABLE
+#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)
+
static constexpr bool fast_fp16_available(const int cc) {
return cc >= CC_PASCAL && cc != 610;
}
diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
index 51deb75fd..54c0f66d2 100644
--- a/ggml/src/ggml-cuda/cpy.cu
+++ b/ggml/src/ggml-cuda/cpy.cu
@@ -81,6 +81,17 @@ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
}
}
+static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
+ const block_q8_0 * xi = (const block_q8_0 *) cxi;
+ float * dsti = (float *) cdsti;
+
+ const float d = (float)xi->d;
+
+ for (int j = 0; j < QK8_0; j++) {
+ dsti[j] = xi->qs[j] * d;
+ }
+}
+
static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_q4_0 * dsti = (block_q4_0 *) cdsti;
@@ -288,6 +299,32 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
cpy_blck(cx + x_offset, cdst + dst_offset);
}
+template <cpy_kernel_t cpy_blck, int qk>
+static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
+ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+ const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+ const int nb12, const int nb13) {
+ const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
+
+ if (i >= ne) {
+ return;
+ }
+
+ const int i03 = i/(ne00 * ne01 * ne02);
+ const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+ const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
+ const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+ const int x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
+
+ const int i13 = i/(ne10 * ne11 * ne12);
+ const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+ const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+ const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+ const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
+
+ cpy_blck(cx + x_offset, cdst + dst_offset);
+}
+
static void ggml_cpy_f16_f32_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -329,6 +366,16 @@ static void ggml_cpy_f32_q8_0_cuda(
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
+static void ggml_cpy_q8_0_f32_cuda(
+ const char * cx, char * cdst, const int ne,
+ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+ const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+ const int num_blocks = ne;
+ cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
+ (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
static void ggml_cpy_f32_q4_0_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -437,6 +484,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
+ ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
@@ -471,6 +520,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
+ } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
+ return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
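For reference, a minimal CPU-side sketch of the Q8_0 -> F32 conversion performed by cpy_blck_q8_0_f32 above. This is a sketch only: ggml stores the per-block scale as fp16 (ggml_half); a plain float is used here so the example stays self-contained.

#include <stdio.h>

#define QK8_0 32

struct q8_0_block_ref {
    float       d;          /* per-block scale; fp16 (ggml_half) in ggml proper */
    signed char qs[QK8_0];  /* 32 quantized values */
};

/* same arithmetic as cpy_blck_q8_0_f32: dst[j] = qs[j] * d */
void dequant_q8_0_ref(const struct q8_0_block_ref *x, float *dst) {
    for (int j = 0; j < QK8_0; j++) {
        dst[j] = x->qs[j] * x->d;
    }
}

int main(void) {
    struct q8_0_block_ref b = { 0.5f, {0} };
    b.qs[0] = -4;
    b.qs[1] = 10;
    float out[QK8_0];
    dequant_q8_0_ref(&b, out);
    printf("%.1f %.1f\n", out[0], out[1]); /* -2.0 5.0 */
    return 0;
}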
diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu
index 827437ca0..f402195ce 100644
--- a/ggml/src/ggml-cuda/fattn-tile-f32.cu
+++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu
@@ -44,13 +44,17 @@ static __global__ void flash_attn_tile_ext_f32(
const int ne1,
const int ne2,
const int ne3) {
+#ifndef FLASH_ATTN_AVAILABLE
+ NO_DEVICE_CODE;
+ return;
+#endif // FLASH_ATTN_AVAILABLE
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(D == 128 || D == 256)) {
NO_DEVICE_CODE;
return;
}
- //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+ // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index f28a19d40..83e5589a1 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -314,7 +314,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
}
if (!fast_fp16_available(cc)) {
- if (Q->ne[1] <= 8) {
+ if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
diff --git a/ggml/src/ggml-cuda/rwkv-wkv.cu b/ggml/src/ggml-cuda/rwkv-wkv.cu
new file mode 100644
index 000000000..098e92d35
--- /dev/null
+++ b/ggml/src/ggml-cuda/rwkv-wkv.cu
@@ -0,0 +1,89 @@
+#include "common.cuh"
+#include "rwkv-wkv.cuh"
+
+static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
+ const int tid = threadIdx.x;
+ const int bid = blockIdx.x;
+
+ const int head_size = CUDA_WKV_BLOCK_SIZE;
+ const int batch_i = bid / H;
+ const int head_i = bid % H;
+ const int state_size = C * head_size;
+ const int n_seq_tokens = T / B;
+
+ float state[head_size];
+ __shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];
+
+ #pragma unroll
+ for (int i = 0; i < head_size; i++) {
+ state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
+ }
+
+ __syncthreads();
+ _tf[tid] = tf[head_i * head_size + tid];
+ __syncthreads();
+
+ for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
+ __syncthreads();
+ _k[tid] = k[t];
+ _r[tid] = r[t];
+ _td[tid] = td[t];
+ __syncthreads();
+
+ const float _v = v[t];
+ float y = 0;
+ for (int j = 0; j < head_size; j += 4) {
+ const float4& k = (float4&)(_k[j]);
+ const float4& r = (float4&)(_r[j]);
+ const float4& tf = (float4&)(_tf[j]);
+ const float4& td = (float4&)(_td[j]);
+ float4& s = (float4&)(state[j]);
+ float4 kv;
+
+ kv.x = k.x * _v;
+ kv.y = k.y * _v;
+ kv.z = k.z * _v;
+ kv.w = k.w * _v;
+
+ y += r.x * (tf.x * kv.x + s.x);
+ y += r.y * (tf.y * kv.y + s.y);
+ y += r.z * (tf.z * kv.z + s.z);
+ y += r.w * (tf.w * kv.w + s.w);
+
+ s.x = s.x * td.x + kv.x;
+ s.y = s.y * td.y + kv.y;
+ s.z = s.z * td.z + kv.z;
+ s.w = s.w * td.w + kv.w;
+ }
+ dst[t] = y;
+ }
+
+ #pragma unroll
+ for (int i = 0; i < head_size; i++) {
+ dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
+ }
+}
+
+void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const float * k_d = (const float *)dst->src[0]->data;
+ const float * v_d = (const float *)dst->src[1]->data;
+ const float * r_d = (const float *)dst->src[2]->data;
+ const float * tf_d = (const float *)dst->src[3]->data;
+ const float * td_d = (const float *)dst->src[4]->data;
+ const float * s_d = (const float *)dst->src[5]->data;
+
+ const int64_t B = dst->src[5]->ne[1];
+ const int64_t T = dst->src[0]->ne[3];
+ const int64_t C = dst->ne[0];
+ const int64_t H = dst->src[0]->ne[2];
+
+ float * dst_d = (float *)dst->data;
+
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
+ GGML_ASSERT(C % H == 0);
+ GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE);
+
+ rwkv_wkv_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
+}
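For reference, a scalar sketch of the per-token recurrence that each thread of rwkv_wkv_f32 above evaluates for its output channel (illustration only; tensor layouts and launch geometry follow the CUDA code):

/* one output channel of one head, for one token:
 *   y     = sum_j r[j] * (tf[j] * k[j] * v + s[j])
 *   s[j] <- s[j] * td[j] + k[j] * v
 */
float wkv_step_ref(int head_size,
                   const float *k, const float *r,
                   const float *tf, const float *td,
                   float v, float *s /* running state for this channel, length head_size */) {
    float y = 0.0f;
    for (int j = 0; j < head_size; j++) {
        const float kv = k[j] * v;
        y += r[j] * (tf[j] * kv + s[j]);  /* "first-visit" bonus plus decayed history */
        s[j] = s[j] * td[j] + kv;         /* update the running state */
    }
    return y;
}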
diff --git a/ggml/src/ggml-cuda/rwkv-wkv.cuh b/ggml/src/ggml-cuda/rwkv-wkv.cuh
new file mode 100644
index 000000000..13795247f
--- /dev/null
+++ b/ggml/src/ggml-cuda/rwkv-wkv.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_WKV_BLOCK_SIZE 64
+
+void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu
index 163b5a8ff..81fc92202 100644
--- a/ggml/src/ggml-cuda/unary.cu
+++ b/ggml/src/ggml-cuda/unary.cu
@@ -95,6 +95,15 @@ static __global__ void hardswish_f32(const float * x, float * dst, const int k)
dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
}
+static __global__ void exp_f32(const float * x, float * dst, const int k) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+ dst[i] = expf(x[i]);
+}
+
static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
@@ -189,6 +198,11 @@ static void hardswish_f32_cuda(const float * x, float * dst, const int k, cudaSt
hardswish_f32<<<num_blocks, CUDA_HARDSWISH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
+static void exp_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_EXP_BLOCK_SIZE - 1) / CUDA_EXP_BLOCK_SIZE;
+ exp_f32<<<num_blocks, CUDA_EXP_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
@@ -354,6 +368,20 @@ void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
hardswish_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}
+void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ exp_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh
index fe519f6a2..c91936728 100644
--- a/ggml/src/ggml-cuda/unary.cuh
+++ b/ggml/src/ggml-cuda/unary.cuh
@@ -8,6 +8,7 @@
#define CUDA_RELU_BLOCK_SIZE 256
#define CUDA_SIGMOID_BLOCK_SIZE 256
#define CUDA_HARDSIGMOID_BLOCK_SIZE 256
+#define CUDA_EXP_BLOCK_SIZE 256
#define CUDA_HARDSWISH_BLOCK_SIZE 256
#define CUDA_SQR_BLOCK_SIZE 256
#define CUDA_SQRT_BLOCK_SIZE 256
@@ -32,6 +33,8 @@ void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/vendors/musa.h b/ggml/src/ggml-cuda/vendors/musa.h
index 8df571149..1604b8229 100644
--- a/ggml/src/ggml-cuda/vendors/musa.h
+++ b/ggml/src/ggml-cuda/vendors/musa.h
@@ -26,6 +26,7 @@
#define cublasSetStream mublasSetStream
#define cublasSgemm mublasSgemm
#define cublasStatus_t mublasStatus_t
+#define cublasOperation_t mublasOperation_t
#define cublasGetStatusString mublasStatus_to_string
#define cudaDataType_t musaDataType_t
#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
@@ -56,6 +57,7 @@
#define cudaLaunchHostFunc musaLaunchHostFunc
#define cudaMalloc musaMalloc
#define cudaMallocHost musaMallocHost
+#define cudaMallocManaged musaMallocManaged
#define cudaMemcpy musaMemcpy
#define cudaMemcpyAsync musaMemcpyAsync
#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index f323ab5f4..2b2000323 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -2631,11 +2631,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
const short iv3 = iq3 / rv3;
// load the queries from shared memory into local memory
- half4 mq[D4];
+ float4 mq[D4];
for (short ii = 0; ii < D4; ii += NW) {
short i = ii + tiisg;
- mq[i] = sq4[i];
+ mq[i] = (float4) sq4[i];
}
// pointer to the mask
@@ -2661,11 +2661,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
for (short ii = 0; ii < D4; ii += NW) {
const short i = ii + tiisg;
- half4x4 mk;
- mk[0] = pk4[i + 0*(nb11/8)];
- mk[1] = pk4[i + 1*(nb11/8)];
- mk[2] = pk4[i + 2*(nb11/8)];
- mk[3] = pk4[i + 3*(nb11/8)];
+ float4x4 mk;
+ mk[0] = (float4) pk4[i + 0*(nb11/8)];
+ mk[1] = (float4) pk4[i + 1*(nb11/8)];
+ mk[2] = (float4) pk4[i + 2*(nb11/8)];
+ mk[3] = (float4) pk4[i + 3*(nb11/8)];
mqk += (float4) (mq[i] * mk);
}
diff --git a/ggml/src/ggml-sycl.cpp b/ggml/src/ggml-sycl.cpp
index 16e6be4a0..6978a3192 100644
--- a/ggml/src/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl.cpp
@@ -3496,8 +3496,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
- && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE
- && (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda || src1->ne[1] > MMVQ_MIN_BATCH_SIZE);
+ && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp
index 05947ccb7..bc0faa867 100644
--- a/ggml/src/ggml-sycl/common.hpp
+++ b/ggml/src/ggml-sycl/common.hpp
@@ -134,7 +134,6 @@ typedef sycl::float2 dfloat2;
#endif // GGML_SYCL_F16
#define MMVQ_MAX_BATCH_SIZE 8
-#define MMVQ_MIN_BATCH_SIZE 4
static const int8_t kvalues_iq4nl[16]={-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 69f663529..9a9ea12e0 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -63,6 +63,25 @@ int ggml_sve_cnt_b = 0;
#pragma warning(disable: 4702)
#endif
+// Note: once we move threading into a separate C++ file
+// we will use std::hardware_destructive_interference_size instead of hardcoding it here
+// and we'll use C++ attribute syntax.
+#define GGML_CACHE_LINE 64
+
+#if defined(__clang__) || defined(__GNUC__)
+#define GGML_CACHE_ALIGN __attribute__((aligned(GGML_CACHE_LINE)))
+#endif
+
+#if defined(__has_feature)
+#if __has_feature(thread_sanitizer)
+#define GGML_TSAN_ENABLED 1
+#endif
+#else // __has_feature
+#if defined(__SANITIZE_THREAD__)
+#define GGML_TSAN_ENABLED 1
+#endif
+#endif // __has_feature
+
#if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN
@@ -72,6 +91,8 @@ int ggml_sve_cnt_b = 0;
#include <windows.h>
#if !defined(__clang__)
+#define GGML_CACHE_ALIGN __declspec(align(GGML_CACHE_LINE))
+
typedef volatile LONG atomic_int;
typedef atomic_int atomic_bool;
typedef atomic_int atomic_flag;
@@ -114,6 +135,9 @@ static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) {
static void atomic_flag_clear(atomic_flag * ptr) {
InterlockedExchange(ptr, 0);
}
+static void atomic_thread_fence(memory_order mo) {
+ MemoryBarrier();
+}
#else // clang
#include <stdatomic.h>
#endif
@@ -289,7 +313,6 @@ void ggml_abort(const char * file, int line, const char * fmt, ...) {
#define GGML_DEBUG 0
#define GGML_GELU_FP16
#define GGML_GELU_QUICK_FP16
-#define GGML_N_TASKS_MAX (-1)
#define GGML_SOFT_MAX_UNROLL 4
#define GGML_VEC_DOT_UNROLL 2
@@ -2015,8 +2038,8 @@ struct ggml_threadpool {
// synchronization primitives
atomic_int n_graph; // incremented when there is work to be done (i.e each graph)
- atomic_int n_barrier;
- atomic_int n_barrier_passed;
+ atomic_int GGML_CACHE_ALIGN n_barrier;
+ atomic_int GGML_CACHE_ALIGN n_barrier_passed;
atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
// these are atomic as an annotation for thread-sanitizer
@@ -3213,20 +3236,27 @@ static void ggml_barrier(struct ggml_threadpool * tp) {
// enter barrier (full seq-cst fence)
int n_barrier = atomic_fetch_add_explicit(&tp->n_barrier, 1, memory_order_seq_cst);
- int last = 0;
if (n_barrier == (n_threads - 1)) {
// last thread
atomic_store_explicit(&tp->n_barrier, 0, memory_order_relaxed);
- last = 1;
- } else {
- // wait for other threads
- while (atomic_load_explicit(&tp->n_barrier_passed, memory_order_relaxed) == n_passed) {
- ggml_thread_cpu_relax();
- }
+
+ // exit barrier (full seq-cst fence)
+ atomic_fetch_add_explicit(&tp->n_barrier_passed, 1, memory_order_seq_cst);
+ return;
+ }
+
+ // wait for other threads
+ while (atomic_load_explicit(&tp->n_barrier_passed, memory_order_relaxed) == n_passed) {
+ ggml_thread_cpu_relax();
}
// exit barrier (full seq-cst fence)
- atomic_fetch_add_explicit(&tp->n_barrier_passed, last, memory_order_seq_cst);
+ // TSAN doesn't support standalone fence yet, we use a dummy read-modify-write instead
+ #ifdef GGML_TSAN_ENABLED
+ atomic_fetch_add_explicit(&tp->n_barrier_passed, 0, memory_order_seq_cst);
+ #else
+ atomic_thread_fence(memory_order_seq_cst);
+ #endif
#endif
}
@@ -20299,10 +20329,13 @@ static inline bool ggml_graph_compute_thread_ready(struct ggml_compute_state * s
// sync thread state after polling
static inline void ggml_graph_compute_thread_sync(struct ggml_compute_state * state) {
- struct ggml_threadpool * threadpool = state->threadpool;
- // this should just be atomic_thread_fence(seq_cst) but it confuses thread-sanitizer
- // so instead we just use a dummy read-modify-write
- atomic_fetch_add_explicit(&threadpool->n_graph, 0, memory_order_seq_cst);
+ // TSAN doesn't support standalone fence yet, we use a dummy read-modify-write instead
+ #ifdef GGML_TSAN_ENABLED
+ atomic_fetch_add_explicit(&state->threadpool->n_graph, 0, memory_order_seq_cst);
+ #else
+ atomic_thread_fence(memory_order_seq_cst);
+ #endif
+ UNUSED(state);
}
static inline bool ggml_graph_compute_poll_for_work(struct ggml_compute_state * state) {
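For reference, a stand-alone sketch of the barrier scheme used in ggml_barrier above, written with plain C11 atomics and without the TSAN workaround; a sketch under those assumptions, not the ggml implementation itself:

#include <stdatomic.h>

struct spin_barrier {
    atomic_int n_arrived;  /* threads that have reached the barrier this round */
    atomic_int n_passed;   /* rounds completed */
};

void spin_barrier_wait(struct spin_barrier *b, int n_threads) {
    const int passed  = atomic_load_explicit(&b->n_passed, memory_order_relaxed);
    const int arrived = atomic_fetch_add_explicit(&b->n_arrived, 1, memory_order_seq_cst);

    if (arrived == n_threads - 1) {
        /* last thread: reset the arrival counter and release the others */
        atomic_store_explicit(&b->n_arrived, 0, memory_order_relaxed);
        atomic_fetch_add_explicit(&b->n_passed, 1, memory_order_seq_cst);
        return;
    }
    while (atomic_load_explicit(&b->n_passed, memory_order_relaxed) == passed) {
        /* spin; ggml additionally calls ggml_thread_cpu_relax() here */
    }
    atomic_thread_fence(memory_order_seq_cst); /* pairs with the releasing RMW above */
}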
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index b36a60d49..560eee916 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -235,6 +235,7 @@ class MODEL_ARCH(IntEnum):
NEMOTRON = auto()
EXAONE = auto()
GRANITE = auto()
+ GRANITE_MOE = auto()
class MODEL_TENSOR(IntEnum):
@@ -392,6 +393,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.NEMOTRON: "nemotron",
MODEL_ARCH.EXAONE: "exaone",
MODEL_ARCH.GRANITE: "granite",
+ MODEL_ARCH.GRANITE_MOE: "granitemoe",
}
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -1232,6 +1234,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_ARCH.GRANITE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
@@ -1242,6 +1245,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
+ MODEL_ARCH.GRANITE_MOE: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ ],
# TODO
}
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 2ebfa2b43..4e850726e 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -251,11 +251,12 @@ class TensorNameMap:
),
MODEL_TENSOR.FFN_GATE_INP: (
- "layers.{bid}.feed_forward.gate", # mixtral
- "model.layers.{bid}.block_sparse_moe.gate", # mixtral
- "model.layers.{bid}.mlp.gate", # qwen2moe olmoe
- "transformer.decoder_layer.{bid}.router", # Grok
- "transformer.blocks.{bid}.ffn.router.layer", # dbrx
+ "layers.{bid}.feed_forward.gate", # mixtral
+ "model.layers.{bid}.block_sparse_moe.gate", # mixtral
+ "model.layers.{bid}.mlp.gate", # qwen2moe olmoe
+ "transformer.decoder_layer.{bid}.router", # Grok
+ "transformer.blocks.{bid}.ffn.router.layer", # dbrx
+ "model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe
),
MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
@@ -364,10 +365,11 @@ class TensorNameMap:
),
MODEL_TENSOR.FFN_DOWN_EXP: (
- "layers.{bid}.feed_forward.experts.w2", # mixtral (merged)
- "transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged)
- "transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx
- "model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged)
+ "layers.{bid}.feed_forward.experts.w2", # mixtral (merged)
+ "transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged)
+ "transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx
+ "model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged)
+ "model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe
),
MODEL_TENSOR.FFN_DOWN_SHEXP: (
diff --git a/include/llama.h b/include/llama.h
index 8d326447a..645602d0d 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -1068,6 +1068,7 @@ extern "C" {
LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed);
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
+ /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void);
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
diff --git a/klite.embd b/klite.embd
index c7af6a494..7d6a4e849 100644
--- a/klite.embd
+++ b/klite.embd
@@ -4185,6 +4185,7 @@ Current version indicated by LITEVER below.
const default_oai_image_endpoint = "/images/generations";
const default_oai_tts_endpoint = "/audio/speech";
+ const default_dalle_model_name = "dall-e-3";
const claude_submit_endpoint = "/complete";
const claude_submit_endpoint_v3 = "/messages";
@@ -4325,6 +4326,7 @@ Current version indicated by LITEVER below.
saved_oai_addr: default_oai_base, //do not ever share this in save files!
saved_dalle_key: "",
saved_dalle_url: (default_oai_base + "/v1" + default_oai_image_endpoint),
+ saved_dalle_model: default_dalle_model_name,
saved_oai_tts_key: "",
saved_oai_tts_url: (default_oai_base + "/v1" + default_oai_tts_endpoint),
saved_openrouter_key: "",
@@ -4557,16 +4559,23 @@ Current version indicated by LITEVER below.
},
{
"id":12,
- "name":"Mistral Gen 1",
- "user":"\\n[INST] ",
- "assistant":" [/INST]\\n",
+ "name":"Mistral V1",
+ "user":" [INST] ",
+ "assistant":" [/INST]",
"system":"",
},
{
"id":13,
- "name":"Mistral Gen 2",
- "user":"\\n[INST]",
- "assistant":"[/INST]\\n",
+ "name":"Mistral V2 & V3",
+ "user":"[INST] ",
+ "assistant":"[/INST]",
+ "system":"",
+ },
+ {
+ "id":14,
+ "name":"Mistral V3-Tekken",
+ "user":"[INST]",
+ "assistant":"[/INST]",
"system":"",
}
];
@@ -5114,6 +5123,7 @@ Current version indicated by LITEVER below.
const foundChub = urlParams.get('chub');
const foundPyg = urlParams.get('pyg');
const foundAicc = urlParams.get('aicc');
+ const foundQuery = urlParams.get('query');
if (foundStory && foundStory != "") {
if (localsettings.persist_session && !safe_to_overwrite()) {
@@ -5150,6 +5160,25 @@ Current version indicated by LITEVER below.
//purge url params
window.history.replaceState(null, null, window.location.pathname);
}
+ else if (foundQuery && foundQuery != "")
+ {
+ window.history.replaceState(null, null, window.location.pathname);
+ if (localsettings.persist_session && !safe_to_overwrite()) {
+ msgboxYesNo("You already have an existing persistent story. Do you want to overwrite it?","Overwrite Story Warning",()=>{
+ localsettings.opmode = 4;
+ restart_new_game(false);
+ document.getElementById("input_text").value = foundQuery;
+ submit_generation();
+ },null,false);
+ }
+ else
+ {
+ localsettings.opmode = 4;
+ restart_new_game(false);
+ document.getElementById("input_text").value = foundQuery;
+ submit_generation();
+ }
+ }
}
var image_models_fetched = false;
@@ -5363,6 +5392,18 @@ Current version indicated by LITEVER below.
}
},false);
}
+ function set_dalle_model()
+ {
+ inputBox("Enter DALL-E API Model Identifier.","DALL-E API Model Identifier",localsettings.saved_dalle_model,"Input DALL-E Model Identifier", ()=>{
+ let userinput = getInputBoxValue();
+ userinput = userinput.trim();
+ if (userinput != null && userinput!="") {
+ localsettings.saved_dalle_model = userinput.trim();
+ }else{
+ localsettings.saved_dalle_model = default_dalle_model_name;
+ }
+ },false);
+ }
function set_oai_tts_key()
{
@@ -5394,7 +5435,7 @@ Current version indicated by LITEVER below.
let prompt = splits[0].trim();
let dalle_payload = {
- "model": "dall-e-3",
+ "model": localsettings.saved_dalle_model,
"prompt": prompt,
"n": 1,
"size": "1024x1024",
@@ -12596,7 +12637,7 @@ Current version indicated by LITEVER below.
//mistral api does not support presence pen
oai_payload.presence_penalty = scaled_rep_pen;
}
- if(targetep.toLowerCase().includes("featherless.ai"))
+ if(document.getElementById("useoainonstandard").checked || targetep.toLowerCase().includes("featherless.ai"))
{
//featherless api supports additional fields, include them
oai_payload.top_k = (submit_payload.params.top_k<1?300:submit_payload.params.top_k);
@@ -12605,6 +12646,7 @@ Current version indicated by LITEVER below.
{
oai_payload.seed = submit_payload.params.sampler_seed;
}
+ oai_payload.top_a = localsettings.top_a;
}
if(submit_payload.params.logit_bias && JSON.stringify(submit_payload.params.logit_bias) != '{}')
{
@@ -17982,11 +18024,13 @@ Current version indicated by LITEVER below.
+ Chat-Completions API
+
@@ -18694,8 +18738,9 @@ Current version indicated by LITEVER below.
- |
- |
+ |
+ |
+ |
diff --git a/koboldcpp.py b/koboldcpp.py
index 5f4314d9a..5d3423a08 100644
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -41,7 +41,7 @@ maxhordelen = 400
modelbusy = threading.Lock()
requestsinqueue = 0
defaultport = 5001
-KcppVersion = "1.75.2"
+KcppVersion = "1.76"
showdebug = True
guimode = False
showsamplerwarning = True
diff --git a/src/llama-impl.h b/src/llama-impl.h
index 2bde75ec1..70f16b61c 100644
--- a/src/llama-impl.h
+++ b/src/llama-impl.h
@@ -28,6 +28,8 @@ void llama_log_callback_default(ggml_log_level level, const char * text, void *
#define LLAMA_LOG_INFO(...) llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
#define LLAMA_LOG_WARN(...) llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
#define LLAMA_LOG_ERROR(...) llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
+#define LLAMA_LOG_DEBUG(...) llama_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
+#define LLAMA_LOG_CONT(...) llama_log_internal(GGML_LOG_LEVEL_CONT , __VA_ARGS__)
//
// helpers
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 5299f5116..e255a8fc4 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -3,13 +3,14 @@
#include "llama-vocab.h"
#include "llama-grammar.h"
-#include
#include
-#include
-#include
+#include
#include
#include
#include
+#include
+#include
+#include
#include
#include
#include
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index ec1d7d736..7737992f5 100644
--- a/src/llama-vocab.cpp
+++ b/src/llama-vocab.cpp
@@ -1826,11 +1826,7 @@ llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, lla
}
bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token) {
- return token != -1 && (
- token == llama_token_eos_impl(vocab) ||
- token == llama_token_eot_impl(vocab) ||
- token == llama_token_eom_impl(vocab)
- );
+ return token != -1 && vocab.special_eog_ids.count(token) > 0;
}
bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token) {
diff --git a/src/llama-vocab.h b/src/llama-vocab.h
index dc4b5f12f..cc46f642b 100644
--- a/src/llama-vocab.h
+++ b/src/llama-vocab.h
@@ -6,6 +6,7 @@
#include
#include
#include