mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 17:14:36 +00:00
Merge branch 'upstream' into concedo_experimental
# Conflicts: # .dockerignore # .github/workflows/build.yml # .github/workflows/docker.yml # Makefile # README.md # examples/infill/infill.cpp # examples/perplexity/perplexity.cpp # examples/server/README.md # examples/speculative/speculative.cpp # flake.lock # ggml/src/CMakeLists.txt # scripts/sync-ggml.last # tests/test-backend-ops.cpp # tests/test-sampling.cpp
This commit is contained in:
commit
ea55f69dc1
39 changed files with 2587 additions and 1564 deletions
|
@ -692,7 +692,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
|
||||||
[](gpt_params & params) {
|
[](gpt_params & params) {
|
||||||
params.ctx_shift = false;
|
params.ctx_shift = false;
|
||||||
}
|
}
|
||||||
).set_examples({LLAMA_EXAMPLE_MAIN}));
|
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_NO_CONTEXT_SHIFT"));
|
||||||
add_opt(llama_arg(
|
add_opt(llama_arg(
|
||||||
{"--chunks"}, "N",
|
{"--chunks"}, "N",
|
||||||
format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks),
|
format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks),
|
||||||
|
@ -1103,7 +1103,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
|
||||||
else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; }
|
else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; }
|
||||||
else { throw std::invalid_argument("invalid value"); }
|
else { throw std::invalid_argument("invalid value"); }
|
||||||
}
|
}
|
||||||
).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
|
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_POOLING"));
|
||||||
add_opt(llama_arg(
|
add_opt(llama_arg(
|
||||||
{"--attention"}, "{causal,non,causal}",
|
{"--attention"}, "{causal,non,causal}",
|
||||||
"attention type for embeddings, use model default if unspecified",
|
"attention type for embeddings, use model default if unspecified",
|
||||||
|
@ -1122,77 +1122,77 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
|
||||||
else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; }
|
else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; }
|
||||||
else { throw std::invalid_argument("invalid value"); }
|
else { throw std::invalid_argument("invalid value"); }
|
||||||
}
|
}
|
||||||
));
|
).set_env("LLAMA_ARG_ROPE_SCALING_TYPE"));
|
||||||
add_opt(llama_arg(
|
add_opt(llama_arg(
|
||||||
{"--rope-scale"}, "N",
|
{"--rope-scale"}, "N",
|
||||||
"RoPE context scaling factor, expands context by a factor of N",
|
"RoPE context scaling factor, expands context by a factor of N",
|
||||||
[](gpt_params & params, const std::string & value) {
|
[](gpt_params & params, const std::string & value) {
|
||||||
params.rope_freq_scale = 1.0f / std::stof(value);
|
params.rope_freq_scale = 1.0f / std::stof(value);
|
||||||
}
|
}
|
||||||
));
|
).set_env("LLAMA_ARG_ROPE_SCALE"));
|
||||||
add_opt(llama_arg(
|
add_opt(llama_arg(
|
||||||
{"--rope-freq-base"}, "N",
|
{"--rope-freq-base"}, "N",
|
||||||
"RoPE base frequency, used by NTK-aware scaling (default: loaded from model)",
|
"RoPE base frequency, used by NTK-aware scaling (default: loaded from model)",
|
||||||
[](gpt_params & params, const std::string & value) {
|
[](gpt_params & params, const std::string & value) {
|
||||||
params.rope_freq_base = std::stof(value);
|
params.rope_freq_base = std::stof(value);
|
||||||
}
|
}
|
||||||
));
|
).set_env("LLAMA_ARG_ROPE_FREQ_BASE"));
|
||||||
add_opt(llama_arg(
|
add_opt(llama_arg(
|
||||||
{"--rope-freq-scale"}, "N",
|
{"--rope-freq-scale"}, "N",
|
||||||
"RoPE frequency scaling factor, expands context by a factor of 1/N",
|
"RoPE frequency scaling factor, expands context by a factor of 1/N",
|
||||||
[](gpt_params & params, const std::string & value) {
|
[](gpt_params & params, const std::string & value) {
|
||||||
params.rope_freq_scale = std::stof(value);
|
params.rope_freq_scale = std::stof(value);
|
||||||
}
|
}
|
||||||
));
|
).set_env("LLAMA_ARG_ROPE_FREQ_SCALE"));
|
||||||
add_opt(llama_arg(
|
add_opt(llama_arg(
|
||||||
{"--yarn-orig-ctx"}, "N",
|
{"--yarn-orig-ctx"}, "N",
|
||||||
format("YaRN: original context size of model (default: %d = model training context size)", params.yarn_orig_ctx),
|
format("YaRN: original context size of model (default: %d = model training context size)", params.yarn_orig_ctx),
|
||||||
[](gpt_params & params, int value) {
|
[](gpt_params & params, int value) {
|
||||||
params.yarn_orig_ctx = value;
|
params.yarn_orig_ctx = value;
|
||||||
}
|
}
|
||||||
));
|
).set_env("LLAMA_ARG_YARN_ORIG_CTX"));
|
||||||
add_opt(llama_arg(
|
add_opt(llama_arg(
|
||||||
{"--yarn-ext-factor"}, "N",
|
{"--yarn-ext-factor"}, "N",
|
||||||
format("YaRN: extrapolation mix factor (default: %.1f, 0.0 = full interpolation)", (double)params.yarn_ext_factor),
|
format("YaRN: extrapolation mix factor (default: %.1f, 0.0 = full interpolation)", (double)params.yarn_ext_factor),
|
||||||
[](gpt_params & params, const std::string & value) {
|
[](gpt_params & params, const std::string & value) {
|
||||||
params.yarn_ext_factor = std::stof(value);
|
params.yarn_ext_factor = std::stof(value);
|
||||||
}
|
}
|
||||||
));
|
).set_env("LLAMA_ARG_YARN_EXT_FACTOR"));
|
||||||
add_opt(llama_arg(
|
add_opt(llama_arg(
|
||||||
{"--yarn-attn-factor"}, "N",
|
{"--yarn-attn-factor"}, "N",
|
||||||
format("YaRN: scale sqrt(t) or attention magnitude (default: %.1f)", (double)params.yarn_attn_factor),
|
format("YaRN: scale sqrt(t) or attention magnitude (default: %.1f)", (double)params.yarn_attn_factor),
|
||||||
[](gpt_params & params, const std::string & value) {
|
[](gpt_params & params, const std::string & value) {
|
||||||
params.yarn_attn_factor = std::stof(value);
|
params.yarn_attn_factor = std::stof(value);
|
||||||
}
|
}
|
||||||
));
|
).set_env("LLAMA_ARG_YARN_ATTN_FACTOR"));
|
||||||
add_opt(llama_arg(
|
add_opt(llama_arg(
|
||||||
{"--yarn-beta-slow"}, "N",
|
{"--yarn-beta-slow"}, "N",
|
||||||
format("YaRN: high correction dim or alpha (default: %.1f)", (double)params.yarn_beta_slow),
|
format("YaRN: high correction dim or alpha (default: %.1f)", (double)params.yarn_beta_slow),
|
||||||
[](gpt_params & params, const std::string & value) {
|
[](gpt_params & params, const std::string & value) {
|
||||||
params.yarn_beta_slow = std::stof(value);
|
params.yarn_beta_slow = std::stof(value);
|
||||||
}
|
}
|
||||||
));
|
).set_env("LLAMA_ARG_YARN_BETA_SLOW"));
|
||||||
add_opt(llama_arg(
|
add_opt(llama_arg(
|
||||||
{"--yarn-beta-fast"}, "N",
|
{"--yarn-beta-fast"}, "N",
|
||||||
format("YaRN: low correction dim or beta (default: %.1f)", (double)params.yarn_beta_fast),
|
format("YaRN: low correction dim or beta (default: %.1f)", (double)params.yarn_beta_fast),
|
||||||
[](gpt_params & params, const std::string & value) {
|
[](gpt_params & params, const std::string & value) {
|
||||||
params.yarn_beta_fast = std::stof(value);
|
params.yarn_beta_fast = std::stof(value);
|
||||||
}
|
}
|
||||||
));
|
).set_env("LLAMA_ARG_YARN_BETA_FAST"));
|
||||||
add_opt(llama_arg(
|
add_opt(llama_arg(
|
||||||
{"-gan", "--grp-attn-n"}, "N",
|
{"-gan", "--grp-attn-n"}, "N",
|
||||||
format("group-attention factor (default: %d)", params.grp_attn_n),
|
format("group-attention factor (default: %d)", params.grp_attn_n),
|
||||||
[](gpt_params & params, int value) {
|
[](gpt_params & params, int value) {
|
||||||
params.grp_attn_n = value;
|
params.grp_attn_n = value;
|
||||||
}
|
}
|
||||||
));
|
).set_env("LLAMA_ARG_GRP_ATTN_N"));
|
||||||
add_opt(llama_arg(
|
add_opt(llama_arg(
|
||||||
{"-gaw", "--grp-attn-w"}, "N",
|
{"-gaw", "--grp-attn-w"}, "N",
|
||||||
format("group-attention width (default: %.1f)", (double)params.grp_attn_w),
|
format("group-attention width (default: %.1f)", (double)params.grp_attn_w),
|
||||||
[](gpt_params & params, int value) {
|
[](gpt_params & params, int value) {
|
||||||
params.grp_attn_w = value;
|
params.grp_attn_w = value;
|
||||||
}
|
}
|
||||||
));
|
).set_env("LLAMA_ARG_GRP_ATTN_W"));
|
||||||
add_opt(llama_arg(
|
add_opt(llama_arg(
|
||||||
{"-dkvc", "--dump-kv-cache"},
|
{"-dkvc", "--dump-kv-cache"},
|
||||||
"verbose print of the KV cache",
|
"verbose print of the KV cache",
|
||||||
|
@ -1206,7 +1206,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
|
||||||
[](gpt_params & params) {
|
[](gpt_params & params) {
|
||||||
params.no_kv_offload = true;
|
params.no_kv_offload = true;
|
||||||
}
|
}
|
||||||
));
|
).set_env("LLAMA_ARG_NO_KV_OFFLOAD"));
|
||||||
add_opt(llama_arg(
|
add_opt(llama_arg(
|
||||||
{"-ctk", "--cache-type-k"}, "TYPE",
|
{"-ctk", "--cache-type-k"}, "TYPE",
|
||||||
format("KV cache data type for K (default: %s)", params.cache_type_k.c_str()),
|
format("KV cache data type for K (default: %s)", params.cache_type_k.c_str()),
|
||||||
|
@ -1214,7 +1214,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
|
||||||
// TODO: get the type right here
|
// TODO: get the type right here
|
||||||
params.cache_type_k = value;
|
params.cache_type_k = value;
|
||||||
}
|
}
|
||||||
));
|
).set_env("LLAMA_ARG_CACHE_TYPE_K"));
|
||||||
add_opt(llama_arg(
|
add_opt(llama_arg(
|
||||||
{"-ctv", "--cache-type-v"}, "TYPE",
|
{"-ctv", "--cache-type-v"}, "TYPE",
|
||||||
format("KV cache data type for V (default: %s)", params.cache_type_v.c_str()),
|
format("KV cache data type for V (default: %s)", params.cache_type_v.c_str()),
|
||||||
|
@ -1222,7 +1222,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
|
||||||
// TODO: get the type right here
|
// TODO: get the type right here
|
||||||
params.cache_type_v = value;
|
params.cache_type_v = value;
|
||||||
}
|
}
|
||||||
));
|
).set_env("LLAMA_ARG_CACHE_TYPE_V"));
|
||||||
add_opt(llama_arg(
|
add_opt(llama_arg(
|
||||||
{"--perplexity", "--all-logits"},
|
{"--perplexity", "--all-logits"},
|
||||||
format("return logits for all tokens in the batch (default: %s)", params.logits_all ? "true" : "false"),
|
format("return logits for all tokens in the batch (default: %s)", params.logits_all ? "true" : "false"),
|
||||||
|
@ -1356,7 +1356,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
|
||||||
[](gpt_params & params, const std::string & value) {
|
[](gpt_params & params, const std::string & value) {
|
||||||
params.rpc_servers = value;
|
params.rpc_servers = value;
|
||||||
}
|
}
|
||||||
));
|
).set_env("LLAMA_ARG_RPC"));
|
||||||
#endif
|
#endif
|
||||||
add_opt(llama_arg(
|
add_opt(llama_arg(
|
||||||
{"--mlock"},
|
{"--mlock"},
|
||||||
|
@ -1364,14 +1364,14 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
|
||||||
[](gpt_params & params) {
|
[](gpt_params & params) {
|
||||||
params.use_mlock = true;
|
params.use_mlock = true;
|
||||||
}
|
}
|
||||||
));
|
).set_env("LLAMA_ARG_MLOCK"));
|
||||||
add_opt(llama_arg(
|
add_opt(llama_arg(
|
||||||
{"--no-mmap"},
|
{"--no-mmap"},
|
||||||
"do not memory-map model (slower load but may reduce pageouts if not using mlock)",
|
"do not memory-map model (slower load but may reduce pageouts if not using mlock)",
|
||||||
[](gpt_params & params) {
|
[](gpt_params & params) {
|
||||||
params.use_mmap = false;
|
params.use_mmap = false;
|
||||||
}
|
}
|
||||||
));
|
).set_env("LLAMA_ARG_NO_MMAP"));
|
||||||
add_opt(llama_arg(
|
add_opt(llama_arg(
|
||||||
{"--numa"}, "TYPE",
|
{"--numa"}, "TYPE",
|
||||||
"attempt optimizations that help on some NUMA systems\n"
|
"attempt optimizations that help on some NUMA systems\n"
|
||||||
|
@ -1386,7 +1386,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
|
||||||
else if (value == "numactl") { params.numa = GGML_NUMA_STRATEGY_NUMACTL; }
|
else if (value == "numactl") { params.numa = GGML_NUMA_STRATEGY_NUMACTL; }
|
||||||
else { throw std::invalid_argument("invalid value"); }
|
else { throw std::invalid_argument("invalid value"); }
|
||||||
}
|
}
|
||||||
));
|
).set_env("LLAMA_ARG_NUMA"));
|
||||||
add_opt(llama_arg(
|
add_opt(llama_arg(
|
||||||
{"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N",
|
{"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N",
|
||||||
"number of layers to store in VRAM",
|
"number of layers to store in VRAM",
|
||||||
|
@ -1434,7 +1434,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
|
||||||
fprintf(stderr, "warning: llama.cpp was compiled without support for GPU offload. Setting the split mode has no effect.\n");
|
fprintf(stderr, "warning: llama.cpp was compiled without support for GPU offload. Setting the split mode has no effect.\n");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
));
|
).set_env("LLAMA_ARG_SPLIT_MODE"));
|
||||||
add_opt(llama_arg(
|
add_opt(llama_arg(
|
||||||
{"-ts", "--tensor-split"}, "N0,N1,N2,...",
|
{"-ts", "--tensor-split"}, "N0,N1,N2,...",
|
||||||
"fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1",
|
"fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1",
|
||||||
|
@ -1461,7 +1461,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
|
||||||
fprintf(stderr, "warning: llama.cpp was compiled without support for GPU offload. Setting a tensor split has no effect.\n");
|
fprintf(stderr, "warning: llama.cpp was compiled without support for GPU offload. Setting a tensor split has no effect.\n");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
));
|
).set_env("LLAMA_ARG_TENSOR_SPLIT"));
|
||||||
add_opt(llama_arg(
|
add_opt(llama_arg(
|
||||||
{"-mg", "--main-gpu"}, "INDEX",
|
{"-mg", "--main-gpu"}, "INDEX",
|
||||||
format("the GPU to use for the model (with split-mode = none), or for intermediate results and KV (with split-mode = row) (default: %d)", params.main_gpu),
|
format("the GPU to use for the model (with split-mode = none), or for intermediate results and KV (with split-mode = row) (default: %d)", params.main_gpu),
|
||||||
|
@ -1471,7 +1471,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
|
||||||
fprintf(stderr, "warning: llama.cpp was compiled without support for GPU offload. Setting the main GPU has no effect.\n");
|
fprintf(stderr, "warning: llama.cpp was compiled without support for GPU offload. Setting the main GPU has no effect.\n");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
));
|
).set_env("LLAMA_ARG_MAIN_GPU"));
|
||||||
add_opt(llama_arg(
|
add_opt(llama_arg(
|
||||||
{"--check-tensors"},
|
{"--check-tensors"},
|
||||||
format("check model tensor data for invalid values (default: %s)", params.check_tensors ? "true" : "false"),
|
format("check model tensor data for invalid values (default: %s)", params.check_tensors ? "true" : "false"),
|
||||||
|
@ -1534,7 +1534,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
|
||||||
[](gpt_params & params, const std::string & value) {
|
[](gpt_params & params, const std::string & value) {
|
||||||
params.model_alias = value;
|
params.model_alias = value;
|
||||||
}
|
}
|
||||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ALIAS"));
|
||||||
add_opt(llama_arg(
|
add_opt(llama_arg(
|
||||||
{"-m", "--model"}, "FNAME",
|
{"-m", "--model"}, "FNAME",
|
||||||
ex == LLAMA_EXAMPLE_EXPORT_LORA
|
ex == LLAMA_EXAMPLE_EXPORT_LORA
|
||||||
|
@ -1742,7 +1742,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
|
||||||
[](gpt_params & params, const std::string & value) {
|
[](gpt_params & params, const std::string & value) {
|
||||||
params.public_path = value;
|
params.public_path = value;
|
||||||
}
|
}
|
||||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_STATIC_PATH"));
|
||||||
add_opt(llama_arg(
|
add_opt(llama_arg(
|
||||||
{"--embedding", "--embeddings"},
|
{"--embedding", "--embeddings"},
|
||||||
format("restrict to only support embedding use case; use only with dedicated embedding models (default: %s)", params.embedding ? "enabled" : "disabled"),
|
format("restrict to only support embedding use case; use only with dedicated embedding models (default: %s)", params.embedding ? "enabled" : "disabled"),
|
||||||
|
@ -1780,14 +1780,14 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
|
||||||
[](gpt_params & params, const std::string & value) {
|
[](gpt_params & params, const std::string & value) {
|
||||||
params.ssl_file_key = value;
|
params.ssl_file_key = value;
|
||||||
}
|
}
|
||||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_SSL_KEY_FILE"));
|
||||||
add_opt(llama_arg(
|
add_opt(llama_arg(
|
||||||
{"--ssl-cert-file"}, "FNAME",
|
{"--ssl-cert-file"}, "FNAME",
|
||||||
"path to file a PEM-encoded SSL certificate",
|
"path to file a PEM-encoded SSL certificate",
|
||||||
[](gpt_params & params, const std::string & value) {
|
[](gpt_params & params, const std::string & value) {
|
||||||
params.ssl_file_cert = value;
|
params.ssl_file_cert = value;
|
||||||
}
|
}
|
||||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_SSL_CERT_FILE"));
|
||||||
add_opt(llama_arg(
|
add_opt(llama_arg(
|
||||||
{"-to", "--timeout"}, "N",
|
{"-to", "--timeout"}, "N",
|
||||||
format("server read/write timeout in seconds (default: %d)", params.timeout_read),
|
format("server read/write timeout in seconds (default: %d)", params.timeout_read),
|
||||||
|
@ -1795,7 +1795,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
|
||||||
params.timeout_read = value;
|
params.timeout_read = value;
|
||||||
params.timeout_write = value;
|
params.timeout_write = value;
|
||||||
}
|
}
|
||||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_TIMEOUT"));
|
||||||
add_opt(llama_arg(
|
add_opt(llama_arg(
|
||||||
{"--threads-http"}, "N",
|
{"--threads-http"}, "N",
|
||||||
format("number of threads used to process HTTP requests (default: %d)", params.n_threads_http),
|
format("number of threads used to process HTTP requests (default: %d)", params.n_threads_http),
|
||||||
|
|
|
@ -82,7 +82,7 @@ struct gpt_log_entry {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (level != GGML_LOG_LEVEL_NONE && prefix) {
|
if (level != GGML_LOG_LEVEL_NONE && level != GGML_LOG_LEVEL_CONT && prefix) {
|
||||||
if (timestamp) {
|
if (timestamp) {
|
||||||
// [M.s.ms.us]
|
// [M.s.ms.us]
|
||||||
fprintf(fcur, "%s%d.%02d.%03d.%03d%s ",
|
fprintf(fcur, "%s%d.%02d.%03d.%03d%s ",
|
||||||
|
|
|
@ -83,8 +83,10 @@ void gpt_log_set_timestamps(struct gpt_log * log, bool timestamps); // w
|
||||||
#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN, 0, __VA_ARGS__)
|
#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN, 0, __VA_ARGS__)
|
||||||
#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, 0, __VA_ARGS__)
|
#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, 0, __VA_ARGS__)
|
||||||
#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, LOG_DEFAULT_DEBUG, __VA_ARGS__)
|
#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, LOG_DEFAULT_DEBUG, __VA_ARGS__)
|
||||||
|
#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT, 0, __VA_ARGS__)
|
||||||
|
|
||||||
#define LOG_INFV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_INFO, verbosity, __VA_ARGS__)
|
#define LOG_INFV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_INFO, verbosity, __VA_ARGS__)
|
||||||
#define LOG_WRNV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_WARN, verbosity, __VA_ARGS__)
|
#define LOG_WRNV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_WARN, verbosity, __VA_ARGS__)
|
||||||
#define LOG_ERRV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, verbosity, __VA_ARGS__)
|
#define LOG_ERRV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, verbosity, __VA_ARGS__)
|
||||||
#define LOG_DBGV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, verbosity, __VA_ARGS__)
|
#define LOG_DBGV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, verbosity, __VA_ARGS__)
|
||||||
|
#define LOG_CNTV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_CONT, verbosity, __VA_ARGS__)
|
||||||
|
|
|
@ -209,7 +209,15 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
|
||||||
GGML_ASSERT(false && "unknown mirostat version");
|
GGML_ASSERT(false && "unknown mirostat version");
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
if (params.n_probs > 0) {
|
||||||
|
// some use cases require to sample greedily, but still obtain the probabilities of the top tokens
|
||||||
|
// ref: https://github.com/ggerganov/llama.cpp/pull/9605
|
||||||
|
//
|
||||||
|
// the following will not produce exactly the same probs as applyging softmax to the full vocabulary, but
|
||||||
|
// it is much faster, since we avoid sorting all tokens and should give a good approximation
|
||||||
|
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k(params.n_probs));
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
|
llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
|
||||||
|
}
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_greedy());
|
llama_sampler_chain_add(result->chain, llama_sampler_init_greedy());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -4102,16 +4102,45 @@ class GraniteModel(LlamaModel):
|
||||||
# consistency
|
# consistency
|
||||||
if attention_scale := self.hparams.get("attention_multiplier"):
|
if attention_scale := self.hparams.get("attention_multiplier"):
|
||||||
self.gguf_writer.add_attention_scale(attention_scale)
|
self.gguf_writer.add_attention_scale(attention_scale)
|
||||||
|
logger.info("gguf: (granite) attention_scale = %s", attention_scale)
|
||||||
if embedding_scale := self.hparams.get("embedding_multiplier"):
|
if embedding_scale := self.hparams.get("embedding_multiplier"):
|
||||||
self.gguf_writer.add_embedding_scale(embedding_scale)
|
self.gguf_writer.add_embedding_scale(embedding_scale)
|
||||||
|
logger.info("gguf: (granite) embedding_scale = %s", embedding_scale)
|
||||||
if residual_scale := self.hparams.get("residual_multiplier"):
|
if residual_scale := self.hparams.get("residual_multiplier"):
|
||||||
self.gguf_writer.add_residual_scale(residual_scale)
|
self.gguf_writer.add_residual_scale(residual_scale)
|
||||||
if logits_scaling := self.hparams.get("logits_scaling"):
|
logger.info("gguf: (granite) residual_scale = %s", residual_scale)
|
||||||
self.gguf_writer.add_logit_scale(logits_scaling)
|
if logits_scale := self.hparams.get("logits_scaling"):
|
||||||
|
self.gguf_writer.add_logit_scale(logits_scale)
|
||||||
|
logger.info("gguf: (granite) logits_scale = %s", logits_scale)
|
||||||
|
|
||||||
|
|
||||||
|
@Model.register("GraniteMoeForCausalLM")
|
||||||
|
class GraniteMoeModel(GraniteModel):
|
||||||
|
"""Conversion for IBM's GraniteMoeForCausalLM"""
|
||||||
|
model_arch = gguf.MODEL_ARCH.GRANITE_MOE
|
||||||
|
|
||||||
|
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||||
|
"""In modeling_granitemoe, the JetMoe implementation of parallel experts
|
||||||
|
is used. This essentially merges w1 and w3 into a single tensor with 2x
|
||||||
|
the hidden size that is then split during forward. To keep compatibility
|
||||||
|
with existing mixtral support, we pull them apart here.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if name.endswith("block_sparse_moe.input_linear.weight"):
|
||||||
|
ffn_dim = self.hparams["intermediate_size"]
|
||||||
|
assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * intermediate_size"
|
||||||
|
gate, up = data_torch[..., :ffn_dim, :], data_torch[..., ffn_dim:, :]
|
||||||
|
return [
|
||||||
|
(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_EXP, bid), gate),
|
||||||
|
(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), up),
|
||||||
|
]
|
||||||
|
|
||||||
|
return super().modify_tensors(data_torch, name, bid)
|
||||||
|
|
||||||
|
|
||||||
###### CONVERSION LOGIC ######
|
###### CONVERSION LOGIC ######
|
||||||
|
|
||||||
|
|
||||||
# tree of lazy tensors
|
# tree of lazy tensors
|
||||||
class LazyTorchTensor(gguf.LazyBase):
|
class LazyTorchTensor(gguf.LazyBase):
|
||||||
_tensor_type = torch.Tensor
|
_tensor_type = torch.Tensor
|
||||||
|
|
|
@ -6,15 +6,12 @@
|
||||||
|
|
||||||
// Export usage message (-h) to markdown format
|
// Export usage message (-h) to markdown format
|
||||||
|
|
||||||
static void export_md(std::string fname, llama_example ex) {
|
static void write_table_header(std::ofstream & file) {
|
||||||
std::ofstream file(fname, std::ofstream::out | std::ofstream::trunc);
|
|
||||||
|
|
||||||
gpt_params params;
|
|
||||||
auto ctx_arg = gpt_params_parser_init(params, ex);
|
|
||||||
|
|
||||||
file << "| Argument | Explanation |\n";
|
file << "| Argument | Explanation |\n";
|
||||||
file << "| -------- | ----------- |\n";
|
file << "| -------- | ----------- |\n";
|
||||||
for (auto & opt : ctx_arg.options) {
|
}
|
||||||
|
|
||||||
|
static void write_table_entry(std::ofstream & file, const llama_arg & opt) {
|
||||||
file << "| `";
|
file << "| `";
|
||||||
// args
|
// args
|
||||||
for (const auto & arg : opt.args) {
|
for (const auto & arg : opt.args) {
|
||||||
|
@ -41,9 +38,43 @@ static void export_md(std::string fname, llama_example ex) {
|
||||||
string_replace_all(md_help, "\n", "<br/>");
|
string_replace_all(md_help, "\n", "<br/>");
|
||||||
string_replace_all(md_help, "|", "\\|");
|
string_replace_all(md_help, "|", "\\|");
|
||||||
file << "` | " << md_help << " |\n";
|
file << "` | " << md_help << " |\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
static void write_table(std::ofstream & file, std::vector<llama_arg *> & opts) {
|
||||||
|
write_table_header(file);
|
||||||
|
for (const auto & opt : opts) {
|
||||||
|
write_table_entry(file, *opt);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void export_md(std::string fname, llama_example ex) {
|
||||||
|
std::ofstream file(fname, std::ofstream::out | std::ofstream::trunc);
|
||||||
|
|
||||||
|
gpt_params params;
|
||||||
|
auto ctx_arg = gpt_params_parser_init(params, ex);
|
||||||
|
|
||||||
|
std::vector<llama_arg *> common_options;
|
||||||
|
std::vector<llama_arg *> sparam_options;
|
||||||
|
std::vector<llama_arg *> specific_options;
|
||||||
|
for (auto & opt : ctx_arg.options) {
|
||||||
|
// in case multiple LLAMA_EXAMPLE_* are set, we prioritize the LLAMA_EXAMPLE_* matching current example
|
||||||
|
if (opt.is_sparam) {
|
||||||
|
sparam_options.push_back(&opt);
|
||||||
|
} else if (opt.in_example(ctx_arg.ex)) {
|
||||||
|
specific_options.push_back(&opt);
|
||||||
|
} else {
|
||||||
|
common_options.push_back(&opt);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
file << "**Common params**\n\n";
|
||||||
|
write_table(file, common_options);
|
||||||
|
file << "\n\n**Sampling params**\n\n";
|
||||||
|
write_table(file, sparam_options);
|
||||||
|
file << "\n\n**Example-specific params**\n\n";
|
||||||
|
write_table(file, specific_options);
|
||||||
|
}
|
||||||
|
|
||||||
int main(int, char **) {
|
int main(int, char **) {
|
||||||
export_md("autogen-main.md", LLAMA_EXAMPLE_MAIN);
|
export_md("autogen-main.md", LLAMA_EXAMPLE_MAIN);
|
||||||
export_md("autogen-server.md", LLAMA_EXAMPLE_SERVER);
|
export_md("autogen-server.md", LLAMA_EXAMPLE_SERVER);
|
||||||
|
|
|
@ -386,9 +386,9 @@ int main(int argc, char ** argv) {
|
||||||
if (params.n_keep > add_bos) {
|
if (params.n_keep > add_bos) {
|
||||||
LOG_INF("%s: static prompt based on n_keep: '", __func__);
|
LOG_INF("%s: static prompt based on n_keep: '", __func__);
|
||||||
for (int i = 0; i < params.n_keep; i++) {
|
for (int i = 0; i < params.n_keep; i++) {
|
||||||
LOG("%s", llama_token_to_piece(ctx, embd_inp[i]).c_str());
|
LOG_CNT("%s", llama_token_to_piece(ctx, embd_inp[i]).c_str());
|
||||||
}
|
}
|
||||||
LOG("'\n");
|
LOG_CNT("'\n");
|
||||||
}
|
}
|
||||||
LOG_INF("\n");
|
LOG_INF("\n");
|
||||||
}
|
}
|
||||||
|
@ -410,40 +410,40 @@ int main(int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (params.interactive) {
|
if (params.interactive) {
|
||||||
LOG("%s: interactive mode on.\n", __func__);
|
LOG_INF("%s: interactive mode on.\n", __func__);
|
||||||
|
|
||||||
if (!params.antiprompt.empty()) {
|
if (!params.antiprompt.empty()) {
|
||||||
for (const auto & antiprompt : params.antiprompt) {
|
for (const auto & antiprompt : params.antiprompt) {
|
||||||
LOG("Reverse prompt: '%s'\n", antiprompt.c_str());
|
LOG_INF("Reverse prompt: '%s'\n", antiprompt.c_str());
|
||||||
if (params.verbose_prompt) {
|
if (params.verbose_prompt) {
|
||||||
auto tmp = ::llama_tokenize(ctx, antiprompt, false, true);
|
auto tmp = ::llama_tokenize(ctx, antiprompt, false, true);
|
||||||
for (int i = 0; i < (int) tmp.size(); i++) {
|
for (int i = 0; i < (int) tmp.size(); i++) {
|
||||||
LOG("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str());
|
LOG_INF("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (params.input_prefix_bos) {
|
if (params.input_prefix_bos) {
|
||||||
LOG("Input prefix with BOS\n");
|
LOG_INF("Input prefix with BOS\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!params.input_prefix.empty()) {
|
if (!params.input_prefix.empty()) {
|
||||||
LOG("Input prefix: '%s'\n", params.input_prefix.c_str());
|
LOG_INF("Input prefix: '%s'\n", params.input_prefix.c_str());
|
||||||
if (params.verbose_prompt) {
|
if (params.verbose_prompt) {
|
||||||
auto tmp = ::llama_tokenize(ctx, params.input_prefix, true, true);
|
auto tmp = ::llama_tokenize(ctx, params.input_prefix, true, true);
|
||||||
for (int i = 0; i < (int) tmp.size(); i++) {
|
for (int i = 0; i < (int) tmp.size(); i++) {
|
||||||
LOG("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str());
|
LOG_INF("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!params.input_suffix.empty()) {
|
if (!params.input_suffix.empty()) {
|
||||||
LOG("Input suffix: '%s'\n", params.input_suffix.c_str());
|
LOG_INF("Input suffix: '%s'\n", params.input_suffix.c_str());
|
||||||
if (params.verbose_prompt) {
|
if (params.verbose_prompt) {
|
||||||
auto tmp = ::llama_tokenize(ctx, params.input_suffix, false, true);
|
auto tmp = ::llama_tokenize(ctx, params.input_suffix, false, true);
|
||||||
for (int i = 0; i < (int) tmp.size(); i++) {
|
for (int i = 0; i < (int) tmp.size(); i++) {
|
||||||
LOG("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str());
|
LOG_INF("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -475,7 +475,7 @@ int main(int argc, char ** argv) {
|
||||||
//GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * grp_attn_n"); // NOLINT
|
//GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * grp_attn_n"); // NOLINT
|
||||||
LOG_INF("self-extend: n_ctx_train = %d, grp_attn_n = %d, grp_attn_w = %d\n", n_ctx_train, ga_n, ga_w);
|
LOG_INF("self-extend: n_ctx_train = %d, grp_attn_n = %d, grp_attn_w = %d\n", n_ctx_train, ga_n, ga_w);
|
||||||
}
|
}
|
||||||
LOG("\n");
|
LOG_INF("\n");
|
||||||
|
|
||||||
if (params.interactive) {
|
if (params.interactive) {
|
||||||
const char * control_message;
|
const char * control_message;
|
||||||
|
@ -487,11 +487,11 @@ int main(int argc, char ** argv) {
|
||||||
" - To return control without starting a new line, end your input with '/'.\n"
|
" - To return control without starting a new line, end your input with '/'.\n"
|
||||||
" - If you want to submit another line, end your input with '\\'.\n";
|
" - If you want to submit another line, end your input with '\\'.\n";
|
||||||
}
|
}
|
||||||
LOG("== Running in interactive mode. ==\n");
|
LOG_INF("== Running in interactive mode. ==\n");
|
||||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
|
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
|
||||||
LOG( " - Press Ctrl+C to interject at any time.\n");
|
LOG_INF( " - Press Ctrl+C to interject at any time.\n");
|
||||||
#endif
|
#endif
|
||||||
LOG( "%s\n", control_message);
|
LOG_INF( "%s\n", control_message);
|
||||||
|
|
||||||
is_interacting = params.interactive_first;
|
is_interacting = params.interactive_first;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1181,6 +1181,15 @@ struct server_context {
|
||||||
SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
|
SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if context shift is disabled, we stop when it reaches the context limit
|
||||||
|
if (slot.n_decoded >= slot.n_ctx) {
|
||||||
|
slot.truncated = true;
|
||||||
|
slot.stopped_limit = true;
|
||||||
|
slot.has_next_token = false;
|
||||||
|
|
||||||
|
SLT_DBG(slot, "stopped due to running out of context capacity, n_decoded = %d, n_ctx = %d\n", slot.n_decoded, slot.n_ctx);
|
||||||
|
}
|
||||||
|
|
||||||
if (llama_token_is_eog(model, result.tok)) {
|
if (llama_token_is_eog(model, result.tok)) {
|
||||||
slot.stopped_eos = true;
|
slot.stopped_eos = true;
|
||||||
slot.has_next_token = false;
|
slot.has_next_token = false;
|
||||||
|
@ -1481,7 +1490,7 @@ struct server_context {
|
||||||
if (result.error) {
|
if (result.error) {
|
||||||
error_handler(result.data);
|
error_handler(result.data);
|
||||||
cancel_tasks(id_tasks);
|
cancel_tasks(id_tasks);
|
||||||
break;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t idx = result.data["index"];
|
size_t idx = result.data["index"];
|
||||||
|
@ -1828,6 +1837,14 @@ struct server_context {
|
||||||
for (server_slot & slot : slots) {
|
for (server_slot & slot : slots) {
|
||||||
if (slot.ga_n == 1) {
|
if (slot.ga_n == 1) {
|
||||||
if (slot.is_processing() && (int) system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {
|
if (slot.is_processing() && (int) system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {
|
||||||
|
if (!params.ctx_shift) {
|
||||||
|
// this check is redundant (for good)
|
||||||
|
// we should never get here, because generation should already stopped in process_token()
|
||||||
|
slot.release();
|
||||||
|
send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
// Shift context
|
// Shift context
|
||||||
const int n_keep = slot.params.n_keep + add_bos_token;
|
const int n_keep = slot.params.n_keep + add_bos_token;
|
||||||
const int n_left = (int) system_tokens.size() + slot.n_past - n_keep;
|
const int n_left = (int) system_tokens.size() + slot.n_past - n_keep;
|
||||||
|
@ -1962,6 +1979,14 @@ struct server_context {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
if (!params.ctx_shift) {
|
||||||
|
// if context shift is disabled, we make sure prompt size is smaller than KV size
|
||||||
|
if ((int) system_tokens.size() + slot.n_prompt_tokens >= slot.n_ctx) {
|
||||||
|
slot.release();
|
||||||
|
send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
if (slot.params.n_keep < 0) {
|
if (slot.params.n_keep < 0) {
|
||||||
slot.params.n_keep = slot.n_prompt_tokens;
|
slot.params.n_keep = slot.n_prompt_tokens;
|
||||||
}
|
}
|
||||||
|
@ -2332,6 +2357,10 @@ int main(int argc, char ** argv) {
|
||||||
svr.reset(new httplib::Server());
|
svr.reset(new httplib::Server());
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
|
if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
|
||||||
|
LOG_ERR("Server is built without SSL support\n");
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
svr.reset(new httplib::Server());
|
svr.reset(new httplib::Server());
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
@ -3155,7 +3184,7 @@ int main(int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// print sample chat example to make it clear which template is used
|
// print sample chat example to make it clear which template is used
|
||||||
LOG_INF("%s: chat template, built_in: %d, chat_example: '%s\n'", __func__, params.chat_template.empty(), llama_chat_format_example(ctx_server.model, params.chat_template).c_str());
|
LOG_INF("%s: chat template, built_in: %d, chat_example: '%s'\n", __func__, params.chat_template.empty(), llama_chat_format_example(ctx_server.model, params.chat_template).c_str());
|
||||||
|
|
||||||
ctx_server.queue_tasks.on_new_task(std::bind(
|
ctx_server.queue_tasks.on_new_task(std::bind(
|
||||||
&server_context::process_single_task, &ctx_server, std::placeholders::_1));
|
&server_context::process_single_task, &ctx_server, std::placeholders::_1));
|
||||||
|
|
62
examples/server/tests/features/ctx_shift.feature
Normal file
62
examples/server/tests/features/ctx_shift.feature
Normal file
|
@ -0,0 +1,62 @@
|
||||||
|
@llama.cpp
|
||||||
|
@ctx_shift
|
||||||
|
Feature: llama.cpp server
|
||||||
|
|
||||||
|
Background: Server startup
|
||||||
|
Given a server listening on localhost:8080
|
||||||
|
And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
|
||||||
|
And a model file test-model.gguf
|
||||||
|
And a model alias tinyllama-2
|
||||||
|
And BOS token is 1
|
||||||
|
And 42 as server seed
|
||||||
|
And 256 KV cache size
|
||||||
|
And 32 as batch size
|
||||||
|
And 2 slots
|
||||||
|
|
||||||
|
Scenario: Inference with context shift
|
||||||
|
And 64 server max tokens to predict
|
||||||
|
Then the server is starting
|
||||||
|
Then the server is healthy
|
||||||
|
Given a prompt:
|
||||||
|
"""
|
||||||
|
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
|
||||||
|
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
|
||||||
|
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
|
||||||
|
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
|
||||||
|
"""
|
||||||
|
And a completion request with no api error
|
||||||
|
Then 64 tokens are predicted matching fun|Annaks|popcorns|pictry|bowl
|
||||||
|
And the completion is truncated
|
||||||
|
And 109 prompt tokens are processed
|
||||||
|
|
||||||
|
Scenario Outline: Inference without context shift
|
||||||
|
And <n_predict> server max tokens to predict
|
||||||
|
And disable context shifting
|
||||||
|
Then the server is starting
|
||||||
|
Then the server is healthy
|
||||||
|
Given a prompt:
|
||||||
|
"""
|
||||||
|
Hi how are you
|
||||||
|
"""
|
||||||
|
And a completion request with no api error
|
||||||
|
Then <n_token_output> tokens are predicted matching twind|Anna
|
||||||
|
And the completion is <truncated> truncated
|
||||||
|
And 8 prompt tokens are processed
|
||||||
|
Examples:
|
||||||
|
| n_predict | n_token_output | truncated |
|
||||||
|
| 64 | 64 | not |
|
||||||
|
| -1 | 120 | |
|
||||||
|
|
||||||
|
Scenario: Inference without context shift (expected error: prompt too long)
|
||||||
|
And disable context shifting
|
||||||
|
Then the server is starting
|
||||||
|
Then the server is healthy
|
||||||
|
Given a prompt:
|
||||||
|
"""
|
||||||
|
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
|
||||||
|
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
|
||||||
|
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
|
||||||
|
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
|
||||||
|
"""
|
||||||
|
And a completion request with 400 api error
|
||||||
|
|
|
@ -10,11 +10,11 @@ Feature: llama.cpp server
|
||||||
And 42 as server seed
|
And 42 as server seed
|
||||||
And 2 slots
|
And 2 slots
|
||||||
# the bert-bge-small model has context size of 512
|
# the bert-bge-small model has context size of 512
|
||||||
# since the generated prompts are as big as the batch size, we need to set the batch size to 512
|
# since the generated prompts are as big as the batch size, we need to set the batch size to <= 512
|
||||||
# ref: https://huggingface.co/BAAI/bge-small-en-v1.5/blob/5c38ec7c405ec4b44b94cc5a9bb96e735b38267a/config.json#L20
|
# ref: https://huggingface.co/BAAI/bge-small-en-v1.5/blob/5c38ec7c405ec4b44b94cc5a9bb96e735b38267a/config.json#L20
|
||||||
And 512 as batch size
|
And 128 as batch size
|
||||||
And 512 as ubatch size
|
And 128 as ubatch size
|
||||||
And 2048 KV cache size
|
And 512 KV cache size
|
||||||
And embeddings extraction
|
And embeddings extraction
|
||||||
Then the server is starting
|
Then the server is starting
|
||||||
Then the server is healthy
|
Then the server is healthy
|
||||||
|
@ -26,6 +26,20 @@ Feature: llama.cpp server
|
||||||
"""
|
"""
|
||||||
Then embeddings are generated
|
Then embeddings are generated
|
||||||
|
|
||||||
|
Scenario: Embedding (error: prompt too long)
|
||||||
|
When embeddings are computed for:
|
||||||
|
"""
|
||||||
|
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
|
||||||
|
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
|
||||||
|
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
|
||||||
|
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
|
||||||
|
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
|
||||||
|
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
|
||||||
|
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
|
||||||
|
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
|
||||||
|
"""
|
||||||
|
And embeddings request with 500 api error
|
||||||
|
|
||||||
Scenario: OAI Embeddings compatibility
|
Scenario: OAI Embeddings compatibility
|
||||||
Given a model bert-bge-small
|
Given a model bert-bge-small
|
||||||
When an OAI compatible embeddings computation request for:
|
When an OAI compatible embeddings computation request for:
|
||||||
|
|
|
@ -77,6 +77,7 @@ def step_server_config(context, server_fqdn: str, server_port: str):
|
||||||
context.response_format = None
|
context.response_format = None
|
||||||
context.temperature = None
|
context.temperature = None
|
||||||
context.lora_file = None
|
context.lora_file = None
|
||||||
|
context.disable_ctx_shift = False
|
||||||
|
|
||||||
context.tasks_result = []
|
context.tasks_result = []
|
||||||
context.concurrent_tasks = []
|
context.concurrent_tasks = []
|
||||||
|
@ -148,7 +149,7 @@ def step_n_slots(context, n_slots: int):
|
||||||
|
|
||||||
@step('{n_predict:d} server max tokens to predict')
|
@step('{n_predict:d} server max tokens to predict')
|
||||||
def step_server_n_predict(context, n_predict: int):
|
def step_server_n_predict(context, n_predict: int):
|
||||||
context.n_server_predict = n_predict
|
context.n_server_predict = n_predict if n_predict > 0 else None
|
||||||
|
|
||||||
|
|
||||||
@step('{slot_save_path} as slot save path')
|
@step('{slot_save_path} as slot save path')
|
||||||
|
@ -180,6 +181,9 @@ def step_server_embeddings(context):
|
||||||
def step_server_metrics(context):
|
def step_server_metrics(context):
|
||||||
context.server_metrics = True
|
context.server_metrics = True
|
||||||
|
|
||||||
|
@step('disable context shifting')
|
||||||
|
def step_server_disable_ctx_shift(context):
|
||||||
|
context.disable_ctx_shift = True
|
||||||
|
|
||||||
@step("the server is starting")
|
@step("the server is starting")
|
||||||
def step_start_server(context):
|
def step_start_server(context):
|
||||||
|
@ -257,7 +261,7 @@ async def step_all_slots_status(context, expected_slot_status_string: Literal['i
|
||||||
@step('a completion request with {api_error} api error')
|
@step('a completion request with {api_error} api error')
|
||||||
@async_run_until_complete
|
@async_run_until_complete
|
||||||
async def step_request_completion(context, api_error: Literal['raised'] | str):
|
async def step_request_completion(context, api_error: Literal['raised'] | str):
|
||||||
expect_api_error = api_error == 'raised'
|
expect_api_error = api_error == 'raised' or api_error != 'no'
|
||||||
seeds = await completions_seed(context, num_seeds=1)
|
seeds = await completions_seed(context, num_seeds=1)
|
||||||
completion = await request_completion(context.prompts.pop(),
|
completion = await request_completion(context.prompts.pop(),
|
||||||
seeds[0] if seeds is not None else seeds,
|
seeds[0] if seeds is not None else seeds,
|
||||||
|
@ -272,8 +276,11 @@ async def step_request_completion(context, api_error: Literal['raised'] | str):
|
||||||
context.tasks_result.append(completion)
|
context.tasks_result.append(completion)
|
||||||
if context.debug:
|
if context.debug:
|
||||||
print(f"Completion response: {completion}")
|
print(f"Completion response: {completion}")
|
||||||
if expect_api_error:
|
if api_error == 'raised':
|
||||||
assert completion == 401, f"completion must be an 401 status code: {completion}"
|
assert completion == 401, f"completion must be an 401 status code: {completion}"
|
||||||
|
elif api_error.isdigit():
|
||||||
|
api_error_code = int(api_error)
|
||||||
|
assert completion == api_error_code, f"completion must be an {api_error_code} status code: {completion}"
|
||||||
|
|
||||||
|
|
||||||
@step('{predicted_n:d} tokens are predicted matching {re_content}')
|
@step('{predicted_n:d} tokens are predicted matching {re_content}')
|
||||||
|
@ -645,6 +652,9 @@ def step_assert_embeddings(context):
|
||||||
for embedding in context.embeddings:
|
for embedding in context.embeddings:
|
||||||
assert_embeddings(embedding)
|
assert_embeddings(embedding)
|
||||||
|
|
||||||
|
@step('embeddings request with {api_error_code:d} api error')
|
||||||
|
def step_assert_embeddings(context, api_error_code: int):
|
||||||
|
assert context.embeddings == api_error_code, f"embeddings request must return code {api_error_code}, but got {context.embeddings}"
|
||||||
|
|
||||||
@step('an OAI compatible embeddings computation request for')
|
@step('an OAI compatible embeddings computation request for')
|
||||||
@async_run_until_complete
|
@async_run_until_complete
|
||||||
|
@ -1089,15 +1099,17 @@ async def oai_chat_completions(user_prompt,
|
||||||
return completion_response
|
return completion_response
|
||||||
|
|
||||||
|
|
||||||
async def request_embedding(content, seed, base_url=None) -> list[list[float]]:
|
async def request_embedding(content, seed, base_url=None) -> list[list[float]] | int:
|
||||||
async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
|
async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
|
||||||
async with session.post(f'{base_url}/embedding',
|
async with session.post(f'{base_url}/embedding',
|
||||||
json={
|
json={
|
||||||
"content": content,
|
"content": content,
|
||||||
}) as response:
|
}) as response:
|
||||||
assert response.status == 200
|
if response.status == 200:
|
||||||
response_json = await response.json()
|
response_json = await response.json()
|
||||||
return [response_json['embedding']]
|
return [response_json['embedding']]
|
||||||
|
else:
|
||||||
|
return response.status
|
||||||
|
|
||||||
|
|
||||||
async def request_oai_embeddings(input, seed,
|
async def request_oai_embeddings(input, seed,
|
||||||
|
@ -1372,6 +1384,8 @@ def start_server_background(context):
|
||||||
server_args.append('--verbose')
|
server_args.append('--verbose')
|
||||||
if context.lora_file:
|
if context.lora_file:
|
||||||
server_args.extend(['--lora', context.lora_file])
|
server_args.extend(['--lora', context.lora_file])
|
||||||
|
if context.disable_ctx_shift:
|
||||||
|
server_args.extend(['--no-context-shift'])
|
||||||
|
|
||||||
args = [str(arg) for arg in [context.server_path, *server_args]]
|
args = [str(arg) for arg in [context.server_path, *server_args]]
|
||||||
print(f"bench: starting server with: {' '.join(args)}")
|
print(f"bench: starting server with: {' '.join(args)}")
|
||||||
|
|
|
@ -576,6 +576,7 @@ extern "C" {
|
||||||
GGML_LOG_LEVEL_WARN = 2,
|
GGML_LOG_LEVEL_WARN = 2,
|
||||||
GGML_LOG_LEVEL_ERROR = 3,
|
GGML_LOG_LEVEL_ERROR = 3,
|
||||||
GGML_LOG_LEVEL_DEBUG = 4,
|
GGML_LOG_LEVEL_DEBUG = 4,
|
||||||
|
GGML_LOG_LEVEL_CONT = 5, // continue previous log
|
||||||
};
|
};
|
||||||
|
|
||||||
// this tensor...
|
// this tensor...
|
||||||
|
@ -1985,6 +1986,9 @@ extern "C" {
|
||||||
typedef void (*ggml_custom2_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata);
|
typedef void (*ggml_custom2_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata);
|
||||||
typedef void (*ggml_custom3_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, const struct ggml_tensor * c, int ith, int nth, void * userdata);
|
typedef void (*ggml_custom3_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, const struct ggml_tensor * c, int ith, int nth, void * userdata);
|
||||||
|
|
||||||
|
#define GGML_N_TASKS_MAX (-1)
|
||||||
|
// n_tasks == GGML_N_TASKS_MAX means to use max number of tasks
|
||||||
|
|
||||||
GGML_API struct ggml_tensor * ggml_map_custom1(
|
GGML_API struct ggml_tensor * ggml_map_custom1(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
|
|
|
@ -1,4 +1,7 @@
|
||||||
// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd.
|
// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
|
||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
//
|
||||||
|
|
||||||
#define GGML_COMMON_IMPL_C
|
#define GGML_COMMON_IMPL_C
|
||||||
#include "ggml-common.h"
|
#include "ggml-common.h"
|
||||||
|
|
||||||
|
@ -39,11 +42,44 @@
|
||||||
//
|
//
|
||||||
#if defined(__AVX__)
|
#if defined(__AVX__)
|
||||||
#if defined(__F16C__)
|
#if defined(__F16C__)
|
||||||
|
#if defined(__AVX512F__)
|
||||||
|
#define GGML_F32Cx8x2_LOAD(x, y) _mm512_cvtph_ps(_mm256_set_m128i(_mm_loadu_si128((const __m128i *)(y)), _mm_loadu_si128((const __m128i *)(x))))
|
||||||
|
#define GGML_F32Cx16_REPEAT_LOAD(x) _mm512_cvtph_ps(_mm256_set_m128i(x, x))
|
||||||
|
#endif
|
||||||
// the _mm256_cvt intrinsics require F16C
|
// the _mm256_cvt intrinsics require F16C
|
||||||
#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
|
#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
|
||||||
#define GGML_F32Cx8_REPEAT_LOAD(x, loadMask) _mm256_cvtph_ps(_mm_shuffle_epi32(_mm_maskload_epi32((int const*)(x), loadMask), 68))
|
#define GGML_F32Cx8_REPEAT_LOAD(x, loadMask) _mm256_cvtph_ps(_mm_shuffle_epi32(_mm_maskload_epi32((int const*)(x), loadMask), 68))
|
||||||
#define GGML_F32Cx8_REARRANGE_LOAD(x, arrangeMask) _mm256_cvtph_ps(_mm_shuffle_epi8(_mm_loadu_si128((const __m128i *) x), arrangeMask))
|
#define GGML_F32Cx8_REARRANGE_LOAD(x, arrangeMask) _mm256_cvtph_ps(_mm_shuffle_epi8(_mm_loadu_si128((const __m128i *) x), arrangeMask))
|
||||||
#else
|
#else
|
||||||
|
#if defined(__AVX512F__)
|
||||||
|
static inline __m512 __avx512_f32cx8x2_load(ggml_fp16_t *x, ggml_fp16_t *y) {
|
||||||
|
float tmp[16];
|
||||||
|
|
||||||
|
for (int i = 0; i < 8; i++) {
|
||||||
|
tmp[i] = GGML_FP16_TO_FP32(x[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < 8; i++) {
|
||||||
|
tmp[i + 8] = GGML_FP16_TO_FP32(y[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
return _mm512_loadu_ps(tmp);
|
||||||
|
}
|
||||||
|
static inline __m512 __avx512_repeat_f32cx16_load(__m128i x) {
|
||||||
|
float tmp[16];
|
||||||
|
uint16_t tmphalf[8];
|
||||||
|
_mm_storeu_si128((__m128i*)tmphalf, x);
|
||||||
|
|
||||||
|
for (int i = 0; i < 4; i++) {
|
||||||
|
tmp[i] = GGML_FP16_TO_FP32(tmphalf[i]);
|
||||||
|
tmp[i + 4] = GGML_FP16_TO_FP32(tmphalf[i]);
|
||||||
|
tmp[i + 8] = GGML_FP16_TO_FP32(tmphalf[i]);
|
||||||
|
tmp[i + 12] = GGML_FP16_TO_FP32(tmphalf[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
return _mm512_loadu_ps(tmp);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
|
static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
|
||||||
float tmp[8];
|
float tmp[8];
|
||||||
|
|
||||||
|
@ -78,30 +114,65 @@ static inline __m256 __avx_rearranged_f32cx8_load(ggml_fp16_t *x, __m128i arrang
|
||||||
#define GGML_F32Cx8_LOAD(x) __avx_f32cx8_load(x)
|
#define GGML_F32Cx8_LOAD(x) __avx_f32cx8_load(x)
|
||||||
#define GGML_F32Cx8_REPEAT_LOAD(x, loadMask) __avx_repeat_f32cx8_load(x)
|
#define GGML_F32Cx8_REPEAT_LOAD(x, loadMask) __avx_repeat_f32cx8_load(x)
|
||||||
#define GGML_F32Cx8_REARRANGE_LOAD(x, arrangeMask) __avx_rearranged_f32cx8_load(x, arrangeMask)
|
#define GGML_F32Cx8_REARRANGE_LOAD(x, arrangeMask) __avx_rearranged_f32cx8_load(x, arrangeMask)
|
||||||
|
#if defined(__AVX512F__)
|
||||||
|
#define GGML_F32Cx8x2_LOAD(x, y) __avx512_f32cx8x2_load(x, y)
|
||||||
|
#define GGML_F32Cx16_REPEAT_LOAD(x) __avx512_repeat_f32cx16_load(x)
|
||||||
|
#endif
|
||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
||||||
#if defined(__AVX2__) || defined(__AVX512F__)
|
#if defined(__AVX2__) || defined(__AVX512F__)
|
||||||
static inline __m256i sum_i16_pairs_int(const __m256i x) {
|
#if defined(__AVX512F__)
|
||||||
|
// add int16_t pairwise and return as 512 bit int vector
|
||||||
|
static inline __m512i sum_i16_pairs_int_32x16(const __m512i x) {
|
||||||
|
const __m512i ones = _mm512_set1_epi16(1);
|
||||||
|
return _mm512_madd_epi16(ones, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline __m512i mul_sum_us8_pairs_int32x16(const __m512i ax, const __m512i sy) {
|
||||||
|
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
|
||||||
|
const __m512i zero = _mm512_setzero_si512();
|
||||||
|
return _mm512_dpbusd_epi32(zero, ax, sy);
|
||||||
|
#else
|
||||||
|
// Perform multiplication and create 16-bit values
|
||||||
|
const __m512i dot = _mm512_maddubs_epi16(ax, sy);
|
||||||
|
return sum_i16_pairs_int_32x16(dot);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
// multiply int8_t, add results pairwise twice and return as 512 bit int vector
|
||||||
|
static inline __m512i mul_sum_i8_pairs_int32x16(const __m512i x, const __m512i y) {
|
||||||
|
const __m512i zero = _mm512_setzero_si512();
|
||||||
|
// Get absolute values of x vectors
|
||||||
|
const __m512i ax = _mm512_abs_epi8(x);
|
||||||
|
// Sign the values of the y vectors
|
||||||
|
__mmask64 blt0 = _mm512_movepi8_mask(x);
|
||||||
|
const __m512i sy = _mm512_mask_sub_epi8(y, blt0, zero, y);
|
||||||
|
return mul_sum_us8_pairs_int32x16(ax, sy);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// add int16_t pairwise and return as 256 bit int vector
|
||||||
|
static inline __m256i sum_i16_pairs_int32x8(const __m256i x) {
|
||||||
const __m256i ones = _mm256_set1_epi16(1);
|
const __m256i ones = _mm256_set1_epi16(1);
|
||||||
return _mm256_madd_epi16(ones, x);
|
return _mm256_madd_epi16(ones, x);
|
||||||
}
|
}
|
||||||
|
|
||||||
static inline __m256i mul_sum_us8_pairs_int(const __m256i ax, const __m256i sy) {
|
static inline __m256i mul_sum_us8_pairs_int32x8(const __m256i ax, const __m256i sy) {
|
||||||
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
|
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
|
||||||
const __m256i zero = _mm256_setzero_si256();
|
const __m256i zero = _mm256_setzero_si256();
|
||||||
return _mm256_dpbusd_epi32(zero, ax, sy);
|
return _mm256_dpbusd_epi32(zero, ax, sy);
|
||||||
#else
|
#else
|
||||||
// Perform multiplication and create 16-bit values
|
// Perform multiplication and create 16-bit values
|
||||||
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
|
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
|
||||||
return sum_i16_pairs_int(dot);
|
return sum_i16_pairs_int32x8(dot);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
// Integer variant of the function defined in ggml-quants.c
|
// Integer variant of the function defined in ggml-quants.c
|
||||||
// multiply int8_t, add results pairwise twice and return as float vector
|
// multiply int8_t, add results pairwise twice and return as 256 bit int vector
|
||||||
static inline __m256i mul_sum_i8_pairs_int(const __m256i x, const __m256i y) {
|
static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y) {
|
||||||
#if __AVXVNNIINT8__
|
#if __AVXVNNIINT8__
|
||||||
const __m256i zero = _mm256_setzero_si256();
|
const __m256i zero = _mm256_setzero_si256();
|
||||||
return _mm256_dpbssd_epi32(zero, x, y);
|
return _mm256_dpbssd_epi32(zero, x, y);
|
||||||
|
@ -110,7 +181,7 @@ static inline __m256i mul_sum_i8_pairs_int(const __m256i x, const __m256i y) {
|
||||||
const __m256i ax = _mm256_sign_epi8(x, x);
|
const __m256i ax = _mm256_sign_epi8(x, x);
|
||||||
// Sign the values of the y vectors
|
// Sign the values of the y vectors
|
||||||
const __m256i sy = _mm256_sign_epi8(y, x);
|
const __m256i sy = _mm256_sign_epi8(y, x);
|
||||||
return mul_sum_us8_pairs_int(ax, sy);
|
return mul_sum_us8_pairs_int32x8(ax, sy);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
@ -527,6 +598,15 @@ size_t quantize_q4_0_8x8(const float * restrict src, void * restrict dst, int64_
|
||||||
return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 8, 8);
|
return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 8, 8);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Return the number of byte lanes in the SVE vector if SVE is supported; otherwise, returns 0 if SVE is not supported.
|
||||||
|
static int sve_lane_count(void) {
|
||||||
|
#if defined(__ARM_FEATURE_SVE)
|
||||||
|
return ggml_sve_cnt_b;
|
||||||
|
#else
|
||||||
|
return 0;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
|
void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
|
||||||
const int qk = QK8_0;
|
const int qk = QK8_0;
|
||||||
const int nb = n / qk;
|
const int nb = n / qk;
|
||||||
|
@ -546,16 +626,8 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
|
||||||
UNUSED(ncols_interleaved);
|
UNUSED(ncols_interleaved);
|
||||||
UNUSED(blocklen);
|
UNUSED(blocklen);
|
||||||
|
|
||||||
#if defined(__ARM_FEATURE_SVE)
|
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
|
||||||
if (ggml_sve_cnt_b == QK8_0) {
|
if (ggml_cpu_has_neon()) {
|
||||||
GGML_ASSERT(!(ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) &&
|
|
||||||
"__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance");
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
|
|
||||||
GGML_ASSERT(!(ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) &&
|
|
||||||
"__ARM_NEON and __ARM_FEATURE_MATMUL_INT8 defined, use the Q4_0_4_8 quantization format for optimal performance");
|
|
||||||
#elif defined(__ARM_NEON) && defined(__aarch64__) && ! ((defined(_MSC_VER)) && ! defined(__clang__))
|
|
||||||
const void * b_ptr = vx;
|
const void * b_ptr = vx;
|
||||||
const void * a_ptr = vy;
|
const void * a_ptr = vy;
|
||||||
float * res_ptr = s;
|
float * res_ptr = s;
|
||||||
|
@ -612,7 +684,9 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
|
||||||
: [a_ptr] "r" (a_ptr), [nb] "r" (nb)
|
: [a_ptr] "r" (a_ptr), [nb] "r" (nb)
|
||||||
: "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22"
|
: "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22"
|
||||||
);
|
);
|
||||||
#else
|
return;
|
||||||
|
}
|
||||||
|
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
|
||||||
float sumf[4];
|
float sumf[4];
|
||||||
int sumi;
|
int sumi;
|
||||||
|
|
||||||
|
@ -636,7 +710,6 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
|
||||||
}
|
}
|
||||||
for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
|
for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
|
void ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
|
||||||
|
@ -658,13 +731,8 @@ void ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
||||||
UNUSED(ncols_interleaved);
|
UNUSED(ncols_interleaved);
|
||||||
UNUSED(blocklen);
|
UNUSED(blocklen);
|
||||||
|
|
||||||
#if defined(__ARM_FEATURE_SVE)
|
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||||
if (ggml_sve_cnt_b == QK8_0) {
|
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
|
||||||
GGML_ASSERT(!(ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) &&
|
|
||||||
"__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance");
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) && ! ((defined(_MSC_VER)) && ! defined(__clang__))
|
|
||||||
const void * b_ptr = vx;
|
const void * b_ptr = vx;
|
||||||
const void * a_ptr = vy;
|
const void * a_ptr = vy;
|
||||||
float * res_ptr = s;
|
float * res_ptr = s;
|
||||||
|
@ -726,11 +794,9 @@ void ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
||||||
: [a_ptr] "r" (a_ptr), [nb] "r" (nb)
|
: [a_ptr] "r" (a_ptr), [nb] "r" (nb)
|
||||||
: "memory", "v0", "v1", "v2", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23"
|
: "memory", "v0", "v1", "v2", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23"
|
||||||
);
|
);
|
||||||
#elif defined(__ARM_NEON) && defined(__aarch64__)
|
return;
|
||||||
GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) &&
|
}
|
||||||
"__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal "
|
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||||
"performance");
|
|
||||||
#else
|
|
||||||
float sumf[4];
|
float sumf[4];
|
||||||
int sumi;
|
int sumi;
|
||||||
|
|
||||||
|
@ -754,7 +820,6 @@ void ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
||||||
}
|
}
|
||||||
for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
|
for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
|
void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
|
||||||
|
@ -776,8 +841,9 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
||||||
UNUSED(ncols_interleaved);
|
UNUSED(ncols_interleaved);
|
||||||
UNUSED(blocklen);
|
UNUSED(blocklen);
|
||||||
|
|
||||||
#if defined(__ARM_FEATURE_SVE) && ! ((defined(_MSC_VER)) && ! defined(__clang__))
|
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
|
||||||
if (ggml_sve_cnt_b == QK8_0) {
|
#if defined(__ARM_FEATURE_SVE)
|
||||||
|
if (ggml_cpu_has_sve() && sve_lane_count() == QK8_0) {
|
||||||
const void * b_ptr = vx;
|
const void * b_ptr = vx;
|
||||||
const void * a_ptr = vy;
|
const void * a_ptr = vy;
|
||||||
float * res_ptr = s;
|
float * res_ptr = s;
|
||||||
|
@ -842,24 +908,7 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
||||||
);
|
);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
else if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
|
#endif // #if defined(__ARM_FEATURE_SVE)
|
||||||
GGML_ASSERT((ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) &&
|
|
||||||
"__ARM_FEATURE_SVE for vector size of 256-bits not defined, use the Q4_0_4_8 quantization format for optimal "
|
|
||||||
"performance");
|
|
||||||
}
|
|
||||||
else if (ggml_cpu_has_neon()) {
|
|
||||||
GGML_ASSERT(((ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) || ggml_cpu_has_matmul_int8()) &&
|
|
||||||
"__ARM_FEATURE_SVE for vector size of 256-bits and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 "
|
|
||||||
"quantization format for optimal performance");
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
|
|
||||||
GGML_ASSERT(ggml_cpu_has_sve() &&
|
|
||||||
"__ARM_FEATURE_SVE not defined, use the Q4_0_4_8 quantization format for optimal performance");
|
|
||||||
#elif defined(__ARM_NEON) && defined(__aarch64__)
|
|
||||||
GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) &&
|
|
||||||
"__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal "
|
|
||||||
"performance");
|
|
||||||
#elif defined(__AVX2__)
|
#elif defined(__AVX2__)
|
||||||
// Lookup table to convert signed nibbles to signed bytes
|
// Lookup table to convert signed nibbles to signed bytes
|
||||||
__m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
|
__m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
|
||||||
|
@ -929,17 +978,17 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
||||||
// ...........................................................................
|
// ...........................................................................
|
||||||
// B0(28-31) B4(28-31) B1(28-31) B5(28-31) B2(28-31) B6(28-31) B3(28-31) B7(28-31) with A0(28-31)
|
// B0(28-31) B4(28-31) B1(28-31) B5(28-31) B2(28-31) B6(28-31) B3(28-31) B7(28-31) with A0(28-31)
|
||||||
|
|
||||||
iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(rhs_vec_0123_0 ,_mm256_shuffle_epi32(rhs_vec_4567_0, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 0)));
|
iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_0 ,_mm256_shuffle_epi32(rhs_vec_4567_0, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 0)));
|
||||||
iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_0, 177) ,rhs_vec_4567_0, 170), _mm256_shuffle_epi32(lhs_vec_0, 85)));
|
iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_0, 177) ,rhs_vec_4567_0, 170), _mm256_shuffle_epi32(lhs_vec_0, 85)));
|
||||||
|
|
||||||
iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(rhs_vec_0123_1 ,_mm256_shuffle_epi32(rhs_vec_4567_1, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 170)));
|
iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_1 ,_mm256_shuffle_epi32(rhs_vec_4567_1, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 170)));
|
||||||
iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_1, 177) ,rhs_vec_4567_1, 170), _mm256_shuffle_epi32(lhs_vec_0, 255)));
|
iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_1, 177) ,rhs_vec_4567_1, 170), _mm256_shuffle_epi32(lhs_vec_0, 255)));
|
||||||
|
|
||||||
iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(rhs_vec_0123_2 ,_mm256_shuffle_epi32(rhs_vec_4567_2, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 0)));
|
iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_2 ,_mm256_shuffle_epi32(rhs_vec_4567_2, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 0)));
|
||||||
iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_2, 177) ,rhs_vec_4567_2, 170), _mm256_shuffle_epi32(lhs_vec_1, 85)));
|
iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_2, 177) ,rhs_vec_4567_2, 170), _mm256_shuffle_epi32(lhs_vec_1, 85)));
|
||||||
|
|
||||||
iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(rhs_vec_0123_3 ,_mm256_shuffle_epi32(rhs_vec_4567_3, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 170)));
|
iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_3 ,_mm256_shuffle_epi32(rhs_vec_4567_3, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 170)));
|
||||||
iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_3, 177) ,rhs_vec_4567_3, 170), _mm256_shuffle_epi32(lhs_vec_1, 255)));
|
iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_3, 177) ,rhs_vec_4567_3, 170), _mm256_shuffle_epi32(lhs_vec_1, 255)));
|
||||||
|
|
||||||
// Accumulated values multipled with appropriate scales
|
// Accumulated values multipled with appropriate scales
|
||||||
acc_row = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc), _mm256_mul_ps(col_scale_f32, row_scale_f32), acc_row);
|
acc_row = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc), _mm256_mul_ps(col_scale_f32, row_scale_f32), acc_row);
|
||||||
|
@ -950,7 +999,9 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
||||||
_mm256_storeu_ps(s + (y * nr + x * 8), acc_row);
|
_mm256_storeu_ps(s + (y * nr + x * 8), acc_row);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#else
|
return;
|
||||||
|
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
|
||||||
|
{
|
||||||
float sumf[8];
|
float sumf[8];
|
||||||
int sumi;
|
int sumi;
|
||||||
|
|
||||||
|
@ -974,7 +1025,7 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
||||||
}
|
}
|
||||||
for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
|
for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
|
||||||
}
|
}
|
||||||
#endif
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
|
void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
|
||||||
|
@ -997,16 +1048,8 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
|
||||||
UNUSED(ncols_interleaved);
|
UNUSED(ncols_interleaved);
|
||||||
UNUSED(blocklen);
|
UNUSED(blocklen);
|
||||||
|
|
||||||
#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
|
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
|
||||||
if (ggml_sve_cnt_b == QK8_0) {
|
if (ggml_cpu_has_neon()) {
|
||||||
GGML_ASSERT(!(ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) &&
|
|
||||||
"__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance");
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
|
|
||||||
GGML_ASSERT(!(ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) &&
|
|
||||||
"__ARM_NEON and __ARM_FEATURE_MATMUL_INT8 defined, use the Q4_0_4_8 quantization format for optimal performance");
|
|
||||||
#elif defined(__ARM_NEON) && defined(__aarch64__) && ! ((defined(_MSC_VER)) && ! defined(__clang__))
|
|
||||||
const void * b_ptr = vx;
|
const void * b_ptr = vx;
|
||||||
const void * a_ptr = vy;
|
const void * a_ptr = vy;
|
||||||
float * res_ptr = s;
|
float * res_ptr = s;
|
||||||
|
@ -1462,7 +1505,10 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
|
||||||
: [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
|
: [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
|
||||||
: "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
|
: "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
|
||||||
);
|
);
|
||||||
#else
|
return;
|
||||||
|
}
|
||||||
|
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
|
||||||
|
{
|
||||||
float sumf[4][4];
|
float sumf[4][4];
|
||||||
int sumi;
|
int sumi;
|
||||||
|
|
||||||
|
@ -1495,7 +1541,7 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
|
void ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
|
||||||
|
@ -1518,13 +1564,8 @@ void ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
||||||
UNUSED(ncols_interleaved);
|
UNUSED(ncols_interleaved);
|
||||||
UNUSED(blocklen);
|
UNUSED(blocklen);
|
||||||
|
|
||||||
#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
|
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||||
if (ggml_sve_cnt_b == QK8_0) {
|
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
|
||||||
GGML_ASSERT(!(ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) &&
|
|
||||||
"__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance");
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) && ! ((defined(_MSC_VER)) && ! defined(__clang__))
|
|
||||||
const void * b_ptr = vx;
|
const void * b_ptr = vx;
|
||||||
const void * a_ptr = vy;
|
const void * a_ptr = vy;
|
||||||
float * res_ptr = s;
|
float * res_ptr = s;
|
||||||
|
@ -1920,11 +1961,9 @@ void ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
||||||
: [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
|
: [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
|
||||||
: "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
|
: "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
|
||||||
);
|
);
|
||||||
#elif defined(__ARM_NEON) && defined(__aarch64__)
|
return;
|
||||||
GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) &&
|
}
|
||||||
"__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal "
|
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||||
"performance");
|
|
||||||
#else
|
|
||||||
float sumf[4][4];
|
float sumf[4][4];
|
||||||
int sumi;
|
int sumi;
|
||||||
|
|
||||||
|
@ -1957,7 +1996,6 @@ void ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
|
void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
|
||||||
|
@ -1980,8 +2018,9 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
||||||
UNUSED(ncols_interleaved);
|
UNUSED(ncols_interleaved);
|
||||||
UNUSED(blocklen);
|
UNUSED(blocklen);
|
||||||
|
|
||||||
#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) && ! ((defined(_MSC_VER)) && ! defined(__clang__))
|
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
|
||||||
if (ggml_sve_cnt_b == QK8_0) {
|
#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||||
|
if (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && sve_lane_count() == QK8_0) {
|
||||||
const void * b_ptr = vx;
|
const void * b_ptr = vx;
|
||||||
const void * a_ptr = vy;
|
const void * a_ptr = vy;
|
||||||
float * res_ptr = s;
|
float * res_ptr = s;
|
||||||
|
@ -2391,25 +2430,9 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
||||||
);
|
);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
else if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
|
#endif // #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||||
GGML_ASSERT((ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) &&
|
|
||||||
"__ARM_FEATURE_SVE for vector size of 256-bits not defined, use the Q4_0_4_8 quantization format for optimal "
|
|
||||||
"performance");
|
|
||||||
}
|
|
||||||
else if (ggml_cpu_has_neon()) {
|
|
||||||
GGML_ASSERT(((ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) || ggml_cpu_has_matmul_int8()) &&
|
|
||||||
"__ARM_FEATURE_SVE for vector size of 256-bits and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 "
|
|
||||||
"quantization format for optimal performance");
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
|
|
||||||
GGML_ASSERT(ggml_cpu_has_sve() &&
|
|
||||||
"__ARM_FEATURE_SVE not defined, use the Q4_0_4_8 quantization format for optimal performance");
|
|
||||||
#elif defined(__ARM_NEON) && defined(__aarch64__)
|
|
||||||
GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) &&
|
|
||||||
"__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal "
|
|
||||||
"performance");
|
|
||||||
#elif defined(__AVX2__) || defined(__AVX512F__)
|
#elif defined(__AVX2__) || defined(__AVX512F__)
|
||||||
|
{
|
||||||
const block_q4_0x8 * b_ptr_start = (const block_q4_0x8 *)vx;
|
const block_q4_0x8 * b_ptr_start = (const block_q4_0x8 *)vx;
|
||||||
const block_q8_0x4 * a_ptr_start = (const block_q8_0x4 *)vy;
|
const block_q8_0x4 * a_ptr_start = (const block_q8_0x4 *)vy;
|
||||||
int64_t b_nb = n / QK4_0;
|
int64_t b_nb = n / QK4_0;
|
||||||
|
@ -2421,10 +2444,411 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
||||||
__m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
|
__m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
|
||||||
signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
|
signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
|
||||||
// Permute mask used for easier vector processing at later stages
|
// Permute mask used for easier vector processing at later stages
|
||||||
__m256i requiredOrder = _mm256_set_epi32(3 ,2 ,1 ,0, 7 ,6, 5, 4);
|
__m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);
|
||||||
|
int64_t xstart = 0;
|
||||||
|
int anr = nr - nr%16; // Used to align nr with boundary of 16
|
||||||
|
#ifdef __AVX512F__
|
||||||
|
int anc = nc - nc%16; // Used to align nc with boundary of 16
|
||||||
|
// Mask to mask out nibbles from packed bytes expanded to 512 bit length
|
||||||
|
const __m512i m4bexpanded = _mm512_set1_epi8(0x0F);
|
||||||
|
// Lookup table to convert signed nibbles to signed bytes expanded to 512 bit length
|
||||||
|
__m512i signextendlutexpanded = _mm512_inserti32x8(_mm512_castsi256_si512(signextendlut), signextendlut, 1);
|
||||||
|
|
||||||
|
// Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation
|
||||||
|
for (; y < anr / 4; y += 4) {
|
||||||
|
|
||||||
|
const block_q8_0x4 * a_ptrs[4];
|
||||||
|
|
||||||
|
a_ptrs[0] = a_ptr_start + (y * nb);
|
||||||
|
for (int i = 0; i < 3; ++i) {
|
||||||
|
a_ptrs[i + 1] = a_ptrs[i] + nb;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Take group of two block_q4_0x8 structures at each pass of the loop and perform dot product operation
|
||||||
|
for (int64_t x = 0; x < anc / 8; x += 2) {
|
||||||
|
|
||||||
|
const block_q4_0x8 * b_ptr_0 = b_ptr_start + ((x) * b_nb);
|
||||||
|
const block_q4_0x8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb);
|
||||||
|
|
||||||
|
// Master FP accumulators
|
||||||
|
__m512 acc_rows[16];
|
||||||
|
for (int i = 0; i < 16; i++) {
|
||||||
|
acc_rows[i] = _mm512_setzero_ps();
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int64_t b = 0; b < nb; b++) {
|
||||||
|
// Load the sixteen block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....BE,BF
|
||||||
|
const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs));
|
||||||
|
const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 32));
|
||||||
|
const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 64));
|
||||||
|
const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 96));
|
||||||
|
|
||||||
|
const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs));
|
||||||
|
const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 32));
|
||||||
|
const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 64));
|
||||||
|
const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 96));
|
||||||
|
|
||||||
|
// Save the values in the following vectors in the formats B0B1B4B5B8B9BCBD, B2B3B6B7BABBBEBF for further processing and storing of values
|
||||||
|
const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
|
||||||
|
const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
|
||||||
|
const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
|
||||||
|
const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
|
||||||
|
|
||||||
|
const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240);
|
||||||
|
const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240);
|
||||||
|
const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240);
|
||||||
|
const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240);
|
||||||
|
|
||||||
|
const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1);
|
||||||
|
const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1);
|
||||||
|
const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1);
|
||||||
|
const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1);
|
||||||
|
|
||||||
|
// 4-bit -> 8-bit - Sign is maintained
|
||||||
|
const __m512i rhs_mat_014589CD_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) B8(0-7) B9(0-7) BC(0-7) BD(0-7)
|
||||||
|
const __m512i rhs_mat_2367ABEF_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) BA(0-7) BB(0-7) BE(0-7) BF(0-7)
|
||||||
|
|
||||||
|
const __m512i rhs_mat_014589CD_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) B8(8-15) B9(8-15) BC(8-15) BD(8-15)
|
||||||
|
const __m512i rhs_mat_2367ABEF_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) BA(8-15) BB(8-15) BE(8-15) BF(8-15)
|
||||||
|
|
||||||
|
const __m512i rhs_mat_014589CD_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) B8(16-23) B9(16-23) BC(16-23) BD(16-23)
|
||||||
|
const __m512i rhs_mat_2367ABEF_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) BA(16-23) BB(16-23) BE(16-23) BF(16-23)
|
||||||
|
|
||||||
|
const __m512i rhs_mat_014589CD_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) B8(24-31) B9(24-31) BC(24-31) BD(24-31)
|
||||||
|
const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)
|
||||||
|
|
||||||
|
// Shuffle pattern one - right side input
|
||||||
|
const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
|
||||||
|
const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
|
||||||
|
|
||||||
|
const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
|
||||||
|
const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
|
||||||
|
|
||||||
|
const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
|
||||||
|
const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
|
||||||
|
|
||||||
|
const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
|
||||||
|
const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
|
||||||
|
|
||||||
|
// Shuffle pattern two - right side input
|
||||||
|
|
||||||
|
const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
|
||||||
|
const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
|
||||||
|
|
||||||
|
const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
|
||||||
|
const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
|
||||||
|
|
||||||
|
const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
|
||||||
|
const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
|
||||||
|
|
||||||
|
const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
|
||||||
|
const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
|
||||||
|
|
||||||
|
// Scale values - Load the weight scale values of two block_q4_0x8
|
||||||
|
const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
|
||||||
|
|
||||||
|
// Process LHS in pairs of rows
|
||||||
|
for (int rp = 0; rp < 4; rp++) {
|
||||||
|
|
||||||
|
// Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
|
||||||
|
// Loaded as set of 128 bit vectors and repeated and stored into a 256 bit vector before again repeating into 512 bit vector
|
||||||
|
__m256i lhs_mat_ymm_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs)));
|
||||||
|
__m256i lhs_mat_ymm_01_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 0);
|
||||||
|
__m256i lhs_mat_ymm_23_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 17);
|
||||||
|
__m256i lhs_mat_ymm_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32)));
|
||||||
|
__m256i lhs_mat_ymm_01_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 0);
|
||||||
|
__m256i lhs_mat_ymm_23_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 17);
|
||||||
|
__m256i lhs_mat_ymm_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64)));
|
||||||
|
__m256i lhs_mat_ymm_01_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 0);
|
||||||
|
__m256i lhs_mat_ymm_23_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 17);
|
||||||
|
__m256i lhs_mat_ymm_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96)));
|
||||||
|
__m256i lhs_mat_ymm_01_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 0);
|
||||||
|
__m256i lhs_mat_ymm_23_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 17);
|
||||||
|
|
||||||
|
__m512i lhs_mat_01_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_0), lhs_mat_ymm_01_0, 1);
|
||||||
|
__m512i lhs_mat_23_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_0), lhs_mat_ymm_23_0, 1);
|
||||||
|
__m512i lhs_mat_01_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_1), lhs_mat_ymm_01_1, 1);
|
||||||
|
__m512i lhs_mat_23_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_1), lhs_mat_ymm_23_1, 1);
|
||||||
|
__m512i lhs_mat_01_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_2), lhs_mat_ymm_01_2, 1);
|
||||||
|
__m512i lhs_mat_23_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_2), lhs_mat_ymm_23_2, 1);
|
||||||
|
__m512i lhs_mat_01_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_3), lhs_mat_ymm_01_3, 1);
|
||||||
|
__m512i lhs_mat_23_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_3), lhs_mat_ymm_23_3, 1);
|
||||||
|
|
||||||
|
// Shuffle pattern one - left side input
|
||||||
|
|
||||||
|
const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
|
||||||
|
const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
|
||||||
|
|
||||||
|
const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
|
||||||
|
const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
|
||||||
|
|
||||||
|
const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
|
||||||
|
const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
|
||||||
|
|
||||||
|
const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
|
||||||
|
const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
|
||||||
|
|
||||||
|
// Shuffle pattern two - left side input
|
||||||
|
|
||||||
|
const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
|
||||||
|
const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
|
||||||
|
|
||||||
|
const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
|
||||||
|
const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
|
||||||
|
|
||||||
|
const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
|
||||||
|
const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
|
||||||
|
|
||||||
|
const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
|
||||||
|
const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
|
||||||
|
|
||||||
|
// The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
|
||||||
|
// Resembles MMLAs into 2x2 matrices in ARM Version
|
||||||
|
__m512i iacc_mat_00_sp1 =
|
||||||
|
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1));
|
||||||
|
__m512i iacc_mat_01_sp1 =
|
||||||
|
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1));
|
||||||
|
__m512i iacc_mat_10_sp1 =
|
||||||
|
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1));
|
||||||
|
__m512i iacc_mat_11_sp1 =
|
||||||
|
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1));
|
||||||
|
__m512i iacc_mat_00_sp2 =
|
||||||
|
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2));
|
||||||
|
__m512i iacc_mat_01_sp2 =
|
||||||
|
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2));
|
||||||
|
__m512i iacc_mat_10_sp2 =
|
||||||
|
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2));
|
||||||
|
__m512i iacc_mat_11_sp2 =
|
||||||
|
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2));
|
||||||
|
|
||||||
|
// Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
|
||||||
|
__m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
|
||||||
|
__m512i iacc_mat_01 = _mm512_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);
|
||||||
|
__m512i iacc_mat_10 = _mm512_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
|
||||||
|
__m512i iacc_mat_11 = _mm512_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
|
||||||
|
|
||||||
|
|
||||||
|
// Straighten out to make 4 row vectors
|
||||||
|
__m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, 78));
|
||||||
|
__m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01);
|
||||||
|
__m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, 78));
|
||||||
|
__m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11);
|
||||||
|
|
||||||
|
// Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
|
||||||
|
const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptrs[rp][b].d), loadMask), 68);
|
||||||
|
const __m512 row_scale_f32 = GGML_F32Cx16_REPEAT_LOAD(row_scale_f16);
|
||||||
|
|
||||||
|
// Multiply with appropriate scales and accumulate
|
||||||
|
acc_rows[rp * 4] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]);
|
||||||
|
acc_rows[rp * 4 + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]);
|
||||||
|
acc_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
|
||||||
|
acc_rows[rp * 4 + 3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store the accumulated values
|
||||||
|
for (int i = 0; i < 16; i++) {
|
||||||
|
_mm512_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Take a block_q8_0x4 structures at each pass of the loop and perform dot product operation
|
||||||
|
for (; y < nr / 4; y ++) {
|
||||||
|
|
||||||
|
const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb);
|
||||||
|
|
||||||
|
// Take group of two block_q4_0x8 structures at each pass of the loop and perform dot product operation
|
||||||
|
for (int64_t x = 0; x < anc / 8; x += 2) {
|
||||||
|
|
||||||
|
const block_q4_0x8 * b_ptr_0 = b_ptr_start + ((x) * b_nb);
|
||||||
|
const block_q4_0x8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb);
|
||||||
|
|
||||||
|
// Master FP accumulators
|
||||||
|
__m512 acc_rows[4];
|
||||||
|
for (int i = 0; i < 4; i++) {
|
||||||
|
acc_rows[i] = _mm512_setzero_ps();
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int64_t b = 0; b < nb; b++) {
|
||||||
|
// Load the sixteen block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....BE,BF
|
||||||
|
const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs));
|
||||||
|
const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 32));
|
||||||
|
const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 64));
|
||||||
|
const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 96));
|
||||||
|
|
||||||
|
const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs));
|
||||||
|
const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 32));
|
||||||
|
const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 64));
|
||||||
|
const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 96));
|
||||||
|
|
||||||
|
// Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values
|
||||||
|
const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
|
||||||
|
const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
|
||||||
|
const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
|
||||||
|
const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
|
||||||
|
|
||||||
|
const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240);
|
||||||
|
const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240);
|
||||||
|
const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240);
|
||||||
|
const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240);
|
||||||
|
|
||||||
|
const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1);
|
||||||
|
const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1);
|
||||||
|
const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1);
|
||||||
|
const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1);
|
||||||
|
|
||||||
|
// 4-bit -> 8-bit - Sign is maintained
|
||||||
|
const __m512i rhs_mat_014589CD_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) B8(0-7) B9(0-7) BC(0-7) BD(0-7)
|
||||||
|
const __m512i rhs_mat_2367ABEF_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) BA(0-7) BB(0-7) BE(0-7) BF(0-7)
|
||||||
|
|
||||||
|
const __m512i rhs_mat_014589CD_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) B8(8-15) B9(8-15) BC(8-15) BD(8-15)
|
||||||
|
const __m512i rhs_mat_2367ABEF_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) BA(8-15) BB(8-15) BE(8-15) BF(8-15)
|
||||||
|
|
||||||
|
const __m512i rhs_mat_014589CD_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) B8(16-23) B9(16-23) BC(16-23) BD(16-23)
|
||||||
|
const __m512i rhs_mat_2367ABEF_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) BA(16-23) BB(16-23) BE(16-23) BF(16-23)
|
||||||
|
|
||||||
|
const __m512i rhs_mat_014589CD_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) B8(24-31) B9(24-31) BC(24-31) BD(24-31)
|
||||||
|
const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)
|
||||||
|
|
||||||
|
// Shuffle pattern one - right side input
|
||||||
|
const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
|
||||||
|
const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
|
||||||
|
|
||||||
|
const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
|
||||||
|
const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
|
||||||
|
|
||||||
|
const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
|
||||||
|
const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
|
||||||
|
|
||||||
|
const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
|
||||||
|
const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
|
||||||
|
|
||||||
|
// Shuffle pattern two - right side input
|
||||||
|
|
||||||
|
const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
|
||||||
|
const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
|
||||||
|
|
||||||
|
const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
|
||||||
|
const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
|
||||||
|
|
||||||
|
const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
|
||||||
|
const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
|
||||||
|
|
||||||
|
const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
|
||||||
|
const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
|
||||||
|
|
||||||
|
|
||||||
|
// Scale values - Load the weight scale values of two block_q4_0x8
|
||||||
|
const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
|
||||||
|
|
||||||
|
// Load the four block_q8_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
|
||||||
|
// Loaded as set of 128 bit vectors and repeated and stored into a 256 bit vector before again repeating into 512 bit vector
|
||||||
|
__m256i lhs_mat_ymm_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs)));
|
||||||
|
__m256i lhs_mat_ymm_01_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 0);
|
||||||
|
__m256i lhs_mat_ymm_23_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 17);
|
||||||
|
__m256i lhs_mat_ymm_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32)));
|
||||||
|
__m256i lhs_mat_ymm_01_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 0);
|
||||||
|
__m256i lhs_mat_ymm_23_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 17);
|
||||||
|
__m256i lhs_mat_ymm_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64)));
|
||||||
|
__m256i lhs_mat_ymm_01_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 0);
|
||||||
|
__m256i lhs_mat_ymm_23_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 17);
|
||||||
|
__m256i lhs_mat_ymm_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96)));
|
||||||
|
__m256i lhs_mat_ymm_01_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 0);
|
||||||
|
__m256i lhs_mat_ymm_23_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 17);
|
||||||
|
|
||||||
|
__m512i lhs_mat_01_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_0), lhs_mat_ymm_01_0, 1);
|
||||||
|
__m512i lhs_mat_23_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_0), lhs_mat_ymm_23_0, 1);
|
||||||
|
__m512i lhs_mat_01_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_1), lhs_mat_ymm_01_1, 1);
|
||||||
|
__m512i lhs_mat_23_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_1), lhs_mat_ymm_23_1, 1);
|
||||||
|
__m512i lhs_mat_01_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_2), lhs_mat_ymm_01_2, 1);
|
||||||
|
__m512i lhs_mat_23_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_2), lhs_mat_ymm_23_2, 1);
|
||||||
|
__m512i lhs_mat_01_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_3), lhs_mat_ymm_01_3, 1);
|
||||||
|
__m512i lhs_mat_23_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_3), lhs_mat_ymm_23_3, 1);
|
||||||
|
|
||||||
|
// Shuffle pattern one - left side input
|
||||||
|
|
||||||
|
const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
|
||||||
|
const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
|
||||||
|
|
||||||
|
const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
|
||||||
|
const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
|
||||||
|
|
||||||
|
const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
|
||||||
|
const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
|
||||||
|
|
||||||
|
const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
|
||||||
|
const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
|
||||||
|
|
||||||
|
// Shuffle pattern two - left side input
|
||||||
|
|
||||||
|
const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
|
||||||
|
const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
|
||||||
|
|
||||||
|
const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
|
||||||
|
const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
|
||||||
|
|
||||||
|
const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
|
||||||
|
const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
|
||||||
|
|
||||||
|
const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
|
||||||
|
const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
|
||||||
|
|
||||||
|
// The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
|
||||||
|
// Resembles MMLAs into 2x2 matrices in ARM Version
|
||||||
|
__m512i iacc_mat_00_sp1 =
|
||||||
|
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1));
|
||||||
|
__m512i iacc_mat_01_sp1 =
|
||||||
|
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1));
|
||||||
|
__m512i iacc_mat_10_sp1 =
|
||||||
|
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1));
|
||||||
|
__m512i iacc_mat_11_sp1 =
|
||||||
|
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1));
|
||||||
|
__m512i iacc_mat_00_sp2 =
|
||||||
|
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2));
|
||||||
|
__m512i iacc_mat_01_sp2 =
|
||||||
|
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2));
|
||||||
|
__m512i iacc_mat_10_sp2 =
|
||||||
|
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2));
|
||||||
|
__m512i iacc_mat_11_sp2 =
|
||||||
|
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2));
|
||||||
|
|
||||||
|
// Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
|
||||||
|
__m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
|
||||||
|
__m512i iacc_mat_01 = _mm512_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);
|
||||||
|
__m512i iacc_mat_10 = _mm512_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
|
||||||
|
__m512i iacc_mat_11 = _mm512_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
|
||||||
|
|
||||||
|
|
||||||
|
// Straighten out to make 4 row vectors
|
||||||
|
__m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, 78));
|
||||||
|
__m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01);
|
||||||
|
__m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, 78));
|
||||||
|
__m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11);
|
||||||
|
|
||||||
|
// Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
|
||||||
|
const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptr[b].d), loadMask), 68);
|
||||||
|
const __m512 row_scale_f32 = GGML_F32Cx16_REPEAT_LOAD(row_scale_f16);
|
||||||
|
|
||||||
|
// Multiply with appropriate scales and accumulate
|
||||||
|
acc_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]);
|
||||||
|
acc_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]);
|
||||||
|
acc_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);
|
||||||
|
acc_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store the accumulated values
|
||||||
|
for (int i = 0; i < 4; i++) {
|
||||||
|
_mm512_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (anc != nc) {
|
||||||
|
xstart = anc/8;
|
||||||
|
y = 0;
|
||||||
|
}
|
||||||
|
#endif // __AVX512F__
|
||||||
|
|
||||||
// Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation
|
// Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation
|
||||||
int anr = nr - nr %16; // Used to align nr with boundary of 16
|
|
||||||
|
|
||||||
for (; y < anr / 4; y += 4) {
|
for (; y < anr / 4; y += 4) {
|
||||||
const block_q8_0x4 * a_ptrs[4];
|
const block_q8_0x4 * a_ptrs[4];
|
||||||
|
@ -2435,7 +2859,7 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
||||||
}
|
}
|
||||||
|
|
||||||
// Take group of eight block_q4_0x8 structures at each pass of the loop and perform dot product operation
|
// Take group of eight block_q4_0x8 structures at each pass of the loop and perform dot product operation
|
||||||
for (int64_t x = 0; x < nc / 8; x++) {
|
for (int64_t x = xstart; x < nc / 8; x++) {
|
||||||
|
|
||||||
const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb);
|
const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb);
|
||||||
|
|
||||||
|
@ -2547,21 +2971,21 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
||||||
// The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
|
// The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
|
||||||
// Resembles MMLAs into 2x2 matrices in ARM Version
|
// Resembles MMLAs into 2x2 matrices in ARM Version
|
||||||
__m256i iacc_mat_00_sp1 =
|
__m256i iacc_mat_00_sp1 =
|
||||||
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1));
|
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1));
|
||||||
__m256i iacc_mat_01_sp1 =
|
__m256i iacc_mat_01_sp1 =
|
||||||
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1));
|
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1));
|
||||||
__m256i iacc_mat_10_sp1 =
|
__m256i iacc_mat_10_sp1 =
|
||||||
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1));
|
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1));
|
||||||
__m256i iacc_mat_11_sp1 =
|
__m256i iacc_mat_11_sp1 =
|
||||||
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1));
|
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1));
|
||||||
__m256i iacc_mat_00_sp2 =
|
__m256i iacc_mat_00_sp2 =
|
||||||
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2));
|
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2));
|
||||||
__m256i iacc_mat_01_sp2 =
|
__m256i iacc_mat_01_sp2 =
|
||||||
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2));
|
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2));
|
||||||
__m256i iacc_mat_10_sp2 =
|
__m256i iacc_mat_10_sp2 =
|
||||||
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2));
|
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2));
|
||||||
__m256i iacc_mat_11_sp2 =
|
__m256i iacc_mat_11_sp2 =
|
||||||
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2));
|
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2));
|
||||||
|
|
||||||
// Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
|
// Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
|
||||||
__m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
|
__m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
|
||||||
|
@ -2599,7 +3023,7 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
||||||
const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb);
|
const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb);
|
||||||
|
|
||||||
// Load the eight block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
|
// Load the eight block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
|
||||||
for (int64_t x = 0; x < nc / 8; x++) {
|
for (int64_t x = xstart; x < nc / 8; x++) {
|
||||||
|
|
||||||
const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb);
|
const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb);
|
||||||
|
|
||||||
|
@ -2711,21 +3135,21 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
||||||
// The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
|
// The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
|
||||||
// Resembles MMLAs into 2x2 matrices in ARM Version
|
// Resembles MMLAs into 2x2 matrices in ARM Version
|
||||||
__m256i iacc_mat_00_sp1 =
|
__m256i iacc_mat_00_sp1 =
|
||||||
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1));
|
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1));
|
||||||
__m256i iacc_mat_01_sp1 =
|
__m256i iacc_mat_01_sp1 =
|
||||||
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1));
|
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1));
|
||||||
__m256i iacc_mat_10_sp1 =
|
__m256i iacc_mat_10_sp1 =
|
||||||
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1));
|
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1));
|
||||||
__m256i iacc_mat_11_sp1 =
|
__m256i iacc_mat_11_sp1 =
|
||||||
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1));
|
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1));
|
||||||
__m256i iacc_mat_00_sp2 =
|
__m256i iacc_mat_00_sp2 =
|
||||||
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2));
|
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2));
|
||||||
__m256i iacc_mat_01_sp2 =
|
__m256i iacc_mat_01_sp2 =
|
||||||
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2));
|
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2));
|
||||||
__m256i iacc_mat_10_sp2 =
|
__m256i iacc_mat_10_sp2 =
|
||||||
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2));
|
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2));
|
||||||
__m256i iacc_mat_11_sp2 =
|
__m256i iacc_mat_11_sp2 =
|
||||||
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2));
|
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2));
|
||||||
|
|
||||||
// Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
|
// Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
|
||||||
__m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
|
__m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
|
||||||
|
@ -2756,7 +3180,9 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#else
|
return;
|
||||||
|
}
|
||||||
|
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
|
||||||
float sumf[4][8];
|
float sumf[4][8];
|
||||||
int sumi;
|
int sumi;
|
||||||
|
|
||||||
|
@ -2789,5 +3215,4 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -294,6 +294,12 @@ static void ggml_dyn_tallocr_reset(struct ggml_dyn_tallocr * alloc) {
|
||||||
alloc->free_blocks[0].offset = 0;
|
alloc->free_blocks[0].offset = 0;
|
||||||
alloc->free_blocks[0].size = SIZE_MAX/2; // restrict maximum size of a measure allocator to half size_t max to avoid overflows
|
alloc->free_blocks[0].size = SIZE_MAX/2; // restrict maximum size of a measure allocator to half size_t max to avoid overflows
|
||||||
alloc->max_size = 0;
|
alloc->max_size = 0;
|
||||||
|
|
||||||
|
#ifdef GGML_ALLOCATOR_DEBUG
|
||||||
|
for (int i = 0; i < 1024; i++) {
|
||||||
|
alloc->allocated_tensors[i].tensor = NULL;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
static struct ggml_dyn_tallocr * ggml_dyn_tallocr_new(size_t alignment) {
|
static struct ggml_dyn_tallocr * ggml_dyn_tallocr_new(size_t alignment) {
|
||||||
|
|
|
@ -227,6 +227,7 @@ struct ggml_backend_cann_context {
|
||||||
* @brief Destructor for cleaning up resources.
|
* @brief Destructor for cleaning up resources.
|
||||||
*/
|
*/
|
||||||
~ggml_backend_cann_context() {
|
~ggml_backend_cann_context() {
|
||||||
|
ggml_cann_set_device(device);
|
||||||
if (copy_event != nullptr) {
|
if (copy_event != nullptr) {
|
||||||
ACL_CHECK(aclrtDestroyEvent(copy_event));
|
ACL_CHECK(aclrtDestroyEvent(copy_event));
|
||||||
}
|
}
|
||||||
|
|
|
@ -36,6 +36,7 @@ bool g_mul_mat_q = false;
|
||||||
#include "ggml-cuda/tsembd.cuh"
|
#include "ggml-cuda/tsembd.cuh"
|
||||||
#include "ggml-cuda/unary.cuh"
|
#include "ggml-cuda/unary.cuh"
|
||||||
#include "ggml-cuda/upscale.cuh"
|
#include "ggml-cuda/upscale.cuh"
|
||||||
|
#include "ggml-cuda/rwkv-wkv.cuh"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <array>
|
#include <array>
|
||||||
|
@ -137,7 +138,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
|
||||||
return res;
|
return res;
|
||||||
#else
|
#else
|
||||||
|
|
||||||
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
|
#if !defined(GGML_USE_HIPBLAS)
|
||||||
cudaError_t err;
|
cudaError_t err;
|
||||||
if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
|
if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
|
||||||
{
|
{
|
||||||
|
@ -150,7 +151,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
|
||||||
return err;
|
return err;
|
||||||
#else
|
#else
|
||||||
return cudaMalloc(ptr, size);
|
return cudaMalloc(ptr, size);
|
||||||
#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
|
#endif // !defined(GGML_USE_HIPBLAS)
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
@ -188,7 +189,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
||||||
for (int id = 0; id < info.device_count; ++id) {
|
for (int id = 0; id < info.device_count; ++id) {
|
||||||
int device_vmm = 0;
|
int device_vmm = 0;
|
||||||
|
|
||||||
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
|
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
|
||||||
CUdevice device;
|
CUdevice device;
|
||||||
CU_CHECK(cuDeviceGet(&device, id));
|
CU_CHECK(cuDeviceGet(&device, id));
|
||||||
CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));
|
CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));
|
||||||
|
@ -200,7 +201,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
||||||
alloc_prop.location.id = id;
|
alloc_prop.location.id = id;
|
||||||
CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
|
CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
|
||||||
}
|
}
|
||||||
#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
|
#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
|
||||||
info.devices[id].vmm = !!device_vmm;
|
info.devices[id].vmm = !!device_vmm;
|
||||||
|
|
||||||
cudaDeviceProp prop;
|
cudaDeviceProp prop;
|
||||||
|
@ -334,7 +335,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
|
||||||
};
|
};
|
||||||
|
|
||||||
// pool with virtual memory
|
// pool with virtual memory
|
||||||
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
|
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
|
||||||
struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
|
struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
|
||||||
static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB
|
static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB
|
||||||
|
|
||||||
|
@ -428,14 +429,14 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
|
||||||
GGML_ASSERT(ptr == (void *) (pool_addr + pool_used));
|
GGML_ASSERT(ptr == (void *) (pool_addr + pool_used));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
|
#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
|
||||||
|
|
||||||
std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) {
|
std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) {
|
||||||
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
|
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
|
||||||
if (ggml_cuda_info().devices[device].vmm) {
|
if (ggml_cuda_info().devices[device].vmm) {
|
||||||
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
|
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
|
||||||
}
|
}
|
||||||
#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
|
#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
|
||||||
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
|
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2247,6 +2248,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||||
case GGML_UNARY_OP_HARDSWISH:
|
case GGML_UNARY_OP_HARDSWISH:
|
||||||
ggml_cuda_op_hardswish(ctx, dst);
|
ggml_cuda_op_hardswish(ctx, dst);
|
||||||
break;
|
break;
|
||||||
|
case GGML_UNARY_OP_EXP:
|
||||||
|
ggml_cuda_op_exp(ctx, dst);
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -2349,6 +2353,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||||
ggml_cuda_cross_entropy_loss(ctx, dst);
|
ggml_cuda_cross_entropy_loss(ctx, dst);
|
||||||
break;
|
break;
|
||||||
|
case GGML_OP_RWKV_WKV:
|
||||||
|
ggml_cuda_op_rwkv_wkv(ctx, dst);
|
||||||
|
break;
|
||||||
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
||||||
ggml_cuda_cross_entropy_loss_back(ctx, dst);
|
ggml_cuda_cross_entropy_loss_back(ctx, dst);
|
||||||
break;
|
break;
|
||||||
|
@ -2810,6 +2817,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
||||||
case GGML_UNARY_OP_HARDSWISH:
|
case GGML_UNARY_OP_HARDSWISH:
|
||||||
case GGML_UNARY_OP_GELU_QUICK:
|
case GGML_UNARY_OP_GELU_QUICK:
|
||||||
case GGML_UNARY_OP_TANH:
|
case GGML_UNARY_OP_TANH:
|
||||||
|
case GGML_UNARY_OP_EXP:
|
||||||
return ggml_is_contiguous(op->src[0]);
|
return ggml_is_contiguous(op->src[0]);
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
|
@ -2826,6 +2834,12 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
||||||
if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) {
|
if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
#ifdef GGML_USE_MUSA
|
||||||
|
if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
|
||||||
|
!ggml_is_transposed(a) && !ggml_is_transposed(b)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
#endif // GGML_USE_MUSA
|
||||||
switch (a->type) {
|
switch (a->type) {
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
|
@ -2849,6 +2863,11 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
||||||
case GGML_TYPE_IQ3_XXS:
|
case GGML_TYPE_IQ3_XXS:
|
||||||
case GGML_TYPE_IQ4_NL:
|
case GGML_TYPE_IQ4_NL:
|
||||||
case GGML_TYPE_IQ4_XS:
|
case GGML_TYPE_IQ4_XS:
|
||||||
|
#ifdef GGML_USE_MUSA
|
||||||
|
if (a->type == GGML_TYPE_Q3_K) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
#endif // GGML_USE_MUSA
|
||||||
return true;
|
return true;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
|
@ -2884,6 +2903,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
||||||
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
|
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
if (src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_F32) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
|
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -2971,20 +2993,24 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
||||||
case GGML_OP_ARANGE:
|
case GGML_OP_ARANGE:
|
||||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||||
case GGML_OP_LEAKY_RELU:
|
case GGML_OP_LEAKY_RELU:
|
||||||
|
case GGML_OP_RWKV_WKV:
|
||||||
return true;
|
return true;
|
||||||
case GGML_OP_FLASH_ATTN_EXT:
|
case GGML_OP_FLASH_ATTN_EXT: {
|
||||||
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
#ifndef FLASH_ATTN_AVAILABLE
|
||||||
return (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) || op->src[0]->ne[0] == 128;
|
return false;
|
||||||
#else
|
#endif
|
||||||
if (op->src[0]->ne[0] == 128) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
|
if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA &&
|
if (op->src[0]->ne[0] == 128) {
|
||||||
op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
|
return true;
|
||||||
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
}
|
||||||
|
if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
const int cc = ggml_cuda_info().devices[cuda_ctx->device].cc;
|
||||||
|
return cc >= CC_VOLTA && cc < CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
|
||||||
|
}
|
||||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||||
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
||||||
case GGML_OP_OPT_STEP_ADAMW:
|
case GGML_OP_OPT_STEP_ADAMW:
|
||||||
|
|
|
@ -50,6 +50,8 @@
|
||||||
#define CC_RDNA1 (CC_OFFSET_AMD + 1010)
|
#define CC_RDNA1 (CC_OFFSET_AMD + 1010)
|
||||||
#define CC_RDNA2 (CC_OFFSET_AMD + 1030)
|
#define CC_RDNA2 (CC_OFFSET_AMD + 1030)
|
||||||
#define CC_RDNA3 (CC_OFFSET_AMD + 1100)
|
#define CC_RDNA3 (CC_OFFSET_AMD + 1100)
|
||||||
|
#define CC_QY1 210
|
||||||
|
#define CC_QY2 220
|
||||||
|
|
||||||
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
|
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
|
||||||
|
|
||||||
|
@ -134,6 +136,10 @@ typedef float2 dfloat2;
|
||||||
#define INT8_MMA_AVAILABLE
|
#define INT8_MMA_AVAILABLE
|
||||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
|
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
|
||||||
|
|
||||||
|
#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)
|
||||||
|
#define FLASH_ATTN_AVAILABLE
|
||||||
|
#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)
|
||||||
|
|
||||||
static constexpr bool fast_fp16_available(const int cc) {
|
static constexpr bool fast_fp16_available(const int cc) {
|
||||||
return cc >= CC_PASCAL && cc != 610;
|
return cc >= CC_PASCAL && cc != 610;
|
||||||
}
|
}
|
||||||
|
|
|
@ -81,6 +81,17 @@ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
|
||||||
|
const block_q8_0 * xi = (const block_q8_0 *) cxi;
|
||||||
|
float * dsti = (float *) cdsti;
|
||||||
|
|
||||||
|
const float d = (float)xi->d;
|
||||||
|
|
||||||
|
for (int j = 0; j < QK8_0; j++) {
|
||||||
|
dsti[j] = xi->qs[j] * d;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
|
static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
|
||||||
const float * xi = (const float *) cxi;
|
const float * xi = (const float *) cxi;
|
||||||
block_q4_0 * dsti = (block_q4_0 *) cdsti;
|
block_q4_0 * dsti = (block_q4_0 *) cdsti;
|
||||||
|
@ -288,6 +299,32 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
|
||||||
cpy_blck(cx + x_offset, cdst + dst_offset);
|
cpy_blck(cx + x_offset, cdst + dst_offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <cpy_kernel_t cpy_blck, int qk>
|
||||||
|
static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
|
||||||
|
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||||
|
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||||
|
const int nb12, const int nb13) {
|
||||||
|
const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
|
||||||
|
|
||||||
|
if (i >= ne) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int i03 = i/(ne00 * ne01 * ne02);
|
||||||
|
const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
|
||||||
|
const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
|
||||||
|
const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
|
||||||
|
const int x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
|
||||||
|
|
||||||
|
const int i13 = i/(ne10 * ne11 * ne12);
|
||||||
|
const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
|
||||||
|
const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
|
||||||
|
const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
|
||||||
|
const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
|
||||||
|
|
||||||
|
cpy_blck(cx + x_offset, cdst + dst_offset);
|
||||||
|
}
|
||||||
|
|
||||||
static void ggml_cpy_f16_f32_cuda(
|
static void ggml_cpy_f16_f32_cuda(
|
||||||
const char * cx, char * cdst, const int ne,
|
const char * cx, char * cdst, const int ne,
|
||||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||||
|
@ -329,6 +366,16 @@ static void ggml_cpy_f32_q8_0_cuda(
|
||||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ggml_cpy_q8_0_f32_cuda(
|
||||||
|
const char * cx, char * cdst, const int ne,
|
||||||
|
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||||
|
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||||
|
|
||||||
|
const int num_blocks = ne;
|
||||||
|
cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
|
||||||
|
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||||
|
}
|
||||||
|
|
||||||
static void ggml_cpy_f32_q4_0_cuda(
|
static void ggml_cpy_f32_q4_0_cuda(
|
||||||
const char * cx, char * cdst, const int ne,
|
const char * cx, char * cdst, const int ne,
|
||||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||||
|
@ -437,6 +484,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||||
ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
|
||||||
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
|
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
|
||||||
|
ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
|
||||||
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
|
||||||
|
@ -471,6 +520,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
|
||||||
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
|
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
|
||||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
|
||||||
return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
|
return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
|
||||||
|
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
|
||||||
|
return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
|
||||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
|
||||||
return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
|
return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
|
||||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
|
||||||
|
|
|
@ -44,13 +44,17 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||||
const int ne1,
|
const int ne1,
|
||||||
const int ne2,
|
const int ne2,
|
||||||
const int ne3) {
|
const int ne3) {
|
||||||
|
#ifndef FLASH_ATTN_AVAILABLE
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
return;
|
||||||
|
#endif // FLASH_ATTN_AVAILABLE
|
||||||
// Skip unused kernel variants for faster compilation:
|
// Skip unused kernel variants for faster compilation:
|
||||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
// In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||||
|
|
||||||
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
|
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
|
||||||
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
|
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
|
||||||
|
|
|
@ -314,7 +314,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!fast_fp16_available(cc)) {
|
if (!fast_fp16_available(cc)) {
|
||||||
if (Q->ne[1] <= 8) {
|
if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
|
||||||
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
||||||
} else {
|
} else {
|
||||||
ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
|
ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
|
||||||
|
|
89
ggml/src/ggml-cuda/rwkv-wkv.cu
Normal file
89
ggml/src/ggml-cuda/rwkv-wkv.cu
Normal file
|
@ -0,0 +1,89 @@
|
||||||
|
#include "common.cuh"
|
||||||
|
#include "rwkv-wkv.cuh"
|
||||||
|
|
||||||
|
static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
|
||||||
|
const int tid = threadIdx.x;
|
||||||
|
const int bid = blockIdx.x;
|
||||||
|
|
||||||
|
const int head_size = CUDA_WKV_BLOCK_SIZE;
|
||||||
|
const int batch_i = bid / H;
|
||||||
|
const int head_i = bid % H;
|
||||||
|
const int state_size = C * head_size;
|
||||||
|
const int n_seq_tokens = T / B;
|
||||||
|
|
||||||
|
float state[head_size];
|
||||||
|
__shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < head_size; i++) {
|
||||||
|
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
_tf[tid] = tf[head_i * head_size + tid];
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
|
||||||
|
__syncthreads();
|
||||||
|
_k[tid] = k[t];
|
||||||
|
_r[tid] = r[t];
|
||||||
|
_td[tid] = td[t];
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
const float _v = v[t];
|
||||||
|
float y = 0;
|
||||||
|
for (int j = 0; j < head_size; j += 4) {
|
||||||
|
const float4& k = (float4&)(_k[j]);
|
||||||
|
const float4& r = (float4&)(_r[j]);
|
||||||
|
const float4& tf = (float4&)(_tf[j]);
|
||||||
|
const float4& td = (float4&)(_td[j]);
|
||||||
|
float4& s = (float4&)(state[j]);
|
||||||
|
float4 kv;
|
||||||
|
|
||||||
|
kv.x = k.x * _v;
|
||||||
|
kv.y = k.y * _v;
|
||||||
|
kv.z = k.z * _v;
|
||||||
|
kv.w = k.w * _v;
|
||||||
|
|
||||||
|
y += r.x * (tf.x * kv.x + s.x);
|
||||||
|
y += r.y * (tf.y * kv.y + s.y);
|
||||||
|
y += r.z * (tf.z * kv.z + s.z);
|
||||||
|
y += r.w * (tf.w * kv.w + s.w);
|
||||||
|
|
||||||
|
s.x = s.x * td.x + kv.x;
|
||||||
|
s.y = s.y * td.y + kv.y;
|
||||||
|
s.z = s.z * td.z + kv.z;
|
||||||
|
s.w = s.w * td.w + kv.w;
|
||||||
|
}
|
||||||
|
dst[t] = y;
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < head_size; i++) {
|
||||||
|
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
|
const float * k_d = (const float *)dst->src[0]->data;
|
||||||
|
const float * v_d = (const float *)dst->src[1]->data;
|
||||||
|
const float * r_d = (const float *)dst->src[2]->data;
|
||||||
|
const float * tf_d = (const float *)dst->src[3]->data;
|
||||||
|
const float * td_d = (const float *)dst->src[4]->data;
|
||||||
|
const float * s_d = (const float *)dst->src[5]->data;
|
||||||
|
|
||||||
|
const int64_t B = dst->src[5]->ne[1];
|
||||||
|
const int64_t T = dst->src[0]->ne[3];
|
||||||
|
const int64_t C = dst->ne[0];
|
||||||
|
const int64_t H = dst->src[0]->ne[2];
|
||||||
|
|
||||||
|
float * dst_d = (float *)dst->data;
|
||||||
|
|
||||||
|
cudaStream_t stream = ctx.stream();
|
||||||
|
|
||||||
|
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(C % H == 0);
|
||||||
|
GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE);
|
||||||
|
|
||||||
|
rwkv_wkv_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
|
||||||
|
}
|
5
ggml/src/ggml-cuda/rwkv-wkv.cuh
Normal file
5
ggml/src/ggml-cuda/rwkv-wkv.cuh
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
#include "common.cuh"
|
||||||
|
|
||||||
|
#define CUDA_WKV_BLOCK_SIZE 64
|
||||||
|
|
||||||
|
void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
|
@ -95,6 +95,15 @@ static __global__ void hardswish_f32(const float * x, float * dst, const int k)
|
||||||
dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
|
dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static __global__ void exp_f32(const float * x, float * dst, const int k) {
|
||||||
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||||
|
|
||||||
|
if (i >= k) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
dst[i] = expf(x[i]);
|
||||||
|
}
|
||||||
|
|
||||||
static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
|
static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
|
||||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||||
if (i >= k) {
|
if (i >= k) {
|
||||||
|
@ -189,6 +198,11 @@ static void hardswish_f32_cuda(const float * x, float * dst, const int k, cudaSt
|
||||||
hardswish_f32<<<num_blocks, CUDA_HARDSWISH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
hardswish_f32<<<num_blocks, CUDA_HARDSWISH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void exp_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
||||||
|
const int num_blocks = (k + CUDA_EXP_BLOCK_SIZE - 1) / CUDA_EXP_BLOCK_SIZE;
|
||||||
|
exp_f32<<<num_blocks, CUDA_EXP_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||||
|
}
|
||||||
|
|
||||||
static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
|
static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
|
||||||
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
|
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
|
||||||
leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
|
leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
|
||||||
|
@ -354,6 +368,20 @@ void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
|
||||||
hardswish_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
hardswish_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
|
const ggml_tensor * src0 = dst->src[0];
|
||||||
|
const float * src0_d = (const float *)src0->data;
|
||||||
|
float * dst_d = (float *)dst->data;
|
||||||
|
cudaStream_t stream = ctx.stream();
|
||||||
|
|
||||||
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||||
|
|
||||||
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||||
|
|
||||||
|
exp_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
||||||
|
}
|
||||||
|
|
||||||
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
const ggml_tensor * src0 = dst->src[0];
|
const ggml_tensor * src0 = dst->src[0];
|
||||||
const float * src0_d = (const float *)src0->data;
|
const float * src0_d = (const float *)src0->data;
|
||||||
|
|
|
@ -8,6 +8,7 @@
|
||||||
#define CUDA_RELU_BLOCK_SIZE 256
|
#define CUDA_RELU_BLOCK_SIZE 256
|
||||||
#define CUDA_SIGMOID_BLOCK_SIZE 256
|
#define CUDA_SIGMOID_BLOCK_SIZE 256
|
||||||
#define CUDA_HARDSIGMOID_BLOCK_SIZE 256
|
#define CUDA_HARDSIGMOID_BLOCK_SIZE 256
|
||||||
|
#define CUDA_EXP_BLOCK_SIZE 256
|
||||||
#define CUDA_HARDSWISH_BLOCK_SIZE 256
|
#define CUDA_HARDSWISH_BLOCK_SIZE 256
|
||||||
#define CUDA_SQR_BLOCK_SIZE 256
|
#define CUDA_SQR_BLOCK_SIZE 256
|
||||||
#define CUDA_SQRT_BLOCK_SIZE 256
|
#define CUDA_SQRT_BLOCK_SIZE 256
|
||||||
|
@ -32,6 +33,8 @@ void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
|
void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
2
ggml/src/ggml-cuda/vendors/musa.h
vendored
2
ggml/src/ggml-cuda/vendors/musa.h
vendored
|
@ -26,6 +26,7 @@
|
||||||
#define cublasSetStream mublasSetStream
|
#define cublasSetStream mublasSetStream
|
||||||
#define cublasSgemm mublasSgemm
|
#define cublasSgemm mublasSgemm
|
||||||
#define cublasStatus_t mublasStatus_t
|
#define cublasStatus_t mublasStatus_t
|
||||||
|
#define cublasOperation_t mublasOperation_t
|
||||||
#define cublasGetStatusString mublasStatus_to_string
|
#define cublasGetStatusString mublasStatus_to_string
|
||||||
#define cudaDataType_t musaDataType_t
|
#define cudaDataType_t musaDataType_t
|
||||||
#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
|
#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
|
||||||
|
@ -56,6 +57,7 @@
|
||||||
#define cudaLaunchHostFunc musaLaunchHostFunc
|
#define cudaLaunchHostFunc musaLaunchHostFunc
|
||||||
#define cudaMalloc musaMalloc
|
#define cudaMalloc musaMalloc
|
||||||
#define cudaMallocHost musaMallocHost
|
#define cudaMallocHost musaMallocHost
|
||||||
|
#define cudaMallocManaged musaMallocManaged
|
||||||
#define cudaMemcpy musaMemcpy
|
#define cudaMemcpy musaMemcpy
|
||||||
#define cudaMemcpyAsync musaMemcpyAsync
|
#define cudaMemcpyAsync musaMemcpyAsync
|
||||||
#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
|
#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
|
||||||
|
|
|
@ -2631,11 +2631,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
||||||
const short iv3 = iq3 / rv3;
|
const short iv3 = iq3 / rv3;
|
||||||
|
|
||||||
// load the queries from shared memory into local memory
|
// load the queries from shared memory into local memory
|
||||||
half4 mq[D4];
|
float4 mq[D4];
|
||||||
|
|
||||||
for (short ii = 0; ii < D4; ii += NW) {
|
for (short ii = 0; ii < D4; ii += NW) {
|
||||||
short i = ii + tiisg;
|
short i = ii + tiisg;
|
||||||
mq[i] = sq4[i];
|
mq[i] = (float4) sq4[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
// pointer to the mask
|
// pointer to the mask
|
||||||
|
@ -2661,11 +2661,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
||||||
for (short ii = 0; ii < D4; ii += NW) {
|
for (short ii = 0; ii < D4; ii += NW) {
|
||||||
const short i = ii + tiisg;
|
const short i = ii + tiisg;
|
||||||
|
|
||||||
half4x4 mk;
|
float4x4 mk;
|
||||||
mk[0] = pk4[i + 0*(nb11/8)];
|
mk[0] = (float4) pk4[i + 0*(nb11/8)];
|
||||||
mk[1] = pk4[i + 1*(nb11/8)];
|
mk[1] = (float4) pk4[i + 1*(nb11/8)];
|
||||||
mk[2] = pk4[i + 2*(nb11/8)];
|
mk[2] = (float4) pk4[i + 2*(nb11/8)];
|
||||||
mk[3] = pk4[i + 3*(nb11/8)];
|
mk[3] = (float4) pk4[i + 3*(nb11/8)];
|
||||||
|
|
||||||
mqk += (float4) (mq[i] * mk);
|
mqk += (float4) (mq[i] * mk);
|
||||||
}
|
}
|
||||||
|
|
|
@ -3496,8 +3496,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
|
||||||
|
|
||||||
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
|
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
|
||||||
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
|
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
|
||||||
&& src1->ne[1] <= MMVQ_MAX_BATCH_SIZE
|
&& src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
|
||||||
&& (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda || src1->ne[1] > MMVQ_MIN_BATCH_SIZE);
|
|
||||||
|
|
||||||
bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
|
bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
|
||||||
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
|
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
|
||||||
|
|
|
@ -134,7 +134,6 @@ typedef sycl::float2 dfloat2;
|
||||||
#endif // GGML_SYCL_F16
|
#endif // GGML_SYCL_F16
|
||||||
|
|
||||||
#define MMVQ_MAX_BATCH_SIZE 8
|
#define MMVQ_MAX_BATCH_SIZE 8
|
||||||
#define MMVQ_MIN_BATCH_SIZE 4
|
|
||||||
|
|
||||||
static const int8_t kvalues_iq4nl[16]={-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
|
static const int8_t kvalues_iq4nl[16]={-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
|
||||||
|
|
||||||
|
|
|
@ -63,6 +63,25 @@ int ggml_sve_cnt_b = 0;
|
||||||
#pragma warning(disable: 4702)
|
#pragma warning(disable: 4702)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// Note: once we move threading into a separate C++ file
|
||||||
|
// will use std::hardware_destructive_interference_size instead of hardcoding it here
|
||||||
|
// and we'll use C++ attribute syntax.
|
||||||
|
#define GGML_CACHE_LINE 64
|
||||||
|
|
||||||
|
#if defined(__clang__) || defined(__GNUC__)
|
||||||
|
#define GGML_CACHE_ALIGN __attribute__((aligned(GGML_CACHE_LINE)))
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined(__has_feature)
|
||||||
|
#if __has_feature(thread_sanitizer)
|
||||||
|
#define GGML_TSAN_ENABLED 1
|
||||||
|
#endif
|
||||||
|
#else // __has_feature
|
||||||
|
#if defined(__SANITIZE_THREAD__)
|
||||||
|
#define GGML_TSAN_ENABLED 1
|
||||||
|
#endif
|
||||||
|
#endif // __has_feature
|
||||||
|
|
||||||
#if defined(_WIN32)
|
#if defined(_WIN32)
|
||||||
|
|
||||||
#define WIN32_LEAN_AND_MEAN
|
#define WIN32_LEAN_AND_MEAN
|
||||||
|
@ -72,6 +91,8 @@ int ggml_sve_cnt_b = 0;
|
||||||
#include <windows.h>
|
#include <windows.h>
|
||||||
|
|
||||||
#if !defined(__clang__)
|
#if !defined(__clang__)
|
||||||
|
#define GGML_CACHE_ALIGN __declspec(align(GGML_CACHE_LINE))
|
||||||
|
|
||||||
typedef volatile LONG atomic_int;
|
typedef volatile LONG atomic_int;
|
||||||
typedef atomic_int atomic_bool;
|
typedef atomic_int atomic_bool;
|
||||||
typedef atomic_int atomic_flag;
|
typedef atomic_int atomic_flag;
|
||||||
|
@ -114,6 +135,9 @@ static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) {
|
||||||
static void atomic_flag_clear(atomic_flag * ptr) {
|
static void atomic_flag_clear(atomic_flag * ptr) {
|
||||||
InterlockedExchange(ptr, 0);
|
InterlockedExchange(ptr, 0);
|
||||||
}
|
}
|
||||||
|
static void atomic_thread_fence(memory_order mo) {
|
||||||
|
MemoryBarrier();
|
||||||
|
}
|
||||||
#else // clang
|
#else // clang
|
||||||
#include <stdatomic.h>
|
#include <stdatomic.h>
|
||||||
#endif
|
#endif
|
||||||
|
@ -289,7 +313,6 @@ void ggml_abort(const char * file, int line, const char * fmt, ...) {
|
||||||
#define GGML_DEBUG 0
|
#define GGML_DEBUG 0
|
||||||
#define GGML_GELU_FP16
|
#define GGML_GELU_FP16
|
||||||
#define GGML_GELU_QUICK_FP16
|
#define GGML_GELU_QUICK_FP16
|
||||||
#define GGML_N_TASKS_MAX (-1)
|
|
||||||
|
|
||||||
#define GGML_SOFT_MAX_UNROLL 4
|
#define GGML_SOFT_MAX_UNROLL 4
|
||||||
#define GGML_VEC_DOT_UNROLL 2
|
#define GGML_VEC_DOT_UNROLL 2
|
||||||
|
@ -2015,8 +2038,8 @@ struct ggml_threadpool {
|
||||||
|
|
||||||
// synchronization primitives
|
// synchronization primitives
|
||||||
atomic_int n_graph; // incremented when there is work to be done (i.e each graph)
|
atomic_int n_graph; // incremented when there is work to be done (i.e each graph)
|
||||||
atomic_int n_barrier;
|
atomic_int GGML_CACHE_ALIGN n_barrier;
|
||||||
atomic_int n_barrier_passed;
|
atomic_int GGML_CACHE_ALIGN n_barrier_passed;
|
||||||
atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
|
atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
|
||||||
|
|
||||||
// these are atomic as an annotation for thread-sanitizer
|
// these are atomic as an annotation for thread-sanitizer
|
||||||
|
@ -3213,20 +3236,27 @@ static void ggml_barrier(struct ggml_threadpool * tp) {
|
||||||
// enter barrier (full seq-cst fence)
|
// enter barrier (full seq-cst fence)
|
||||||
int n_barrier = atomic_fetch_add_explicit(&tp->n_barrier, 1, memory_order_seq_cst);
|
int n_barrier = atomic_fetch_add_explicit(&tp->n_barrier, 1, memory_order_seq_cst);
|
||||||
|
|
||||||
int last = 0;
|
|
||||||
if (n_barrier == (n_threads - 1)) {
|
if (n_barrier == (n_threads - 1)) {
|
||||||
// last thread
|
// last thread
|
||||||
atomic_store_explicit(&tp->n_barrier, 0, memory_order_relaxed);
|
atomic_store_explicit(&tp->n_barrier, 0, memory_order_relaxed);
|
||||||
last = 1;
|
|
||||||
} else {
|
// exit barrier (fill seq-cst fence)
|
||||||
|
atomic_fetch_add_explicit(&tp->n_barrier_passed, 1, memory_order_seq_cst);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// wait for other threads
|
// wait for other threads
|
||||||
while (atomic_load_explicit(&tp->n_barrier_passed, memory_order_relaxed) == n_passed) {
|
while (atomic_load_explicit(&tp->n_barrier_passed, memory_order_relaxed) == n_passed) {
|
||||||
ggml_thread_cpu_relax();
|
ggml_thread_cpu_relax();
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// exit barrier (full seq-cst fence)
|
// exit barrier (full seq-cst fence)
|
||||||
atomic_fetch_add_explicit(&tp->n_barrier_passed, last, memory_order_seq_cst);
|
// TSAN doesn't support standalone fence yet, we use a dummy read-modify-write instead
|
||||||
|
#ifdef GGML_TSAN_ENABLED
|
||||||
|
atomic_fetch_add_explicit(&tp->n_barrier_passed, 0, memory_order_seq_cst);
|
||||||
|
#else
|
||||||
|
atomic_thread_fence(memory_order_seq_cst);
|
||||||
|
#endif
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -20299,10 +20329,13 @@ static inline bool ggml_graph_compute_thread_ready(struct ggml_compute_state * s
|
||||||
|
|
||||||
// sync thread state after polling
|
// sync thread state after polling
|
||||||
static inline void ggml_graph_compute_thread_sync(struct ggml_compute_state * state) {
|
static inline void ggml_graph_compute_thread_sync(struct ggml_compute_state * state) {
|
||||||
struct ggml_threadpool * threadpool = state->threadpool;
|
// TSAN doesn't support standalone fence yet, we use a dummy read-modify-write instead
|
||||||
// this should just be atomic_thread_fence(seq_cst) but it confuses thread-sanitizer
|
#ifdef GGML_TSAN_ENABLED
|
||||||
// so instead we just use a dummy read-modify-write
|
atomic_fetch_add_explicit(&state->threadpool->n_graph, 0, memory_order_seq_cst);
|
||||||
atomic_fetch_add_explicit(&threadpool->n_graph, 0, memory_order_seq_cst);
|
#else
|
||||||
|
atomic_thread_fence(memory_order_seq_cst);
|
||||||
|
#endif
|
||||||
|
UNUSED(state);
|
||||||
}
|
}
|
||||||
|
|
||||||
static inline bool ggml_graph_compute_poll_for_work(struct ggml_compute_state * state) {
|
static inline bool ggml_graph_compute_poll_for_work(struct ggml_compute_state * state) {
|
||||||
|
|
|
@ -235,6 +235,7 @@ class MODEL_ARCH(IntEnum):
|
||||||
NEMOTRON = auto()
|
NEMOTRON = auto()
|
||||||
EXAONE = auto()
|
EXAONE = auto()
|
||||||
GRANITE = auto()
|
GRANITE = auto()
|
||||||
|
GRANITE_MOE = auto()
|
||||||
|
|
||||||
|
|
||||||
class MODEL_TENSOR(IntEnum):
|
class MODEL_TENSOR(IntEnum):
|
||||||
|
@ -392,6 +393,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||||
MODEL_ARCH.NEMOTRON: "nemotron",
|
MODEL_ARCH.NEMOTRON: "nemotron",
|
||||||
MODEL_ARCH.EXAONE: "exaone",
|
MODEL_ARCH.EXAONE: "exaone",
|
||||||
MODEL_ARCH.GRANITE: "granite",
|
MODEL_ARCH.GRANITE: "granite",
|
||||||
|
MODEL_ARCH.GRANITE_MOE: "granitemoe",
|
||||||
}
|
}
|
||||||
|
|
||||||
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||||
|
@ -1232,6 +1234,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||||
MODEL_ARCH.GRANITE: [
|
MODEL_ARCH.GRANITE: [
|
||||||
MODEL_TENSOR.TOKEN_EMBD,
|
MODEL_TENSOR.TOKEN_EMBD,
|
||||||
MODEL_TENSOR.OUTPUT_NORM,
|
MODEL_TENSOR.OUTPUT_NORM,
|
||||||
|
MODEL_TENSOR.OUTPUT,
|
||||||
MODEL_TENSOR.ATTN_NORM,
|
MODEL_TENSOR.ATTN_NORM,
|
||||||
MODEL_TENSOR.ATTN_Q,
|
MODEL_TENSOR.ATTN_Q,
|
||||||
MODEL_TENSOR.ATTN_K,
|
MODEL_TENSOR.ATTN_K,
|
||||||
|
@ -1242,6 +1245,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||||
MODEL_TENSOR.FFN_DOWN,
|
MODEL_TENSOR.FFN_DOWN,
|
||||||
MODEL_TENSOR.FFN_UP,
|
MODEL_TENSOR.FFN_UP,
|
||||||
],
|
],
|
||||||
|
MODEL_ARCH.GRANITE_MOE: [
|
||||||
|
MODEL_TENSOR.TOKEN_EMBD,
|
||||||
|
MODEL_TENSOR.OUTPUT_NORM,
|
||||||
|
MODEL_TENSOR.OUTPUT,
|
||||||
|
MODEL_TENSOR.ATTN_NORM,
|
||||||
|
MODEL_TENSOR.ATTN_Q,
|
||||||
|
MODEL_TENSOR.ATTN_K,
|
||||||
|
MODEL_TENSOR.ATTN_V,
|
||||||
|
MODEL_TENSOR.ATTN_OUT,
|
||||||
|
MODEL_TENSOR.FFN_NORM,
|
||||||
|
MODEL_TENSOR.FFN_GATE_INP,
|
||||||
|
MODEL_TENSOR.FFN_GATE_EXP,
|
||||||
|
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||||
|
MODEL_TENSOR.FFN_UP_EXP,
|
||||||
|
],
|
||||||
# TODO
|
# TODO
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -256,6 +256,7 @@ class TensorNameMap:
|
||||||
"model.layers.{bid}.mlp.gate", # qwen2moe olmoe
|
"model.layers.{bid}.mlp.gate", # qwen2moe olmoe
|
||||||
"transformer.decoder_layer.{bid}.router", # Grok
|
"transformer.decoder_layer.{bid}.router", # Grok
|
||||||
"transformer.blocks.{bid}.ffn.router.layer", # dbrx
|
"transformer.blocks.{bid}.ffn.router.layer", # dbrx
|
||||||
|
"model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe
|
||||||
),
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
|
MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
|
||||||
|
@ -368,6 +369,7 @@ class TensorNameMap:
|
||||||
"transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged)
|
"transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged)
|
||||||
"transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx
|
"transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx
|
||||||
"model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged)
|
"model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged)
|
||||||
|
"model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe
|
||||||
),
|
),
|
||||||
|
|
||||||
MODEL_TENSOR.FFN_DOWN_SHEXP: (
|
MODEL_TENSOR.FFN_DOWN_SHEXP: (
|
||||||
|
|
|
@ -1068,6 +1068,7 @@ extern "C" {
|
||||||
LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed);
|
LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed);
|
||||||
|
|
||||||
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
|
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
|
||||||
|
/// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
|
||||||
LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void);
|
LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void);
|
||||||
|
|
||||||
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||||
|
|
69
klite.embd
69
klite.embd
|
@ -4185,6 +4185,7 @@ Current version indicated by LITEVER below.
|
||||||
|
|
||||||
const default_oai_image_endpoint = "/images/generations";
|
const default_oai_image_endpoint = "/images/generations";
|
||||||
const default_oai_tts_endpoint = "/audio/speech";
|
const default_oai_tts_endpoint = "/audio/speech";
|
||||||
|
const default_dalle_model_name = "dall-e-3";
|
||||||
|
|
||||||
const claude_submit_endpoint = "/complete";
|
const claude_submit_endpoint = "/complete";
|
||||||
const claude_submit_endpoint_v3 = "/messages";
|
const claude_submit_endpoint_v3 = "/messages";
|
||||||
|
@ -4325,6 +4326,7 @@ Current version indicated by LITEVER below.
|
||||||
saved_oai_addr: default_oai_base, //do not ever share this in save files!
|
saved_oai_addr: default_oai_base, //do not ever share this in save files!
|
||||||
saved_dalle_key: "",
|
saved_dalle_key: "",
|
||||||
saved_dalle_url: (default_oai_base + "/v1" + default_oai_image_endpoint),
|
saved_dalle_url: (default_oai_base + "/v1" + default_oai_image_endpoint),
|
||||||
|
saved_dalle_model: default_dalle_model_name,
|
||||||
saved_oai_tts_key: "",
|
saved_oai_tts_key: "",
|
||||||
saved_oai_tts_url: (default_oai_base + "/v1" + default_oai_tts_endpoint),
|
saved_oai_tts_url: (default_oai_base + "/v1" + default_oai_tts_endpoint),
|
||||||
saved_openrouter_key: "",
|
saved_openrouter_key: "",
|
||||||
|
@ -4557,16 +4559,23 @@ Current version indicated by LITEVER below.
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id":12,
|
"id":12,
|
||||||
"name":"Mistral Gen 1",
|
"name":"Mistral V1",
|
||||||
"user":"\\n[INST] ",
|
"user":"</s> [INST] ",
|
||||||
"assistant":" [/INST]\\n",
|
"assistant":" [/INST]",
|
||||||
"system":"",
|
"system":"",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id":13,
|
"id":13,
|
||||||
"name":"Mistral Gen 2",
|
"name":"Mistral V2 & V3",
|
||||||
"user":"</s>\\n[INST]",
|
"user":"</s>[INST] ",
|
||||||
"assistant":"[/INST]\\n",
|
"assistant":"[/INST]",
|
||||||
|
"system":"",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id":14,
|
||||||
|
"name":"Mistral V3-Tekken",
|
||||||
|
"user":"</s>[INST]",
|
||||||
|
"assistant":"[/INST]",
|
||||||
"system":"",
|
"system":"",
|
||||||
}
|
}
|
||||||
];
|
];
|
||||||
|
@ -5114,6 +5123,7 @@ Current version indicated by LITEVER below.
|
||||||
const foundChub = urlParams.get('chub');
|
const foundChub = urlParams.get('chub');
|
||||||
const foundPyg = urlParams.get('pyg');
|
const foundPyg = urlParams.get('pyg');
|
||||||
const foundAicc = urlParams.get('aicc');
|
const foundAicc = urlParams.get('aicc');
|
||||||
|
const foundQuery = urlParams.get('query');
|
||||||
|
|
||||||
if (foundStory && foundStory != "") {
|
if (foundStory && foundStory != "") {
|
||||||
if (localsettings.persist_session && !safe_to_overwrite()) {
|
if (localsettings.persist_session && !safe_to_overwrite()) {
|
||||||
|
@ -5150,6 +5160,25 @@ Current version indicated by LITEVER below.
|
||||||
//purge url params
|
//purge url params
|
||||||
window.history.replaceState(null, null, window.location.pathname);
|
window.history.replaceState(null, null, window.location.pathname);
|
||||||
}
|
}
|
||||||
|
else if (foundQuery && foundQuery != "")
|
||||||
|
{
|
||||||
|
window.history.replaceState(null, null, window.location.pathname);
|
||||||
|
if (localsettings.persist_session && !safe_to_overwrite()) {
|
||||||
|
msgboxYesNo("You already have an existing persistent story. Do you want to overwrite it?","Overwrite Story Warning",()=>{
|
||||||
|
localsettings.opmode = 4;
|
||||||
|
restart_new_game(false);
|
||||||
|
document.getElementById("input_text").value = foundQuery;
|
||||||
|
submit_generation();
|
||||||
|
},null,false);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
localsettings.opmode = 4;
|
||||||
|
restart_new_game(false);
|
||||||
|
document.getElementById("input_text").value = foundQuery;
|
||||||
|
submit_generation();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var image_models_fetched = false;
|
var image_models_fetched = false;
|
||||||
|
@ -5363,6 +5392,18 @@ Current version indicated by LITEVER below.
|
||||||
}
|
}
|
||||||
},false);
|
},false);
|
||||||
}
|
}
|
||||||
|
function set_dalle_model()
|
||||||
|
{
|
||||||
|
inputBox("Enter DALL-E API Model Identifier.","DALL-E API Model Identifier",localsettings.saved_dalle_model,"Input DALL-E Model Identifier", ()=>{
|
||||||
|
let userinput = getInputBoxValue();
|
||||||
|
userinput = userinput.trim();
|
||||||
|
if (userinput != null && userinput!="") {
|
||||||
|
localsettings.saved_dalle_model = userinput.trim();
|
||||||
|
}else{
|
||||||
|
localsettings.saved_dalle_model = default_dalle_model_name;
|
||||||
|
}
|
||||||
|
},false);
|
||||||
|
}
|
||||||
|
|
||||||
function set_oai_tts_key()
|
function set_oai_tts_key()
|
||||||
{
|
{
|
||||||
|
@ -5394,7 +5435,7 @@ Current version indicated by LITEVER below.
|
||||||
let prompt = splits[0].trim();
|
let prompt = splits[0].trim();
|
||||||
|
|
||||||
let dalle_payload = {
|
let dalle_payload = {
|
||||||
"model": "dall-e-3",
|
"model": localsettings.saved_dalle_model,
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
"n": 1,
|
"n": 1,
|
||||||
"size": "1024x1024",
|
"size": "1024x1024",
|
||||||
|
@ -12596,7 +12637,7 @@ Current version indicated by LITEVER below.
|
||||||
//mistral api does not support presence pen
|
//mistral api does not support presence pen
|
||||||
oai_payload.presence_penalty = scaled_rep_pen;
|
oai_payload.presence_penalty = scaled_rep_pen;
|
||||||
}
|
}
|
||||||
if(targetep.toLowerCase().includes("featherless.ai"))
|
if(document.getElementById("useoainonstandard").checked || targetep.toLowerCase().includes("featherless.ai"))
|
||||||
{
|
{
|
||||||
//featherless api supports additional fields, include them
|
//featherless api supports additional fields, include them
|
||||||
oai_payload.top_k = (submit_payload.params.top_k<1?300:submit_payload.params.top_k);
|
oai_payload.top_k = (submit_payload.params.top_k<1?300:submit_payload.params.top_k);
|
||||||
|
@ -12605,6 +12646,7 @@ Current version indicated by LITEVER below.
|
||||||
{
|
{
|
||||||
oai_payload.seed = submit_payload.params.sampler_seed;
|
oai_payload.seed = submit_payload.params.sampler_seed;
|
||||||
}
|
}
|
||||||
|
oai_payload.top_a = localsettings.top_a;
|
||||||
}
|
}
|
||||||
if(submit_payload.params.logit_bias && JSON.stringify(submit_payload.params.logit_bias) != '{}')
|
if(submit_payload.params.logit_bias && JSON.stringify(submit_payload.params.logit_bias) != '{}')
|
||||||
{
|
{
|
||||||
|
@ -17982,11 +18024,13 @@ Current version indicated by LITEVER below.
|
||||||
<button type="button" class="btn btn-primary" style="display:inline;width:105px;" id="oaiusecustom" onclick="select_custom_oai_model()">Use Custom</button>
|
<button type="button" class="btn btn-primary" style="display:inline;width:105px;" id="oaiusecustom" onclick="select_custom_oai_model()">Use Custom</button>
|
||||||
<div style="display:inline-flex">
|
<div style="display:inline-flex">
|
||||||
<div><input type="checkbox" id="oaiaddversion" title="Add Endpoint Version Number" onchange="" checked>
|
<div><input type="checkbox" id="oaiaddversion" title="Add Endpoint Version Number" onchange="" checked>
|
||||||
<div class="box-label">Add Version Num</div></div>
|
<div class="box-label">Add Ver. Num</div></div>
|
||||||
<div><input type="checkbox" id="oaistreaming" title="Enable SSE Streaming" onchange="">
|
<div><input type="checkbox" id="oaistreaming" title="Enable SSE Streaming" onchange="">
|
||||||
<div class="box-label">Streaming</div></div>
|
<div class="box-label">Streaming</div></div>
|
||||||
<div><input type="checkbox" id="useoaichatcompl" title="Use ChatCompletions API" onchange="toggleoaichatcompl()">
|
<div><input type="checkbox" id="useoaichatcompl" title="Use ChatCompletions API" onchange="toggleoaichatcompl()">
|
||||||
<div class="box-label" id="useoaichatcompllabel">ChatCompletions API</div></div>
|
<div class="box-label">Chat-Completions API</div></div>
|
||||||
|
<div><input type="checkbox" id="useoainonstandard" title="Send Non-Standard Fields">
|
||||||
|
<div class="box-label">Non-Standard Fields</div></div>
|
||||||
</div>
|
</div>
|
||||||
<span id="useoaichatcomplbox" class="hidden" onload="toggleoaichatcompl();">
|
<span id="useoaichatcomplbox" class="hidden" onload="toggleoaichatcompl();">
|
||||||
<br>
|
<br>
|
||||||
|
@ -18694,8 +18738,9 @@ Current version indicated by LITEVER below.
|
||||||
</div>
|
</div>
|
||||||
<div id="generate_images_dalle_container" class="settinglabel hidden">
|
<div id="generate_images_dalle_container" class="settinglabel hidden">
|
||||||
<table width="100%"><tr>
|
<table width="100%"><tr>
|
||||||
<td><button id="generate_images_dalle_setkey" type="button" class="btn btn-primary" style="width:100%; padding:2px 3px;margin-top:2px;font-size:11px;" onclick="set_dalle_url()">Set URL</button></td>
|
<td><button type="button" class="btn btn-primary" style="width:100%; padding:2px 3px;margin-top:2px;font-size:11px;" onclick="set_dalle_url()">Set URL</button></td>
|
||||||
<td><button id="generate_images_dalle_seturl" type="button" class="btn btn-primary" style="width:100%; padding:2px 3px;margin-top:2px;font-size:11px;" onclick="set_dalle_key()">Set Key</button></td>
|
<td><button type="button" class="btn btn-primary" style="width:100%; padding:2px 3px;margin-top:2px;font-size:11px;" onclick="set_dalle_key()">Set Key</button></td>
|
||||||
|
<td><button type="button" class="btn btn-primary" style="width:100%; padding:2px 3px;margin-top:2px;font-size:11px;" onclick="set_dalle_model()">Model</button></td>
|
||||||
</tr></table>
|
</tr></table>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
|
|
@ -41,7 +41,7 @@ maxhordelen = 400
|
||||||
modelbusy = threading.Lock()
|
modelbusy = threading.Lock()
|
||||||
requestsinqueue = 0
|
requestsinqueue = 0
|
||||||
defaultport = 5001
|
defaultport = 5001
|
||||||
KcppVersion = "1.75.2"
|
KcppVersion = "1.76"
|
||||||
showdebug = True
|
showdebug = True
|
||||||
guimode = False
|
guimode = False
|
||||||
showsamplerwarning = True
|
showsamplerwarning = True
|
||||||
|
|
|
@ -28,6 +28,8 @@ void llama_log_callback_default(ggml_log_level level, const char * text, void *
|
||||||
#define LLAMA_LOG_INFO(...) llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
|
#define LLAMA_LOG_INFO(...) llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
|
||||||
#define LLAMA_LOG_WARN(...) llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
|
#define LLAMA_LOG_WARN(...) llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
|
||||||
#define LLAMA_LOG_ERROR(...) llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
|
#define LLAMA_LOG_ERROR(...) llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
|
||||||
|
#define LLAMA_LOG_DEBUG(...) llama_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
|
||||||
|
#define LLAMA_LOG_CONT(...) llama_log_internal(GGML_LOG_LEVEL_CONT , __VA_ARGS__)
|
||||||
|
|
||||||
//
|
//
|
||||||
// helpers
|
// helpers
|
||||||
|
|
|
@ -3,13 +3,14 @@
|
||||||
#include "llama-vocab.h"
|
#include "llama-vocab.h"
|
||||||
#include "llama-grammar.h"
|
#include "llama-grammar.h"
|
||||||
|
|
||||||
#include <cassert>
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cstring>
|
#include <cassert>
|
||||||
#include <ctime>
|
|
||||||
#include <cfloat>
|
#include <cfloat>
|
||||||
#include <chrono>
|
#include <chrono>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <cstring>
|
||||||
|
#include <ctime>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <random>
|
#include <random>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
|
@ -1826,11 +1826,7 @@ llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, lla
|
||||||
}
|
}
|
||||||
|
|
||||||
bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token) {
|
bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token) {
|
||||||
return token != -1 && (
|
return token != -1 && vocab.special_eog_ids.count(token) > 0;
|
||||||
token == llama_token_eos_impl(vocab) ||
|
|
||||||
token == llama_token_eot_impl(vocab) ||
|
|
||||||
token == llama_token_eom_impl(vocab)
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token) {
|
bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token) {
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <map>
|
#include <map>
|
||||||
|
#include <set>
|
||||||
|
|
||||||
struct llama_vocab {
|
struct llama_vocab {
|
||||||
using id = llama_token;
|
using id = llama_token;
|
||||||
|
@ -49,6 +50,9 @@ struct llama_vocab {
|
||||||
id special_eot_id = -1; // TODO: move above after "eos_id", and here add "file separator" token
|
id special_eot_id = -1; // TODO: move above after "eos_id", and here add "file separator" token
|
||||||
id special_eom_id = -1;
|
id special_eom_id = -1;
|
||||||
|
|
||||||
|
// set of all tokens that cause "end of generation"
|
||||||
|
std::set<id> special_eog_ids;
|
||||||
|
|
||||||
// tokenizer flags
|
// tokenizer flags
|
||||||
bool tokenizer_add_space_prefix = false;
|
bool tokenizer_add_space_prefix = false;
|
||||||
bool tokenizer_add_bos = false;
|
bool tokenizer_add_bos = false;
|
||||||
|
|
124
src/llama.cpp
124
src/llama.cpp
|
@ -225,6 +225,7 @@ enum llm_arch {
|
||||||
LLM_ARCH_EXAONE,
|
LLM_ARCH_EXAONE,
|
||||||
LLM_ARCH_RWKV6,
|
LLM_ARCH_RWKV6,
|
||||||
LLM_ARCH_GRANITE,
|
LLM_ARCH_GRANITE,
|
||||||
|
LLM_ARCH_GRANITE_MOE,
|
||||||
LLM_ARCH_UNKNOWN,
|
LLM_ARCH_UNKNOWN,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -276,6 +277,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||||
{ LLM_ARCH_EXAONE, "exaone" },
|
{ LLM_ARCH_EXAONE, "exaone" },
|
||||||
{ LLM_ARCH_RWKV6, "rwkv6" },
|
{ LLM_ARCH_RWKV6, "rwkv6" },
|
||||||
{ LLM_ARCH_GRANITE, "granite" },
|
{ LLM_ARCH_GRANITE, "granite" },
|
||||||
|
{ LLM_ARCH_GRANITE_MOE, "granitemoe" },
|
||||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -1477,6 +1479,7 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
|
||||||
{
|
{
|
||||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||||
|
{ LLM_TENSOR_OUTPUT, "output" },
|
||||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||||
|
@ -1488,6 +1491,24 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
|
||||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
LLM_ARCH_GRANITE_MOE,
|
||||||
|
{
|
||||||
|
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||||
|
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||||
|
{ LLM_TENSOR_OUTPUT, "output" },
|
||||||
|
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||||
|
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||||
|
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||||
|
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||||
|
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||||
|
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||||
|
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||||
|
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||||
|
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||||
|
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
LLM_ARCH_UNKNOWN,
|
LLM_ARCH_UNKNOWN,
|
||||||
{
|
{
|
||||||
|
@ -2410,7 +2431,7 @@ struct llama_hparams {
|
||||||
float f_max_alibi_bias = 0.0f;
|
float f_max_alibi_bias = 0.0f;
|
||||||
float f_logit_scale = 0.0f;
|
float f_logit_scale = 0.0f;
|
||||||
|
|
||||||
// Additional scale factors (Granite)
|
// Additional scale factors (Granite/Granite MoE)
|
||||||
float f_residual_scale = 0.0f;
|
float f_residual_scale = 0.0f;
|
||||||
float f_embedding_scale = 0.0f;
|
float f_embedding_scale = 0.0f;
|
||||||
float f_attention_scale = 0.0f;
|
float f_attention_scale = 0.0f;
|
||||||
|
@ -3070,18 +3091,14 @@ struct llama_sbatch {
|
||||||
} else {
|
} else {
|
||||||
// simple split
|
// simple split
|
||||||
if (batch->n_seq_id) {
|
if (batch->n_seq_id) {
|
||||||
for (size_t i = 0; i < length; ++i) {
|
|
||||||
ubatch.n_seq_id = batch->n_seq_id + seq.offset;
|
ubatch.n_seq_id = batch->n_seq_id + seq.offset;
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
for (size_t i = 0; i < length; ++i) {
|
for (size_t i = 0; i < length; ++i) {
|
||||||
ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
|
ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (batch->seq_id) {
|
if (batch->seq_id) {
|
||||||
for (size_t i = 0; i < length; ++i) {
|
|
||||||
ubatch.seq_id = batch->seq_id + seq.offset;
|
ubatch.seq_id = batch->seq_id + seq.offset;
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
for (size_t i = 0; i < length; ++i) {
|
for (size_t i = 0; i < length; ++i) {
|
||||||
ubatch.seq_id[ubatch.n_seqs + i] = &seq.all_seq_id;
|
ubatch.seq_id[ubatch.n_seqs + i] = &seq.all_seq_id;
|
||||||
|
@ -6084,6 +6101,7 @@ static void llm_load_hparams(
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case LLM_ARCH_GRANITE:
|
case LLM_ARCH_GRANITE:
|
||||||
|
case LLM_ARCH_GRANITE_MOE:
|
||||||
{
|
{
|
||||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||||
ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
|
ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
|
||||||
|
@ -6092,6 +6110,7 @@ static void llm_load_hparams(
|
||||||
ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale);
|
ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale);
|
||||||
|
|
||||||
switch (hparams.n_layer) {
|
switch (hparams.n_layer) {
|
||||||
|
case 32: model.type = e_model::MODEL_3B; break;
|
||||||
case 40: model.type = e_model::MODEL_3B; break;
|
case 40: model.type = e_model::MODEL_3B; break;
|
||||||
// Add additional layer/vocab/etc checks here for other model sizes
|
// Add additional layer/vocab/etc checks here for other model sizes
|
||||||
default: model.type = e_model::MODEL_UNKNOWN;
|
default: model.type = e_model::MODEL_UNKNOWN;
|
||||||
|
@ -6563,16 +6582,16 @@ static void llm_load_vocab(
|
||||||
// for now, we apply this workaround to find the EOT token based on its text
|
// for now, we apply this workaround to find the EOT token based on its text
|
||||||
if (vocab.special_eot_id == -1) {
|
if (vocab.special_eot_id == -1) {
|
||||||
for (const auto & t : vocab.token_to_id) {
|
for (const auto & t : vocab.token_to_id) {
|
||||||
if (
|
if (false
|
||||||
// TODO: gemma "<end_of_turn>" is exported as a normal token, so the following check does not work
|
// TODO: gemma "<end_of_turn>" is exported as a normal token, so the following check does not work
|
||||||
// need to fix convert script
|
// need to fix convert script
|
||||||
//vocab.id_to_token[t.second].type == LLAMA_TOKEN_TYPE_CONTROL &&
|
//vocab.id_to_token[t.second].type == LLAMA_TOKEN_TYPE_CONTROL &&
|
||||||
(t.first == "<|eot_id|>" ||
|
|| t.first == "<|eot_id|>"
|
||||||
t.first == "<|im_end|>" ||
|
|| t.first == "<|im_end|>"
|
||||||
t.first == "<|end|>" ||
|
|| t.first == "<|end|>"
|
||||||
t.first == "<end_of_turn>" ||
|
|| t.first == "<end_of_turn>"
|
||||||
t.first == "<|endoftext|>"
|
|| t.first == "<|endoftext|>"
|
||||||
)
|
|| t.first == "<EOT>"
|
||||||
) {
|
) {
|
||||||
vocab.special_eot_id = t.second;
|
vocab.special_eot_id = t.second;
|
||||||
if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
||||||
|
@ -6600,6 +6619,44 @@ static void llm_load_vocab(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// maintain a list of tokens that cause end-of-generation
|
||||||
|
// this is currently determined based on the token text, which is obviously not ideal
|
||||||
|
// ref: https://github.com/ggerganov/llama.cpp/issues/9606
|
||||||
|
vocab.special_eog_ids.clear();
|
||||||
|
for (const auto & t : vocab.token_to_id) {
|
||||||
|
if (false
|
||||||
|
|| t.first == "<|eot_id|>"
|
||||||
|
|| t.first == "<|im_end|>"
|
||||||
|
|| t.first == "<|end|>"
|
||||||
|
|| t.first == "<end_of_turn>"
|
||||||
|
|| t.first == "<|endoftext|>"
|
||||||
|
|| t.first == "<|eom_id|>"
|
||||||
|
|| t.first == "<EOT>"
|
||||||
|
) {
|
||||||
|
vocab.special_eog_ids.insert(t.second);
|
||||||
|
if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
||||||
|
LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
|
||||||
|
__func__, t.first.c_str());
|
||||||
|
vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (vocab.special_eos_id != -1 && vocab.special_eog_ids.count(vocab.special_eos_id) == 0) {
|
||||||
|
vocab.special_eog_ids.insert(vocab.special_eos_id);
|
||||||
|
LLAMA_LOG_WARN("%s: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (vocab.special_eot_id != -1 && vocab.special_eog_ids.count(vocab.special_eot_id) == 0) {
|
||||||
|
vocab.special_eog_ids.insert(vocab.special_eot_id);
|
||||||
|
LLAMA_LOG_WARN("%s: special_eot_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (vocab.special_eom_id != -1 && vocab.special_eog_ids.count(vocab.special_eom_id) == 0) {
|
||||||
|
vocab.special_eog_ids.insert(vocab.special_eom_id);
|
||||||
|
LLAMA_LOG_WARN("%s: special_eom_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// build special tokens cache
|
// build special tokens cache
|
||||||
|
@ -6803,6 +6860,11 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
|
||||||
if (vocab.special_suffix_id != -1) { LLAMA_LOG_INFO( "%s: SUF token = %d '%s'\n", __func__, vocab.special_suffix_id, vocab.id_to_token[vocab.special_suffix_id].text.c_str() ); }
|
if (vocab.special_suffix_id != -1) { LLAMA_LOG_INFO( "%s: SUF token = %d '%s'\n", __func__, vocab.special_suffix_id, vocab.id_to_token[vocab.special_suffix_id].text.c_str() ); }
|
||||||
if (vocab.special_middle_id != -1) { LLAMA_LOG_INFO( "%s: MID token = %d '%s'\n", __func__, vocab.special_middle_id, vocab.id_to_token[vocab.special_middle_id].text.c_str() ); }
|
if (vocab.special_middle_id != -1) { LLAMA_LOG_INFO( "%s: MID token = %d '%s'\n", __func__, vocab.special_middle_id, vocab.id_to_token[vocab.special_middle_id].text.c_str() ); }
|
||||||
if (vocab.special_eot_id != -1) { LLAMA_LOG_INFO( "%s: EOT token = %d '%s'\n", __func__, vocab.special_eot_id, vocab.id_to_token[vocab.special_eot_id].text.c_str() ); }
|
if (vocab.special_eot_id != -1) { LLAMA_LOG_INFO( "%s: EOT token = %d '%s'\n", __func__, vocab.special_eot_id, vocab.id_to_token[vocab.special_eot_id].text.c_str() ); }
|
||||||
|
if (vocab.special_eom_id != -1) { LLAMA_LOG_INFO( "%s: EOM token = %d '%s'\n", __func__, vocab.special_eom_id, vocab.id_to_token[vocab.special_eom_id].text.c_str() ); }
|
||||||
|
|
||||||
|
for (const auto & id : vocab.special_eog_ids) {
|
||||||
|
LLAMA_LOG_INFO( "%s: EOG token = %d '%s'\n", __func__, id, vocab.id_to_token[id].text.c_str() );
|
||||||
|
}
|
||||||
|
|
||||||
LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, vocab.max_token_len);
|
LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, vocab.max_token_len);
|
||||||
|
|
||||||
|
@ -6821,7 +6883,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
|
||||||
LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
|
LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (model.arch == LLM_ARCH_GRANITE) {
|
if (model.arch == LLM_ARCH_GRANITE || model.arch == LLM_ARCH_GRANITE_MOE) {
|
||||||
LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale);
|
LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale);
|
||||||
LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale);
|
LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale);
|
||||||
LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale);
|
LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale);
|
||||||
|
@ -7004,6 +7066,7 @@ static bool llm_load_tensors(
|
||||||
case LLM_ARCH_REFACT:
|
case LLM_ARCH_REFACT:
|
||||||
case LLM_ARCH_MINICPM:
|
case LLM_ARCH_MINICPM:
|
||||||
case LLM_ARCH_GRANITE:
|
case LLM_ARCH_GRANITE:
|
||||||
|
case LLM_ARCH_GRANITE_MOE:
|
||||||
{
|
{
|
||||||
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
||||||
|
|
||||||
|
@ -9993,17 +10056,36 @@ struct llm_build_context {
|
||||||
const int64_t n_head_kv = hparams.n_head_kv(il);
|
const int64_t n_head_kv = hparams.n_head_kv(il);
|
||||||
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
||||||
struct ggml_tensor * rope_factors = build_rope_factors(il);
|
struct ggml_tensor * rope_factors = build_rope_factors(il);
|
||||||
struct ggml_tensor * tmp =
|
struct ggml_tensor * k =
|
||||||
// we rotate only the first n_rot dimensions
|
|
||||||
ggml_rope_ext_inplace(ctx0,
|
|
||||||
ggml_view_3d(ctx0, kv_self.k_l[il],
|
ggml_view_3d(ctx0, kv_self.k_l[il],
|
||||||
n_embd_head_k, n_head_kv, n_ctx,
|
n_embd_head_k, n_head_kv, n_ctx,
|
||||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
|
ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
|
||||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
|
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
|
||||||
0),
|
0);
|
||||||
|
|
||||||
|
struct ggml_tensor * tmp;
|
||||||
|
if (ggml_is_quantized(k->type)) {
|
||||||
|
// dequantize to f32 -> RoPE -> quantize back
|
||||||
|
tmp = ggml_cast(ctx0, k, GGML_TYPE_F32);
|
||||||
|
cb(tmp, "K_f32", il);
|
||||||
|
for (auto * backend : lctx.backends) {
|
||||||
|
// Figure out which backend KV cache belongs to
|
||||||
|
if (ggml_backend_supports_buft(backend, lctx.model.buft_layer[il].buft)) {
|
||||||
|
ggml_backend_sched_set_tensor_backend(lctx.sched, tmp, backend);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tmp = ggml_rope_ext_inplace(ctx0, tmp,
|
||||||
lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||||
|
cb(tmp, "K_shifted_f32", il);
|
||||||
|
tmp = ggml_cpy(ctx0, tmp, k);
|
||||||
|
} else {
|
||||||
|
// we rotate only the first n_rot dimensions
|
||||||
|
tmp = ggml_rope_ext_inplace(ctx0, k,
|
||||||
|
lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
|
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||||
|
}
|
||||||
cb(tmp, "K_shifted", il);
|
cb(tmp, "K_shifted", il);
|
||||||
ggml_build_forward_expand(gf, tmp);
|
ggml_build_forward_expand(gf, tmp);
|
||||||
}
|
}
|
||||||
|
@ -15949,6 +16031,7 @@ static struct ggml_cgraph * llama_build_graph(
|
||||||
switch (model.arch) {
|
switch (model.arch) {
|
||||||
case LLM_ARCH_LLAMA:
|
case LLM_ARCH_LLAMA:
|
||||||
case LLM_ARCH_GRANITE:
|
case LLM_ARCH_GRANITE:
|
||||||
|
case LLM_ARCH_GRANITE_MOE:
|
||||||
{
|
{
|
||||||
result = llm.build_llama();
|
result = llm.build_llama();
|
||||||
} break;
|
} break;
|
||||||
|
@ -18719,9 +18802,9 @@ struct llama_model * llama_load_model_from_file(
|
||||||
unsigned percentage = (unsigned) (100 * progress);
|
unsigned percentage = (unsigned) (100 * progress);
|
||||||
while (percentage > *cur_percentage_p) {
|
while (percentage > *cur_percentage_p) {
|
||||||
*cur_percentage_p = percentage;
|
*cur_percentage_p = percentage;
|
||||||
LLAMA_LOG(".");
|
LLAMA_LOG_CONT(".");
|
||||||
if (percentage >= 100) {
|
if (percentage >= 100) {
|
||||||
LLAMA_LOG("\n");
|
LLAMA_LOG_CONT("\n");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
|
@ -19236,6 +19319,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
||||||
case LLM_ARCH_DEEPSEEK2:
|
case LLM_ARCH_DEEPSEEK2:
|
||||||
case LLM_ARCH_CHATGLM:
|
case LLM_ARCH_CHATGLM:
|
||||||
case LLM_ARCH_GRANITE:
|
case LLM_ARCH_GRANITE:
|
||||||
|
case LLM_ARCH_GRANITE_MOE:
|
||||||
return LLAMA_ROPE_TYPE_NORM;
|
return LLAMA_ROPE_TYPE_NORM;
|
||||||
|
|
||||||
// the pairs of head values are offset by n_rot/2
|
// the pairs of head values are offset by n_rot/2
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue