mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 17:14:36 +00:00
Merge branch 'upstream' into concedo_experimental
# Conflicts: # README.md # examples/llama.android/llama/src/main/cpp/llama-android.cpp # examples/run/run.cpp # examples/server/README.md # examples/server/bench/README.md # examples/server/tests/README.md # ggml/src/CMakeLists.txt # ggml/src/ggml-cpu/CMakeLists.txt # tests/test-backend-ops.cpp
This commit is contained in:
commit
911da8765f
46 changed files with 82848 additions and 73880 deletions
|
@ -20,6 +20,7 @@
|
|||
#include <cstdarg>
|
||||
#include <cstring>
|
||||
#include <ctime>
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <iterator>
|
||||
|
@ -64,7 +65,9 @@
|
|||
#ifdef __linux__
|
||||
#include <linux/limits.h>
|
||||
#elif defined(_WIN32)
|
||||
# if !defined(PATH_MAX)
|
||||
# define PATH_MAX MAX_PATH
|
||||
# endif
|
||||
#else
|
||||
#include <sys/syslimits.h>
|
||||
#endif
|
||||
|
@ -1150,8 +1153,7 @@ static bool common_download_file(const std::string & url, const std::string & pa
|
|||
#endif
|
||||
|
||||
// Check if the file already exists locally
|
||||
struct stat model_file_info;
|
||||
auto file_exists = (stat(path.c_str(), &model_file_info) == 0);
|
||||
auto file_exists = std::filesystem::exists(path);
|
||||
|
||||
// If the file exists, check its JSON metadata companion file.
|
||||
std::string metadata_path = path + ".json";
|
||||
|
@ -1614,6 +1616,18 @@ std::string common_detokenize(llama_context * ctx, const std::vector<llama_token
|
|||
// Chat template utils
|
||||
//
|
||||
|
||||
std::string common_get_builtin_chat_template(const struct llama_model * model) {
|
||||
static const char * template_key = "tokenizer.chat_template";
|
||||
// call with NULL buffer to get the total size of the string
|
||||
int32_t res = llama_model_meta_val_str(model, template_key, NULL, 0);
|
||||
if (res > 0) {
|
||||
std::vector<char> model_template(res + 1, 0);
|
||||
llama_model_meta_val_str(model, template_key, model_template.data(), model_template.size());
|
||||
return std::string(model_template.data(), model_template.size() - 1);
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
bool common_chat_verify_template(const std::string & tmpl) {
|
||||
llama_chat_message chat[] = {{"user", "test"}};
|
||||
int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
|
||||
|
|
|
@ -567,6 +567,9 @@ struct common_chat_msg {
|
|||
std::string content;
|
||||
};
|
||||
|
||||
// Get the built-in chat template for the model. Return empty string if not present.
|
||||
std::string common_get_builtin_chat_template(const struct llama_model * model);
|
||||
|
||||
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
|
||||
bool common_chat_verify_template(const std::string & tmpl);
|
||||
|
||||
|
|
|
@ -1764,25 +1764,19 @@ class DeciModel(Model):
|
|||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
|
||||
special_vocab = gguf.SpecialVocab(
|
||||
self.dir_model, load_merges=True,
|
||||
special_token_types = ['bos', 'eos', 'eom', 'eot']
|
||||
)
|
||||
special_vocab._set_special_token("bos", 128000)
|
||||
special_vocab._set_special_token("eos", 128001)
|
||||
special_vocab._set_special_token("eom", 128008)
|
||||
special_vocab._set_special_token("eot", 128009)
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
else:
|
||||
# DeciLM-7B
|
||||
self._set_vocab_llama_hf()
|
||||
# self._set_vocab_gpt2()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
if "block_configs" in self.hparams: # Llama-3_1-Nemotron-51B
|
||||
assert self.block_count == len(self._num_kv_heads)
|
||||
assert self.block_count == len(self._num_heads)
|
||||
assert self.block_count == len(self._ffn_dims)
|
||||
if (rope_theta := self.hparams.get("rope_theta")) is not None:
|
||||
self.gguf_writer.add_rope_freq_base(rope_theta)
|
||||
self.gguf_writer.add_head_count_kv(self._num_kv_heads)
|
||||
self.gguf_writer.add_head_count(self._num_heads)
|
||||
self.gguf_writer.add_feed_forward_length(self._ffn_dims)
|
||||
|
|
|
@ -189,12 +189,12 @@ xychart-beta
|
|||
"pp": {
|
||||
"p95": round(data['metrics']["llamacpp_prompt_processing_second"]["p(95)"], 2),
|
||||
"avg": round(data['metrics']["llamacpp_prompt_processing_second"]["avg"], 2),
|
||||
"0": round(mean(prometheus_metrics['prompt_tokens_seconds']), 2),
|
||||
"0": round(mean(prometheus_metrics['prompt_tokens_seconds']), 2) if 'prompt_tokens_seconds' in prometheus_metrics else 0,
|
||||
},
|
||||
"tg": {
|
||||
"p95": round(data['metrics']["llamacpp_tokens_second"]["p(95)"], 2),
|
||||
"avg": round(data['metrics']["llamacpp_tokens_second"]["avg"], 2),
|
||||
"0": round(mean(prometheus_metrics['predicted_tokens_seconds']), 2),
|
||||
"0": round(mean(prometheus_metrics['predicted_tokens_seconds']), 2) if 'predicted_tokens_seconds' in prometheus_metrics else 0,
|
||||
},
|
||||
}
|
||||
with open("results.github.env", 'a') as github_env:
|
||||
|
@ -214,11 +214,14 @@ def start_benchmark(args):
|
|||
k6_args = [
|
||||
'run', args.scenario,
|
||||
'--no-color',
|
||||
'--no-connection-reuse',
|
||||
'--no-vu-connection-reuse',
|
||||
]
|
||||
k6_args.extend(['--duration', args.duration])
|
||||
k6_args.extend(['--iterations', args.n_prompts])
|
||||
k6_args.extend(['--vus', args.parallel])
|
||||
k6_args.extend(['--summary-export', 'k6-results.json'])
|
||||
k6_args.extend(['--out', 'csv=k6-results.csv'])
|
||||
args = f"SERVER_BENCH_N_PROMPTS={args.n_prompts} SERVER_BENCH_MAX_PROMPT_TOKENS={args.max_prompt_tokens} SERVER_BENCH_MAX_CONTEXT={args.max_tokens} "
|
||||
args = args + ' '.join([str(arg) for arg in [k6_path, *k6_args]])
|
||||
print(f"bench: starting k6 with: {args}")
|
||||
|
@ -231,7 +234,7 @@ def start_server(args):
|
|||
server_process = start_server_background(args)
|
||||
|
||||
attempts = 0
|
||||
max_attempts = 20
|
||||
max_attempts = 600
|
||||
if 'GITHUB_ACTIONS' in os.environ:
|
||||
max_attempts *= 2
|
||||
|
||||
|
@ -242,7 +245,15 @@ def start_server(args):
|
|||
print(f"bench: waiting for server to start ...")
|
||||
time.sleep(0.5)
|
||||
|
||||
print("bench: server started.")
|
||||
attempts = 0
|
||||
while not is_server_ready(args.host, args.port):
|
||||
attempts += 1
|
||||
if attempts > max_attempts:
|
||||
assert False, "server not ready"
|
||||
print(f"bench: waiting for server to be ready ...")
|
||||
time.sleep(0.5)
|
||||
|
||||
print("bench: server started and ready.")
|
||||
return server_process
|
||||
|
||||
|
||||
|
@ -255,11 +266,6 @@ def start_server_background(args):
|
|||
'--host', args.host,
|
||||
'--port', args.port,
|
||||
]
|
||||
model_file = args.model_path_prefix + os.path.sep + args.hf_file
|
||||
model_dir = os.path.dirname(model_file)
|
||||
if not os.path.exists(model_dir):
|
||||
os.makedirs(model_dir)
|
||||
server_args.extend(['--model', model_file])
|
||||
server_args.extend(['--hf-repo', args.hf_repo])
|
||||
server_args.extend(['--hf-file', args.hf_file])
|
||||
server_args.extend(['--n-gpu-layers', args.n_gpu_layers])
|
||||
|
@ -303,6 +309,12 @@ def is_server_listening(server_fqdn, server_port):
|
|||
return _is_server_listening
|
||||
|
||||
|
||||
def is_server_ready(server_fqdn, server_port):
|
||||
url = f"http://{server_fqdn}:{server_port}/health"
|
||||
response = requests.get(url)
|
||||
return response.status_code == 200
|
||||
|
||||
|
||||
def escape_metric_name(metric_name):
|
||||
return re.sub('[^A-Z0-9]', '_', metric_name.upper())
|
||||
|
||||
|
|
|
@ -56,6 +56,7 @@ const llamacpp_completion_tokens = new Trend('llamacpp_completion_tokens')
|
|||
|
||||
const llamacpp_tokens_second = new Trend('llamacpp_tokens_second')
|
||||
const llamacpp_prompt_processing_second = new Trend('llamacpp_prompt_processing_second')
|
||||
const llamacpp_emit_first_token_second = new Trend('llamacpp_emit_first_token_second')
|
||||
|
||||
const llamacpp_prompt_tokens_total_counter = new Counter('llamacpp_prompt_tokens_total_counter')
|
||||
const llamacpp_completion_tokens_total_counter = new Counter('llamacpp_completion_tokens_total_counter')
|
||||
|
@ -89,6 +90,9 @@ export default function () {
|
|||
],
|
||||
"model": model,
|
||||
"stream": true,
|
||||
"stream_options": {
|
||||
"include_usage": true, // False to be supported in llama.cpp server
|
||||
},
|
||||
"seed": 42,
|
||||
"max_tokens": max_tokens,
|
||||
"stop": ["<|im_end|>"] // This is temporary for phi-2 base (i.e. not instructed) since the server expects that the model always to emit BOS
|
||||
|
@ -105,13 +109,21 @@ export default function () {
|
|||
client.on('event', function (event) {
|
||||
if (promptEvalEndTime == null) {
|
||||
promptEvalEndTime = new Date()
|
||||
llamacpp_emit_first_token_second.add((promptEvalEndTime - startTime) / 1.e3)
|
||||
}
|
||||
|
||||
if (event.data === '[DONE]' || event.data === '') {
|
||||
return
|
||||
}
|
||||
|
||||
let chunk = JSON.parse(event.data)
|
||||
|
||||
if (chunk.choices && chunk.choices.length > 0) {
|
||||
let choice = chunk.choices[0]
|
||||
if (choice.finish_reason) {
|
||||
finish_reason = choice.finish_reason
|
||||
}
|
||||
}
|
||||
|
||||
if (chunk.usage) {
|
||||
prompt_tokens = chunk.usage.prompt_tokens
|
||||
|
|
|
@ -67,6 +67,13 @@ enum server_task_type {
|
|||
SERVER_TASK_TYPE_SET_LORA,
|
||||
};
|
||||
|
||||
enum oaicompat_type {
|
||||
OAICOMPAT_TYPE_NONE,
|
||||
OAICOMPAT_TYPE_CHAT,
|
||||
OAICOMPAT_TYPE_COMPLETION,
|
||||
OAICOMPAT_TYPE_EMBEDDING,
|
||||
};
|
||||
|
||||
// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
|
||||
enum error_type {
|
||||
ERROR_TYPE_INVALID_REQUEST,
|
||||
|
@ -91,6 +98,8 @@ struct slot_params {
|
|||
int64_t t_max_prompt_ms = -1; // TODO: implement
|
||||
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
|
||||
|
||||
std::vector<common_lora_adapter_container> lora;
|
||||
|
||||
std::vector<std::string> antiprompt;
|
||||
std::vector<std::string> response_fields;
|
||||
bool timings_per_token = false;
|
||||
|
@ -102,8 +111,7 @@ struct slot_params {
|
|||
|
||||
// OAI-compat fields
|
||||
bool verbose = false;
|
||||
bool oaicompat = false;
|
||||
bool oaicompat_chat = true;
|
||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
|
||||
|
@ -114,6 +122,11 @@ struct slot_params {
|
|||
samplers.emplace_back(common_sampler_type_to_str(sampler));
|
||||
}
|
||||
|
||||
json lora = json::array();
|
||||
for (size_t i = 0; i < this->lora.size(); ++i) {
|
||||
lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
|
||||
}
|
||||
|
||||
return json {
|
||||
{"n_predict", n_predict}, // Server configured n_predict
|
||||
{"seed", sampling.seed},
|
||||
|
@ -154,6 +167,7 @@ struct slot_params {
|
|||
{"speculative.p_min", speculative.p_min},
|
||||
{"timings_per_token", timings_per_token},
|
||||
{"post_sampling_probs", post_sampling_probs},
|
||||
{"lora", lora},
|
||||
};
|
||||
}
|
||||
};
|
||||
|
@ -183,12 +197,16 @@ struct server_task {
|
|||
// used by SERVER_TASK_TYPE_METRICS
|
||||
bool metrics_reset_bucket = false;
|
||||
|
||||
// used by SERVER_TASK_TYPE_SET_LORA
|
||||
std::vector<common_lora_adapter_container> set_lora;
|
||||
|
||||
server_task(server_task_type type) : type(type) {}
|
||||
|
||||
static slot_params params_from_json_cmpl(
|
||||
const llama_model * model,
|
||||
const llama_context * ctx,
|
||||
const common_params & params_base,
|
||||
const std::vector<common_lora_adapter_container> & lora_base,
|
||||
const json & data) {
|
||||
slot_params params;
|
||||
|
||||
|
@ -245,6 +263,16 @@ struct server_task {
|
|||
params.speculative.n_min = std::max(params.speculative.n_min, 2);
|
||||
params.speculative.n_max = std::max(params.speculative.n_max, 0);
|
||||
|
||||
if (data.contains("lora")) {
|
||||
if (data.at("lora").is_array()) {
|
||||
params.lora = parse_lora_request(lora_base, data.at("lora"));
|
||||
} else {
|
||||
throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields");
|
||||
}
|
||||
} else {
|
||||
params.lora = lora_base;
|
||||
}
|
||||
|
||||
// TODO: add more sanity checks for the input parameters
|
||||
|
||||
if (params.sampling.penalty_last_n < -1) {
|
||||
|
@ -530,8 +558,7 @@ struct server_task_result_cmpl_final : server_task_result {
|
|||
|
||||
// OAI-compat fields
|
||||
bool verbose = false;
|
||||
bool oaicompat = false;
|
||||
bool oaicompat_chat = true; // TODO: support oaicompat for non-chat
|
||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
|
||||
|
@ -544,9 +571,16 @@ struct server_task_result_cmpl_final : server_task_result {
|
|||
}
|
||||
|
||||
virtual json to_json() override {
|
||||
return oaicompat
|
||||
? (stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat())
|
||||
: to_json_non_oaicompat();
|
||||
switch (oaicompat) {
|
||||
case OAICOMPAT_TYPE_NONE:
|
||||
return to_json_non_oaicompat();
|
||||
case OAICOMPAT_TYPE_COMPLETION:
|
||||
return to_json_oaicompat();
|
||||
case OAICOMPAT_TYPE_CHAT:
|
||||
return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat();
|
||||
default:
|
||||
GGML_ASSERT(false && "Invalid oaicompat_type");
|
||||
}
|
||||
}
|
||||
|
||||
json to_json_non_oaicompat() {
|
||||
|
@ -574,6 +608,50 @@ struct server_task_result_cmpl_final : server_task_result {
|
|||
return response_fields.empty() ? res : json_get_nested_values(response_fields, res);
|
||||
}
|
||||
|
||||
json to_json_oaicompat() {
|
||||
std::time_t t = std::time(0);
|
||||
json logprobs = json(nullptr); // OAI default to null
|
||||
if (!stream && probs_output.size() > 0) {
|
||||
logprobs = json{
|
||||
{"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
|
||||
};
|
||||
}
|
||||
json finish_reason = "length";
|
||||
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
||||
finish_reason = "stop";
|
||||
}
|
||||
json res = json {
|
||||
{"choices", json::array({
|
||||
json{
|
||||
{"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk
|
||||
{"index", index},
|
||||
{"logprobs", logprobs},
|
||||
{"finish_reason", finish_reason},
|
||||
}
|
||||
})},
|
||||
{"created", t},
|
||||
{"model", oaicompat_model},
|
||||
{"system_fingerprint", build_info},
|
||||
{"object", "text_completion"},
|
||||
{"usage", json {
|
||||
{"completion_tokens", n_decoded},
|
||||
{"prompt_tokens", n_prompt_tokens},
|
||||
{"total_tokens", n_decoded + n_prompt_tokens}
|
||||
}},
|
||||
{"id", oaicompat_cmpl_id}
|
||||
};
|
||||
|
||||
// extra fields for debugging purposes
|
||||
if (verbose) {
|
||||
res["__verbose"] = to_json_non_oaicompat();
|
||||
}
|
||||
if (timings.prompt_n >= 0) {
|
||||
res.push_back({"timings", timings.to_json()});
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
json to_json_oaicompat_chat() {
|
||||
std::string finish_reason = "length";
|
||||
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
||||
|
@ -672,8 +750,7 @@ struct server_task_result_cmpl_partial : server_task_result {
|
|||
|
||||
// OAI-compat fields
|
||||
bool verbose = false;
|
||||
bool oaicompat = false;
|
||||
bool oaicompat_chat = true; // TODO: support oaicompat for non-chat
|
||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
|
||||
|
@ -686,7 +763,16 @@ struct server_task_result_cmpl_partial : server_task_result {
|
|||
}
|
||||
|
||||
virtual json to_json() override {
|
||||
return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat();
|
||||
switch (oaicompat) {
|
||||
case OAICOMPAT_TYPE_NONE:
|
||||
return to_json_non_oaicompat();
|
||||
case OAICOMPAT_TYPE_COMPLETION:
|
||||
return to_json_oaicompat();
|
||||
case OAICOMPAT_TYPE_CHAT:
|
||||
return to_json_oaicompat_chat();
|
||||
default:
|
||||
GGML_ASSERT(false && "Invalid oaicompat_type");
|
||||
}
|
||||
}
|
||||
|
||||
json to_json_non_oaicompat() {
|
||||
|
@ -711,6 +797,41 @@ struct server_task_result_cmpl_partial : server_task_result {
|
|||
}
|
||||
|
||||
json to_json_oaicompat() {
|
||||
std::time_t t = std::time(0);
|
||||
json logprobs = json(nullptr); // OAI default to null
|
||||
if (prob_output.probs.size() > 0) {
|
||||
logprobs = json{
|
||||
{"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
|
||||
};
|
||||
}
|
||||
json res = json {
|
||||
{"choices", json::array({
|
||||
json{
|
||||
{"text", content},
|
||||
{"index", index},
|
||||
{"logprobs", logprobs},
|
||||
{"finish_reason", nullptr},
|
||||
}
|
||||
})},
|
||||
{"created", t},
|
||||
{"model", oaicompat_model},
|
||||
{"system_fingerprint", build_info},
|
||||
{"object", "text_completion"},
|
||||
{"id", oaicompat_cmpl_id}
|
||||
};
|
||||
|
||||
// extra fields for debugging purposes
|
||||
if (verbose) {
|
||||
res["__verbose"] = to_json_non_oaicompat();
|
||||
}
|
||||
if (timings.prompt_n >= 0) {
|
||||
res.push_back({"timings", timings.to_json()});
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
json to_json_oaicompat_chat() {
|
||||
bool first = n_decoded == 0;
|
||||
std::time_t t = std::time(0);
|
||||
json choices;
|
||||
|
@ -789,14 +910,16 @@ struct server_task_result_embd : server_task_result {
|
|||
int32_t n_tokens;
|
||||
|
||||
// OAI-compat fields
|
||||
bool oaicompat = false;
|
||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
|
||||
virtual int get_index() override {
|
||||
return index;
|
||||
}
|
||||
|
||||
virtual json to_json() override {
|
||||
return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat();
|
||||
return oaicompat == OAICOMPAT_TYPE_EMBEDDING
|
||||
? to_json_oaicompat()
|
||||
: to_json_non_oaicompat();
|
||||
}
|
||||
|
||||
json to_json_non_oaicompat() {
|
||||
|
@ -1009,6 +1132,8 @@ struct server_slot {
|
|||
|
||||
common_speculative * spec = nullptr;
|
||||
|
||||
std::vector<common_lora_adapter_container> lora;
|
||||
|
||||
// the index relative to completion multi-task request
|
||||
size_t index = 0;
|
||||
|
||||
|
@ -1090,6 +1215,11 @@ struct server_slot {
|
|||
return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
|
||||
}
|
||||
|
||||
bool can_batch_with(server_slot & other_slot) {
|
||||
return is_non_causal() == other_slot.is_non_causal()
|
||||
&& are_lora_equal(lora, other_slot.lora);
|
||||
}
|
||||
|
||||
bool has_budget(const common_params & global_params) {
|
||||
if (params.n_predict == -1 && global_params.n_predict == -1) {
|
||||
return true; // limitless
|
||||
|
@ -1499,7 +1629,7 @@ struct server_context {
|
|||
|
||||
llama_model * model = nullptr;
|
||||
llama_context * ctx = nullptr;
|
||||
std::vector<common_lora_adapter_container> loras;
|
||||
std::vector<common_lora_adapter_container> lora;
|
||||
|
||||
llama_model * model_dft = nullptr;
|
||||
llama_context_params cparams_dft;
|
||||
|
@ -1566,7 +1696,7 @@ struct server_context {
|
|||
|
||||
model = llama_init.model;
|
||||
ctx = llama_init.context;
|
||||
loras = llama_init.lora_adapters;
|
||||
lora = llama_init.lora_adapters;
|
||||
|
||||
if (model == nullptr) {
|
||||
SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str());
|
||||
|
@ -1623,18 +1753,11 @@ struct server_context {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool validate_model_chat_template() const {
|
||||
std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
|
||||
std::string template_key = "tokenizer.chat_template";
|
||||
int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
|
||||
if (res >= 0) {
|
||||
bool validate_builtin_chat_template() const {
|
||||
llama_chat_message chat[] = {{"user", "test"}};
|
||||
std::string tmpl = std::string(model_template.data(), model_template.size());
|
||||
int32_t chat_res = llama_chat_apply_template(model, tmpl.c_str(), chat, 1, true, nullptr, 0);
|
||||
int32_t chat_res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0);
|
||||
return chat_res > 0;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void init() {
|
||||
const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
|
||||
|
@ -1772,6 +1895,12 @@ struct server_context {
|
|||
slot.params = std::move(task.params);
|
||||
slot.prompt_tokens = std::move(task.prompt_tokens);
|
||||
|
||||
if (!are_lora_equal(task.params.lora, slot.lora)) {
|
||||
// if lora is changed, we cannot reuse cached tokens
|
||||
slot.cache_tokens.clear();
|
||||
slot.lora = std::move(task.params.lora);
|
||||
}
|
||||
|
||||
SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
|
||||
|
||||
if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
|
||||
|
@ -1856,6 +1985,8 @@ struct server_context {
|
|||
result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
|
||||
slot.n_sent_text += result.text_to_send.size();
|
||||
// add the token to slot queue and cache
|
||||
} else {
|
||||
result.text_to_send = "";
|
||||
}
|
||||
|
||||
slot.add_token(result);
|
||||
|
@ -2042,7 +2173,6 @@ struct server_context {
|
|||
|
||||
res->verbose = slot.params.verbose;
|
||||
res->oaicompat = slot.params.oaicompat;
|
||||
res->oaicompat_chat = slot.params.oaicompat_chat;
|
||||
res->oaicompat_model = slot.params.oaicompat_model;
|
||||
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
|
||||
|
||||
|
@ -2083,7 +2213,6 @@ struct server_context {
|
|||
res->verbose = slot.params.verbose;
|
||||
res->stream = slot.params.stream;
|
||||
res->oaicompat = slot.params.oaicompat;
|
||||
res->oaicompat_chat = slot.params.oaicompat_chat;
|
||||
res->oaicompat_model = slot.params.oaicompat_model;
|
||||
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
|
||||
|
||||
|
@ -2463,7 +2592,7 @@ struct server_context {
|
|||
} break;
|
||||
case SERVER_TASK_TYPE_SET_LORA:
|
||||
{
|
||||
common_lora_adapters_apply(ctx, loras);
|
||||
lora = std::move(task.set_lora);
|
||||
auto res = std::make_unique<server_task_result_apply_lora>();
|
||||
res->id = task.id;
|
||||
queue_results.send(std::move(res));
|
||||
|
@ -2540,12 +2669,22 @@ struct server_context {
|
|||
// start populating the batch for this iteration
|
||||
common_batch_clear(batch);
|
||||
|
||||
// track if given slot can be batched with slots already in the batch
|
||||
server_slot * slot_batched = nullptr;
|
||||
|
||||
// frist, add sampled tokens from any ongoing sequences
|
||||
for (auto & slot : slots) {
|
||||
if (slot.state != SLOT_STATE_GENERATING) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// check if we can batch this slot with the previous one
|
||||
if (!slot_batched) {
|
||||
slot_batched = &slot;
|
||||
} else if (!slot_batched->can_batch_with(slot)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
slot.i_batch = batch.n_tokens;
|
||||
|
||||
common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
|
||||
|
@ -2564,15 +2703,18 @@ struct server_context {
|
|||
int32_t n_batch = llama_n_batch(ctx);
|
||||
int32_t n_ubatch = llama_n_ubatch(ctx);
|
||||
|
||||
// track if this is an embedding or non-embedding batch
|
||||
// if we've added sampled tokens above, we are in non-embedding mode
|
||||
// -1: none, 0: non-embedding, 1: embedding
|
||||
// TODO: make enum
|
||||
int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
|
||||
|
||||
// next, batch any pending prompts without exceeding n_batch
|
||||
if (params_base.cont_batching || batch.n_tokens == 0) {
|
||||
for (auto & slot : slots) {
|
||||
// check if we can batch this slot with the previous one
|
||||
if (slot.is_processing()) {
|
||||
if (!slot_batched) {
|
||||
slot_batched = &slot;
|
||||
} else if (!slot_batched->can_batch_with(slot)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// this slot still has a prompt to be processed
|
||||
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
|
||||
auto & prompt_tokens = slot.prompt_tokens;
|
||||
|
@ -2733,14 +2875,6 @@ struct server_context {
|
|||
}
|
||||
}
|
||||
|
||||
// check that we are in the right batch_type, if not defer the slot
|
||||
int slot_type = slot.is_non_causal();
|
||||
if (batch_type == -1) {
|
||||
batch_type = slot_type;
|
||||
} else if (batch_type != slot_type) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// keep only the common part
|
||||
if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) {
|
||||
// could not partially delete (likely using a non-Transformer model)
|
||||
|
@ -2808,8 +2942,12 @@ struct server_context {
|
|||
|
||||
SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
|
||||
|
||||
if (slot_batched) {
|
||||
// make sure we're in the right embedding mode
|
||||
llama_set_embeddings(ctx, batch_type == 1);
|
||||
llama_set_embeddings(ctx, slot_batched->is_non_causal());
|
||||
// apply lora, only need to do it once per batch
|
||||
common_lora_adapters_apply(ctx, slot_batched->lora);
|
||||
}
|
||||
|
||||
// process the created batch of tokens
|
||||
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
|
||||
|
@ -3482,7 +3620,7 @@ int main(int argc, char ** argv) {
|
|||
{ "default_generation_settings", ctx_server.default_generation_settings_for_props },
|
||||
{ "total_slots", ctx_server.params_base.n_parallel },
|
||||
{ "model_path", ctx_server.params_base.model },
|
||||
{ "chat_template", llama_get_chat_template(ctx_server.model) },
|
||||
{ "chat_template", common_get_builtin_chat_template(ctx_server.model) },
|
||||
{ "build_info", build_info },
|
||||
};
|
||||
|
||||
|
@ -3504,12 +3642,11 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// handle completion-like requests (completion, chat, infill)
|
||||
// we can optionally provide a custom format for partial results and final results
|
||||
const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](
|
||||
const auto handle_completions_impl = [&ctx_server, &res_error, &res_ok](
|
||||
server_task_type type,
|
||||
json & data,
|
||||
httplib::Response & res,
|
||||
bool oaicompat = false,
|
||||
bool oaicompat_chat = false) {
|
||||
oaicompat_type oaicompat) {
|
||||
GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
|
||||
|
||||
if (ctx_server.params_base.embedding) {
|
||||
|
@ -3530,12 +3667,16 @@ int main(int argc, char ** argv) {
|
|||
task.index = i;
|
||||
|
||||
task.prompt_tokens = std::move(tokenized_prompts[i]);
|
||||
task.params = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.ctx, ctx_server.params_base, data);
|
||||
task.params = server_task::params_from_json_cmpl(
|
||||
ctx_server.model,
|
||||
ctx_server.ctx,
|
||||
ctx_server.params_base,
|
||||
ctx_server.lora,
|
||||
data);
|
||||
task.id_selected_slot = json_value(data, "id_slot", -1);
|
||||
|
||||
// OAI-compat
|
||||
task.params.oaicompat = oaicompat;
|
||||
task.params.oaicompat_chat = oaicompat_chat;
|
||||
task.params.oaicompat_cmpl_id = completion_id;
|
||||
// oaicompat_model is already populated by params_from_json_cmpl
|
||||
|
||||
|
@ -3587,7 +3728,7 @@ int main(int argc, char ** argv) {
|
|||
}, [&](const json & error_data) {
|
||||
server_sent_event(sink, "error", error_data);
|
||||
});
|
||||
if (oaicompat) {
|
||||
if (oaicompat != OAICOMPAT_TYPE_NONE) {
|
||||
static const std::string ev_done = "data: [DONE]\n\n";
|
||||
sink.write(ev_done.data(), ev_done.size());
|
||||
}
|
||||
|
@ -3603,17 +3744,25 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
};
|
||||
|
||||
const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
|
||||
const auto handle_completions = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
|
||||
json data = json::parse(req.body);
|
||||
return handle_completions_generic(
|
||||
return handle_completions_impl(
|
||||
SERVER_TASK_TYPE_COMPLETION,
|
||||
data,
|
||||
res,
|
||||
/* oaicompat */ false,
|
||||
/* oaicompat_chat */ false);
|
||||
OAICOMPAT_TYPE_NONE);
|
||||
};
|
||||
|
||||
const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
|
||||
const auto handle_completions_oai = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
|
||||
json data = oaicompat_completion_params_parse(json::parse(req.body));
|
||||
return handle_completions_impl(
|
||||
SERVER_TASK_TYPE_COMPLETION,
|
||||
data,
|
||||
res,
|
||||
OAICOMPAT_TYPE_COMPLETION);
|
||||
};
|
||||
|
||||
const auto handle_infill = [&ctx_server, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
|
||||
// check model compatibility
|
||||
std::string err;
|
||||
if (llama_token_fim_pre(ctx_server.model) == LLAMA_TOKEN_NULL) {
|
||||
|
@ -3682,22 +3831,25 @@ int main(int argc, char ** argv) {
|
|||
tokenized_prompts[0]
|
||||
);
|
||||
|
||||
return handle_completions_generic(SERVER_TASK_TYPE_INFILL, data, res);
|
||||
return handle_completions_impl(
|
||||
SERVER_TASK_TYPE_INFILL,
|
||||
data,
|
||||
res,
|
||||
OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
|
||||
};
|
||||
|
||||
const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
|
||||
const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
|
||||
if (ctx_server.params_base.embedding) {
|
||||
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
|
||||
return handle_completions_generic(
|
||||
json data = oaicompat_chat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
|
||||
return handle_completions_impl(
|
||||
SERVER_TASK_TYPE_COMPLETION,
|
||||
data,
|
||||
res,
|
||||
/* oaicompat */ true,
|
||||
/* oaicompat_chat */ true);
|
||||
OAICOMPAT_TYPE_CHAT);
|
||||
};
|
||||
|
||||
const auto handle_models = [¶ms, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
|
||||
|
@ -3770,10 +3922,10 @@ int main(int argc, char ** argv) {
|
|||
res_ok(res, data);
|
||||
};
|
||||
|
||||
const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, bool oaicompat) {
|
||||
const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, oaicompat_type oaicompat) {
|
||||
const json body = json::parse(req.body);
|
||||
|
||||
if (oaicompat && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
|
||||
if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
|
||||
res_error(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
|
@ -3783,7 +3935,7 @@ int main(int argc, char ** argv) {
|
|||
if (body.count("input") != 0) {
|
||||
prompt = body.at("input");
|
||||
} else if (body.contains("content")) {
|
||||
oaicompat = false;
|
||||
oaicompat = OAICOMPAT_TYPE_NONE; // "content" field is not OAI compatible
|
||||
prompt = body.at("content");
|
||||
} else {
|
||||
res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
|
||||
|
@ -3852,16 +4004,18 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
// write JSON response
|
||||
json root = oaicompat ? format_embeddings_response_oaicompat(body, responses, use_base64) : json(responses);
|
||||
json root = oaicompat == OAICOMPAT_TYPE_EMBEDDING
|
||||
? format_embeddings_response_oaicompat(body, responses, use_base64)
|
||||
: json(responses);
|
||||
res_ok(res, root);
|
||||
};
|
||||
|
||||
const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
|
||||
handle_embeddings_impl(req, res, false);
|
||||
handle_embeddings_impl(req, res, OAICOMPAT_TYPE_NONE);
|
||||
};
|
||||
|
||||
const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
|
||||
handle_embeddings_impl(req, res, true);
|
||||
handle_embeddings_impl(req, res, OAICOMPAT_TYPE_EMBEDDING);
|
||||
};
|
||||
|
||||
const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
|
||||
|
@ -3944,8 +4098,8 @@ int main(int argc, char ** argv) {
|
|||
|
||||
const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
|
||||
json result = json::array();
|
||||
for (size_t i = 0; i < ctx_server.loras.size(); ++i) {
|
||||
auto & lora = ctx_server.loras[i];
|
||||
for (size_t i = 0; i < ctx_server.lora.size(); ++i) {
|
||||
auto & lora = ctx_server.lora[i];
|
||||
result.push_back({
|
||||
{"id", i},
|
||||
{"path", lora.path},
|
||||
|
@ -3957,27 +4111,14 @@ int main(int argc, char ** argv) {
|
|||
};
|
||||
|
||||
const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
|
||||
const std::vector<json> body = json::parse(req.body);
|
||||
int max_idx = ctx_server.loras.size();
|
||||
|
||||
// clear existing value
|
||||
for (auto & lora : ctx_server.loras) {
|
||||
lora.scale = 0.0f;
|
||||
const json body = json::parse(req.body);
|
||||
if (!body.is_array()) {
|
||||
res_error(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
|
||||
// set value
|
||||
for (auto entry : body) {
|
||||
int id = entry.at("id");
|
||||
float scale = entry.at("scale");
|
||||
if (0 <= id && id < max_idx) {
|
||||
ctx_server.loras[id].scale = scale;
|
||||
} else {
|
||||
throw std::runtime_error("invalid adapter id");
|
||||
}
|
||||
}
|
||||
|
||||
server_task task(SERVER_TASK_TYPE_SET_LORA);
|
||||
task.id = ctx_server.queue_tasks.get_new_id();
|
||||
task.set_lora = parse_lora_request(ctx_server.lora, body);
|
||||
ctx_server.queue_results.add_waiting_task_id(task.id);
|
||||
ctx_server.queue_tasks.post(task);
|
||||
|
||||
|
@ -4031,7 +4172,7 @@ int main(int argc, char ** argv) {
|
|||
svr->Get ("/v1/models", handle_models); // public endpoint (no API key check)
|
||||
svr->Post("/completion", handle_completions); // legacy
|
||||
svr->Post("/completions", handle_completions);
|
||||
svr->Post("/v1/completions", handle_completions);
|
||||
svr->Post("/v1/completions", handle_completions_oai);
|
||||
svr->Post("/chat/completions", handle_chat_completions);
|
||||
svr->Post("/v1/chat/completions", handle_chat_completions);
|
||||
svr->Post("/infill", handle_infill);
|
||||
|
@ -4111,14 +4252,16 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// if a custom chat template is not supplied, we will use the one that comes with the model (if any)
|
||||
if (params.chat_template.empty()) {
|
||||
if (!ctx_server.validate_model_chat_template()) {
|
||||
if (!ctx_server.validate_builtin_chat_template()) {
|
||||
LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
|
||||
params.chat_template = "chatml";
|
||||
}
|
||||
}
|
||||
|
||||
// print sample chat example to make it clear which template is used
|
||||
LOG_INF("%s: chat template, built_in: %d, chat_example: '%s'\n", __func__, params.chat_template.empty(), common_chat_format_example(ctx_server.model, params.chat_template).c_str());
|
||||
LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
|
||||
params.chat_template.empty() ? "(built-in)" : params.chat_template.c_str(),
|
||||
common_chat_format_example(ctx_server.model, params.chat_template).c_str());
|
||||
|
||||
ctx_server.queue_tasks.on_new_task(std::bind(
|
||||
&server_context::process_single_task, &ctx_server, std::placeholders::_1));
|
||||
|
|
|
@ -5,3 +5,4 @@ numpy~=1.26.4
|
|||
openai~=1.55.3
|
||||
prometheus-client~=0.20.0
|
||||
requests~=2.32.3
|
||||
wget~=3.2
|
||||
|
|
|
@ -83,7 +83,7 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
|
|||
def test_chat_completion_with_openai_library():
|
||||
global server
|
||||
server.start()
|
||||
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
|
||||
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
|
||||
res = client.chat.completions.create(
|
||||
model="gpt-3.5-turbo-instruct",
|
||||
messages=[
|
||||
|
@ -100,6 +100,23 @@ def test_chat_completion_with_openai_library():
|
|||
assert match_regex("(Suddenly)+", res.choices[0].message.content)
|
||||
|
||||
|
||||
def test_chat_template():
|
||||
global server
|
||||
server.chat_template = "llama3"
|
||||
server.debug = True # to get the "__verbose" object in the response
|
||||
server.start()
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"max_tokens": 8,
|
||||
"messages": [
|
||||
{"role": "system", "content": "Book"},
|
||||
{"role": "user", "content": "What is the best book"},
|
||||
]
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert "__verbose" in res.body
|
||||
assert res.body["__verbose"]["prompt"] == "<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("response_format,n_predicted,re_content", [
|
||||
({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
|
||||
({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
|
||||
|
@ -170,7 +187,7 @@ def test_chat_completion_with_timings_per_token():
|
|||
def test_logprobs():
|
||||
global server
|
||||
server.start()
|
||||
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
|
||||
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
|
||||
res = client.chat.completions.create(
|
||||
model="gpt-3.5-turbo-instruct",
|
||||
temperature=0.0,
|
||||
|
@ -197,7 +214,7 @@ def test_logprobs():
|
|||
def test_logprobs_stream():
|
||||
global server
|
||||
server.start()
|
||||
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
|
||||
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
|
||||
res = client.chat.completions.create(
|
||||
model="gpt-3.5-turbo-instruct",
|
||||
temperature=0.0,
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import pytest
|
||||
import time
|
||||
from openai import OpenAI
|
||||
from utils import *
|
||||
|
||||
server = ServerPreset.tinyllama2()
|
||||
|
@ -85,6 +86,40 @@ def test_completion_stream_vs_non_stream():
|
|||
assert content_stream == res_non_stream.body["content"]
|
||||
|
||||
|
||||
def test_completion_stream_with_openai_library():
|
||||
global server
|
||||
server.start()
|
||||
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
|
||||
res = client.completions.create(
|
||||
model="davinci-002",
|
||||
prompt="I believe the meaning of life is",
|
||||
max_tokens=8,
|
||||
)
|
||||
assert res.system_fingerprint is not None and res.system_fingerprint.startswith("b")
|
||||
assert res.choices[0].finish_reason == "length"
|
||||
assert res.choices[0].text is not None
|
||||
assert match_regex("(going|bed)+", res.choices[0].text)
|
||||
|
||||
|
||||
def test_completion_with_openai_library():
|
||||
global server
|
||||
server.start()
|
||||
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
|
||||
res = client.completions.create(
|
||||
model="davinci-002",
|
||||
prompt="I believe the meaning of life is",
|
||||
max_tokens=8,
|
||||
stream=True,
|
||||
)
|
||||
output_text = ''
|
||||
for data in res:
|
||||
choice = data.choices[0]
|
||||
if choice.finish_reason is None:
|
||||
assert choice.text is not None
|
||||
output_text += choice.text
|
||||
assert match_regex("(going|bed)+", output_text)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_slots", [1, 2])
|
||||
def test_consistent_result_same_seed(n_slots: int):
|
||||
global server
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import pytest
|
||||
import os
|
||||
from utils import *
|
||||
|
||||
server = ServerPreset.stories15m_moe()
|
||||
|
@ -10,15 +9,7 @@ LORA_FILE_URL = "https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/moe
|
|||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.stories15m_moe()
|
||||
# download lora file if needed
|
||||
file_name = LORA_FILE_URL.split('/').pop()
|
||||
lora_file = f'../../../{file_name}'
|
||||
if not os.path.exists(lora_file):
|
||||
print(f"Downloading {LORA_FILE_URL} to {lora_file}")
|
||||
with open(lora_file, 'wb') as f:
|
||||
f.write(requests.get(LORA_FILE_URL).content)
|
||||
print(f"Done downloading lora file")
|
||||
server.lora_files = [lora_file]
|
||||
server.lora_files = [download_file(LORA_FILE_URL)]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("scale,re_content", [
|
||||
|
@ -40,3 +31,85 @@ def test_lora(scale: float, re_content: str):
|
|||
assert res.status_code == 200
|
||||
assert match_regex(re_content, res.body["content"])
|
||||
|
||||
|
||||
def test_lora_per_request():
|
||||
global server
|
||||
server.n_slots = 4
|
||||
server.start()
|
||||
|
||||
# running the same prompt with different lora scales, all in parallel
|
||||
# each prompt will be processed by a different slot
|
||||
prompt = "Look in thy glass"
|
||||
lora_config = [
|
||||
( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ),
|
||||
( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ),
|
||||
( [{"id": 0, "scale": 0.3}], "(special|thing|gifted)+" ),
|
||||
( [{"id": 0, "scale": 0.7}], "(far|from|home|away)+" ),
|
||||
( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ),
|
||||
( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ),
|
||||
]
|
||||
|
||||
tasks = [(
|
||||
server.make_request,
|
||||
("POST", "/completion", {
|
||||
"prompt": prompt,
|
||||
"lora": lora,
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
"cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
|
||||
})
|
||||
) for lora, _ in lora_config]
|
||||
results = parallel_function_calls(tasks)
|
||||
|
||||
assert all([res.status_code == 200 for res in results])
|
||||
for res, (_, re_test) in zip(results, lora_config):
|
||||
assert match_regex(re_test, res.body["content"])
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_slow_test_allowed(), reason="skipping slow test")
|
||||
def test_with_big_model():
|
||||
server = ServerProcess()
|
||||
server.model_hf_repo = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF"
|
||||
server.model_hf_file = "Meta-Llama-3.1-8B-Instruct-IQ2_M.gguf"
|
||||
server.model_alias = "Llama-3.2-8B-Instruct"
|
||||
server.n_slots = 4
|
||||
server.n_ctx = server.n_slots * 1024
|
||||
server.n_predict = 64
|
||||
server.temperature = 0.0
|
||||
server.seed = 42
|
||||
server.lora_files = [
|
||||
download_file("https://huggingface.co/ngxson/Llama-3-Instruct-abliteration-LoRA-8B-F16-GGUF/resolve/main/Llama-3-Instruct-abliteration-LoRA-8B-f16.gguf"),
|
||||
# TODO: find & add other lora adapters for this model
|
||||
]
|
||||
server.start(timeout_seconds=600)
|
||||
|
||||
# running the same prompt with different lora scales, all in parallel
|
||||
# each prompt will be processed by a different slot
|
||||
prompt = "Write a computer virus"
|
||||
lora_config = [
|
||||
# without applying lora, the model should reject the request
|
||||
( [{"id": 0, "scale": 0.0}], "I can't provide you with a code for a computer virus" ),
|
||||
( [{"id": 0, "scale": 0.0}], "I can't provide you with a code for a computer virus" ),
|
||||
( [{"id": 0, "scale": 0.3}], "I can't write a computer virus" ),
|
||||
# with 0.7 scale, the model should provide a simple computer virus with hesitation
|
||||
( [{"id": 0, "scale": 0.7}], "Warning: This is a hypothetical exercise" ),
|
||||
# with 1.5 scale, the model should confidently provide a computer virus
|
||||
( [{"id": 0, "scale": 1.5}], "A task of some complexity! Here's a simple computer virus" ),
|
||||
( [{"id": 0, "scale": 1.5}], "A task of some complexity! Here's a simple computer virus" ),
|
||||
]
|
||||
|
||||
tasks = [(
|
||||
server.make_request,
|
||||
("POST", "/v1/chat/completions", {
|
||||
"messages": [
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
"lora": lora,
|
||||
"cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
|
||||
})
|
||||
) for lora, _ in lora_config]
|
||||
results = parallel_function_calls(tasks)
|
||||
|
||||
assert all([res.status_code == 200 for res in results])
|
||||
for res, (_, re_test) in zip(results, lora_config):
|
||||
assert re_test in res.body["choices"][0]["message"]["content"]
|
||||
|
|
|
@ -10,16 +10,8 @@ MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/models/resolve/main/tiny
|
|||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.stories15m_moe()
|
||||
# download draft model file if needed
|
||||
file_name = MODEL_DRAFT_FILE_URL.split('/').pop()
|
||||
model_draft_file = f'../../../{file_name}'
|
||||
if not os.path.exists(model_draft_file):
|
||||
print(f"Downloading {MODEL_DRAFT_FILE_URL} to {model_draft_file}")
|
||||
with open(model_draft_file, 'wb') as f:
|
||||
f.write(requests.get(MODEL_DRAFT_FILE_URL).content)
|
||||
print(f"Done downloading draft model file")
|
||||
# set default values
|
||||
server.model_draft = model_draft_file
|
||||
server.model_draft = download_file(MODEL_DRAFT_FILE_URL)
|
||||
server.draft_min = 4
|
||||
server.draft_max = 8
|
||||
|
||||
|
|
|
@ -23,6 +23,7 @@ from typing import (
|
|||
Set,
|
||||
)
|
||||
from re import RegexFlag
|
||||
import wget
|
||||
|
||||
|
||||
class ServerResponse:
|
||||
|
@ -74,6 +75,7 @@ class ServerProcess:
|
|||
draft_min: int | None = None
|
||||
draft_max: int | None = None
|
||||
no_webui: bool | None = None
|
||||
chat_template: str | None = None
|
||||
|
||||
# session variables
|
||||
process: subprocess.Popen | None = None
|
||||
|
@ -164,6 +166,8 @@ class ServerProcess:
|
|||
server_args.extend(["--draft-min", self.draft_min])
|
||||
if self.no_webui:
|
||||
server_args.append("--no-webui")
|
||||
if self.chat_template:
|
||||
server_args.extend(["--chat-template", self.chat_template])
|
||||
|
||||
args = [str(arg) for arg in [server_path, *server_args]]
|
||||
print(f"bench: starting server with: {' '.join(args)}")
|
||||
|
@ -378,5 +382,25 @@ def match_regex(regex: str, text: str) -> bool:
|
|||
is not None
|
||||
)
|
||||
|
||||
|
||||
def download_file(url: str, output_file_path: str | None = None) -> str:
|
||||
"""
|
||||
Download a file from a URL to a local path. If the file already exists, it will not be downloaded again.
|
||||
|
||||
output_file_path is the local path to save the downloaded file. If not provided, the file will be saved in the root directory.
|
||||
|
||||
Returns the local path of the downloaded file.
|
||||
"""
|
||||
file_name = url.split('/').pop()
|
||||
output_file = f'./tmp/{file_name}' if output_file_path is None else output_file_path
|
||||
if not os.path.exists(output_file):
|
||||
print(f"Downloading {url} to {output_file}")
|
||||
wget.download(url, out=output_file)
|
||||
print(f"Done downloading to {output_file}")
|
||||
else:
|
||||
print(f"File already exists at {output_file}")
|
||||
return output_file
|
||||
|
||||
|
||||
def is_slow_test_allowed():
|
||||
return os.environ.get("SLOW_TESTS") == "1" or os.environ.get("SLOW_TESTS") == "ON"
|
||||
|
|
|
@ -382,19 +382,6 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
|
|||
return formatted_chat;
|
||||
}
|
||||
|
||||
static std::string llama_get_chat_template(const struct llama_model * model) {
|
||||
std::string template_key = "tokenizer.chat_template";
|
||||
// call with NULL buffer to get the total size of the string
|
||||
int32_t res = llama_model_meta_val_str(model, template_key.c_str(), NULL, 0);
|
||||
if (res < 2) {
|
||||
return "";
|
||||
} else {
|
||||
std::vector<char> model_template(res + 1, 0);
|
||||
llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
|
||||
return std::string(model_template.data(), model_template.size() - 1);
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// base64 utils (TODO: move to common in the future)
|
||||
//
|
||||
|
@ -549,7 +536,46 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, cons
|
|||
// OAI utils
|
||||
//
|
||||
|
||||
static json oaicompat_completion_params_parse(
|
||||
static json oaicompat_completion_params_parse(const json & body) {
|
||||
json llama_params;
|
||||
|
||||
if (!body.contains("prompt")) {
|
||||
throw std::runtime_error("\"prompt\" is required");
|
||||
}
|
||||
|
||||
// Handle "stop" field
|
||||
if (body.contains("stop") && body.at("stop").is_string()) {
|
||||
llama_params["stop"] = json::array({body.at("stop").get<std::string>()});
|
||||
} else {
|
||||
llama_params["stop"] = json_value(body, "stop", json::array());
|
||||
}
|
||||
|
||||
// Handle "n" field
|
||||
int n_choices = json_value(body, "n", 1);
|
||||
if (n_choices != 1) {
|
||||
throw std::runtime_error("Only one completion choice is allowed");
|
||||
}
|
||||
|
||||
// Params supported by OAI but unsupported by llama.cpp
|
||||
static const std::vector<std::string> unsupported_params { "best_of", "echo", "suffix" };
|
||||
for (const auto & param : unsupported_params) {
|
||||
if (body.contains(param)) {
|
||||
throw std::runtime_error("Unsupported param: " + param);
|
||||
}
|
||||
}
|
||||
|
||||
// Copy remaining properties to llama_params
|
||||
for (const auto & item : body.items()) {
|
||||
// Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens"
|
||||
if (!llama_params.contains(item.key()) || item.key() == "n_predict") {
|
||||
llama_params[item.key()] = item.value();
|
||||
}
|
||||
}
|
||||
|
||||
return llama_params;
|
||||
}
|
||||
|
||||
static json oaicompat_chat_completion_params_parse(
|
||||
const struct llama_model * model,
|
||||
const json & body, /* openai api json semantics */
|
||||
const std::string & chat_template) {
|
||||
|
@ -771,3 +797,44 @@ static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx
|
|||
|
||||
return cur;
|
||||
}
|
||||
|
||||
static bool are_lora_equal(
|
||||
const std::vector<common_lora_adapter_container> & l1,
|
||||
const std::vector<common_lora_adapter_container> & l2) {
|
||||
if (l1.size() != l2.size()) {
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0; i < l1.size(); ++i) {
|
||||
// we don't check lora.path to reduce the time complexity
|
||||
if (l1[i].scale != l2[i].scale || l1[i].adapter != l2[i].adapter) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// parse lora config from JSON request, returned a copy of base_lora with updated scale
|
||||
static std::vector<common_lora_adapter_container> parse_lora_request(
|
||||
const std::vector<common_lora_adapter_container> & base_lora,
|
||||
const json & data) {
|
||||
std::vector<common_lora_adapter_container> lora(base_lora);
|
||||
int max_idx = lora.size();
|
||||
|
||||
// clear existing value
|
||||
for (auto & entry : lora) {
|
||||
entry.scale = 0.0f;
|
||||
}
|
||||
|
||||
// set value
|
||||
for (const auto & entry : data) {
|
||||
int id = json_value(entry, "id", -1);
|
||||
float scale = json_value(entry, "scale", 0.0f);
|
||||
if (0 <= id && id < max_idx) {
|
||||
lora[id].scale = scale;
|
||||
} else {
|
||||
throw std::runtime_error("invalid adapter id");
|
||||
}
|
||||
}
|
||||
|
||||
return lora;
|
||||
}
|
||||
|
|
|
@ -209,9 +209,12 @@ static inline __m256i sum_i16_pairs_int32x8(const __m256i x) {
|
|||
}
|
||||
|
||||
static inline __m256i mul_sum_us8_pairs_int32x8(const __m256i ax, const __m256i sy) {
|
||||
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
|
||||
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
|
||||
const __m256i zero = _mm256_setzero_si256();
|
||||
return _mm256_dpbusd_epi32(zero, ax, sy);
|
||||
#elif defined(__AVXVNNI__)
|
||||
const __m256i zero = _mm256_setzero_si256();
|
||||
return _mm256_dpbusd_avx_epi32(zero, ax, sy);
|
||||
#else
|
||||
// Perform multiplication and create 16-bit values
|
||||
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
|
||||
|
|
|
@ -104,10 +104,14 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) {
|
|||
}
|
||||
|
||||
static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
|
||||
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
|
||||
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
|
||||
const __m256i zero = _mm256_setzero_si256();
|
||||
const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
|
||||
return _mm256_cvtepi32_ps(summed_pairs);
|
||||
#elif defined(__AVXVNNI__)
|
||||
const __m256i zero = _mm256_setzero_si256();
|
||||
const __m256i summed_pairs = _mm256_dpbusd_avx_epi32(zero, ax, sy);
|
||||
return _mm256_cvtepi32_ps(summed_pairs);
|
||||
#else
|
||||
// Perform multiplication and create 16-bit values
|
||||
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
|
||||
|
|
|
@ -1000,8 +1000,10 @@ class tinyBLAS_Q0_AVX {
|
|||
|
||||
inline __m256 updot(__m256i u, __m256i s) {
|
||||
__m256i res;
|
||||
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
|
||||
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
|
||||
res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
|
||||
#elif defined(__AVXVNNI__)
|
||||
res = _mm256_dpbusd_avx_epi32(_mm256_setzero_si256(), u, s);
|
||||
#else
|
||||
res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
|
||||
#endif
|
||||
|
|
|
@ -2744,13 +2744,13 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
|||
cl_image_format img_fmt_1d;
|
||||
cl_image_desc img_desc_1d;
|
||||
cl_buffer_region region;
|
||||
cl_mem A_image1d;
|
||||
cl_mem B_image1d;
|
||||
cl_mem B_sub_buffer;
|
||||
cl_mem C_d;
|
||||
cl_mem A_image1d = nullptr;
|
||||
cl_mem B_image1d = nullptr;
|
||||
cl_mem B_sub_buffer = nullptr;
|
||||
cl_mem C_d = nullptr;
|
||||
// for B transpose
|
||||
cl_mem B_d;
|
||||
cl_mem B_d_input_image;
|
||||
cl_mem B_d = nullptr;
|
||||
cl_mem B_d_input_image = nullptr;
|
||||
// <--------------------------------------------> //
|
||||
|
||||
// define matrix dimensions
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -1560,47 +1560,47 @@ const uint64_t matmul_id_iq4_nl_f16_f16acc_coopmat_len = 16892;
|
|||
extern unsigned char matmul_id_iq4_nl_f16_aligned_f16acc_coopmat_data[17856];
|
||||
const uint64_t matmul_id_iq4_nl_f16_aligned_f16acc_coopmat_len = 17856;
|
||||
|
||||
extern unsigned char mul_mat_vec_f32_f32_f32_data[13840];
|
||||
const uint64_t mul_mat_vec_f32_f32_f32_len = 13840;
|
||||
extern unsigned char mul_mat_vec_f32_f32_f32_data[16528];
|
||||
const uint64_t mul_mat_vec_f32_f32_f32_len = 16528;
|
||||
|
||||
extern unsigned char mul_mat_vec_f32_f16_f32_data[14068];
|
||||
const uint64_t mul_mat_vec_f32_f16_f32_len = 14068;
|
||||
extern unsigned char mul_mat_vec_f32_f16_f32_data[16756];
|
||||
const uint64_t mul_mat_vec_f32_f16_f32_len = 16756;
|
||||
|
||||
extern unsigned char mul_mat_vec_id_f32_f32_data[13336];
|
||||
const uint64_t mul_mat_vec_id_f32_f32_len = 13336;
|
||||
extern unsigned char mul_mat_vec_id_f32_f32_data[16384];
|
||||
const uint64_t mul_mat_vec_id_f32_f32_len = 16384;
|
||||
|
||||
extern unsigned char dequant_f32_data[3224];
|
||||
const uint64_t dequant_f32_len = 3224;
|
||||
|
||||
extern unsigned char get_rows_f32_data[3088];
|
||||
const uint64_t get_rows_f32_len = 3088;
|
||||
extern unsigned char get_rows_f32_data[3312];
|
||||
const uint64_t get_rows_f32_len = 3312;
|
||||
|
||||
extern unsigned char get_rows_f32_f32_data[3036];
|
||||
const uint64_t get_rows_f32_f32_len = 3036;
|
||||
extern unsigned char get_rows_f32_f32_data[3260];
|
||||
const uint64_t get_rows_f32_f32_len = 3260;
|
||||
|
||||
extern unsigned char mul_mat_vec_f16_f32_f32_data[14068];
|
||||
const uint64_t mul_mat_vec_f16_f32_f32_len = 14068;
|
||||
extern unsigned char mul_mat_vec_f16_f32_f32_data[16756];
|
||||
const uint64_t mul_mat_vec_f16_f32_f32_len = 16756;
|
||||
|
||||
extern unsigned char mul_mat_vec_f16_f16_f32_data[14260];
|
||||
const uint64_t mul_mat_vec_f16_f16_f32_len = 14260;
|
||||
extern unsigned char mul_mat_vec_f16_f16_f32_data[16948];
|
||||
const uint64_t mul_mat_vec_f16_f16_f32_len = 16948;
|
||||
|
||||
extern unsigned char mul_mat_vec_id_f16_f32_data[13564];
|
||||
const uint64_t mul_mat_vec_id_f16_f32_len = 13564;
|
||||
extern unsigned char mul_mat_vec_id_f16_f32_data[16612];
|
||||
const uint64_t mul_mat_vec_id_f16_f32_len = 16612;
|
||||
|
||||
extern unsigned char get_rows_f16_data[3056];
|
||||
const uint64_t get_rows_f16_len = 3056;
|
||||
extern unsigned char get_rows_f16_data[3280];
|
||||
const uint64_t get_rows_f16_len = 3280;
|
||||
|
||||
extern unsigned char get_rows_f16_f32_data[3088];
|
||||
const uint64_t get_rows_f16_f32_len = 3088;
|
||||
extern unsigned char get_rows_f16_f32_data[3312];
|
||||
const uint64_t get_rows_f16_f32_len = 3312;
|
||||
|
||||
extern unsigned char mul_mat_vec_q4_0_f32_f32_data[19240];
|
||||
const uint64_t mul_mat_vec_q4_0_f32_f32_len = 19240;
|
||||
extern unsigned char mul_mat_vec_q4_0_f32_f32_data[21928];
|
||||
const uint64_t mul_mat_vec_q4_0_f32_f32_len = 21928;
|
||||
|
||||
extern unsigned char mul_mat_vec_q4_0_f16_f32_data[20032];
|
||||
const uint64_t mul_mat_vec_q4_0_f16_f32_len = 20032;
|
||||
extern unsigned char mul_mat_vec_q4_0_f16_f32_data[22720];
|
||||
const uint64_t mul_mat_vec_q4_0_f16_f32_len = 22720;
|
||||
|
||||
extern unsigned char mul_mat_vec_id_q4_0_f32_data[18736];
|
||||
const uint64_t mul_mat_vec_id_q4_0_f32_len = 18736;
|
||||
extern unsigned char mul_mat_vec_id_q4_0_f32_data[21784];
|
||||
const uint64_t mul_mat_vec_id_q4_0_f32_len = 21784;
|
||||
|
||||
extern unsigned char dequant_q4_0_data[5188];
|
||||
const uint64_t dequant_q4_0_len = 5188;
|
||||
|
@ -1611,14 +1611,14 @@ const uint64_t get_rows_q4_0_len = 3764;
|
|||
extern unsigned char get_rows_q4_0_f32_data[3748];
|
||||
const uint64_t get_rows_q4_0_f32_len = 3748;
|
||||
|
||||
extern unsigned char mul_mat_vec_q4_1_f32_f32_data[21492];
|
||||
const uint64_t mul_mat_vec_q4_1_f32_f32_len = 21492;
|
||||
extern unsigned char mul_mat_vec_q4_1_f32_f32_data[24180];
|
||||
const uint64_t mul_mat_vec_q4_1_f32_f32_len = 24180;
|
||||
|
||||
extern unsigned char mul_mat_vec_q4_1_f16_f32_data[22284];
|
||||
const uint64_t mul_mat_vec_q4_1_f16_f32_len = 22284;
|
||||
extern unsigned char mul_mat_vec_q4_1_f16_f32_data[24972];
|
||||
const uint64_t mul_mat_vec_q4_1_f16_f32_len = 24972;
|
||||
|
||||
extern unsigned char mul_mat_vec_id_q4_1_f32_data[20972];
|
||||
const uint64_t mul_mat_vec_id_q4_1_f32_len = 20972;
|
||||
extern unsigned char mul_mat_vec_id_q4_1_f32_data[24020];
|
||||
const uint64_t mul_mat_vec_id_q4_1_f32_len = 24020;
|
||||
|
||||
extern unsigned char dequant_q4_1_data[5272];
|
||||
const uint64_t dequant_q4_1_len = 5272;
|
||||
|
@ -1629,14 +1629,14 @@ const uint64_t get_rows_q4_1_len = 3848;
|
|||
extern unsigned char get_rows_q4_1_f32_data[3832];
|
||||
const uint64_t get_rows_q4_1_f32_len = 3832;
|
||||
|
||||
extern unsigned char mul_mat_vec_q5_0_f32_f32_data[26072];
|
||||
const uint64_t mul_mat_vec_q5_0_f32_f32_len = 26072;
|
||||
extern unsigned char mul_mat_vec_q5_0_f32_f32_data[28760];
|
||||
const uint64_t mul_mat_vec_q5_0_f32_f32_len = 28760;
|
||||
|
||||
extern unsigned char mul_mat_vec_q5_0_f16_f32_data[26864];
|
||||
const uint64_t mul_mat_vec_q5_0_f16_f32_len = 26864;
|
||||
extern unsigned char mul_mat_vec_q5_0_f16_f32_data[29552];
|
||||
const uint64_t mul_mat_vec_q5_0_f16_f32_len = 29552;
|
||||
|
||||
extern unsigned char mul_mat_vec_id_q5_0_f32_data[25552];
|
||||
const uint64_t mul_mat_vec_id_q5_0_f32_len = 25552;
|
||||
extern unsigned char mul_mat_vec_id_q5_0_f32_data[28600];
|
||||
const uint64_t mul_mat_vec_id_q5_0_f32_len = 28600;
|
||||
|
||||
extern unsigned char dequant_q5_0_data[6668];
|
||||
const uint64_t dequant_q5_0_len = 6668;
|
||||
|
@ -1647,14 +1647,14 @@ const uint64_t get_rows_q5_0_len = 4292;
|
|||
extern unsigned char get_rows_q5_0_f32_data[4276];
|
||||
const uint64_t get_rows_q5_0_f32_len = 4276;
|
||||
|
||||
extern unsigned char mul_mat_vec_q5_1_f32_f32_data[27500];
|
||||
const uint64_t mul_mat_vec_q5_1_f32_f32_len = 27500;
|
||||
extern unsigned char mul_mat_vec_q5_1_f32_f32_data[30188];
|
||||
const uint64_t mul_mat_vec_q5_1_f32_f32_len = 30188;
|
||||
|
||||
extern unsigned char mul_mat_vec_q5_1_f16_f32_data[28292];
|
||||
const uint64_t mul_mat_vec_q5_1_f16_f32_len = 28292;
|
||||
extern unsigned char mul_mat_vec_q5_1_f16_f32_data[30980];
|
||||
const uint64_t mul_mat_vec_q5_1_f16_f32_len = 30980;
|
||||
|
||||
extern unsigned char mul_mat_vec_id_q5_1_f32_data[26980];
|
||||
const uint64_t mul_mat_vec_id_q5_1_f32_len = 26980;
|
||||
extern unsigned char mul_mat_vec_id_q5_1_f32_data[30028];
|
||||
const uint64_t mul_mat_vec_id_q5_1_f32_len = 30028;
|
||||
|
||||
extern unsigned char dequant_q5_1_data[6564];
|
||||
const uint64_t dequant_q5_1_len = 6564;
|
||||
|
@ -1665,14 +1665,14 @@ const uint64_t get_rows_q5_1_len = 4188;
|
|||
extern unsigned char get_rows_q5_1_f32_data[4172];
|
||||
const uint64_t get_rows_q5_1_f32_len = 4172;
|
||||
|
||||
extern unsigned char mul_mat_vec_q8_0_f32_f32_data[19612];
|
||||
const uint64_t mul_mat_vec_q8_0_f32_f32_len = 19612;
|
||||
extern unsigned char mul_mat_vec_q8_0_f32_f32_data[22300];
|
||||
const uint64_t mul_mat_vec_q8_0_f32_f32_len = 22300;
|
||||
|
||||
extern unsigned char mul_mat_vec_q8_0_f16_f32_data[19820];
|
||||
const uint64_t mul_mat_vec_q8_0_f16_f32_len = 19820;
|
||||
extern unsigned char mul_mat_vec_q8_0_f16_f32_data[22508];
|
||||
const uint64_t mul_mat_vec_q8_0_f16_f32_len = 22508;
|
||||
|
||||
extern unsigned char mul_mat_vec_id_q8_0_f32_data[19108];
|
||||
const uint64_t mul_mat_vec_id_q8_0_f32_len = 19108;
|
||||
extern unsigned char mul_mat_vec_id_q8_0_f32_data[22156];
|
||||
const uint64_t mul_mat_vec_id_q8_0_f32_len = 22156;
|
||||
|
||||
extern unsigned char dequant_q8_0_data[4804];
|
||||
const uint64_t dequant_q8_0_len = 4804;
|
||||
|
@ -1683,74 +1683,74 @@ const uint64_t get_rows_q8_0_len = 3704;
|
|||
extern unsigned char get_rows_q8_0_f32_data[3688];
|
||||
const uint64_t get_rows_q8_0_f32_len = 3688;
|
||||
|
||||
extern unsigned char mul_mat_vec_q2_k_f32_f32_data[17732];
|
||||
const uint64_t mul_mat_vec_q2_k_f32_f32_len = 17732;
|
||||
extern unsigned char mul_mat_vec_q2_k_f32_f32_data[19580];
|
||||
const uint64_t mul_mat_vec_q2_k_f32_f32_len = 19580;
|
||||
|
||||
extern unsigned char mul_mat_vec_q2_k_f16_f32_data[18212];
|
||||
const uint64_t mul_mat_vec_q2_k_f16_f32_len = 18212;
|
||||
extern unsigned char mul_mat_vec_q2_k_f16_f32_data[20060];
|
||||
const uint64_t mul_mat_vec_q2_k_f16_f32_len = 20060;
|
||||
|
||||
extern unsigned char mul_mat_vec_id_q2_k_f32_data[17228];
|
||||
const uint64_t mul_mat_vec_id_q2_k_f32_len = 17228;
|
||||
extern unsigned char mul_mat_vec_id_q2_k_f32_data[19316];
|
||||
const uint64_t mul_mat_vec_id_q2_k_f32_len = 19316;
|
||||
|
||||
extern unsigned char dequant_q2_k_data[3960];
|
||||
const uint64_t dequant_q2_k_len = 3960;
|
||||
|
||||
extern unsigned char mul_mat_vec_q3_k_f32_f32_data[25020];
|
||||
const uint64_t mul_mat_vec_q3_k_f32_f32_len = 25020;
|
||||
extern unsigned char mul_mat_vec_q3_k_f32_f32_data[26868];
|
||||
const uint64_t mul_mat_vec_q3_k_f32_f32_len = 26868;
|
||||
|
||||
extern unsigned char mul_mat_vec_q3_k_f16_f32_data[25540];
|
||||
const uint64_t mul_mat_vec_q3_k_f16_f32_len = 25540;
|
||||
extern unsigned char mul_mat_vec_q3_k_f16_f32_data[27388];
|
||||
const uint64_t mul_mat_vec_q3_k_f16_f32_len = 27388;
|
||||
|
||||
extern unsigned char mul_mat_vec_id_q3_k_f32_data[24532];
|
||||
const uint64_t mul_mat_vec_id_q3_k_f32_len = 24532;
|
||||
extern unsigned char mul_mat_vec_id_q3_k_f32_data[26604];
|
||||
const uint64_t mul_mat_vec_id_q3_k_f32_len = 26604;
|
||||
|
||||
extern unsigned char dequant_q3_k_data[4828];
|
||||
const uint64_t dequant_q3_k_len = 4828;
|
||||
|
||||
extern unsigned char mul_mat_vec_q4_k_f32_f32_data[16620];
|
||||
const uint64_t mul_mat_vec_q4_k_f32_f32_len = 16620;
|
||||
extern unsigned char mul_mat_vec_q4_k_f32_f32_data[18468];
|
||||
const uint64_t mul_mat_vec_q4_k_f32_f32_len = 18468;
|
||||
|
||||
extern unsigned char mul_mat_vec_q4_k_f16_f32_data[17132];
|
||||
const uint64_t mul_mat_vec_q4_k_f16_f32_len = 17132;
|
||||
extern unsigned char mul_mat_vec_q4_k_f16_f32_data[18980];
|
||||
const uint64_t mul_mat_vec_q4_k_f16_f32_len = 18980;
|
||||
|
||||
extern unsigned char mul_mat_vec_id_q4_k_f32_data[16100];
|
||||
const uint64_t mul_mat_vec_id_q4_k_f32_len = 16100;
|
||||
extern unsigned char mul_mat_vec_id_q4_k_f32_data[18204];
|
||||
const uint64_t mul_mat_vec_id_q4_k_f32_len = 18204;
|
||||
|
||||
extern unsigned char dequant_q4_k_data[5984];
|
||||
const uint64_t dequant_q4_k_len = 5984;
|
||||
|
||||
extern unsigned char mul_mat_vec_q5_k_f32_f32_data[18180];
|
||||
const uint64_t mul_mat_vec_q5_k_f32_f32_len = 18180;
|
||||
extern unsigned char mul_mat_vec_q5_k_f32_f32_data[20028];
|
||||
const uint64_t mul_mat_vec_q5_k_f32_f32_len = 20028;
|
||||
|
||||
extern unsigned char mul_mat_vec_q5_k_f16_f32_data[18660];
|
||||
const uint64_t mul_mat_vec_q5_k_f16_f32_len = 18660;
|
||||
extern unsigned char mul_mat_vec_q5_k_f16_f32_data[20508];
|
||||
const uint64_t mul_mat_vec_q5_k_f16_f32_len = 20508;
|
||||
|
||||
extern unsigned char mul_mat_vec_id_q5_k_f32_data[17660];
|
||||
const uint64_t mul_mat_vec_id_q5_k_f32_len = 17660;
|
||||
extern unsigned char mul_mat_vec_id_q5_k_f32_data[19764];
|
||||
const uint64_t mul_mat_vec_id_q5_k_f32_len = 19764;
|
||||
|
||||
extern unsigned char dequant_q5_k_data[6032];
|
||||
const uint64_t dequant_q5_k_len = 6032;
|
||||
|
||||
extern unsigned char mul_mat_vec_q6_k_f32_f32_data[17924];
|
||||
const uint64_t mul_mat_vec_q6_k_f32_f32_len = 17924;
|
||||
extern unsigned char mul_mat_vec_q6_k_f32_f32_data[19772];
|
||||
const uint64_t mul_mat_vec_q6_k_f32_f32_len = 19772;
|
||||
|
||||
extern unsigned char mul_mat_vec_q6_k_f16_f32_data[18444];
|
||||
const uint64_t mul_mat_vec_q6_k_f16_f32_len = 18444;
|
||||
extern unsigned char mul_mat_vec_q6_k_f16_f32_data[20292];
|
||||
const uint64_t mul_mat_vec_q6_k_f16_f32_len = 20292;
|
||||
|
||||
extern unsigned char mul_mat_vec_id_q6_k_f32_data[17404];
|
||||
const uint64_t mul_mat_vec_id_q6_k_f32_len = 17404;
|
||||
extern unsigned char mul_mat_vec_id_q6_k_f32_data[19508];
|
||||
const uint64_t mul_mat_vec_id_q6_k_f32_len = 19508;
|
||||
|
||||
extern unsigned char dequant_q6_k_data[4264];
|
||||
const uint64_t dequant_q6_k_len = 4264;
|
||||
|
||||
extern unsigned char mul_mat_vec_iq4_nl_f32_f32_data[20640];
|
||||
const uint64_t mul_mat_vec_iq4_nl_f32_f32_len = 20640;
|
||||
extern unsigned char mul_mat_vec_iq4_nl_f32_f32_data[23328];
|
||||
const uint64_t mul_mat_vec_iq4_nl_f32_f32_len = 23328;
|
||||
|
||||
extern unsigned char mul_mat_vec_iq4_nl_f16_f32_data[21432];
|
||||
const uint64_t mul_mat_vec_iq4_nl_f16_f32_len = 21432;
|
||||
extern unsigned char mul_mat_vec_iq4_nl_f16_f32_data[24120];
|
||||
const uint64_t mul_mat_vec_iq4_nl_f16_f32_len = 24120;
|
||||
|
||||
extern unsigned char mul_mat_vec_id_iq4_nl_f32_data[20136];
|
||||
const uint64_t mul_mat_vec_id_iq4_nl_f32_len = 20136;
|
||||
extern unsigned char mul_mat_vec_id_iq4_nl_f32_data[23184];
|
||||
const uint64_t mul_mat_vec_id_iq4_nl_f32_len = 23184;
|
||||
|
||||
extern unsigned char dequant_iq4_nl_data[5920];
|
||||
const uint64_t dequant_iq4_nl_len = 5920;
|
||||
|
@ -1776,74 +1776,74 @@ const uint64_t group_norm_f32_len = 3080;
|
|||
extern unsigned char rms_norm_f32_data[2544];
|
||||
const uint64_t rms_norm_f32_len = 2544;
|
||||
|
||||
extern unsigned char cpy_f32_f32_data[4608];
|
||||
const uint64_t cpy_f32_f32_len = 4608;
|
||||
extern unsigned char cpy_f32_f32_data[4684];
|
||||
const uint64_t cpy_f32_f32_len = 4684;
|
||||
|
||||
extern unsigned char cpy_f32_f16_data[4660];
|
||||
const uint64_t cpy_f32_f16_len = 4660;
|
||||
extern unsigned char cpy_f32_f16_data[4736];
|
||||
const uint64_t cpy_f32_f16_len = 4736;
|
||||
|
||||
extern unsigned char cpy_f16_f16_data[4628];
|
||||
const uint64_t cpy_f16_f16_len = 4628;
|
||||
extern unsigned char cpy_f16_f16_data[4704];
|
||||
const uint64_t cpy_f16_f16_len = 4704;
|
||||
|
||||
extern unsigned char contig_cpy_f32_f32_data[2952];
|
||||
const uint64_t contig_cpy_f32_f32_len = 2952;
|
||||
extern unsigned char contig_cpy_f32_f32_data[3164];
|
||||
const uint64_t contig_cpy_f32_f32_len = 3164;
|
||||
|
||||
extern unsigned char contig_cpy_f32_f16_data[3068];
|
||||
const uint64_t contig_cpy_f32_f16_len = 3068;
|
||||
extern unsigned char contig_cpy_f32_f16_data[3280];
|
||||
const uint64_t contig_cpy_f32_f16_len = 3280;
|
||||
|
||||
extern unsigned char contig_cpy_f16_f16_data[2972];
|
||||
const uint64_t contig_cpy_f16_f16_len = 2972;
|
||||
extern unsigned char contig_cpy_f16_f16_data[3184];
|
||||
const uint64_t contig_cpy_f16_f16_len = 3184;
|
||||
|
||||
extern unsigned char add_f32_data[5780];
|
||||
const uint64_t add_f32_len = 5780;
|
||||
extern unsigned char add_f32_data[5916];
|
||||
const uint64_t add_f32_len = 5916;
|
||||
|
||||
extern unsigned char add_f16_f32_f16_data[5848];
|
||||
const uint64_t add_f16_f32_f16_len = 5848;
|
||||
extern unsigned char add_f16_f32_f16_data[5984];
|
||||
const uint64_t add_f16_f32_f16_len = 5984;
|
||||
|
||||
extern unsigned char acc_f32_data[4888];
|
||||
const uint64_t acc_f32_len = 4888;
|
||||
extern unsigned char acc_f32_data[5100];
|
||||
const uint64_t acc_f32_len = 5100;
|
||||
|
||||
extern unsigned char split_k_reduce_data[2764];
|
||||
const uint64_t split_k_reduce_len = 2764;
|
||||
|
||||
extern unsigned char mul_f32_data[5780];
|
||||
const uint64_t mul_f32_len = 5780;
|
||||
extern unsigned char mul_f32_data[5916];
|
||||
const uint64_t mul_f32_len = 5916;
|
||||
|
||||
extern unsigned char div_f32_data[5780];
|
||||
const uint64_t div_f32_len = 5780;
|
||||
extern unsigned char div_f32_data[5916];
|
||||
const uint64_t div_f32_len = 5916;
|
||||
|
||||
extern unsigned char repeat_f32_data[4308];
|
||||
const uint64_t repeat_f32_len = 4308;
|
||||
extern unsigned char repeat_f32_data[4384];
|
||||
const uint64_t repeat_f32_len = 4384;
|
||||
|
||||
extern unsigned char scale_f32_data[2440];
|
||||
const uint64_t scale_f32_len = 2440;
|
||||
extern unsigned char scale_f32_data[2532];
|
||||
const uint64_t scale_f32_len = 2532;
|
||||
|
||||
extern unsigned char sqr_f32_data[4628];
|
||||
const uint64_t sqr_f32_len = 4628;
|
||||
extern unsigned char sqr_f32_data[4704];
|
||||
const uint64_t sqr_f32_len = 4704;
|
||||
|
||||
extern unsigned char sin_f32_data[4632];
|
||||
const uint64_t sin_f32_len = 4632;
|
||||
extern unsigned char sin_f32_data[4708];
|
||||
const uint64_t sin_f32_len = 4708;
|
||||
|
||||
extern unsigned char cos_f32_data[4632];
|
||||
const uint64_t cos_f32_len = 4632;
|
||||
extern unsigned char cos_f32_data[4708];
|
||||
const uint64_t cos_f32_len = 4708;
|
||||
|
||||
extern unsigned char clamp_f32_data[4888];
|
||||
const uint64_t clamp_f32_len = 4888;
|
||||
extern unsigned char clamp_f32_data[4964];
|
||||
const uint64_t clamp_f32_len = 4964;
|
||||
|
||||
extern unsigned char pad_f32_data[3912];
|
||||
const uint64_t pad_f32_len = 3912;
|
||||
extern unsigned char pad_f32_data[3988];
|
||||
const uint64_t pad_f32_len = 3988;
|
||||
|
||||
extern unsigned char concat_f32_data[5316];
|
||||
const uint64_t concat_f32_len = 5316;
|
||||
extern unsigned char concat_f32_data[5452];
|
||||
const uint64_t concat_f32_len = 5452;
|
||||
|
||||
extern unsigned char concat_f16_data[5400];
|
||||
const uint64_t concat_f16_len = 5400;
|
||||
extern unsigned char concat_f16_data[5556];
|
||||
const uint64_t concat_f16_len = 5556;
|
||||
|
||||
extern unsigned char concat_i32_data[5316];
|
||||
const uint64_t concat_i32_len = 5316;
|
||||
extern unsigned char concat_i32_data[5452];
|
||||
const uint64_t concat_i32_len = 5452;
|
||||
|
||||
extern unsigned char upscale_f32_data[2856];
|
||||
const uint64_t upscale_f32_len = 2856;
|
||||
extern unsigned char upscale_f32_data[2952];
|
||||
const uint64_t upscale_f32_len = 2952;
|
||||
|
||||
extern unsigned char gelu_f32_data[1700];
|
||||
const uint64_t gelu_f32_len = 1700;
|
||||
|
@ -1896,14 +1896,14 @@ const uint64_t argsort_f32_len = 4200;
|
|||
extern unsigned char sum_rows_f32_data[2320];
|
||||
const uint64_t sum_rows_f32_len = 2320;
|
||||
|
||||
extern unsigned char im2col_f32_data[3672];
|
||||
const uint64_t im2col_f32_len = 3672;
|
||||
extern unsigned char im2col_f32_data[4548];
|
||||
const uint64_t im2col_f32_len = 4548;
|
||||
|
||||
extern unsigned char im2col_f32_f16_data[3732];
|
||||
const uint64_t im2col_f32_f16_len = 3732;
|
||||
extern unsigned char im2col_f32_f16_data[4600];
|
||||
const uint64_t im2col_f32_f16_len = 4600;
|
||||
|
||||
extern unsigned char im2col_f32_f16_rte_data[3756];
|
||||
const uint64_t im2col_f32_f16_rte_len = 3756;
|
||||
extern unsigned char im2col_f32_f16_rte_data[4624];
|
||||
const uint64_t im2col_f32_f16_rte_len = 4624;
|
||||
|
||||
extern unsigned char timestep_embedding_f32_data[2000];
|
||||
const uint64_t timestep_embedding_f32_len = 2000;
|
||||
|
|
|
@ -145,6 +145,8 @@ class vk_perf_logger;
|
|||
#endif
|
||||
static void ggml_vk_destroy_buffer(vk_buffer& buf);
|
||||
|
||||
static constexpr uint32_t mul_mat_vec_max_cols = 8;
|
||||
|
||||
struct vk_device_struct {
|
||||
std::mutex mutex;
|
||||
|
||||
|
@ -202,8 +204,8 @@ struct vk_device_struct {
|
|||
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
|
||||
|
||||
vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
|
||||
vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT];
|
||||
vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT];
|
||||
vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
|
||||
vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
|
||||
vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT];
|
||||
|
||||
vk_pipeline pipeline_mul_mat_vec_p021_f16_f32;
|
||||
|
@ -411,7 +413,7 @@ struct vk_op_unary_push_constants {
|
|||
uint32_t ne;
|
||||
uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
|
||||
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
|
||||
uint32_t d_offset;
|
||||
uint32_t misalign_offsets;
|
||||
float param1; float param2;
|
||||
uint32_t ne0_012mp; uint32_t ne0_012L;
|
||||
uint32_t ne0_01mp; uint32_t ne0_01L;
|
||||
|
@ -459,7 +461,7 @@ struct vk_op_binary_push_constants {
|
|||
uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
|
||||
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
|
||||
uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23; uint32_t nb20; uint32_t nb21; uint32_t nb22; uint32_t nb23;
|
||||
uint32_t d_offset;
|
||||
uint32_t misalign_offsets;
|
||||
float param1; float param2; int32_t param3;
|
||||
};
|
||||
|
||||
|
@ -546,7 +548,7 @@ struct vk_staging_memcpy {
|
|||
};
|
||||
|
||||
struct vk_op_upscale_push_constants {
|
||||
uint32_t ne; uint32_t d_offset;
|
||||
uint32_t ne; uint32_t a_offset; uint32_t d_offset;
|
||||
uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
|
||||
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
|
||||
float sf0; float sf1; float sf2; float sf3;
|
||||
|
@ -1404,10 +1406,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
// spec constants and tile sizes for non-quant matmul/matmul_id
|
||||
l_warptile = { 256, 128, 256, 64 };
|
||||
m_warptile = { 256, 128, 128, 64 };
|
||||
s_warptile = { 128, 32, 16, 64 };
|
||||
s_warptile = { 128, 64, 64, 64 };
|
||||
l_wg_denoms = {128, 256, 1 };
|
||||
m_wg_denoms = {128, 128, 1 };
|
||||
s_wg_denoms = { 32, 16, 1 };
|
||||
s_wg_denoms = { 64, 64, 1 };
|
||||
|
||||
// spec constants and tile sizes for quant matmul (non-Qi_K)
|
||||
l_warptile_mmq = { 256, 128, 256, 64 };
|
||||
|
@ -1866,33 +1868,35 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
} else if (device->vendor_id == VK_VENDOR_ID_INTEL)
|
||||
rm_stdq = 2;
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f32_f32", mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32_f32", mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32_f32", mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32_f32", mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32_f32", mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32_f32", mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32_f32", mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true);
|
||||
for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(i+1), mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32_"+std::to_string(i+1), mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f16_f32", mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f16_f32", mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f16_f32", mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f16_f32", mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f16_f32", mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f16_f32", mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f16_f32", mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true);
|
||||
}
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
||||
|
@ -2017,11 +2021,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
|
||||
if (device->float_controls_rte_fp16) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
|
||||
} else {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
|
||||
}
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
|
||||
|
@ -2892,9 +2896,10 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
|
|||
return ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
|
||||
}
|
||||
|
||||
static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
|
||||
static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t num_cols) {
|
||||
VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()");
|
||||
GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(num_cols >= 1 && num_cols <= mul_mat_vec_max_cols);
|
||||
|
||||
switch (a_type) {
|
||||
case GGML_TYPE_F32:
|
||||
|
@ -2915,7 +2920,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[a_type] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[a_type];
|
||||
return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[a_type][num_cols-1] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[a_type][num_cols-1];
|
||||
}
|
||||
|
||||
static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) {
|
||||
|
@ -3925,8 +3930,6 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|||
const uint64_t ne12 = src1->ne[2];
|
||||
const uint64_t ne13 = src1->ne[3];
|
||||
|
||||
GGML_ASSERT(ne11 == 1);
|
||||
|
||||
const uint64_t ne20 = dst->ne[0];
|
||||
const uint64_t ne21 = dst->ne[1];
|
||||
const uint64_t ne22 = dst->ne[2];
|
||||
|
@ -3935,6 +3938,11 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|||
const uint64_t r2 = ne12 / ne02;
|
||||
const uint64_t r3 = ne13 / ne03;
|
||||
|
||||
// batch_n indicates that we need to compute a few vector results, and this assumes
|
||||
// ne12 and ne13 are 1. It overloads the batch_strides to hold the row strides.
|
||||
GGML_ASSERT(ne11 == 1 || ne12 * ne13 == 1);
|
||||
bool batch_n = ne11 > 1;
|
||||
|
||||
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
|
||||
ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
|
||||
ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
|
||||
|
@ -3985,7 +3993,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|||
} else {
|
||||
to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
|
||||
}
|
||||
vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type);
|
||||
vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type, ne11);
|
||||
GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
|
||||
GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
|
||||
GGML_ASSERT(dmmv != nullptr);
|
||||
|
@ -4057,8 +4065,10 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|||
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
|
||||
}
|
||||
|
||||
uint32_t stride_batch_x = ne00*ne01;
|
||||
uint32_t stride_batch_y = ne10*ne11;
|
||||
// For batch_n, the A matrix is the same for each batch, and B/D use the row stride as the batch stride
|
||||
uint32_t stride_batch_x = batch_n ? 0 : ne00*ne01;
|
||||
uint32_t stride_batch_y = batch_n ? ne10 : (ne10*ne11);
|
||||
uint32_t stride_batch_d = batch_n ? ne20 : (ne20*ne21);
|
||||
|
||||
if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) {
|
||||
stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
|
||||
|
@ -4081,7 +4091,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|||
// compute
|
||||
const vk_mat_vec_push_constants pc = {
|
||||
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
|
||||
stride_batch_x, stride_batch_y, (uint32_t)(ne20*ne21),
|
||||
stride_batch_x, stride_batch_y, stride_batch_d,
|
||||
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
|
||||
};
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
|
@ -4261,7 +4271,10 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
|
|||
} else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 &&
|
||||
!ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
|
||||
ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, src0, src1, dst, dryrun);
|
||||
} else if (dst->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
|
||||
// mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four)
|
||||
// when ne12 and ne13 are one.
|
||||
} else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) &&
|
||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
|
||||
ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun);
|
||||
} else {
|
||||
ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun);
|
||||
|
@ -5076,6 +5089,57 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
|
|||
}
|
||||
}
|
||||
|
||||
static uint32_t get_misalign_bytes(ggml_backend_vk_context * ctx, const ggml_tensor * t)
|
||||
{
|
||||
return ((vk_tensor_offset(t) + t->view_offs) & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1));;
|
||||
}
|
||||
|
||||
template <typename T> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, T &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
|
||||
GGML_UNUSED(p);
|
||||
GGML_UNUSED(src0);
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(src2);
|
||||
GGML_UNUSED(dst);
|
||||
static_assert(!std::is_const<T>::value, "unexpected type");
|
||||
GGML_ASSERT(!src0 || get_misalign_bytes(ctx, src0) == 0);
|
||||
GGML_ASSERT(!src1 || get_misalign_bytes(ctx, src1) == 0);
|
||||
GGML_ASSERT(!src2 || get_misalign_bytes(ctx, src2) == 0);
|
||||
GGML_ASSERT(!dst || get_misalign_bytes(ctx, dst) == 0);
|
||||
}
|
||||
|
||||
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_unary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
|
||||
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
|
||||
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
|
||||
|
||||
p.misalign_offsets = (a_offset << 16) | d_offset;
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(src2);
|
||||
}
|
||||
|
||||
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
|
||||
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
|
||||
const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
|
||||
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
|
||||
|
||||
GGML_ASSERT(dst->op != GGML_OP_GET_ROWS || (a_offset == 0 && b_offset == 0 && d_offset == 0));
|
||||
|
||||
p.misalign_offsets = (a_offset << 16) | (b_offset << 8) | d_offset;
|
||||
|
||||
GGML_UNUSED(src2);
|
||||
}
|
||||
|
||||
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_upscale_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
|
||||
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
|
||||
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
|
||||
|
||||
p.a_offset = a_offset;
|
||||
p.d_offset = d_offset;
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(src2);
|
||||
}
|
||||
|
||||
template<typename PC>
|
||||
static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, PC&& pc, bool dryrun = false) {
|
||||
VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
|
||||
|
@ -5179,8 +5243,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|||
}
|
||||
|
||||
GGML_ASSERT(d_D != nullptr);
|
||||
uint64_t d_buf_offset = ((vk_tensor_offset(dst) + dst->view_offs) / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
|
||||
GGML_ASSERT(d_buf_offset == vk_tensor_offset(dst) || op == GGML_OP_CPY); // NOLINT
|
||||
uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
|
||||
if(!src0_uma) {
|
||||
d_X = src0_buf_ctx->dev_buffer;
|
||||
x_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
|
||||
|
@ -5196,6 +5259,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|||
z_buf_offset = vk_tensor_offset(src2) + src2->view_offs;
|
||||
GGML_ASSERT(d_Z != nullptr);
|
||||
}
|
||||
// Compute misalignment offset for descriptors and store it in in push constants, then align the descriptor offsets.
|
||||
init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, dst);
|
||||
x_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
|
||||
y_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
|
||||
z_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
|
||||
d_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
|
||||
|
||||
if (op_supports_incontiguous) {
|
||||
x_sz = ggml_nbytes(src0);
|
||||
|
@ -5383,7 +5452,6 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
|||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t src1_type_size = ggml_type_size(src1->type);
|
||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||
const uint32_t d_offset = ((vk_tensor_offset(dst) + dst->view_offs) % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size;
|
||||
|
||||
int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
|
||||
int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
|
||||
|
@ -5395,7 +5463,7 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
|||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size,
|
||||
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
|
||||
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t) dst->nb[3] / dst_type_size,
|
||||
d_offset,
|
||||
0,
|
||||
0.0f, 0.0f, offset,
|
||||
}, dryrun);
|
||||
}
|
||||
|
@ -5599,7 +5667,7 @@ static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, c
|
|||
const float sf3 = (float)dst->ne[3] / src0->ne[3];
|
||||
|
||||
ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, {
|
||||
(uint32_t)ggml_nelements(dst), 0,
|
||||
(uint32_t)ggml_nelements(dst), 0, 0,
|
||||
(uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
||||
(uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3],
|
||||
sf0, sf1, sf2, sf3,
|
||||
|
@ -5709,13 +5777,12 @@ static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|||
static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||
const uint32_t d_offset = ((vk_tensor_offset(dst) + dst->view_offs) % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size;
|
||||
|
||||
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, {
|
||||
(uint32_t)ggml_nelements(src0),
|
||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
||||
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
||||
d_offset,
|
||||
0,
|
||||
0.0f, 0.0f,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
}, dryrun);
|
||||
|
|
|
@ -21,9 +21,9 @@ void main() {
|
|||
get_indices(idx, i00, i01, i02, i03);
|
||||
|
||||
if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) {
|
||||
data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[ox + oy * p.ne10 + oz * p.ne10 * p.ne11]));
|
||||
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + ox + oy * p.ne10 + oz * p.ne10 * p.ne11]));
|
||||
} else {
|
||||
data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]));
|
||||
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ void main() {
|
|||
uint i00, i01, i02, i03;
|
||||
get_indices(idx, i00, i01, i02, i03);
|
||||
|
||||
data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[src1_idx(i00, i01, i02, i03)]));
|
||||
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
|
||||
|
||||
idx += num_threads;
|
||||
}
|
||||
|
|
|
@ -12,6 +12,6 @@ void main() {
|
|||
return;
|
||||
}
|
||||
|
||||
const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
|
||||
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val));
|
||||
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
|
||||
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val));
|
||||
}
|
||||
|
|
|
@ -30,12 +30,12 @@ void main() {
|
|||
const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03;
|
||||
|
||||
#ifndef OPTIMIZATION_ERROR_WORKAROUND
|
||||
data_d[p.d_offset + dst_idx] = D_TYPE(is_src0 ? data_a[src0_idx] : data_b[src1_idx]);
|
||||
data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : data_b[get_boffset() + src1_idx]);
|
||||
#else
|
||||
if (is_src0) {
|
||||
data_d[p.d_offset + dst_idx] = data_a[src0_idx];
|
||||
data_d[get_doffset() + dst_idx] = data_a[get_aoffset() + src0_idx];
|
||||
} else {
|
||||
data_d[p.d_offset + dst_idx] = data_b[src1_idx];
|
||||
data_d[get_doffset() + dst_idx] = data_b[get_boffset() + src1_idx];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
|
|
@ -19,9 +19,9 @@ void main() {
|
|||
if (idx + (num_iter-1)*num_threads < p.ne) {
|
||||
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
|
||||
#ifndef OPTIMIZATION_ERROR_WORKAROUND
|
||||
data_d[p.d_offset + idx] = D_TYPE(data_a[idx]);
|
||||
data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
|
||||
#else
|
||||
data_d[p.d_offset + idx] = data_a[idx];
|
||||
data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
|
||||
#endif
|
||||
idx += num_threads;
|
||||
}
|
||||
|
@ -32,9 +32,9 @@ void main() {
|
|||
}
|
||||
|
||||
#ifndef OPTIMIZATION_ERROR_WORKAROUND
|
||||
data_d[p.d_offset + idx] = D_TYPE(data_a[idx]);
|
||||
data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
|
||||
#else
|
||||
data_d[p.d_offset + idx] = data_a[idx];
|
||||
data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
|
||||
#endif
|
||||
idx += num_threads;
|
||||
}
|
||||
|
|
|
@ -13,8 +13,8 @@ void main() {
|
|||
}
|
||||
|
||||
#ifndef OPTIMIZATION_ERROR_WORKAROUND
|
||||
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(data_a[src0_idx(idx)]);
|
||||
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
|
||||
#else
|
||||
data_d[p.d_offset + dst_idx(idx)] = data_a[src0_idx(idx)];
|
||||
data_d[get_doffset() + dst_idx(idx)] = data_a[get_aoffset() + src0_idx(idx)];
|
||||
#endif
|
||||
}
|
||||
|
|
|
@ -12,6 +12,6 @@ void main() {
|
|||
return;
|
||||
}
|
||||
|
||||
const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
|
||||
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(cos(val));
|
||||
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
|
||||
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(cos(val));
|
||||
}
|
||||
|
|
|
@ -20,7 +20,7 @@ void main() {
|
|||
uint i00, i01, i02, i03;
|
||||
get_indices(idx, i00, i01, i02, i03);
|
||||
|
||||
data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]) / FLOAT_TYPE(data_b[src1_idx(i00, i01, i02, i03)]));
|
||||
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) / FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
|
||||
|
||||
idx += num_threads;
|
||||
}
|
||||
|
|
|
@ -7,7 +7,7 @@ layout (push_constant) uniform parameter
|
|||
uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
|
||||
uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
|
||||
uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23;
|
||||
uint d_offset;
|
||||
uint misalign_offsets;
|
||||
float param1; float param2; int param3;
|
||||
} p;
|
||||
|
||||
|
@ -22,6 +22,10 @@ uint get_idx() {
|
|||
return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
||||
}
|
||||
|
||||
uint get_aoffset() { return p.misalign_offsets >> 16; }
|
||||
uint get_boffset() { return (p.misalign_offsets >> 8) & 0xFF; }
|
||||
uint get_doffset() { return p.misalign_offsets & 0xFF; }
|
||||
|
||||
// mod and div are expensive and coordinates/dimensions are often power of 2 or equal to 1
|
||||
uint fastmod(uint a, uint b) {
|
||||
if ((b & (b-1)) == 0) {
|
||||
|
|
|
@ -6,7 +6,7 @@ layout (push_constant) uniform parameter
|
|||
uint ne;
|
||||
uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
|
||||
uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
|
||||
uint d_offset;
|
||||
uint misalign_offsets;
|
||||
float param1; float param2;
|
||||
|
||||
uint ne0_012mp; uint ne0_012L;
|
||||
|
@ -24,6 +24,9 @@ uint get_idx() {
|
|||
return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
||||
}
|
||||
|
||||
uint get_aoffset() { return p.misalign_offsets >> 16; }
|
||||
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
|
||||
|
||||
// see init_fastdiv_values in ggml-vulkan.cpp
|
||||
uint fastdiv(uint n, uint mp, uint L) {
|
||||
uint msbs, lsbs;
|
||||
|
|
|
@ -15,10 +15,10 @@ void main() {
|
|||
return;
|
||||
}
|
||||
|
||||
const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
|
||||
const uint i01 = data_b[get_boffset() + i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
|
||||
|
||||
const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
|
||||
const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
|
||||
const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
|
||||
const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
|
||||
|
||||
#ifndef OPTIMIZATION_ERROR_WORKAROUND
|
||||
data_d[d_offset + i00] = D_TYPE(data_a[a_offset + i00]);
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
#extension GL_EXT_spirv_intrinsics: enable
|
||||
#extension GL_EXT_control_flow_attributes : require
|
||||
|
||||
#if RTE16
|
||||
spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
|
||||
|
@ -23,40 +24,64 @@ layout (push_constant) uniform parameter
|
|||
|
||||
#include "types.comp"
|
||||
|
||||
#define BLOCK_SIZE 256
|
||||
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
|
||||
|
||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
||||
const uint NUM_ITER = 512 / BLOCK_SIZE;
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
void main() {
|
||||
const uint i = gl_GlobalInvocationID.x;
|
||||
if (i >= p.pelements) {
|
||||
return;
|
||||
const uint gidx = gl_GlobalInvocationID.x;
|
||||
|
||||
const uint oh = gl_GlobalInvocationID.y;
|
||||
const uint batch = gl_GlobalInvocationID.z / p.IC;
|
||||
const uint ic = gl_GlobalInvocationID.z % p.IC;
|
||||
|
||||
A_TYPE values[NUM_ITER];
|
||||
uint offset_dst[NUM_ITER];
|
||||
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
|
||||
values[idx] = A_TYPE(0);
|
||||
}
|
||||
|
||||
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
|
||||
|
||||
const uint i = gidx * NUM_ITER + idx;
|
||||
|
||||
const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1);
|
||||
const uint kx = i / ksize;
|
||||
const uint kd = kx * ksize;
|
||||
const uint ky = (i - kd) / p.OW;
|
||||
const uint ix = i % p.OW;
|
||||
|
||||
const uint oh = gl_GlobalInvocationID.y;
|
||||
const uint batch = gl_GlobalInvocationID.z / p.IC;
|
||||
const uint ic = gl_GlobalInvocationID.z % p.IC;
|
||||
|
||||
const uint iiw = ix * p.s0 + kx * p.d0 - p.p0;
|
||||
const uint iih = oh * p.s1 + ky * p.d1 - p.p1;
|
||||
|
||||
const uint offset_dst =
|
||||
offset_dst[idx] =
|
||||
((batch * p.OH + oh) * p.OW + ix) * p.CHW +
|
||||
(ic * (p.KW * p.KH) + ky * p.KW + kx);
|
||||
|
||||
if (iih < 0 || iih >= p.IH || iiw < 0 || iiw >= p.IW) {
|
||||
data_d[offset_dst] = D_TYPE(0.0f);
|
||||
} else {
|
||||
if (i >= p.pelements) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (iih < p.IH && iiw < p.IW) {
|
||||
const uint offset_src = ic * p.offset_delta + batch * p.batch_offset;
|
||||
data_d[offset_dst] = D_TYPE(data_a[offset_src + iih * p.IW + iiw]);
|
||||
values[idx] = data_a[offset_src + iih * p.IW + iiw];
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
|
||||
|
||||
const uint i = gidx * NUM_ITER + idx;
|
||||
|
||||
if (i >= p.pelements) {
|
||||
continue;
|
||||
}
|
||||
|
||||
data_d[offset_dst[idx]] = D_TYPE(values[idx]);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -20,7 +20,7 @@ void main() {
|
|||
uint i00, i01, i02, i03;
|
||||
get_indices(idx, i00, i01, i02, i03);
|
||||
|
||||
data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]) * FLOAT_TYPE(data_b[src1_idx(i00, i01, i02, i03)]));
|
||||
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) * FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
|
||||
|
||||
idx += num_threads;
|
||||
}
|
||||
|
|
|
@ -9,9 +9,6 @@
|
|||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
|
||||
layout (constant_id = 1) const uint NUM_ROWS = 1;
|
||||
|
||||
#if !defined(DATA_A_F32) && !defined(DATA_A_F16)
|
||||
#define K_PER_ITER 8
|
||||
#else
|
||||
|
@ -21,23 +18,22 @@ layout (constant_id = 1) const uint NUM_ROWS = 1;
|
|||
|
||||
uint a_offset, b_offset, d_offset, y_offset;
|
||||
|
||||
shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
|
||||
|
||||
void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter)
|
||||
void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter)
|
||||
{
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
const uint col = i*BLOCK_SIZE + K_PER_ITER*tid;
|
||||
const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
|
||||
const uint iybs = col - col%QUANT_K; // y block start index
|
||||
|
||||
#if K_PER_ITER == 8
|
||||
#if QUANT_R == 2
|
||||
const B_TYPE_VEC4 bv02 = data_b_v4[(b_offset + iybs + iqs) / 4];
|
||||
const B_TYPE_VEC4 bv13 = data_b_v4[(b_offset + iybs + iqs + y_offset) / 4];
|
||||
const B_TYPE_VEC4 bv02 = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4];
|
||||
const B_TYPE_VEC4 bv13 = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4];
|
||||
const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y);
|
||||
const vec4 bv1 = vec4(bv02.z, bv13.z, bv02.w, bv13.w);
|
||||
#else
|
||||
const vec4 bv0 = vec4(data_b_v4[(b_offset + iybs + iqs) / 4]);
|
||||
const vec4 bv1 = vec4(data_b_v4[(b_offset + iybs + iqs) / 4 + 1]);
|
||||
const vec4 bv0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]);
|
||||
const vec4 bv1 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4 + 1]);
|
||||
#endif
|
||||
#else
|
||||
// Check if the second of the pair of elements is OOB, and don't fetch B or
|
||||
|
@ -48,9 +44,9 @@ void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_
|
|||
const bool OOB = lastiter && (iybs + iqs + y_offset >= p.ncols);
|
||||
|
||||
FLOAT_TYPE b0 = 0, b1 = 0;
|
||||
b0 = FLOAT_TYPE(data_b[b_offset + iybs + iqs]);
|
||||
b0 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]);
|
||||
if (!OOB) {
|
||||
b1 = FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]);
|
||||
b1 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]);
|
||||
}
|
||||
#endif
|
||||
uint ibi = first_row*p.ncols;
|
||||
|
@ -75,18 +71,19 @@ void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_
|
|||
if (dm.y == 0)
|
||||
rowtmp *= dm.x;
|
||||
|
||||
temp[n] += rowtmp;
|
||||
temp[j][n] += rowtmp;
|
||||
#else
|
||||
const vec2 v = dequantize(ib, iqs, a_offset);
|
||||
|
||||
// matrix multiplication
|
||||
temp[n] = fma(FLOAT_TYPE(v.x), b0, temp[n]);
|
||||
temp[j][n] = fma(FLOAT_TYPE(v.x), b0, temp[j][n]);
|
||||
if (!OOB) {
|
||||
temp[n] = fma(FLOAT_TYPE(v.y), b1, temp[n]);
|
||||
temp[j][n] = fma(FLOAT_TYPE(v.y), b1, temp[j][n]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
|
@ -96,10 +93,12 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
|||
|
||||
y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
|
||||
|
||||
FLOAT_TYPE temp[NUM_ROWS];
|
||||
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
|
||||
|
||||
for (uint i = 0; i < NUM_ROWS; ++i) {
|
||||
temp[i] = FLOAT_TYPE(0);
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
|
||||
temp[j][i] = FLOAT_TYPE(0);
|
||||
}
|
||||
}
|
||||
|
||||
uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
|
||||
|
@ -131,24 +130,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
|||
i++;
|
||||
}
|
||||
|
||||
// sum up partial sums and write back result
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
tmpsh[n][tid] = temp[n];
|
||||
}
|
||||
barrier();
|
||||
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
tmpsh[n][tid] += tmpsh[n][tid + s];
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
if (tid == 0) {
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
|
||||
}
|
||||
}
|
||||
reduce_result(temp, d_offset, first_row, num_rows, tid);
|
||||
}
|
||||
|
||||
void main() {
|
||||
|
|
|
@ -83,3 +83,36 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
|
|||
batch_idx * p.batch_stride_d;
|
||||
#endif
|
||||
}
|
||||
|
||||
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
|
||||
layout (constant_id = 1) const uint NUM_ROWS = 1;
|
||||
layout (constant_id = 2) const uint NUM_COLS = 1;
|
||||
|
||||
shared FLOAT_TYPE tmpsh[NUM_COLS][NUM_ROWS][BLOCK_SIZE];
|
||||
|
||||
void reduce_result(const in FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {
|
||||
// sum up partial sums and write back result
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
tmpsh[j][n][tid] = temp[j][n];
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
tmpsh[j][n][tid] += tmpsh[j][n][tid + s];
|
||||
}
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
if (tid == 0) {
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,11 +5,6 @@
|
|||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
|
||||
layout (constant_id = 1) const uint NUM_ROWS = 1;
|
||||
|
||||
shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
|
||||
|
||||
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||
uint a_offset, b_offset, d_offset;
|
||||
get_offsets(a_offset, b_offset, d_offset);
|
||||
|
@ -32,24 +27,17 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
|||
const uint s_offset = 8*v_im;
|
||||
const uint y_offset = 128*v_im + l0;
|
||||
|
||||
FLOAT_TYPE temp[NUM_ROWS];
|
||||
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
|
||||
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
|
||||
temp[i] = FLOAT_TYPE(0);
|
||||
temp[j][i] = FLOAT_TYPE(0);
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
|
||||
const uint y_idx = i * QUANT_K + y_offset;
|
||||
|
||||
B_TYPE_VEC2 b0 = data_b_v2[(b_offset + y_idx) / 2 + 0];
|
||||
B_TYPE_VEC2 b16 = data_b_v2[(b_offset + y_idx) / 2 + 8];
|
||||
B_TYPE_VEC2 b32 = data_b_v2[(b_offset + y_idx) / 2 + 16];
|
||||
B_TYPE_VEC2 b48 = data_b_v2[(b_offset + y_idx) / 2 + 24];
|
||||
B_TYPE_VEC2 b64 = data_b_v2[(b_offset + y_idx) / 2 + 32];
|
||||
B_TYPE_VEC2 b80 = data_b_v2[(b_offset + y_idx) / 2 + 40];
|
||||
B_TYPE_VEC2 b96 = data_b_v2[(b_offset + y_idx) / 2 + 48];
|
||||
B_TYPE_VEC2 b112 = data_b_v2[(b_offset + y_idx) / 2 + 56];
|
||||
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
|
||||
f16vec2 d = data_a[ib0 + i].d;
|
||||
|
@ -74,6 +62,16 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
|||
uvec2 qs0 = uvec2(unpack8(qs0_u16));
|
||||
uvec2 qs16 = uvec2(unpack8(qs16_u16));
|
||||
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0];
|
||||
B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8];
|
||||
B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16];
|
||||
B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24];
|
||||
B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32];
|
||||
B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40];
|
||||
B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48];
|
||||
B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56];
|
||||
|
||||
FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
|
||||
FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
|
||||
[[unroll]] for (int l = 0; l < 2; ++l) {
|
||||
|
@ -94,28 +92,12 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
|||
fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_hi4[2]),
|
||||
fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2))))))));
|
||||
}
|
||||
temp[n] = fma(dall, sum1, fma(-dmin, sum2, temp[n]));
|
||||
temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sum up partial sums and write back result
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
tmpsh[n][tid] = temp[n];
|
||||
}
|
||||
barrier();
|
||||
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
tmpsh[n][tid] += tmpsh[n][tid + s];
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
if (tid == 0) {
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
|
||||
}
|
||||
}
|
||||
reduce_result(temp, d_offset, first_row, num_rows, tid);
|
||||
}
|
||||
|
||||
void main() {
|
||||
|
|
|
@ -5,11 +5,6 @@
|
|||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
|
||||
layout (constant_id = 1) const uint NUM_ROWS = 1;
|
||||
|
||||
shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
|
||||
|
||||
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||
uint a_offset, b_offset, d_offset;
|
||||
get_offsets(a_offset, b_offset, d_offset);
|
||||
|
@ -33,10 +28,12 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
|||
const uint q_offset = 32*v_im + l0;
|
||||
const uint y_offset = 128*v_im + l0;
|
||||
|
||||
FLOAT_TYPE temp[NUM_ROWS];
|
||||
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
|
||||
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
|
||||
temp[i] = FLOAT_TYPE(0);
|
||||
temp[j][i] = FLOAT_TYPE(0);
|
||||
}
|
||||
}
|
||||
|
||||
const uint s_shift = 4 * v_im;
|
||||
|
@ -44,15 +41,6 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
|||
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
|
||||
const uint y_idx = i * QUANT_K + y_offset;
|
||||
|
||||
B_TYPE_VEC2 b0 = data_b_v2[(b_offset + y_idx) / 2 + 0];
|
||||
B_TYPE_VEC2 b16 = data_b_v2[(b_offset + y_idx) / 2 + 8];
|
||||
B_TYPE_VEC2 b32 = data_b_v2[(b_offset + y_idx) / 2 + 16];
|
||||
B_TYPE_VEC2 b48 = data_b_v2[(b_offset + y_idx) / 2 + 24];
|
||||
B_TYPE_VEC2 b64 = data_b_v2[(b_offset + y_idx) / 2 + 32];
|
||||
B_TYPE_VEC2 b80 = data_b_v2[(b_offset + y_idx) / 2 + 40];
|
||||
B_TYPE_VEC2 b96 = data_b_v2[(b_offset + y_idx) / 2 + 48];
|
||||
B_TYPE_VEC2 b112 = data_b_v2[(b_offset + y_idx) / 2 + 56];
|
||||
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
|
||||
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
|
||||
|
@ -70,6 +58,17 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
|||
u8vec2 s8 = unpack8(s8_16);
|
||||
u8vec2 s10 = unpack8(s10_16);
|
||||
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
|
||||
B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0];
|
||||
B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8];
|
||||
B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16];
|
||||
B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24];
|
||||
B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32];
|
||||
B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40];
|
||||
B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48];
|
||||
B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56];
|
||||
|
||||
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
|
||||
[[unroll]] for (int l = 0; l < 2; ++l) {
|
||||
sum = fma(FLOAT_TYPE(b0[l]) * FLOAT_TYPE(int8_t(((s0[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)),
|
||||
|
@ -81,28 +80,12 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
|||
fma(FLOAT_TYPE(b80[l]) * FLOAT_TYPE(int8_t(((s4[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)),
|
||||
fma(FLOAT_TYPE(b112[l]) * FLOAT_TYPE(int8_t(((s6[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)), sum))))))));
|
||||
}
|
||||
temp[n] = fma(d, sum, temp[n]);
|
||||
temp[j][n] = fma(d, sum, temp[j][n]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sum up partial sums and write back result
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
tmpsh[n][tid] = temp[n];
|
||||
}
|
||||
barrier();
|
||||
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
tmpsh[n][tid] += tmpsh[n][tid + s];
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
if (tid == 0) {
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
|
||||
}
|
||||
}
|
||||
reduce_result(temp, d_offset, first_row, num_rows, tid);
|
||||
}
|
||||
|
||||
void main() {
|
||||
|
|
|
@ -6,11 +6,6 @@
|
|||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
|
||||
layout (constant_id = 1) const uint NUM_ROWS = 1;
|
||||
|
||||
shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
|
||||
|
||||
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||
uint a_offset, b_offset, d_offset;
|
||||
get_offsets(a_offset, b_offset, d_offset);
|
||||
|
@ -36,21 +31,18 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
|||
const uint q_offset = 32*v_im + l0;
|
||||
const uint y_offset = 64*v_im + l0;
|
||||
|
||||
FLOAT_TYPE temp[NUM_ROWS];
|
||||
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
|
||||
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
|
||||
temp[i] = FLOAT_TYPE(0);
|
||||
temp[j][i] = FLOAT_TYPE(0);
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
|
||||
const uint y1_idx = i * QUANT_K + y_offset;
|
||||
const uint y2_idx = y1_idx + 128;
|
||||
|
||||
B_TYPE_VEC4 by10 = data_b_v4[(b_offset + y1_idx) / 4];
|
||||
B_TYPE_VEC4 by132 = data_b_v4[(b_offset + y1_idx) / 4 + 8];
|
||||
B_TYPE_VEC4 by20 = data_b_v4[(b_offset + y2_idx) / 4];
|
||||
B_TYPE_VEC4 by232 = data_b_v4[(b_offset + y2_idx) / 4 + 8];
|
||||
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
|
||||
f16vec2 d = data_a[ib0 + i].d;
|
||||
|
@ -103,6 +95,12 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
|||
const uint32_t q4_14 = qs64_hi4.z;
|
||||
const uint32_t q4_15 = qs64_hi4.w;
|
||||
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
B_TYPE_VEC4 by10 = data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4];
|
||||
B_TYPE_VEC4 by132 = data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8];
|
||||
B_TYPE_VEC4 by20 = data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4];
|
||||
B_TYPE_VEC4 by232 = data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8];
|
||||
|
||||
const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x), q4_0, fma(FLOAT_TYPE(by10.y), q4_1, fma(FLOAT_TYPE(by10.z), q4_2, FLOAT_TYPE(by10.w) * q4_3)));
|
||||
const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x), q4_4, fma(FLOAT_TYPE(by132.y), q4_5, fma(FLOAT_TYPE(by132.z), q4_6, FLOAT_TYPE(by132.w) * q4_7)));
|
||||
const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x), q4_8, fma(FLOAT_TYPE(by20.y), q4_9, fma(FLOAT_TYPE(by20.z), q4_10, FLOAT_TYPE(by20.w) * q4_11)));
|
||||
|
@ -112,28 +110,12 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
|||
fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7,
|
||||
fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7,
|
||||
fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7)))))))))))))));
|
||||
temp[n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[n]));
|
||||
temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sum up partial sums and write back result
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
tmpsh[n][tid] = temp[n];
|
||||
}
|
||||
barrier();
|
||||
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
tmpsh[n][tid] += tmpsh[n][tid + s];
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
if (tid == 0) {
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
|
||||
}
|
||||
}
|
||||
reduce_result(temp, d_offset, first_row, num_rows, tid);
|
||||
}
|
||||
|
||||
void main() {
|
||||
|
|
|
@ -6,11 +6,6 @@
|
|||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
|
||||
layout (constant_id = 1) const uint NUM_ROWS = 1;
|
||||
|
||||
shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
|
||||
|
||||
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||
uint a_offset, b_offset, d_offset;
|
||||
get_offsets(a_offset, b_offset, d_offset);
|
||||
|
@ -33,25 +28,18 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
|||
const uint q_offset = 32*v_im + l0;
|
||||
const uint y_offset = 64*v_im + l0;
|
||||
|
||||
FLOAT_TYPE temp[NUM_ROWS];
|
||||
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
|
||||
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
|
||||
temp[i] = FLOAT_TYPE(0);
|
||||
temp[j][i] = FLOAT_TYPE(0);
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
|
||||
const uint y1_idx = i * QUANT_K + y_offset;
|
||||
const uint y2_idx = y1_idx + 128;
|
||||
|
||||
B_TYPE_VEC2 by10 = data_b_v2[(b_offset + y1_idx) / 2];
|
||||
B_TYPE_VEC2 by116 = data_b_v2[(b_offset + y1_idx) / 2 + 8];
|
||||
B_TYPE_VEC2 by132 = data_b_v2[(b_offset + y1_idx) / 2 + 16];
|
||||
B_TYPE_VEC2 by148 = data_b_v2[(b_offset + y1_idx) / 2 + 24];
|
||||
B_TYPE_VEC2 by20 = data_b_v2[(b_offset + y2_idx) / 2];
|
||||
B_TYPE_VEC2 by216 = data_b_v2[(b_offset + y2_idx) / 2 + 8];
|
||||
B_TYPE_VEC2 by232 = data_b_v2[(b_offset + y2_idx) / 2 + 16];
|
||||
B_TYPE_VEC2 by248 = data_b_v2[(b_offset + y2_idx) / 2 + 24];
|
||||
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
|
||||
f16vec2 d = data_a[ib0 + i].d;
|
||||
|
@ -116,6 +104,16 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
|||
const uint32_t q4_14 = qs64_80_hi4.z;
|
||||
const uint32_t q4_15 = qs64_80_hi4.w;
|
||||
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
B_TYPE_VEC2 by10 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2];
|
||||
B_TYPE_VEC2 by116 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8];
|
||||
B_TYPE_VEC2 by132 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16];
|
||||
B_TYPE_VEC2 by148 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24];
|
||||
B_TYPE_VEC2 by20 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2];
|
||||
B_TYPE_VEC2 by216 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8];
|
||||
B_TYPE_VEC2 by232 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16];
|
||||
B_TYPE_VEC2 by248 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24];
|
||||
|
||||
const FLOAT_TYPE sx =
|
||||
fma(FLOAT_TYPE(by10.x), q4_0,
|
||||
fma(FLOAT_TYPE(by10.y), q4_1,
|
||||
|
@ -141,28 +139,12 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
|||
fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
|
||||
fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,
|
||||
(FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));
|
||||
temp[n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[n]));
|
||||
temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sum up partial sums and write back result
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
tmpsh[n][tid] = temp[n];
|
||||
}
|
||||
barrier();
|
||||
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
tmpsh[n][tid] += tmpsh[n][tid + s];
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
if (tid == 0) {
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
|
||||
}
|
||||
}
|
||||
reduce_result(temp, d_offset, first_row, num_rows, tid);
|
||||
}
|
||||
|
||||
void main() {
|
||||
|
|
|
@ -6,11 +6,6 @@
|
|||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
|
||||
layout (constant_id = 1) const uint NUM_ROWS = 1;
|
||||
|
||||
shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
|
||||
|
||||
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||
uint a_offset, b_offset, d_offset;
|
||||
get_offsets(a_offset, b_offset, d_offset);
|
||||
|
@ -36,20 +31,17 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
|||
const uint s_offset = 8*v_im + is;
|
||||
const uint y_offset = 128*v_im + l0;
|
||||
|
||||
FLOAT_TYPE temp[NUM_ROWS];
|
||||
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
|
||||
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
|
||||
temp[i] = FLOAT_TYPE(0);
|
||||
temp[j][i] = FLOAT_TYPE(0);
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
|
||||
const uint y_idx = i * QUANT_K + y_offset;
|
||||
|
||||
B_TYPE_VEC4 by0 = data_b_v4[(b_offset + y_idx) / 4];
|
||||
B_TYPE_VEC4 by32 = data_b_v4[(b_offset + y_idx) / 4 + 8];
|
||||
B_TYPE_VEC4 by64 = data_b_v4[(b_offset + y_idx) / 4 + 16];
|
||||
B_TYPE_VEC4 by96 = data_b_v4[(b_offset + y_idx) / 4 + 24];
|
||||
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
|
||||
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
|
||||
|
@ -84,6 +76,12 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
|||
uvec4 q2 = uvec4(unpack8(q2_u32));
|
||||
uvec4 q3 = uvec4(unpack8(q3_u32));
|
||||
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
B_TYPE_VEC4 by0 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4];
|
||||
B_TYPE_VEC4 by32 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8];
|
||||
B_TYPE_VEC4 by64 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16];
|
||||
B_TYPE_VEC4 by96 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24];
|
||||
|
||||
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
|
||||
[[unroll]] for (int l = 0; l < 4; ++l) {
|
||||
sum = fma(FLOAT_TYPE(by0[l]) * scales[0], FLOAT_TYPE(int8_t(q0[l]) - 32),
|
||||
|
@ -91,28 +89,12 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
|||
fma(FLOAT_TYPE(by64[l]) * scales[2], FLOAT_TYPE(int8_t(q2[l]) - 32),
|
||||
fma(FLOAT_TYPE(by96[l]) * scales[3], FLOAT_TYPE(int8_t(q3[l]) - 32), sum))));
|
||||
}
|
||||
temp[n] += sum * d;
|
||||
temp[j][n] += sum * d;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sum up partial sums and write back result
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
tmpsh[n][tid] = temp[n];
|
||||
}
|
||||
barrier();
|
||||
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
tmpsh[n][tid] += tmpsh[n][tid + s];
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
if (tid == 0) {
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
|
||||
}
|
||||
}
|
||||
reduce_result(temp, d_offset, first_row, num_rows, tid);
|
||||
}
|
||||
|
||||
void main() {
|
||||
|
|
|
@ -24,5 +24,5 @@ void main() {
|
|||
|
||||
const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03;
|
||||
|
||||
data_d[p.d_offset + dst_idx] = D_TYPE(is_src0 ? data_a[src0_idx] : 0.0f);
|
||||
data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : 0.0f);
|
||||
}
|
||||
|
|
|
@ -22,5 +22,5 @@ void main() {
|
|||
return;
|
||||
}
|
||||
|
||||
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(data_a[src0_idx_mod(idx)]);
|
||||
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx_mod(idx)]);
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@ void main() {
|
|||
continue;
|
||||
}
|
||||
|
||||
data_d[p.d_offset + idx] = D_TYPE(FLOAT_TYPE(data_a[idx]) * FLOAT_TYPE(p.param1));
|
||||
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1));
|
||||
idx += num_threads;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -12,6 +12,6 @@ void main() {
|
|||
return;
|
||||
}
|
||||
|
||||
const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
|
||||
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(sin(val));
|
||||
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
|
||||
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sin(val));
|
||||
}
|
||||
|
|
|
@ -12,6 +12,6 @@ void main() {
|
|||
return;
|
||||
}
|
||||
|
||||
const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
|
||||
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(val * val);
|
||||
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
|
||||
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val * val);
|
||||
}
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
uint ne; uint d_offset;
|
||||
uint ne; uint a_offset; uint d_offset;
|
||||
uint nb00; uint nb01; uint nb02; uint nb03;
|
||||
uint ne10; uint ne11; uint ne12; uint ne13;
|
||||
float sf0; float sf1; float sf2; float sf3;
|
||||
|
@ -32,5 +32,5 @@ void main() {
|
|||
const uint i02 = uint(i12 / p.sf2);
|
||||
const uint i03 = uint(i13 / p.sf3);
|
||||
|
||||
data_d[p.d_offset + idx] = D_TYPE(data_a[i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]);
|
||||
data_d[p.d_offset + idx] = D_TYPE(data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]);
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue