Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	README.md
#	examples/llama.android/llama/src/main/cpp/llama-android.cpp
#	examples/run/run.cpp
#	examples/server/README.md
#	examples/server/bench/README.md
#	examples/server/tests/README.md
#	ggml/src/CMakeLists.txt
#	ggml/src/ggml-cpu/CMakeLists.txt
#	tests/test-backend-ops.cpp
This commit is contained in:
Concedo 2025-01-03 11:56:20 +08:00
commit 911da8765f
46 changed files with 82848 additions and 73880 deletions

View file

@ -20,6 +20,7 @@
#include <cstdarg>
#include <cstring>
#include <ctime>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <iterator>
@ -64,7 +65,9 @@
#ifdef __linux__
#include <linux/limits.h>
#elif defined(_WIN32)
# if !defined(PATH_MAX)
# define PATH_MAX MAX_PATH
# endif
#else
#include <sys/syslimits.h>
#endif
@ -1150,8 +1153,7 @@ static bool common_download_file(const std::string & url, const std::string & pa
#endif
// Check if the file already exists locally
struct stat model_file_info;
auto file_exists = (stat(path.c_str(), &model_file_info) == 0);
auto file_exists = std::filesystem::exists(path);
// If the file exists, check its JSON metadata companion file.
std::string metadata_path = path + ".json";
@ -1614,6 +1616,18 @@ std::string common_detokenize(llama_context * ctx, const std::vector<llama_token
// Chat template utils
//
// Return the model's built-in chat template (the "tokenizer.chat_template"
// GGUF metadata value), or an empty string when the model has none.
std::string common_get_builtin_chat_template(const struct llama_model * model) {
    static const char * template_key = "tokenizer.chat_template";
    // first call with a NULL buffer only reports the required string length
    const int32_t len = llama_model_meta_val_str(model, template_key, NULL, 0);
    if (len <= 0) {
        return "";
    }
    std::string tmpl(len + 1, '\0');
    llama_model_meta_val_str(model, template_key, &tmpl[0], tmpl.size());
    tmpl.resize(len); // drop the trailing NUL written by the C API
    return tmpl;
}
bool common_chat_verify_template(const std::string & tmpl) {
llama_chat_message chat[] = {{"user", "test"}};
int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);

View file

@ -567,6 +567,9 @@ struct common_chat_msg {
std::string content;
};
// Get the built-in chat template for the model. Return empty string if not present.
std::string common_get_builtin_chat_template(const struct llama_model * model);
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
bool common_chat_verify_template(const std::string & tmpl);

View file

@ -1764,25 +1764,19 @@ class DeciModel(Model):
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(
self.dir_model, load_merges=True,
special_token_types = ['bos', 'eos', 'eom', 'eot']
)
special_vocab._set_special_token("bos", 128000)
special_vocab._set_special_token("eos", 128001)
special_vocab._set_special_token("eom", 128008)
special_vocab._set_special_token("eot", 128009)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
special_vocab.add_to_gguf(self.gguf_writer)
else:
# DeciLM-7B
self._set_vocab_llama_hf()
# self._set_vocab_gpt2()
def set_gguf_parameters(self):
if "block_configs" in self.hparams: # Llama-3_1-Nemotron-51B
assert self.block_count == len(self._num_kv_heads)
assert self.block_count == len(self._num_heads)
assert self.block_count == len(self._ffn_dims)
if (rope_theta := self.hparams.get("rope_theta")) is not None:
self.gguf_writer.add_rope_freq_base(rope_theta)
self.gguf_writer.add_head_count_kv(self._num_kv_heads)
self.gguf_writer.add_head_count(self._num_heads)
self.gguf_writer.add_feed_forward_length(self._ffn_dims)

View file

@ -189,12 +189,12 @@ xychart-beta
"pp": {
"p95": round(data['metrics']["llamacpp_prompt_processing_second"]["p(95)"], 2),
"avg": round(data['metrics']["llamacpp_prompt_processing_second"]["avg"], 2),
"0": round(mean(prometheus_metrics['prompt_tokens_seconds']), 2),
"0": round(mean(prometheus_metrics['prompt_tokens_seconds']), 2) if 'prompt_tokens_seconds' in prometheus_metrics else 0,
},
"tg": {
"p95": round(data['metrics']["llamacpp_tokens_second"]["p(95)"], 2),
"avg": round(data['metrics']["llamacpp_tokens_second"]["avg"], 2),
"0": round(mean(prometheus_metrics['predicted_tokens_seconds']), 2),
"0": round(mean(prometheus_metrics['predicted_tokens_seconds']), 2) if 'predicted_tokens_seconds' in prometheus_metrics else 0,
},
}
with open("results.github.env", 'a') as github_env:
@ -214,11 +214,14 @@ def start_benchmark(args):
k6_args = [
'run', args.scenario,
'--no-color',
'--no-connection-reuse',
'--no-vu-connection-reuse',
]
k6_args.extend(['--duration', args.duration])
k6_args.extend(['--iterations', args.n_prompts])
k6_args.extend(['--vus', args.parallel])
k6_args.extend(['--summary-export', 'k6-results.json'])
k6_args.extend(['--out', 'csv=k6-results.csv'])
args = f"SERVER_BENCH_N_PROMPTS={args.n_prompts} SERVER_BENCH_MAX_PROMPT_TOKENS={args.max_prompt_tokens} SERVER_BENCH_MAX_CONTEXT={args.max_tokens} "
args = args + ' '.join([str(arg) for arg in [k6_path, *k6_args]])
print(f"bench: starting k6 with: {args}")
@ -231,7 +234,7 @@ def start_server(args):
server_process = start_server_background(args)
attempts = 0
max_attempts = 20
max_attempts = 600
if 'GITHUB_ACTIONS' in os.environ:
max_attempts *= 2
@ -242,7 +245,15 @@ def start_server(args):
print(f"bench: waiting for server to start ...")
time.sleep(0.5)
print("bench: server started.")
attempts = 0
while not is_server_ready(args.host, args.port):
attempts += 1
if attempts > max_attempts:
assert False, "server not ready"
print(f"bench: waiting for server to be ready ...")
time.sleep(0.5)
print("bench: server started and ready.")
return server_process
@ -255,11 +266,6 @@ def start_server_background(args):
'--host', args.host,
'--port', args.port,
]
model_file = args.model_path_prefix + os.path.sep + args.hf_file
model_dir = os.path.dirname(model_file)
if not os.path.exists(model_dir):
os.makedirs(model_dir)
server_args.extend(['--model', model_file])
server_args.extend(['--hf-repo', args.hf_repo])
server_args.extend(['--hf-file', args.hf_file])
server_args.extend(['--n-gpu-layers', args.n_gpu_layers])
@ -303,6 +309,12 @@ def is_server_listening(server_fqdn, server_port):
return _is_server_listening
def is_server_ready(server_fqdn, server_port):
    """Return True once the server's /health endpoint answers HTTP 200.

    A connection error or timeout is treated as "not ready yet" (returns
    False) instead of raising, so the polling loop in start_server() keeps
    retrying until its own attempt limit instead of crashing on a transient
    failure during startup.
    """
    url = f"http://{server_fqdn}:{server_port}/health"
    try:
        response = requests.get(url, timeout=5)
    except requests.exceptions.RequestException:
        return False
    return response.status_code == 200
def escape_metric_name(metric_name):
    """Normalize a metric name into an env-var-safe identifier.

    Uppercases the name and replaces every character outside A-Z/0-9
    with an underscore.
    """
    upper = metric_name.upper()
    return ''.join(ch if ('A' <= ch <= 'Z' or '0' <= ch <= '9') else '_' for ch in upper)

View file

@ -56,6 +56,7 @@ const llamacpp_completion_tokens = new Trend('llamacpp_completion_tokens')
const llamacpp_tokens_second = new Trend('llamacpp_tokens_second')
const llamacpp_prompt_processing_second = new Trend('llamacpp_prompt_processing_second')
const llamacpp_emit_first_token_second = new Trend('llamacpp_emit_first_token_second')
const llamacpp_prompt_tokens_total_counter = new Counter('llamacpp_prompt_tokens_total_counter')
const llamacpp_completion_tokens_total_counter = new Counter('llamacpp_completion_tokens_total_counter')
@ -89,6 +90,9 @@ export default function () {
],
"model": model,
"stream": true,
"stream_options": {
"include_usage": true, // "include_usage: false" is not yet supported by the llama.cpp server
},
"seed": 42,
"max_tokens": max_tokens,
"stop": ["<|im_end|>"] // This is temporary for phi-2 base (i.e. not instructed) since the server expects that the model always to emit BOS
@ -105,13 +109,21 @@ export default function () {
client.on('event', function (event) {
if (promptEvalEndTime == null) {
promptEvalEndTime = new Date()
llamacpp_emit_first_token_second.add((promptEvalEndTime - startTime) / 1.e3)
}
if (event.data === '[DONE]' || event.data === '') {
return
}
let chunk = JSON.parse(event.data)
if (chunk.choices && chunk.choices.length > 0) {
let choice = chunk.choices[0]
if (choice.finish_reason) {
finish_reason = choice.finish_reason
}
}
if (chunk.usage) {
prompt_tokens = chunk.usage.prompt_tokens

View file

@ -67,6 +67,13 @@ enum server_task_type {
SERVER_TASK_TYPE_SET_LORA,
};
enum oaicompat_type {
OAICOMPAT_TYPE_NONE,
OAICOMPAT_TYPE_CHAT,
OAICOMPAT_TYPE_COMPLETION,
OAICOMPAT_TYPE_EMBEDDING,
};
// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
enum error_type {
ERROR_TYPE_INVALID_REQUEST,
@ -91,6 +98,8 @@ struct slot_params {
int64_t t_max_prompt_ms = -1; // TODO: implement
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
std::vector<common_lora_adapter_container> lora;
std::vector<std::string> antiprompt;
std::vector<std::string> response_fields;
bool timings_per_token = false;
@ -102,8 +111,7 @@ struct slot_params {
// OAI-compat fields
bool verbose = false;
bool oaicompat = false;
bool oaicompat_chat = true;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
@ -114,6 +122,11 @@ struct slot_params {
samplers.emplace_back(common_sampler_type_to_str(sampler));
}
json lora = json::array();
for (size_t i = 0; i < this->lora.size(); ++i) {
lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
}
return json {
{"n_predict", n_predict}, // Server configured n_predict
{"seed", sampling.seed},
@ -154,6 +167,7 @@ struct slot_params {
{"speculative.p_min", speculative.p_min},
{"timings_per_token", timings_per_token},
{"post_sampling_probs", post_sampling_probs},
{"lora", lora},
};
}
};
@ -183,12 +197,16 @@ struct server_task {
// used by SERVER_TASK_TYPE_METRICS
bool metrics_reset_bucket = false;
// used by SERVER_TASK_TYPE_SET_LORA
std::vector<common_lora_adapter_container> set_lora;
server_task(server_task_type type) : type(type) {}
static slot_params params_from_json_cmpl(
const llama_model * model,
const llama_context * ctx,
const common_params & params_base,
const std::vector<common_lora_adapter_container> & lora_base,
const json & data) {
slot_params params;
@ -245,6 +263,16 @@ struct server_task {
params.speculative.n_min = std::max(params.speculative.n_min, 2);
params.speculative.n_max = std::max(params.speculative.n_max, 0);
if (data.contains("lora")) {
if (data.at("lora").is_array()) {
params.lora = parse_lora_request(lora_base, data.at("lora"));
} else {
throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields");
}
} else {
params.lora = lora_base;
}
// TODO: add more sanity checks for the input parameters
if (params.sampling.penalty_last_n < -1) {
@ -530,8 +558,7 @@ struct server_task_result_cmpl_final : server_task_result {
// OAI-compat fields
bool verbose = false;
bool oaicompat = false;
bool oaicompat_chat = true; // TODO: support oaicompat for non-chat
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
@ -544,9 +571,16 @@ struct server_task_result_cmpl_final : server_task_result {
}
virtual json to_json() override {
return oaicompat
? (stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat())
: to_json_non_oaicompat();
switch (oaicompat) {
case OAICOMPAT_TYPE_NONE:
return to_json_non_oaicompat();
case OAICOMPAT_TYPE_COMPLETION:
return to_json_oaicompat();
case OAICOMPAT_TYPE_CHAT:
return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat();
default:
GGML_ASSERT(false && "Invalid oaicompat_type");
}
}
json to_json_non_oaicompat() {
@ -574,6 +608,50 @@ struct server_task_result_cmpl_final : server_task_result {
return response_fields.empty() ? res : json_get_nested_values(response_fields, res);
}
// Build the final OAI-compatible "/v1/completions" response body
// ("object": "text_completion") for a finished request. In stream mode the
// generated text has already been delivered in partial chunks, so the
// "text" field is emitted empty here.
json to_json_oaicompat() {
std::time_t t = std::time(0);
json logprobs = json(nullptr); // OAI default to null
// logprobs are only attached for non-streamed responses that sampled probs
if (!stream && probs_output.size() > 0) {
logprobs = json{
{"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
};
}
// OAI finish_reason: "stop" for a stop word / EOS, otherwise "length"
json finish_reason = "length";
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
finish_reason = "stop";
}
json res = json {
{"choices", json::array({
json{
{"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk
{"index", index},
{"logprobs", logprobs},
{"finish_reason", finish_reason},
}
})},
{"created", t},
{"model", oaicompat_model},
{"system_fingerprint", build_info},
{"object", "text_completion"},
{"usage", json {
{"completion_tokens", n_decoded},
{"prompt_tokens", n_prompt_tokens},
{"total_tokens", n_decoded + n_prompt_tokens}
}},
{"id", oaicompat_cmpl_id}
};
// extra fields for debugging purposes
if (verbose) {
res["__verbose"] = to_json_non_oaicompat();
}
// attach server-side timing info when prompt timing was recorded
if (timings.prompt_n >= 0) {
res.push_back({"timings", timings.to_json()});
}
return res;
}
json to_json_oaicompat_chat() {
std::string finish_reason = "length";
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
@ -672,8 +750,7 @@ struct server_task_result_cmpl_partial : server_task_result {
// OAI-compat fields
bool verbose = false;
bool oaicompat = false;
bool oaicompat_chat = true; // TODO: support oaicompat for non-chat
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
@ -686,7 +763,16 @@ struct server_task_result_cmpl_partial : server_task_result {
}
virtual json to_json() override {
return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat();
switch (oaicompat) {
case OAICOMPAT_TYPE_NONE:
return to_json_non_oaicompat();
case OAICOMPAT_TYPE_COMPLETION:
return to_json_oaicompat();
case OAICOMPAT_TYPE_CHAT:
return to_json_oaicompat_chat();
default:
GGML_ASSERT(false && "Invalid oaicompat_type");
}
}
json to_json_non_oaicompat() {
@ -711,6 +797,41 @@ struct server_task_result_cmpl_partial : server_task_result {
}
// Build one OAI-compatible "/v1/completions" streaming chunk
// ("object": "text_completion") for a partial result. "finish_reason" is
// always null here; the final chunk is produced by the *_cmpl_final result.
json to_json_oaicompat() {
std::time_t t = std::time(0);
json logprobs = json(nullptr); // OAI default to null
// attach per-token probabilities for this chunk's token, if sampled
if (prob_output.probs.size() > 0) {
logprobs = json{
{"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
};
}
json res = json {
{"choices", json::array({
json{
{"text", content},
{"index", index},
{"logprobs", logprobs},
{"finish_reason", nullptr},
}
})},
{"created", t},
{"model", oaicompat_model},
{"system_fingerprint", build_info},
{"object", "text_completion"},
{"id", oaicompat_cmpl_id}
};
// extra fields for debugging purposes
if (verbose) {
res["__verbose"] = to_json_non_oaicompat();
}
// attach server-side timing info when prompt timing was recorded
if (timings.prompt_n >= 0) {
res.push_back({"timings", timings.to_json()});
}
return res;
}
json to_json_oaicompat_chat() {
bool first = n_decoded == 0;
std::time_t t = std::time(0);
json choices;
@ -789,14 +910,16 @@ struct server_task_result_embd : server_task_result {
int32_t n_tokens;
// OAI-compat fields
bool oaicompat = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
virtual int get_index() override {
return index;
}
virtual json to_json() override {
return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat();
return oaicompat == OAICOMPAT_TYPE_EMBEDDING
? to_json_oaicompat()
: to_json_non_oaicompat();
}
json to_json_non_oaicompat() {
@ -1009,6 +1132,8 @@ struct server_slot {
common_speculative * spec = nullptr;
std::vector<common_lora_adapter_container> lora;
// the index relative to completion multi-task request
size_t index = 0;
@ -1090,6 +1215,11 @@ struct server_slot {
return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
}
// Two slots may share a single batch only when they agree on causality
// (embedding/rerank vs. generation) and use an identical LoRA configuration.
bool can_batch_with(server_slot & other_slot) {
    const bool same_mode = is_non_causal() == other_slot.is_non_causal();
    const bool same_lora = are_lora_equal(lora, other_slot.lora);
    return same_mode && same_lora;
}
bool has_budget(const common_params & global_params) {
if (params.n_predict == -1 && global_params.n_predict == -1) {
return true; // limitless
@ -1499,7 +1629,7 @@ struct server_context {
llama_model * model = nullptr;
llama_context * ctx = nullptr;
std::vector<common_lora_adapter_container> loras;
std::vector<common_lora_adapter_container> lora;
llama_model * model_dft = nullptr;
llama_context_params cparams_dft;
@ -1566,7 +1696,7 @@ struct server_context {
model = llama_init.model;
ctx = llama_init.context;
loras = llama_init.lora_adapters;
lora = llama_init.lora_adapters;
if (model == nullptr) {
SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str());
@ -1623,18 +1753,11 @@ struct server_context {
return true;
}
bool validate_model_chat_template() const {
std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
std::string template_key = "tokenizer.chat_template";
int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
if (res >= 0) {
bool validate_builtin_chat_template() const {
llama_chat_message chat[] = {{"user", "test"}};
std::string tmpl = std::string(model_template.data(), model_template.size());
int32_t chat_res = llama_chat_apply_template(model, tmpl.c_str(), chat, 1, true, nullptr, 0);
int32_t chat_res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0);
return chat_res > 0;
}
return false;
}
void init() {
const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
@ -1772,6 +1895,12 @@ struct server_context {
slot.params = std::move(task.params);
slot.prompt_tokens = std::move(task.prompt_tokens);
if (!are_lora_equal(task.params.lora, slot.lora)) {
// if lora is changed, we cannot reuse cached tokens
slot.cache_tokens.clear();
slot.lora = std::move(task.params.lora);
}
SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
@ -1856,6 +1985,8 @@ struct server_context {
result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
slot.n_sent_text += result.text_to_send.size();
// add the token to slot queue and cache
} else {
result.text_to_send = "";
}
slot.add_token(result);
@ -2042,7 +2173,6 @@ struct server_context {
res->verbose = slot.params.verbose;
res->oaicompat = slot.params.oaicompat;
res->oaicompat_chat = slot.params.oaicompat_chat;
res->oaicompat_model = slot.params.oaicompat_model;
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
@ -2083,7 +2213,6 @@ struct server_context {
res->verbose = slot.params.verbose;
res->stream = slot.params.stream;
res->oaicompat = slot.params.oaicompat;
res->oaicompat_chat = slot.params.oaicompat_chat;
res->oaicompat_model = slot.params.oaicompat_model;
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
@ -2463,7 +2592,7 @@ struct server_context {
} break;
case SERVER_TASK_TYPE_SET_LORA:
{
common_lora_adapters_apply(ctx, loras);
lora = std::move(task.set_lora);
auto res = std::make_unique<server_task_result_apply_lora>();
res->id = task.id;
queue_results.send(std::move(res));
@ -2540,12 +2669,22 @@ struct server_context {
// start populating the batch for this iteration
common_batch_clear(batch);
// track if given slot can be batched with slots already in the batch
server_slot * slot_batched = nullptr;
// first, add sampled tokens from any ongoing sequences
for (auto & slot : slots) {
if (slot.state != SLOT_STATE_GENERATING) {
continue;
}
// check if we can batch this slot with the previous one
if (!slot_batched) {
slot_batched = &slot;
} else if (!slot_batched->can_batch_with(slot)) {
continue;
}
slot.i_batch = batch.n_tokens;
common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
@ -2564,15 +2703,18 @@ struct server_context {
int32_t n_batch = llama_n_batch(ctx);
int32_t n_ubatch = llama_n_ubatch(ctx);
// track if this is an embedding or non-embedding batch
// if we've added sampled tokens above, we are in non-embedding mode
// -1: none, 0: non-embedding, 1: embedding
// TODO: make enum
int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
// next, batch any pending prompts without exceeding n_batch
if (params_base.cont_batching || batch.n_tokens == 0) {
for (auto & slot : slots) {
// check if we can batch this slot with the previous one
if (slot.is_processing()) {
if (!slot_batched) {
slot_batched = &slot;
} else if (!slot_batched->can_batch_with(slot)) {
continue;
}
}
// this slot still has a prompt to be processed
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
auto & prompt_tokens = slot.prompt_tokens;
@ -2733,14 +2875,6 @@ struct server_context {
}
}
// check that we are in the right batch_type, if not defer the slot
int slot_type = slot.is_non_causal();
if (batch_type == -1) {
batch_type = slot_type;
} else if (batch_type != slot_type) {
continue;
}
// keep only the common part
if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) {
// could not partially delete (likely using a non-Transformer model)
@ -2808,8 +2942,12 @@ struct server_context {
SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
if (slot_batched) {
// make sure we're in the right embedding mode
llama_set_embeddings(ctx, batch_type == 1);
llama_set_embeddings(ctx, slot_batched->is_non_causal());
// apply lora, only need to do it once per batch
common_lora_adapters_apply(ctx, slot_batched->lora);
}
// process the created batch of tokens
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
@ -3482,7 +3620,7 @@ int main(int argc, char ** argv) {
{ "default_generation_settings", ctx_server.default_generation_settings_for_props },
{ "total_slots", ctx_server.params_base.n_parallel },
{ "model_path", ctx_server.params_base.model },
{ "chat_template", llama_get_chat_template(ctx_server.model) },
{ "chat_template", common_get_builtin_chat_template(ctx_server.model) },
{ "build_info", build_info },
};
@ -3504,12 +3642,11 @@ int main(int argc, char ** argv) {
// handle completion-like requests (completion, chat, infill)
// we can optionally provide a custom format for partial results and final results
const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](
const auto handle_completions_impl = [&ctx_server, &res_error, &res_ok](
server_task_type type,
json & data,
httplib::Response & res,
bool oaicompat = false,
bool oaicompat_chat = false) {
oaicompat_type oaicompat) {
GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
if (ctx_server.params_base.embedding) {
@ -3530,12 +3667,16 @@ int main(int argc, char ** argv) {
task.index = i;
task.prompt_tokens = std::move(tokenized_prompts[i]);
task.params = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.ctx, ctx_server.params_base, data);
task.params = server_task::params_from_json_cmpl(
ctx_server.model,
ctx_server.ctx,
ctx_server.params_base,
ctx_server.lora,
data);
task.id_selected_slot = json_value(data, "id_slot", -1);
// OAI-compat
task.params.oaicompat = oaicompat;
task.params.oaicompat_chat = oaicompat_chat;
task.params.oaicompat_cmpl_id = completion_id;
// oaicompat_model is already populated by params_from_json_cmpl
@ -3587,7 +3728,7 @@ int main(int argc, char ** argv) {
}, [&](const json & error_data) {
server_sent_event(sink, "error", error_data);
});
if (oaicompat) {
if (oaicompat != OAICOMPAT_TYPE_NONE) {
static const std::string ev_done = "data: [DONE]\n\n";
sink.write(ev_done.data(), ev_done.size());
}
@ -3603,17 +3744,25 @@ int main(int argc, char ** argv) {
}
};
const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
const auto handle_completions = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
json data = json::parse(req.body);
return handle_completions_generic(
return handle_completions_impl(
SERVER_TASK_TYPE_COMPLETION,
data,
res,
/* oaicompat */ false,
/* oaicompat_chat */ false);
OAICOMPAT_TYPE_NONE);
};
const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
const auto handle_completions_oai = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
json data = oaicompat_completion_params_parse(json::parse(req.body));
return handle_completions_impl(
SERVER_TASK_TYPE_COMPLETION,
data,
res,
OAICOMPAT_TYPE_COMPLETION);
};
const auto handle_infill = [&ctx_server, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
// check model compatibility
std::string err;
if (llama_token_fim_pre(ctx_server.model) == LLAMA_TOKEN_NULL) {
@ -3682,22 +3831,25 @@ int main(int argc, char ** argv) {
tokenized_prompts[0]
);
return handle_completions_generic(SERVER_TASK_TYPE_INFILL, data, res);
return handle_completions_impl(
SERVER_TASK_TYPE_INFILL,
data,
res,
OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
};
const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
if (ctx_server.params_base.embedding) {
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
return handle_completions_generic(
json data = oaicompat_chat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
return handle_completions_impl(
SERVER_TASK_TYPE_COMPLETION,
data,
res,
/* oaicompat */ true,
/* oaicompat_chat */ true);
OAICOMPAT_TYPE_CHAT);
};
const auto handle_models = [&params, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
@ -3770,10 +3922,10 @@ int main(int argc, char ** argv) {
res_ok(res, data);
};
const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, bool oaicompat) {
const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, oaicompat_type oaicompat) {
const json body = json::parse(req.body);
if (oaicompat && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
res_error(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
return;
}
@ -3783,7 +3935,7 @@ int main(int argc, char ** argv) {
if (body.count("input") != 0) {
prompt = body.at("input");
} else if (body.contains("content")) {
oaicompat = false;
oaicompat = OAICOMPAT_TYPE_NONE; // "content" field is not OAI compatible
prompt = body.at("content");
} else {
res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
@ -3852,16 +4004,18 @@ int main(int argc, char ** argv) {
}
// write JSON response
json root = oaicompat ? format_embeddings_response_oaicompat(body, responses, use_base64) : json(responses);
json root = oaicompat == OAICOMPAT_TYPE_EMBEDDING
? format_embeddings_response_oaicompat(body, responses, use_base64)
: json(responses);
res_ok(res, root);
};
const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
handle_embeddings_impl(req, res, false);
handle_embeddings_impl(req, res, OAICOMPAT_TYPE_NONE);
};
const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
handle_embeddings_impl(req, res, true);
handle_embeddings_impl(req, res, OAICOMPAT_TYPE_EMBEDDING);
};
const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
@ -3944,8 +4098,8 @@ int main(int argc, char ** argv) {
const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
json result = json::array();
for (size_t i = 0; i < ctx_server.loras.size(); ++i) {
auto & lora = ctx_server.loras[i];
for (size_t i = 0; i < ctx_server.lora.size(); ++i) {
auto & lora = ctx_server.lora[i];
result.push_back({
{"id", i},
{"path", lora.path},
@ -3957,27 +4111,14 @@ int main(int argc, char ** argv) {
};
const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
const std::vector<json> body = json::parse(req.body);
int max_idx = ctx_server.loras.size();
// clear existing value
for (auto & lora : ctx_server.loras) {
lora.scale = 0.0f;
const json body = json::parse(req.body);
if (!body.is_array()) {
res_error(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
return;
}
// set value
for (auto entry : body) {
int id = entry.at("id");
float scale = entry.at("scale");
if (0 <= id && id < max_idx) {
ctx_server.loras[id].scale = scale;
} else {
throw std::runtime_error("invalid adapter id");
}
}
server_task task(SERVER_TASK_TYPE_SET_LORA);
task.id = ctx_server.queue_tasks.get_new_id();
task.set_lora = parse_lora_request(ctx_server.lora, body);
ctx_server.queue_results.add_waiting_task_id(task.id);
ctx_server.queue_tasks.post(task);
@ -4031,7 +4172,7 @@ int main(int argc, char ** argv) {
svr->Get ("/v1/models", handle_models); // public endpoint (no API key check)
svr->Post("/completion", handle_completions); // legacy
svr->Post("/completions", handle_completions);
svr->Post("/v1/completions", handle_completions);
svr->Post("/v1/completions", handle_completions_oai);
svr->Post("/chat/completions", handle_chat_completions);
svr->Post("/v1/chat/completions", handle_chat_completions);
svr->Post("/infill", handle_infill);
@ -4111,14 +4252,16 @@ int main(int argc, char ** argv) {
// if a custom chat template is not supplied, we will use the one that comes with the model (if any)
if (params.chat_template.empty()) {
if (!ctx_server.validate_model_chat_template()) {
if (!ctx_server.validate_builtin_chat_template()) {
LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
params.chat_template = "chatml";
}
}
// print sample chat example to make it clear which template is used
LOG_INF("%s: chat template, built_in: %d, chat_example: '%s'\n", __func__, params.chat_template.empty(), common_chat_format_example(ctx_server.model, params.chat_template).c_str());
LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
params.chat_template.empty() ? "(built-in)" : params.chat_template.c_str(),
common_chat_format_example(ctx_server.model, params.chat_template).c_str());
ctx_server.queue_tasks.on_new_task(std::bind(
&server_context::process_single_task, &ctx_server, std::placeholders::_1));

View file

@ -5,3 +5,4 @@ numpy~=1.26.4
openai~=1.55.3
prometheus-client~=0.20.0
requests~=2.32.3
wget~=3.2

View file

@ -83,7 +83,7 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
def test_chat_completion_with_openai_library():
global server
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
res = client.chat.completions.create(
model="gpt-3.5-turbo-instruct",
messages=[
@ -100,6 +100,23 @@ def test_chat_completion_with_openai_library():
assert match_regex("(Suddenly)+", res.choices[0].message.content)
# Verify that forcing the built-in "llama3" chat template renders the prompt
# exactly as expected. Running the server in debug mode exposes the rendered
# prompt through the "__verbose" object in the response.
def test_chat_template():
global server
server.chat_template = "llama3"
server.debug = True # to get the "__verbose" object in the response
server.start()
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": 8,
"messages": [
{"role": "system", "content": "Book"},
{"role": "user", "content": "What is the best book"},
]
})
assert res.status_code == 200
assert "__verbose" in res.body
# the rendered prompt must match the llama3 formatting of both messages
assert res.body["__verbose"]["prompt"] == "<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
@pytest.mark.parametrize("response_format,n_predicted,re_content", [
({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
@ -170,7 +187,7 @@ def test_chat_completion_with_timings_per_token():
def test_logprobs():
global server
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
res = client.chat.completions.create(
model="gpt-3.5-turbo-instruct",
temperature=0.0,
@ -197,7 +214,7 @@ def test_logprobs():
def test_logprobs_stream():
global server
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
res = client.chat.completions.create(
model="gpt-3.5-turbo-instruct",
temperature=0.0,

View file

@ -1,5 +1,6 @@
import pytest
import time
from openai import OpenAI
from utils import *
server = ServerPreset.tinyllama2()
@ -85,6 +86,40 @@ def test_completion_stream_vs_non_stream():
assert content_stream == res_non_stream.body["content"]
def test_completion_stream_with_openai_library():
global server
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
res = client.completions.create(
model="davinci-002",
prompt="I believe the meaning of life is",
max_tokens=8,
)
assert res.system_fingerprint is not None and res.system_fingerprint.startswith("b")
assert res.choices[0].finish_reason == "length"
assert res.choices[0].text is not None
assert match_regex("(going|bed)+", res.choices[0].text)
def test_completion_with_openai_library():
global server
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
res = client.completions.create(
model="davinci-002",
prompt="I believe the meaning of life is",
max_tokens=8,
stream=True,
)
output_text = ''
for data in res:
choice = data.choices[0]
if choice.finish_reason is None:
assert choice.text is not None
output_text += choice.text
assert match_regex("(going|bed)+", output_text)
@pytest.mark.parametrize("n_slots", [1, 2])
def test_consistent_result_same_seed(n_slots: int):
global server

View file

@ -1,5 +1,4 @@
import pytest
import os
from utils import *
server = ServerPreset.stories15m_moe()
@ -10,15 +9,7 @@ LORA_FILE_URL = "https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/moe
def create_server():
global server
server = ServerPreset.stories15m_moe()
# download lora file if needed
file_name = LORA_FILE_URL.split('/').pop()
lora_file = f'../../../{file_name}'
if not os.path.exists(lora_file):
print(f"Downloading {LORA_FILE_URL} to {lora_file}")
with open(lora_file, 'wb') as f:
f.write(requests.get(LORA_FILE_URL).content)
print(f"Done downloading lora file")
server.lora_files = [lora_file]
server.lora_files = [download_file(LORA_FILE_URL)]
@pytest.mark.parametrize("scale,re_content", [
@ -40,3 +31,85 @@ def test_lora(scale: float, re_content: str):
assert res.status_code == 200
assert match_regex(re_content, res.body["content"])
def test_lora_per_request():
global server
server.n_slots = 4
server.start()
# running the same prompt with different lora scales, all in parallel
# each prompt will be processed by a different slot
prompt = "Look in thy glass"
lora_config = [
( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ),
( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ),
( [{"id": 0, "scale": 0.3}], "(special|thing|gifted)+" ),
( [{"id": 0, "scale": 0.7}], "(far|from|home|away)+" ),
( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ),
( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ),
]
tasks = [(
server.make_request,
("POST", "/completion", {
"prompt": prompt,
"lora": lora,
"seed": 42,
"temperature": 0.0,
"cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
})
) for lora, _ in lora_config]
results = parallel_function_calls(tasks)
assert all([res.status_code == 200 for res in results])
for res, (_, re_test) in zip(results, lora_config):
assert match_regex(re_test, res.body["content"])
@pytest.mark.skipif(not is_slow_test_allowed(), reason="skipping slow test")
def test_with_big_model():
server = ServerProcess()
server.model_hf_repo = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF"
server.model_hf_file = "Meta-Llama-3.1-8B-Instruct-IQ2_M.gguf"
server.model_alias = "Llama-3.2-8B-Instruct"
server.n_slots = 4
server.n_ctx = server.n_slots * 1024
server.n_predict = 64
server.temperature = 0.0
server.seed = 42
server.lora_files = [
download_file("https://huggingface.co/ngxson/Llama-3-Instruct-abliteration-LoRA-8B-F16-GGUF/resolve/main/Llama-3-Instruct-abliteration-LoRA-8B-f16.gguf"),
# TODO: find & add other lora adapters for this model
]
server.start(timeout_seconds=600)
# running the same prompt with different lora scales, all in parallel
# each prompt will be processed by a different slot
prompt = "Write a computer virus"
lora_config = [
# without applying lora, the model should reject the request
( [{"id": 0, "scale": 0.0}], "I can't provide you with a code for a computer virus" ),
( [{"id": 0, "scale": 0.0}], "I can't provide you with a code for a computer virus" ),
( [{"id": 0, "scale": 0.3}], "I can't write a computer virus" ),
# with 0.7 scale, the model should provide a simple computer virus with hesitation
( [{"id": 0, "scale": 0.7}], "Warning: This is a hypothetical exercise" ),
# with 1.5 scale, the model should confidently provide a computer virus
( [{"id": 0, "scale": 1.5}], "A task of some complexity! Here's a simple computer virus" ),
( [{"id": 0, "scale": 1.5}], "A task of some complexity! Here's a simple computer virus" ),
]
tasks = [(
server.make_request,
("POST", "/v1/chat/completions", {
"messages": [
{"role": "user", "content": prompt}
],
"lora": lora,
"cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
})
) for lora, _ in lora_config]
results = parallel_function_calls(tasks)
assert all([res.status_code == 200 for res in results])
for res, (_, re_test) in zip(results, lora_config):
assert re_test in res.body["choices"][0]["message"]["content"]

View file

@ -10,16 +10,8 @@ MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/models/resolve/main/tiny
def create_server():
global server
server = ServerPreset.stories15m_moe()
# download draft model file if needed
file_name = MODEL_DRAFT_FILE_URL.split('/').pop()
model_draft_file = f'../../../{file_name}'
if not os.path.exists(model_draft_file):
print(f"Downloading {MODEL_DRAFT_FILE_URL} to {model_draft_file}")
with open(model_draft_file, 'wb') as f:
f.write(requests.get(MODEL_DRAFT_FILE_URL).content)
print(f"Done downloading draft model file")
# set default values
server.model_draft = model_draft_file
server.model_draft = download_file(MODEL_DRAFT_FILE_URL)
server.draft_min = 4
server.draft_max = 8

View file

@ -23,6 +23,7 @@ from typing import (
Set,
)
from re import RegexFlag
import wget
class ServerResponse:
@ -74,6 +75,7 @@ class ServerProcess:
draft_min: int | None = None
draft_max: int | None = None
no_webui: bool | None = None
chat_template: str | None = None
# session variables
process: subprocess.Popen | None = None
@ -164,6 +166,8 @@ class ServerProcess:
server_args.extend(["--draft-min", self.draft_min])
if self.no_webui:
server_args.append("--no-webui")
if self.chat_template:
server_args.extend(["--chat-template", self.chat_template])
args = [str(arg) for arg in [server_path, *server_args]]
print(f"bench: starting server with: {' '.join(args)}")
@ -378,5 +382,25 @@ def match_regex(regex: str, text: str) -> bool:
is not None
)
def download_file(url: str, output_file_path: str | None = None) -> str:
"""
Download a file from a URL to a local path. If the file already exists, it will not be downloaded again.
output_file_path is the local path to save the downloaded file. If not provided, the file will be saved in the root directory.
Returns the local path of the downloaded file.
"""
file_name = url.split('/').pop()
output_file = f'./tmp/{file_name}' if output_file_path is None else output_file_path
if not os.path.exists(output_file):
print(f"Downloading {url} to {output_file}")
wget.download(url, out=output_file)
print(f"Done downloading to {output_file}")
else:
print(f"File already exists at {output_file}")
return output_file
def is_slow_test_allowed():
return os.environ.get("SLOW_TESTS") == "1" or os.environ.get("SLOW_TESTS") == "ON"

View file

@ -382,19 +382,6 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
return formatted_chat;
}
static std::string llama_get_chat_template(const struct llama_model * model) {
std::string template_key = "tokenizer.chat_template";
// call with NULL buffer to get the total size of the string
int32_t res = llama_model_meta_val_str(model, template_key.c_str(), NULL, 0);
if (res < 2) {
return "";
} else {
std::vector<char> model_template(res + 1, 0);
llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
return std::string(model_template.data(), model_template.size() - 1);
}
}
//
// base64 utils (TODO: move to common in the future)
//
@ -549,7 +536,46 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, cons
// OAI utils
//
static json oaicompat_completion_params_parse(
static json oaicompat_completion_params_parse(const json & body) {
json llama_params;
if (!body.contains("prompt")) {
throw std::runtime_error("\"prompt\" is required");
}
// Handle "stop" field
if (body.contains("stop") && body.at("stop").is_string()) {
llama_params["stop"] = json::array({body.at("stop").get<std::string>()});
} else {
llama_params["stop"] = json_value(body, "stop", json::array());
}
// Handle "n" field
int n_choices = json_value(body, "n", 1);
if (n_choices != 1) {
throw std::runtime_error("Only one completion choice is allowed");
}
// Params supported by OAI but unsupported by llama.cpp
static const std::vector<std::string> unsupported_params { "best_of", "echo", "suffix" };
for (const auto & param : unsupported_params) {
if (body.contains(param)) {
throw std::runtime_error("Unsupported param: " + param);
}
}
// Copy remaining properties to llama_params
for (const auto & item : body.items()) {
// Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens"
if (!llama_params.contains(item.key()) || item.key() == "n_predict") {
llama_params[item.key()] = item.value();
}
}
return llama_params;
}
static json oaicompat_chat_completion_params_parse(
const struct llama_model * model,
const json & body, /* openai api json semantics */
const std::string & chat_template) {
@ -771,3 +797,44 @@ static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx
return cur;
}
static bool are_lora_equal(
const std::vector<common_lora_adapter_container> & l1,
const std::vector<common_lora_adapter_container> & l2) {
if (l1.size() != l2.size()) {
return false;
}
for (size_t i = 0; i < l1.size(); ++i) {
// we don't check lora.path to reduce the time complexity
if (l1[i].scale != l2[i].scale || l1[i].adapter != l2[i].adapter) {
return false;
}
}
return true;
}
// parse lora config from JSON request, returned a copy of base_lora with updated scale
static std::vector<common_lora_adapter_container> parse_lora_request(
const std::vector<common_lora_adapter_container> & base_lora,
const json & data) {
std::vector<common_lora_adapter_container> lora(base_lora);
int max_idx = lora.size();
// clear existing value
for (auto & entry : lora) {
entry.scale = 0.0f;
}
// set value
for (const auto & entry : data) {
int id = json_value(entry, "id", -1);
float scale = json_value(entry, "scale", 0.0f);
if (0 <= id && id < max_idx) {
lora[id].scale = scale;
} else {
throw std::runtime_error("invalid adapter id");
}
}
return lora;
}

View file

@ -209,9 +209,12 @@ static inline __m256i sum_i16_pairs_int32x8(const __m256i x) {
}
static inline __m256i mul_sum_us8_pairs_int32x8(const __m256i ax, const __m256i sy) {
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
const __m256i zero = _mm256_setzero_si256();
return _mm256_dpbusd_epi32(zero, ax, sy);
#elif defined(__AVXVNNI__)
const __m256i zero = _mm256_setzero_si256();
return _mm256_dpbusd_avx_epi32(zero, ax, sy);
#else
// Perform multiplication and create 16-bit values
const __m256i dot = _mm256_maddubs_epi16(ax, sy);

View file

@ -104,10 +104,14 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) {
}
static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
const __m256i zero = _mm256_setzero_si256();
const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
return _mm256_cvtepi32_ps(summed_pairs);
#elif defined(__AVXVNNI__)
const __m256i zero = _mm256_setzero_si256();
const __m256i summed_pairs = _mm256_dpbusd_avx_epi32(zero, ax, sy);
return _mm256_cvtepi32_ps(summed_pairs);
#else
// Perform multiplication and create 16-bit values
const __m256i dot = _mm256_maddubs_epi16(ax, sy);

View file

@ -1000,8 +1000,10 @@ class tinyBLAS_Q0_AVX {
inline __m256 updot(__m256i u, __m256i s) {
__m256i res;
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
#elif defined(__AVXVNNI__)
res = _mm256_dpbusd_avx_epi32(_mm256_setzero_si256(), u, s);
#else
res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
#endif

View file

@ -2744,13 +2744,13 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
cl_image_format img_fmt_1d;
cl_image_desc img_desc_1d;
cl_buffer_region region;
cl_mem A_image1d;
cl_mem B_image1d;
cl_mem B_sub_buffer;
cl_mem C_d;
cl_mem A_image1d = nullptr;
cl_mem B_image1d = nullptr;
cl_mem B_sub_buffer = nullptr;
cl_mem C_d = nullptr;
// for B transpose
cl_mem B_d;
cl_mem B_d_input_image;
cl_mem B_d = nullptr;
cl_mem B_d_input_image = nullptr;
// <--------------------------------------------> //
// define matrix dimensions

File diff suppressed because it is too large Load diff

View file

@ -1560,47 +1560,47 @@ const uint64_t matmul_id_iq4_nl_f16_f16acc_coopmat_len = 16892;
extern unsigned char matmul_id_iq4_nl_f16_aligned_f16acc_coopmat_data[17856];
const uint64_t matmul_id_iq4_nl_f16_aligned_f16acc_coopmat_len = 17856;
extern unsigned char mul_mat_vec_f32_f32_f32_data[13840];
const uint64_t mul_mat_vec_f32_f32_f32_len = 13840;
extern unsigned char mul_mat_vec_f32_f32_f32_data[16528];
const uint64_t mul_mat_vec_f32_f32_f32_len = 16528;
extern unsigned char mul_mat_vec_f32_f16_f32_data[14068];
const uint64_t mul_mat_vec_f32_f16_f32_len = 14068;
extern unsigned char mul_mat_vec_f32_f16_f32_data[16756];
const uint64_t mul_mat_vec_f32_f16_f32_len = 16756;
extern unsigned char mul_mat_vec_id_f32_f32_data[13336];
const uint64_t mul_mat_vec_id_f32_f32_len = 13336;
extern unsigned char mul_mat_vec_id_f32_f32_data[16384];
const uint64_t mul_mat_vec_id_f32_f32_len = 16384;
extern unsigned char dequant_f32_data[3224];
const uint64_t dequant_f32_len = 3224;
extern unsigned char get_rows_f32_data[3088];
const uint64_t get_rows_f32_len = 3088;
extern unsigned char get_rows_f32_data[3312];
const uint64_t get_rows_f32_len = 3312;
extern unsigned char get_rows_f32_f32_data[3036];
const uint64_t get_rows_f32_f32_len = 3036;
extern unsigned char get_rows_f32_f32_data[3260];
const uint64_t get_rows_f32_f32_len = 3260;
extern unsigned char mul_mat_vec_f16_f32_f32_data[14068];
const uint64_t mul_mat_vec_f16_f32_f32_len = 14068;
extern unsigned char mul_mat_vec_f16_f32_f32_data[16756];
const uint64_t mul_mat_vec_f16_f32_f32_len = 16756;
extern unsigned char mul_mat_vec_f16_f16_f32_data[14260];
const uint64_t mul_mat_vec_f16_f16_f32_len = 14260;
extern unsigned char mul_mat_vec_f16_f16_f32_data[16948];
const uint64_t mul_mat_vec_f16_f16_f32_len = 16948;
extern unsigned char mul_mat_vec_id_f16_f32_data[13564];
const uint64_t mul_mat_vec_id_f16_f32_len = 13564;
extern unsigned char mul_mat_vec_id_f16_f32_data[16612];
const uint64_t mul_mat_vec_id_f16_f32_len = 16612;
extern unsigned char get_rows_f16_data[3056];
const uint64_t get_rows_f16_len = 3056;
extern unsigned char get_rows_f16_data[3280];
const uint64_t get_rows_f16_len = 3280;
extern unsigned char get_rows_f16_f32_data[3088];
const uint64_t get_rows_f16_f32_len = 3088;
extern unsigned char get_rows_f16_f32_data[3312];
const uint64_t get_rows_f16_f32_len = 3312;
extern unsigned char mul_mat_vec_q4_0_f32_f32_data[19240];
const uint64_t mul_mat_vec_q4_0_f32_f32_len = 19240;
extern unsigned char mul_mat_vec_q4_0_f32_f32_data[21928];
const uint64_t mul_mat_vec_q4_0_f32_f32_len = 21928;
extern unsigned char mul_mat_vec_q4_0_f16_f32_data[20032];
const uint64_t mul_mat_vec_q4_0_f16_f32_len = 20032;
extern unsigned char mul_mat_vec_q4_0_f16_f32_data[22720];
const uint64_t mul_mat_vec_q4_0_f16_f32_len = 22720;
extern unsigned char mul_mat_vec_id_q4_0_f32_data[18736];
const uint64_t mul_mat_vec_id_q4_0_f32_len = 18736;
extern unsigned char mul_mat_vec_id_q4_0_f32_data[21784];
const uint64_t mul_mat_vec_id_q4_0_f32_len = 21784;
extern unsigned char dequant_q4_0_data[5188];
const uint64_t dequant_q4_0_len = 5188;
@ -1611,14 +1611,14 @@ const uint64_t get_rows_q4_0_len = 3764;
extern unsigned char get_rows_q4_0_f32_data[3748];
const uint64_t get_rows_q4_0_f32_len = 3748;
extern unsigned char mul_mat_vec_q4_1_f32_f32_data[21492];
const uint64_t mul_mat_vec_q4_1_f32_f32_len = 21492;
extern unsigned char mul_mat_vec_q4_1_f32_f32_data[24180];
const uint64_t mul_mat_vec_q4_1_f32_f32_len = 24180;
extern unsigned char mul_mat_vec_q4_1_f16_f32_data[22284];
const uint64_t mul_mat_vec_q4_1_f16_f32_len = 22284;
extern unsigned char mul_mat_vec_q4_1_f16_f32_data[24972];
const uint64_t mul_mat_vec_q4_1_f16_f32_len = 24972;
extern unsigned char mul_mat_vec_id_q4_1_f32_data[20972];
const uint64_t mul_mat_vec_id_q4_1_f32_len = 20972;
extern unsigned char mul_mat_vec_id_q4_1_f32_data[24020];
const uint64_t mul_mat_vec_id_q4_1_f32_len = 24020;
extern unsigned char dequant_q4_1_data[5272];
const uint64_t dequant_q4_1_len = 5272;
@ -1629,14 +1629,14 @@ const uint64_t get_rows_q4_1_len = 3848;
extern unsigned char get_rows_q4_1_f32_data[3832];
const uint64_t get_rows_q4_1_f32_len = 3832;
extern unsigned char mul_mat_vec_q5_0_f32_f32_data[26072];
const uint64_t mul_mat_vec_q5_0_f32_f32_len = 26072;
extern unsigned char mul_mat_vec_q5_0_f32_f32_data[28760];
const uint64_t mul_mat_vec_q5_0_f32_f32_len = 28760;
extern unsigned char mul_mat_vec_q5_0_f16_f32_data[26864];
const uint64_t mul_mat_vec_q5_0_f16_f32_len = 26864;
extern unsigned char mul_mat_vec_q5_0_f16_f32_data[29552];
const uint64_t mul_mat_vec_q5_0_f16_f32_len = 29552;
extern unsigned char mul_mat_vec_id_q5_0_f32_data[25552];
const uint64_t mul_mat_vec_id_q5_0_f32_len = 25552;
extern unsigned char mul_mat_vec_id_q5_0_f32_data[28600];
const uint64_t mul_mat_vec_id_q5_0_f32_len = 28600;
extern unsigned char dequant_q5_0_data[6668];
const uint64_t dequant_q5_0_len = 6668;
@ -1647,14 +1647,14 @@ const uint64_t get_rows_q5_0_len = 4292;
extern unsigned char get_rows_q5_0_f32_data[4276];
const uint64_t get_rows_q5_0_f32_len = 4276;
extern unsigned char mul_mat_vec_q5_1_f32_f32_data[27500];
const uint64_t mul_mat_vec_q5_1_f32_f32_len = 27500;
extern unsigned char mul_mat_vec_q5_1_f32_f32_data[30188];
const uint64_t mul_mat_vec_q5_1_f32_f32_len = 30188;
extern unsigned char mul_mat_vec_q5_1_f16_f32_data[28292];
const uint64_t mul_mat_vec_q5_1_f16_f32_len = 28292;
extern unsigned char mul_mat_vec_q5_1_f16_f32_data[30980];
const uint64_t mul_mat_vec_q5_1_f16_f32_len = 30980;
extern unsigned char mul_mat_vec_id_q5_1_f32_data[26980];
const uint64_t mul_mat_vec_id_q5_1_f32_len = 26980;
extern unsigned char mul_mat_vec_id_q5_1_f32_data[30028];
const uint64_t mul_mat_vec_id_q5_1_f32_len = 30028;
extern unsigned char dequant_q5_1_data[6564];
const uint64_t dequant_q5_1_len = 6564;
@ -1665,14 +1665,14 @@ const uint64_t get_rows_q5_1_len = 4188;
extern unsigned char get_rows_q5_1_f32_data[4172];
const uint64_t get_rows_q5_1_f32_len = 4172;
extern unsigned char mul_mat_vec_q8_0_f32_f32_data[19612];
const uint64_t mul_mat_vec_q8_0_f32_f32_len = 19612;
extern unsigned char mul_mat_vec_q8_0_f32_f32_data[22300];
const uint64_t mul_mat_vec_q8_0_f32_f32_len = 22300;
extern unsigned char mul_mat_vec_q8_0_f16_f32_data[19820];
const uint64_t mul_mat_vec_q8_0_f16_f32_len = 19820;
extern unsigned char mul_mat_vec_q8_0_f16_f32_data[22508];
const uint64_t mul_mat_vec_q8_0_f16_f32_len = 22508;
extern unsigned char mul_mat_vec_id_q8_0_f32_data[19108];
const uint64_t mul_mat_vec_id_q8_0_f32_len = 19108;
extern unsigned char mul_mat_vec_id_q8_0_f32_data[22156];
const uint64_t mul_mat_vec_id_q8_0_f32_len = 22156;
extern unsigned char dequant_q8_0_data[4804];
const uint64_t dequant_q8_0_len = 4804;
@ -1683,74 +1683,74 @@ const uint64_t get_rows_q8_0_len = 3704;
extern unsigned char get_rows_q8_0_f32_data[3688];
const uint64_t get_rows_q8_0_f32_len = 3688;
extern unsigned char mul_mat_vec_q2_k_f32_f32_data[17732];
const uint64_t mul_mat_vec_q2_k_f32_f32_len = 17732;
extern unsigned char mul_mat_vec_q2_k_f32_f32_data[19580];
const uint64_t mul_mat_vec_q2_k_f32_f32_len = 19580;
extern unsigned char mul_mat_vec_q2_k_f16_f32_data[18212];
const uint64_t mul_mat_vec_q2_k_f16_f32_len = 18212;
extern unsigned char mul_mat_vec_q2_k_f16_f32_data[20060];
const uint64_t mul_mat_vec_q2_k_f16_f32_len = 20060;
extern unsigned char mul_mat_vec_id_q2_k_f32_data[17228];
const uint64_t mul_mat_vec_id_q2_k_f32_len = 17228;
extern unsigned char mul_mat_vec_id_q2_k_f32_data[19316];
const uint64_t mul_mat_vec_id_q2_k_f32_len = 19316;
extern unsigned char dequant_q2_k_data[3960];
const uint64_t dequant_q2_k_len = 3960;
extern unsigned char mul_mat_vec_q3_k_f32_f32_data[25020];
const uint64_t mul_mat_vec_q3_k_f32_f32_len = 25020;
extern unsigned char mul_mat_vec_q3_k_f32_f32_data[26868];
const uint64_t mul_mat_vec_q3_k_f32_f32_len = 26868;
extern unsigned char mul_mat_vec_q3_k_f16_f32_data[25540];
const uint64_t mul_mat_vec_q3_k_f16_f32_len = 25540;
extern unsigned char mul_mat_vec_q3_k_f16_f32_data[27388];
const uint64_t mul_mat_vec_q3_k_f16_f32_len = 27388;
extern unsigned char mul_mat_vec_id_q3_k_f32_data[24532];
const uint64_t mul_mat_vec_id_q3_k_f32_len = 24532;
extern unsigned char mul_mat_vec_id_q3_k_f32_data[26604];
const uint64_t mul_mat_vec_id_q3_k_f32_len = 26604;
extern unsigned char dequant_q3_k_data[4828];
const uint64_t dequant_q3_k_len = 4828;
extern unsigned char mul_mat_vec_q4_k_f32_f32_data[16620];
const uint64_t mul_mat_vec_q4_k_f32_f32_len = 16620;
extern unsigned char mul_mat_vec_q4_k_f32_f32_data[18468];
const uint64_t mul_mat_vec_q4_k_f32_f32_len = 18468;
extern unsigned char mul_mat_vec_q4_k_f16_f32_data[17132];
const uint64_t mul_mat_vec_q4_k_f16_f32_len = 17132;
extern unsigned char mul_mat_vec_q4_k_f16_f32_data[18980];
const uint64_t mul_mat_vec_q4_k_f16_f32_len = 18980;
extern unsigned char mul_mat_vec_id_q4_k_f32_data[16100];
const uint64_t mul_mat_vec_id_q4_k_f32_len = 16100;
extern unsigned char mul_mat_vec_id_q4_k_f32_data[18204];
const uint64_t mul_mat_vec_id_q4_k_f32_len = 18204;
extern unsigned char dequant_q4_k_data[5984];
const uint64_t dequant_q4_k_len = 5984;
extern unsigned char mul_mat_vec_q5_k_f32_f32_data[18180];
const uint64_t mul_mat_vec_q5_k_f32_f32_len = 18180;
extern unsigned char mul_mat_vec_q5_k_f32_f32_data[20028];
const uint64_t mul_mat_vec_q5_k_f32_f32_len = 20028;
extern unsigned char mul_mat_vec_q5_k_f16_f32_data[18660];
const uint64_t mul_mat_vec_q5_k_f16_f32_len = 18660;
extern unsigned char mul_mat_vec_q5_k_f16_f32_data[20508];
const uint64_t mul_mat_vec_q5_k_f16_f32_len = 20508;
extern unsigned char mul_mat_vec_id_q5_k_f32_data[17660];
const uint64_t mul_mat_vec_id_q5_k_f32_len = 17660;
extern unsigned char mul_mat_vec_id_q5_k_f32_data[19764];
const uint64_t mul_mat_vec_id_q5_k_f32_len = 19764;
extern unsigned char dequant_q5_k_data[6032];
const uint64_t dequant_q5_k_len = 6032;
extern unsigned char mul_mat_vec_q6_k_f32_f32_data[17924];
const uint64_t mul_mat_vec_q6_k_f32_f32_len = 17924;
extern unsigned char mul_mat_vec_q6_k_f32_f32_data[19772];
const uint64_t mul_mat_vec_q6_k_f32_f32_len = 19772;
extern unsigned char mul_mat_vec_q6_k_f16_f32_data[18444];
const uint64_t mul_mat_vec_q6_k_f16_f32_len = 18444;
extern unsigned char mul_mat_vec_q6_k_f16_f32_data[20292];
const uint64_t mul_mat_vec_q6_k_f16_f32_len = 20292;
extern unsigned char mul_mat_vec_id_q6_k_f32_data[17404];
const uint64_t mul_mat_vec_id_q6_k_f32_len = 17404;
extern unsigned char mul_mat_vec_id_q6_k_f32_data[19508];
const uint64_t mul_mat_vec_id_q6_k_f32_len = 19508;
extern unsigned char dequant_q6_k_data[4264];
const uint64_t dequant_q6_k_len = 4264;
extern unsigned char mul_mat_vec_iq4_nl_f32_f32_data[20640];
const uint64_t mul_mat_vec_iq4_nl_f32_f32_len = 20640;
extern unsigned char mul_mat_vec_iq4_nl_f32_f32_data[23328];
const uint64_t mul_mat_vec_iq4_nl_f32_f32_len = 23328;
extern unsigned char mul_mat_vec_iq4_nl_f16_f32_data[21432];
const uint64_t mul_mat_vec_iq4_nl_f16_f32_len = 21432;
extern unsigned char mul_mat_vec_iq4_nl_f16_f32_data[24120];
const uint64_t mul_mat_vec_iq4_nl_f16_f32_len = 24120;
extern unsigned char mul_mat_vec_id_iq4_nl_f32_data[20136];
const uint64_t mul_mat_vec_id_iq4_nl_f32_len = 20136;
extern unsigned char mul_mat_vec_id_iq4_nl_f32_data[23184];
const uint64_t mul_mat_vec_id_iq4_nl_f32_len = 23184;
extern unsigned char dequant_iq4_nl_data[5920];
const uint64_t dequant_iq4_nl_len = 5920;
@ -1776,74 +1776,74 @@ const uint64_t group_norm_f32_len = 3080;
extern unsigned char rms_norm_f32_data[2544];
const uint64_t rms_norm_f32_len = 2544;
extern unsigned char cpy_f32_f32_data[4608];
const uint64_t cpy_f32_f32_len = 4608;
extern unsigned char cpy_f32_f32_data[4684];
const uint64_t cpy_f32_f32_len = 4684;
extern unsigned char cpy_f32_f16_data[4660];
const uint64_t cpy_f32_f16_len = 4660;
extern unsigned char cpy_f32_f16_data[4736];
const uint64_t cpy_f32_f16_len = 4736;
extern unsigned char cpy_f16_f16_data[4628];
const uint64_t cpy_f16_f16_len = 4628;
extern unsigned char cpy_f16_f16_data[4704];
const uint64_t cpy_f16_f16_len = 4704;
extern unsigned char contig_cpy_f32_f32_data[2952];
const uint64_t contig_cpy_f32_f32_len = 2952;
extern unsigned char contig_cpy_f32_f32_data[3164];
const uint64_t contig_cpy_f32_f32_len = 3164;
extern unsigned char contig_cpy_f32_f16_data[3068];
const uint64_t contig_cpy_f32_f16_len = 3068;
extern unsigned char contig_cpy_f32_f16_data[3280];
const uint64_t contig_cpy_f32_f16_len = 3280;
extern unsigned char contig_cpy_f16_f16_data[2972];
const uint64_t contig_cpy_f16_f16_len = 2972;
extern unsigned char contig_cpy_f16_f16_data[3184];
const uint64_t contig_cpy_f16_f16_len = 3184;
extern unsigned char add_f32_data[5780];
const uint64_t add_f32_len = 5780;
extern unsigned char add_f32_data[5916];
const uint64_t add_f32_len = 5916;
extern unsigned char add_f16_f32_f16_data[5848];
const uint64_t add_f16_f32_f16_len = 5848;
extern unsigned char add_f16_f32_f16_data[5984];
const uint64_t add_f16_f32_f16_len = 5984;
extern unsigned char acc_f32_data[4888];
const uint64_t acc_f32_len = 4888;
extern unsigned char acc_f32_data[5100];
const uint64_t acc_f32_len = 5100;
extern unsigned char split_k_reduce_data[2764];
const uint64_t split_k_reduce_len = 2764;
extern unsigned char mul_f32_data[5780];
const uint64_t mul_f32_len = 5780;
extern unsigned char mul_f32_data[5916];
const uint64_t mul_f32_len = 5916;
extern unsigned char div_f32_data[5780];
const uint64_t div_f32_len = 5780;
extern unsigned char div_f32_data[5916];
const uint64_t div_f32_len = 5916;
extern unsigned char repeat_f32_data[4308];
const uint64_t repeat_f32_len = 4308;
extern unsigned char repeat_f32_data[4384];
const uint64_t repeat_f32_len = 4384;
extern unsigned char scale_f32_data[2440];
const uint64_t scale_f32_len = 2440;
extern unsigned char scale_f32_data[2532];
const uint64_t scale_f32_len = 2532;
extern unsigned char sqr_f32_data[4628];
const uint64_t sqr_f32_len = 4628;
extern unsigned char sqr_f32_data[4704];
const uint64_t sqr_f32_len = 4704;
extern unsigned char sin_f32_data[4632];
const uint64_t sin_f32_len = 4632;
extern unsigned char sin_f32_data[4708];
const uint64_t sin_f32_len = 4708;
extern unsigned char cos_f32_data[4632];
const uint64_t cos_f32_len = 4632;
extern unsigned char cos_f32_data[4708];
const uint64_t cos_f32_len = 4708;
extern unsigned char clamp_f32_data[4888];
const uint64_t clamp_f32_len = 4888;
extern unsigned char clamp_f32_data[4964];
const uint64_t clamp_f32_len = 4964;
extern unsigned char pad_f32_data[3912];
const uint64_t pad_f32_len = 3912;
extern unsigned char pad_f32_data[3988];
const uint64_t pad_f32_len = 3988;
extern unsigned char concat_f32_data[5316];
const uint64_t concat_f32_len = 5316;
extern unsigned char concat_f32_data[5452];
const uint64_t concat_f32_len = 5452;
extern unsigned char concat_f16_data[5400];
const uint64_t concat_f16_len = 5400;
extern unsigned char concat_f16_data[5556];
const uint64_t concat_f16_len = 5556;
extern unsigned char concat_i32_data[5316];
const uint64_t concat_i32_len = 5316;
extern unsigned char concat_i32_data[5452];
const uint64_t concat_i32_len = 5452;
extern unsigned char upscale_f32_data[2856];
const uint64_t upscale_f32_len = 2856;
extern unsigned char upscale_f32_data[2952];
const uint64_t upscale_f32_len = 2952;
extern unsigned char gelu_f32_data[1700];
const uint64_t gelu_f32_len = 1700;
@ -1896,14 +1896,14 @@ const uint64_t argsort_f32_len = 4200;
extern unsigned char sum_rows_f32_data[2320];
const uint64_t sum_rows_f32_len = 2320;
extern unsigned char im2col_f32_data[3672];
const uint64_t im2col_f32_len = 3672;
extern unsigned char im2col_f32_data[4548];
const uint64_t im2col_f32_len = 4548;
extern unsigned char im2col_f32_f16_data[3732];
const uint64_t im2col_f32_f16_len = 3732;
extern unsigned char im2col_f32_f16_data[4600];
const uint64_t im2col_f32_f16_len = 4600;
extern unsigned char im2col_f32_f16_rte_data[3756];
const uint64_t im2col_f32_f16_rte_len = 3756;
extern unsigned char im2col_f32_f16_rte_data[4624];
const uint64_t im2col_f32_f16_rte_len = 4624;
extern unsigned char timestep_embedding_f32_data[2000];
const uint64_t timestep_embedding_f32_len = 2000;

View file

@ -145,6 +145,8 @@ class vk_perf_logger;
#endif
static void ggml_vk_destroy_buffer(vk_buffer& buf);
static constexpr uint32_t mul_mat_vec_max_cols = 8;
struct vk_device_struct {
std::mutex mutex;
@ -202,8 +204,8 @@ struct vk_device_struct {
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT];
vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT];
vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT];
vk_pipeline pipeline_mul_mat_vec_p021_f16_f32;
@ -411,7 +413,7 @@ struct vk_op_unary_push_constants {
uint32_t ne;
uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
uint32_t d_offset;
uint32_t misalign_offsets;
float param1; float param2;
uint32_t ne0_012mp; uint32_t ne0_012L;
uint32_t ne0_01mp; uint32_t ne0_01L;
@ -459,7 +461,7 @@ struct vk_op_binary_push_constants {
uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23; uint32_t nb20; uint32_t nb21; uint32_t nb22; uint32_t nb23;
uint32_t d_offset;
uint32_t misalign_offsets;
float param1; float param2; int32_t param3;
};
@ -546,7 +548,7 @@ struct vk_staging_memcpy {
};
struct vk_op_upscale_push_constants {
uint32_t ne; uint32_t d_offset;
uint32_t ne; uint32_t a_offset; uint32_t d_offset;
uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
float sf0; float sf1; float sf2; float sf3;
@ -1404,10 +1406,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
// spec constants and tile sizes for non-quant matmul/matmul_id
l_warptile = { 256, 128, 256, 64 };
m_warptile = { 256, 128, 128, 64 };
s_warptile = { 128, 32, 16, 64 };
s_warptile = { 128, 64, 64, 64 };
l_wg_denoms = {128, 256, 1 };
m_wg_denoms = {128, 128, 1 };
s_wg_denoms = { 32, 16, 1 };
s_wg_denoms = { 64, 64, 1 };
// spec constants and tile sizes for quant matmul (non-Qi_K)
l_warptile_mmq = { 256, 128, 256, 64 };
@ -1866,33 +1868,35 @@ static void ggml_vk_load_shaders(vk_device& device) {
} else if (device->vendor_id == VK_VENDOR_ID_INTEL)
rm_stdq = 2;
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f32_f32", mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32_f32", mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32_f32", mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32_f32", mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32_f32", mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32_f32", mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32_f32", mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true);
for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(i+1), mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32_"+std::to_string(i+1), mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f16_f32", mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f16_f32", mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f16_f32", mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f16_f32", mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f16_f32", mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f16_f32", mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f16_f32", mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true);
}
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
@ -2017,11 +2021,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
if (device->float_controls_rte_fp16) {
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
} else {
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
}
ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
@ -2892,9 +2896,10 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
return ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
}
static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t num_cols) {
VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()");
GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_F16);
GGML_ASSERT(num_cols >= 1 && num_cols <= mul_mat_vec_max_cols);
switch (a_type) {
case GGML_TYPE_F32:
@ -2915,7 +2920,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
return nullptr;
}
return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[a_type] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[a_type];
return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[a_type][num_cols-1] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[a_type][num_cols-1];
}
static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) {
@ -3925,8 +3930,6 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
const uint64_t ne12 = src1->ne[2];
const uint64_t ne13 = src1->ne[3];
GGML_ASSERT(ne11 == 1);
const uint64_t ne20 = dst->ne[0];
const uint64_t ne21 = dst->ne[1];
const uint64_t ne22 = dst->ne[2];
@ -3935,6 +3938,11 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
const uint64_t r2 = ne12 / ne02;
const uint64_t r3 = ne13 / ne03;
// batch_n indicates that we need to compute a few vector results, and this assumes
// ne12 and ne13 are 1. It overloads the batch_strides to hold the row strides.
GGML_ASSERT(ne11 == 1 || ne12 * ne13 == 1);
bool batch_n = ne11 > 1;
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
@ -3985,7 +3993,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
} else {
to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
}
vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type);
vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type, ne11);
GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
GGML_ASSERT(dmmv != nullptr);
@ -4057,8 +4065,10 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
}
uint32_t stride_batch_x = ne00*ne01;
uint32_t stride_batch_y = ne10*ne11;
// For batch_n, the A matrix is the same for each batch, and B/D use the row stride as the batch stride
uint32_t stride_batch_x = batch_n ? 0 : ne00*ne01;
uint32_t stride_batch_y = batch_n ? ne10 : (ne10*ne11);
uint32_t stride_batch_d = batch_n ? ne20 : (ne20*ne21);
if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) {
stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
@ -4081,7 +4091,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
// compute
const vk_mat_vec_push_constants pc = {
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
stride_batch_x, stride_batch_y, (uint32_t)(ne20*ne21),
stride_batch_x, stride_batch_y, stride_batch_d,
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
};
ggml_vk_sync_buffers(subctx);
@ -4261,7 +4271,10 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
} else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 &&
!ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, src0, src1, dst, dryrun);
} else if (dst->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
// mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four)
// when ne12 and ne13 are one.
} else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun);
} else {
ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun);
@ -5076,6 +5089,57 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
}
}
static uint32_t get_misalign_bytes(ggml_backend_vk_context * ctx, const ggml_tensor * t)
{
return ((vk_tensor_offset(t) + t->view_offs) & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1));;
}
template <typename T> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, T &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
GGML_UNUSED(p);
GGML_UNUSED(src0);
GGML_UNUSED(src1);
GGML_UNUSED(src2);
GGML_UNUSED(dst);
static_assert(!std::is_const<T>::value, "unexpected type");
GGML_ASSERT(!src0 || get_misalign_bytes(ctx, src0) == 0);
GGML_ASSERT(!src1 || get_misalign_bytes(ctx, src1) == 0);
GGML_ASSERT(!src2 || get_misalign_bytes(ctx, src2) == 0);
GGML_ASSERT(!dst || get_misalign_bytes(ctx, dst) == 0);
}
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_unary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
p.misalign_offsets = (a_offset << 16) | d_offset;
GGML_UNUSED(src1);
GGML_UNUSED(src2);
}
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
GGML_ASSERT(dst->op != GGML_OP_GET_ROWS || (a_offset == 0 && b_offset == 0 && d_offset == 0));
p.misalign_offsets = (a_offset << 16) | (b_offset << 8) | d_offset;
GGML_UNUSED(src2);
}
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_upscale_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
p.a_offset = a_offset;
p.d_offset = d_offset;
GGML_UNUSED(src1);
GGML_UNUSED(src2);
}
template<typename PC>
static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, PC&& pc, bool dryrun = false) {
VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
@ -5179,8 +5243,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
}
GGML_ASSERT(d_D != nullptr);
uint64_t d_buf_offset = ((vk_tensor_offset(dst) + dst->view_offs) / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
GGML_ASSERT(d_buf_offset == vk_tensor_offset(dst) || op == GGML_OP_CPY); // NOLINT
uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
if(!src0_uma) {
d_X = src0_buf_ctx->dev_buffer;
x_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
@ -5196,6 +5259,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
z_buf_offset = vk_tensor_offset(src2) + src2->view_offs;
GGML_ASSERT(d_Z != nullptr);
}
// Compute misalignment offset for descriptors and store it in in push constants, then align the descriptor offsets.
init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, dst);
x_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
y_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
z_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
d_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
if (op_supports_incontiguous) {
x_sz = ggml_nbytes(src0);
@ -5383,7 +5452,6 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
const uint32_t d_offset = ((vk_tensor_offset(dst) + dst->view_offs) % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size;
int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
@ -5395,7 +5463,7 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t) dst->nb[3] / dst_type_size,
d_offset,
0,
0.0f, 0.0f, offset,
}, dryrun);
}
@ -5599,7 +5667,7 @@ static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, c
const float sf3 = (float)dst->ne[3] / src0->ne[3];
ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, {
(uint32_t)ggml_nelements(dst), 0,
(uint32_t)ggml_nelements(dst), 0, 0,
(uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3],
sf0, sf1, sf2, sf3,
@ -5709,13 +5777,12 @@ static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, co
static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
const uint32_t d_offset = ((vk_tensor_offset(dst) + dst->view_offs) % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size;
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
d_offset,
0,
0.0f, 0.0f,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
}, dryrun);

View file

@ -21,9 +21,9 @@ void main() {
get_indices(idx, i00, i01, i02, i03);
if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) {
data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[ox + oy * p.ne10 + oz * p.ne10 * p.ne11]));
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + ox + oy * p.ne10 + oz * p.ne10 * p.ne11]));
} else {
data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]));
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]));
}
}

View file

@ -22,7 +22,7 @@ void main() {
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03);
data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[src1_idx(i00, i01, i02, i03)]));
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
idx += num_threads;
}

View file

@ -12,6 +12,6 @@ void main() {
return;
}
const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val));
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val));
}

View file

@ -30,12 +30,12 @@ void main() {
const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03;
#ifndef OPTIMIZATION_ERROR_WORKAROUND
data_d[p.d_offset + dst_idx] = D_TYPE(is_src0 ? data_a[src0_idx] : data_b[src1_idx]);
data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : data_b[get_boffset() + src1_idx]);
#else
if (is_src0) {
data_d[p.d_offset + dst_idx] = data_a[src0_idx];
data_d[get_doffset() + dst_idx] = data_a[get_aoffset() + src0_idx];
} else {
data_d[p.d_offset + dst_idx] = data_b[src1_idx];
data_d[get_doffset() + dst_idx] = data_b[get_boffset() + src1_idx];
}
#endif
}

View file

@ -19,9 +19,9 @@ void main() {
if (idx + (num_iter-1)*num_threads < p.ne) {
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
#ifndef OPTIMIZATION_ERROR_WORKAROUND
data_d[p.d_offset + idx] = D_TYPE(data_a[idx]);
data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
#else
data_d[p.d_offset + idx] = data_a[idx];
data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
#endif
idx += num_threads;
}
@ -32,9 +32,9 @@ void main() {
}
#ifndef OPTIMIZATION_ERROR_WORKAROUND
data_d[p.d_offset + idx] = D_TYPE(data_a[idx]);
data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
#else
data_d[p.d_offset + idx] = data_a[idx];
data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
#endif
idx += num_threads;
}

View file

@ -13,8 +13,8 @@ void main() {
}
#ifndef OPTIMIZATION_ERROR_WORKAROUND
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(data_a[src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
#else
data_d[p.d_offset + dst_idx(idx)] = data_a[src0_idx(idx)];
data_d[get_doffset() + dst_idx(idx)] = data_a[get_aoffset() + src0_idx(idx)];
#endif
}

View file

@ -12,6 +12,6 @@ void main() {
return;
}
const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(cos(val));
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(cos(val));
}

View file

@ -20,7 +20,7 @@ void main() {
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03);
data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]) / FLOAT_TYPE(data_b[src1_idx(i00, i01, i02, i03)]));
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) / FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
idx += num_threads;
}

View file

@ -7,7 +7,7 @@ layout (push_constant) uniform parameter
uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23;
uint d_offset;
uint misalign_offsets;
float param1; float param2; int param3;
} p;
@ -22,6 +22,10 @@ uint get_idx() {
return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
}
uint get_aoffset() { return p.misalign_offsets >> 16; }
uint get_boffset() { return (p.misalign_offsets >> 8) & 0xFF; }
uint get_doffset() { return p.misalign_offsets & 0xFF; }
// mod and div are expensive and coordinates/dimensions are often power of 2 or equal to 1
uint fastmod(uint a, uint b) {
if ((b & (b-1)) == 0) {

View file

@ -6,7 +6,7 @@ layout (push_constant) uniform parameter
uint ne;
uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
uint d_offset;
uint misalign_offsets;
float param1; float param2;
uint ne0_012mp; uint ne0_012L;
@ -24,6 +24,9 @@ uint get_idx() {
return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
}
uint get_aoffset() { return p.misalign_offsets >> 16; }
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
// see init_fastdiv_values in ggml-vulkan.cpp
uint fastdiv(uint n, uint mp, uint L) {
uint msbs, lsbs;

View file

@ -15,10 +15,10 @@ void main() {
return;
}
const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
const uint i01 = data_b[get_boffset() + i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
#ifndef OPTIMIZATION_ERROR_WORKAROUND
data_d[d_offset + i00] = D_TYPE(data_a[a_offset + i00]);

View file

@ -2,6 +2,7 @@
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_spirv_intrinsics: enable
#extension GL_EXT_control_flow_attributes : require
#if RTE16
spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
@ -23,40 +24,64 @@ layout (push_constant) uniform parameter
#include "types.comp"
#define BLOCK_SIZE 256
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
const uint NUM_ITER = 512 / BLOCK_SIZE;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i = gl_GlobalInvocationID.x;
if (i >= p.pelements) {
return;
const uint gidx = gl_GlobalInvocationID.x;
const uint oh = gl_GlobalInvocationID.y;
const uint batch = gl_GlobalInvocationID.z / p.IC;
const uint ic = gl_GlobalInvocationID.z % p.IC;
A_TYPE values[NUM_ITER];
uint offset_dst[NUM_ITER];
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
values[idx] = A_TYPE(0);
}
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
const uint i = gidx * NUM_ITER + idx;
const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1);
const uint kx = i / ksize;
const uint kd = kx * ksize;
const uint ky = (i - kd) / p.OW;
const uint ix = i % p.OW;
const uint oh = gl_GlobalInvocationID.y;
const uint batch = gl_GlobalInvocationID.z / p.IC;
const uint ic = gl_GlobalInvocationID.z % p.IC;
const uint iiw = ix * p.s0 + kx * p.d0 - p.p0;
const uint iih = oh * p.s1 + ky * p.d1 - p.p1;
const uint offset_dst =
offset_dst[idx] =
((batch * p.OH + oh) * p.OW + ix) * p.CHW +
(ic * (p.KW * p.KH) + ky * p.KW + kx);
if (iih < 0 || iih >= p.IH || iiw < 0 || iiw >= p.IW) {
data_d[offset_dst] = D_TYPE(0.0f);
} else {
if (i >= p.pelements) {
continue;
}
if (iih < p.IH && iiw < p.IW) {
const uint offset_src = ic * p.offset_delta + batch * p.batch_offset;
data_d[offset_dst] = D_TYPE(data_a[offset_src + iih * p.IW + iiw]);
values[idx] = data_a[offset_src + iih * p.IW + iiw];
}
}
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
const uint i = gidx * NUM_ITER + idx;
if (i >= p.pelements) {
continue;
}
data_d[offset_dst[idx]] = D_TYPE(values[idx]);
}
}

View file

@ -20,7 +20,7 @@ void main() {
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03);
data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]) * FLOAT_TYPE(data_b[src1_idx(i00, i01, i02, i03)]));
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) * FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
idx += num_threads;
}

View file

@ -9,9 +9,6 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
layout (constant_id = 1) const uint NUM_ROWS = 1;
#if !defined(DATA_A_F32) && !defined(DATA_A_F16)
#define K_PER_ITER 8
#else
@ -21,23 +18,22 @@ layout (constant_id = 1) const uint NUM_ROWS = 1;
uint a_offset, b_offset, d_offset, y_offset;
shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter)
void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter)
{
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
const uint col = i*BLOCK_SIZE + K_PER_ITER*tid;
const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
const uint iybs = col - col%QUANT_K; // y block start index
#if K_PER_ITER == 8
#if QUANT_R == 2
const B_TYPE_VEC4 bv02 = data_b_v4[(b_offset + iybs + iqs) / 4];
const B_TYPE_VEC4 bv13 = data_b_v4[(b_offset + iybs + iqs + y_offset) / 4];
const B_TYPE_VEC4 bv02 = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4];
const B_TYPE_VEC4 bv13 = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4];
const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y);
const vec4 bv1 = vec4(bv02.z, bv13.z, bv02.w, bv13.w);
#else
const vec4 bv0 = vec4(data_b_v4[(b_offset + iybs + iqs) / 4]);
const vec4 bv1 = vec4(data_b_v4[(b_offset + iybs + iqs) / 4 + 1]);
const vec4 bv0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]);
const vec4 bv1 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4 + 1]);
#endif
#else
// Check if the second of the pair of elements is OOB, and don't fetch B or
@ -48,9 +44,9 @@ void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_
const bool OOB = lastiter && (iybs + iqs + y_offset >= p.ncols);
FLOAT_TYPE b0 = 0, b1 = 0;
b0 = FLOAT_TYPE(data_b[b_offset + iybs + iqs]);
b0 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]);
if (!OOB) {
b1 = FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]);
b1 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]);
}
#endif
uint ibi = first_row*p.ncols;
@ -75,18 +71,19 @@ void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_
if (dm.y == 0)
rowtmp *= dm.x;
temp[n] += rowtmp;
temp[j][n] += rowtmp;
#else
const vec2 v = dequantize(ib, iqs, a_offset);
// matrix multiplication
temp[n] = fma(FLOAT_TYPE(v.x), b0, temp[n]);
temp[j][n] = fma(FLOAT_TYPE(v.x), b0, temp[j][n]);
if (!OOB) {
temp[n] = fma(FLOAT_TYPE(v.y), b1, temp[n]);
temp[j][n] = fma(FLOAT_TYPE(v.y), b1, temp[j][n]);
}
#endif
}
}
}
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
const uint tid = gl_LocalInvocationID.x;
@ -96,10 +93,12 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
FLOAT_TYPE temp[NUM_ROWS];
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
for (uint i = 0; i < NUM_ROWS; ++i) {
temp[i] = FLOAT_TYPE(0);
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0);
}
}
uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
@ -131,24 +130,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
i++;
}
// sum up partial sums and write back result
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
tmpsh[n][tid] = temp[n];
}
barrier();
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
if (tid < s) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
tmpsh[n][tid] += tmpsh[n][tid + s];
}
}
barrier();
}
if (tid == 0) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
}
}
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {

View file

@ -83,3 +83,36 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
batch_idx * p.batch_stride_d;
#endif
}
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
layout (constant_id = 1) const uint NUM_ROWS = 1;
layout (constant_id = 2) const uint NUM_COLS = 1;
shared FLOAT_TYPE tmpsh[NUM_COLS][NUM_ROWS][BLOCK_SIZE];
void reduce_result(const in FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {
// sum up partial sums and write back result
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
tmpsh[j][n][tid] = temp[j][n];
}
}
barrier();
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
if (tid < s) {
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
tmpsh[j][n][tid] += tmpsh[j][n][tid + s];
}
}
}
barrier();
}
if (tid == 0) {
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]);
}
}
}
}

View file

@ -5,11 +5,6 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
layout (constant_id = 1) const uint NUM_ROWS = 1;
shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
@ -32,24 +27,17 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
const uint s_offset = 8*v_im;
const uint y_offset = 128*v_im + l0;
FLOAT_TYPE temp[NUM_ROWS];
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[i] = FLOAT_TYPE(0);
temp[j][i] = FLOAT_TYPE(0);
}
}
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
const uint y_idx = i * QUANT_K + y_offset;
B_TYPE_VEC2 b0 = data_b_v2[(b_offset + y_idx) / 2 + 0];
B_TYPE_VEC2 b16 = data_b_v2[(b_offset + y_idx) / 2 + 8];
B_TYPE_VEC2 b32 = data_b_v2[(b_offset + y_idx) / 2 + 16];
B_TYPE_VEC2 b48 = data_b_v2[(b_offset + y_idx) / 2 + 24];
B_TYPE_VEC2 b64 = data_b_v2[(b_offset + y_idx) / 2 + 32];
B_TYPE_VEC2 b80 = data_b_v2[(b_offset + y_idx) / 2 + 40];
B_TYPE_VEC2 b96 = data_b_v2[(b_offset + y_idx) / 2 + 48];
B_TYPE_VEC2 b112 = data_b_v2[(b_offset + y_idx) / 2 + 56];
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
f16vec2 d = data_a[ib0 + i].d;
@ -74,6 +62,16 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uvec2 qs0 = uvec2(unpack8(qs0_u16));
uvec2 qs16 = uvec2(unpack8(qs16_u16));
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0];
B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8];
B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16];
B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24];
B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32];
B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40];
B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48];
B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56];
FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
[[unroll]] for (int l = 0; l < 2; ++l) {
@ -94,28 +92,12 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_hi4[2]),
fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2))))))));
}
temp[n] = fma(dall, sum1, fma(-dmin, sum2, temp[n]));
temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
}
}
}
// sum up partial sums and write back result
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
tmpsh[n][tid] = temp[n];
}
barrier();
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
if (tid < s) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
tmpsh[n][tid] += tmpsh[n][tid + s];
}
}
barrier();
}
if (tid == 0) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
}
}
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {

View file

@ -5,11 +5,6 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
layout (constant_id = 1) const uint NUM_ROWS = 1;
shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
@ -33,10 +28,12 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
const uint q_offset = 32*v_im + l0;
const uint y_offset = 128*v_im + l0;
FLOAT_TYPE temp[NUM_ROWS];
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[i] = FLOAT_TYPE(0);
temp[j][i] = FLOAT_TYPE(0);
}
}
const uint s_shift = 4 * v_im;
@ -44,15 +41,6 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
const uint y_idx = i * QUANT_K + y_offset;
B_TYPE_VEC2 b0 = data_b_v2[(b_offset + y_idx) / 2 + 0];
B_TYPE_VEC2 b16 = data_b_v2[(b_offset + y_idx) / 2 + 8];
B_TYPE_VEC2 b32 = data_b_v2[(b_offset + y_idx) / 2 + 16];
B_TYPE_VEC2 b48 = data_b_v2[(b_offset + y_idx) / 2 + 24];
B_TYPE_VEC2 b64 = data_b_v2[(b_offset + y_idx) / 2 + 32];
B_TYPE_VEC2 b80 = data_b_v2[(b_offset + y_idx) / 2 + 40];
B_TYPE_VEC2 b96 = data_b_v2[(b_offset + y_idx) / 2 + 48];
B_TYPE_VEC2 b112 = data_b_v2[(b_offset + y_idx) / 2 + 56];
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
@ -70,6 +58,17 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
u8vec2 s8 = unpack8(s8_16);
u8vec2 s10 = unpack8(s10_16);
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0];
B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8];
B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16];
B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24];
B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32];
B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40];
B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48];
B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56];
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
[[unroll]] for (int l = 0; l < 2; ++l) {
sum = fma(FLOAT_TYPE(b0[l]) * FLOAT_TYPE(int8_t(((s0[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)),
@ -81,28 +80,12 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
fma(FLOAT_TYPE(b80[l]) * FLOAT_TYPE(int8_t(((s4[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)),
fma(FLOAT_TYPE(b112[l]) * FLOAT_TYPE(int8_t(((s6[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)), sum))))))));
}
temp[n] = fma(d, sum, temp[n]);
temp[j][n] = fma(d, sum, temp[j][n]);
}
}
}
// sum up partial sums and write back result
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
tmpsh[n][tid] = temp[n];
}
barrier();
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
if (tid < s) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
tmpsh[n][tid] += tmpsh[n][tid + s];
}
}
barrier();
}
if (tid == 0) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
}
}
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {

View file

@ -6,11 +6,6 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
layout (constant_id = 1) const uint NUM_ROWS = 1;
shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
@ -36,21 +31,18 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
const uint q_offset = 32*v_im + l0;
const uint y_offset = 64*v_im + l0;
FLOAT_TYPE temp[NUM_ROWS];
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[i] = FLOAT_TYPE(0);
temp[j][i] = FLOAT_TYPE(0);
}
}
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
const uint y1_idx = i * QUANT_K + y_offset;
const uint y2_idx = y1_idx + 128;
B_TYPE_VEC4 by10 = data_b_v4[(b_offset + y1_idx) / 4];
B_TYPE_VEC4 by132 = data_b_v4[(b_offset + y1_idx) / 4 + 8];
B_TYPE_VEC4 by20 = data_b_v4[(b_offset + y2_idx) / 4];
B_TYPE_VEC4 by232 = data_b_v4[(b_offset + y2_idx) / 4 + 8];
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
f16vec2 d = data_a[ib0 + i].d;
@ -103,6 +95,12 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
const uint32_t q4_14 = qs64_hi4.z;
const uint32_t q4_15 = qs64_hi4.w;
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
B_TYPE_VEC4 by10 = data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4];
B_TYPE_VEC4 by132 = data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8];
B_TYPE_VEC4 by20 = data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4];
B_TYPE_VEC4 by232 = data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8];
const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x), q4_0, fma(FLOAT_TYPE(by10.y), q4_1, fma(FLOAT_TYPE(by10.z), q4_2, FLOAT_TYPE(by10.w) * q4_3)));
const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x), q4_4, fma(FLOAT_TYPE(by132.y), q4_5, fma(FLOAT_TYPE(by132.z), q4_6, FLOAT_TYPE(by132.w) * q4_7)));
const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x), q4_8, fma(FLOAT_TYPE(by20.y), q4_9, fma(FLOAT_TYPE(by20.z), q4_10, FLOAT_TYPE(by20.w) * q4_11)));
@ -112,28 +110,12 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7,
fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7,
fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7)))))))))))))));
temp[n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[n]));
temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
}
}
}
// sum up partial sums and write back result
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
tmpsh[n][tid] = temp[n];
}
barrier();
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
if (tid < s) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
tmpsh[n][tid] += tmpsh[n][tid + s];
}
}
barrier();
}
if (tid == 0) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
}
}
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {

View file

@ -6,11 +6,6 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
layout (constant_id = 1) const uint NUM_ROWS = 1;
shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
@ -33,25 +28,18 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
const uint q_offset = 32*v_im + l0;
const uint y_offset = 64*v_im + l0;
FLOAT_TYPE temp[NUM_ROWS];
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[i] = FLOAT_TYPE(0);
temp[j][i] = FLOAT_TYPE(0);
}
}
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
const uint y1_idx = i * QUANT_K + y_offset;
const uint y2_idx = y1_idx + 128;
B_TYPE_VEC2 by10 = data_b_v2[(b_offset + y1_idx) / 2];
B_TYPE_VEC2 by116 = data_b_v2[(b_offset + y1_idx) / 2 + 8];
B_TYPE_VEC2 by132 = data_b_v2[(b_offset + y1_idx) / 2 + 16];
B_TYPE_VEC2 by148 = data_b_v2[(b_offset + y1_idx) / 2 + 24];
B_TYPE_VEC2 by20 = data_b_v2[(b_offset + y2_idx) / 2];
B_TYPE_VEC2 by216 = data_b_v2[(b_offset + y2_idx) / 2 + 8];
B_TYPE_VEC2 by232 = data_b_v2[(b_offset + y2_idx) / 2 + 16];
B_TYPE_VEC2 by248 = data_b_v2[(b_offset + y2_idx) / 2 + 24];
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
f16vec2 d = data_a[ib0 + i].d;
@ -116,6 +104,16 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
const uint32_t q4_14 = qs64_80_hi4.z;
const uint32_t q4_15 = qs64_80_hi4.w;
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
B_TYPE_VEC2 by10 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2];
B_TYPE_VEC2 by116 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8];
B_TYPE_VEC2 by132 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16];
B_TYPE_VEC2 by148 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24];
B_TYPE_VEC2 by20 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2];
B_TYPE_VEC2 by216 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8];
B_TYPE_VEC2 by232 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16];
B_TYPE_VEC2 by248 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24];
const FLOAT_TYPE sx =
fma(FLOAT_TYPE(by10.x), q4_0,
fma(FLOAT_TYPE(by10.y), q4_1,
@ -141,28 +139,12 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,
(FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));
temp[n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[n]));
temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
}
}
}
// sum up partial sums and write back result
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
tmpsh[n][tid] = temp[n];
}
barrier();
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
if (tid < s) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
tmpsh[n][tid] += tmpsh[n][tid + s];
}
}
barrier();
}
if (tid == 0) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
}
}
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {

View file

@ -6,11 +6,6 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
layout (constant_id = 1) const uint NUM_ROWS = 1;
shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
@ -36,20 +31,17 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
const uint s_offset = 8*v_im + is;
const uint y_offset = 128*v_im + l0;
FLOAT_TYPE temp[NUM_ROWS];
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[i] = FLOAT_TYPE(0);
temp[j][i] = FLOAT_TYPE(0);
}
}
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
const uint y_idx = i * QUANT_K + y_offset;
B_TYPE_VEC4 by0 = data_b_v4[(b_offset + y_idx) / 4];
B_TYPE_VEC4 by32 = data_b_v4[(b_offset + y_idx) / 4 + 8];
B_TYPE_VEC4 by64 = data_b_v4[(b_offset + y_idx) / 4 + 16];
B_TYPE_VEC4 by96 = data_b_v4[(b_offset + y_idx) / 4 + 24];
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
@ -84,6 +76,12 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uvec4 q2 = uvec4(unpack8(q2_u32));
uvec4 q3 = uvec4(unpack8(q3_u32));
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
B_TYPE_VEC4 by0 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4];
B_TYPE_VEC4 by32 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8];
B_TYPE_VEC4 by64 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16];
B_TYPE_VEC4 by96 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24];
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
[[unroll]] for (int l = 0; l < 4; ++l) {
sum = fma(FLOAT_TYPE(by0[l]) * scales[0], FLOAT_TYPE(int8_t(q0[l]) - 32),
@ -91,28 +89,12 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
fma(FLOAT_TYPE(by64[l]) * scales[2], FLOAT_TYPE(int8_t(q2[l]) - 32),
fma(FLOAT_TYPE(by96[l]) * scales[3], FLOAT_TYPE(int8_t(q3[l]) - 32), sum))));
}
temp[n] += sum * d;
temp[j][n] += sum * d;
}
}
}
// sum up partial sums and write back result
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
tmpsh[n][tid] = temp[n];
}
barrier();
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
if (tid < s) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
tmpsh[n][tid] += tmpsh[n][tid + s];
}
}
barrier();
}
if (tid == 0) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
}
}
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {

View file

@ -24,5 +24,5 @@ void main() {
const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03;
data_d[p.d_offset + dst_idx] = D_TYPE(is_src0 ? data_a[src0_idx] : 0.0f);
data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : 0.0f);
}

View file

@ -22,5 +22,5 @@ void main() {
return;
}
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(data_a[src0_idx_mod(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx_mod(idx)]);
}

View file

@ -18,7 +18,7 @@ void main() {
continue;
}
data_d[p.d_offset + idx] = D_TYPE(FLOAT_TYPE(data_a[idx]) * FLOAT_TYPE(p.param1));
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1));
idx += num_threads;
}
}

View file

@ -12,6 +12,6 @@ void main() {
return;
}
const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(sin(val));
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sin(val));
}

View file

@ -12,6 +12,6 @@ void main() {
return;
}
const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(val * val);
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val * val);
}

View file

@ -2,7 +2,7 @@
layout (push_constant) uniform parameter
{
uint ne; uint d_offset;
uint ne; uint a_offset; uint d_offset;
uint nb00; uint nb01; uint nb02; uint nb03;
uint ne10; uint ne11; uint ne12; uint ne13;
float sf0; float sf1; float sf2; float sf3;
@ -32,5 +32,5 @@ void main() {
const uint i02 = uint(i12 / p.sf2);
const uint i03 = uint(i13 / p.sf3);
data_d[p.d_offset + idx] = D_TYPE(data_a[i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]);
data_d[p.d_offset + idx] = D_TYPE(data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]);
}